Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2020-05-08 18:38:18 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2020-05-08 19:11:57 +0300
commit54c46094baa3ad6af888ba1db6f164e5fa988612 (patch)
tree64165c31191e0e61281730c85c449f1a982eb879
parent180c3b3c318c31d561faa42bd9cc145a007328e2 (diff)
Add scalar versions of callbacks
-rw-r--r--callbacks/implementations.inl21
-rw-r--r--test/tile_test.inl30
-rw-r--r--tile/access.h12
3 files changed, 43 insertions, 20 deletions
diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl
index b54ca9a..fc075c4 100644
--- a/callbacks/implementations.inl
+++ b/callbacks/implementations.inl
@@ -58,6 +58,10 @@ public:
run_callbacks(input, info, callbacks, make_sequence<sizeof...(Configs)>());
}
+ void operator()(int32_t input, const OutputBufferInfo& info) {
+ run_callbacks(input, info, callbacks, make_sequence<sizeof...(Configs)>());
+ }
+
private:
using CallbacksTupleType = std::tuple<CallbackImpl<CPUType::CPU_NAME, Configs>...>;
@@ -82,6 +86,7 @@ private:
RUN_CALLBACKS_PIPELINE_IMPL(vi)
RUN_CALLBACKS_PIPELINE_IMPL(vf)
RUN_CALLBACKS_PIPELINE_IMPL(vd)
+ RUN_CALLBACKS_PIPELINE_IMPL(int32_t)
#undef RUN_CALLBACKS_PIPELINE_IMPL
};
@@ -93,9 +98,14 @@ template <typename Type>
class CallbackImpl<CPUType::CPU_NAME, Identity<Type>> {
public:
CPU_ATTR CallbackImpl(const Identity<Type>&) {}
+
CPU_ATTR vector_t<CPUType::CPU_NAME, Type> operator()(vector_t<CPUType::CPU_NAME, Type> input, const OutputBufferInfo&) {
return input;
}
+
+ Type operator()(Type input, const OutputBufferInfo&) {
+ return input;
+ }
};
/*
@@ -111,6 +121,10 @@ public:
return kernels::unquantize(input, unquant_mult);
}
+ float operator()(int32_t input, const OutputBufferInfo&) {
+ return input * config.unquant_mult;
+ }
+
private:
Unquantize config;
vf unquant_mult;
@@ -130,6 +144,8 @@ public:
kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
}
+ // TODO: Implement scalar version of operator()
+
private:
UnquantizeAndWrite config;
vf unquant_mult;
@@ -147,6 +163,8 @@ public:
kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
}
+ // TODO: Implement scalar version of operator()
+
private:
AddBiasAndWrite config;
};
@@ -165,6 +183,9 @@ public:
result = kernels::add_bias(result, config.bias_addr, info.col_idx);
kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
}
+
+ // TODO: Implement scalar version of operator()
+
private:
UnquantizeAndAddBiasAndWrite config;
vf unquant_mult;
diff --git a/test/tile_test.inl b/test/tile_test.inl
index 814c674..70d0277 100644
--- a/test/tile_test.inl
+++ b/test/tile_test.inl
@@ -102,8 +102,11 @@ TEST_CASE("Reduce " INTGEMM_TEST_NAME, "[tile]") {
}
// Replicate the saturation behavior of the Signed8 kernel with 16-bit accumulation.
-template <class Access, typename ScalarCallback> void Signed8ReferenceMult(Access access, Tile problem, ScalarCallback callback) {
+template <class Access, typename Callback> void Signed8ReferenceMult(Access access, Tile problem, Callback callback) {
assert(!(problem.inner % 2));
+
+ auto callback_impl = callbacks::CallbackImpl<CPUType::INTGEMM_ARCH, Callback>(callback);
+
for (Index a_row = 0; a_row < problem.A_rows; ++a_row) {
for (Index b_col = 0; b_col < problem.B_cols; ++b_col) {
Access acc = access.AAdd(a_row, 0).BAdd(0, b_col).CAdd(a_row, b_col);
@@ -137,15 +140,11 @@ template <class Access, typename ScalarCallback> void Signed8ReferenceMult(Acces
acc.CFront() += static_cast<int32_t>(accumulators[i]);
}
#endif
- acc.CFront() = callback(acc.CFront());
+ acc.CFront() = callback_impl(acc.CFront(), callbacks::OutputBufferInfo(a_row, b_col, 0, 0));
}
}
}
-template <class Access> void Signed8ReferenceMult(Access access, Tile problem) {
- Signed8ReferenceMult(access, problem, [](typename Access::CContent sum) { return sum; });
-}
-
void DumpMatrix(int8_t *m, Index rows, Index cols) {
std::cerr << rows << 'x' << cols << '\n';
for (Index i = 0; i < rows; ++i) {
@@ -165,23 +164,28 @@ struct TestMatricesRef : TestMatrices8 {
RowMajorAccess<int8_t>(A.begin(), shape.inner),
ColMajorAccess<int8_t>(B.begin(), shape.inner),
RowMajorAccess<int32_t>(C_reference.begin(), shape.B_cols));
- Signed8ReferenceMult(ref_access, shape);
+ Signed8ReferenceMult(ref_access, shape, callbacks::Identity<int32_t>());
}
AlignedVector<int32_t> C_reference;
};
struct TestMatricesRef_Unquantize : TestMatrices<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, RowMajorAccess<float>> {
- TestMatricesRef_Unquantize(Tile shape_in, float unquant_mult) : TestMatrices(shape_in), C_reference(shape.A_rows * shape.B_cols) {
- AccessT ref_access({A.begin(), shape.inner}, {B.begin(), shape.inner}, {C_reference.begin(), shape.B_cols});
- Signed8ReferenceMult(ref_access, shape, [unquant_mult](typename AccessT::CContent value) {
- return value * unquant_mult;
- });
+ TestMatricesRef_Unquantize(Tile shape_in, float unquant_mult) :
+ TestMatrices(shape_in),
+ C_reference(shape.A_rows * shape.B_cols) {
+
+ AccessT ref_access(
+ RowMajorAccess<int8_t>(A.begin(), shape.inner),
+ ColMajorAccess<int8_t>(B.begin(), shape.inner),
+ RowMajorAccess<float>(C_reference.begin(), shape.B_cols));
+ Signed8ReferenceMult(ref_access, shape, callbacks::Unquantize(unquant_mult));
}
AlignedVector<float> C_reference;
};
+
#ifndef INTGEMM_THIS_IS_SSE2
template <class Kernel> void TestMultiplyNoOverhang(Tile shape) {
// These are sanity checks on the arguments, not the code.
@@ -251,12 +255,10 @@ TEST_CASE("MultiplyNoOverhang Signed8 " INTGEMM_TEST_NAME, "[tile]") {
TestMultiplyNoOverhangShapes<Signed8>();
}
-#if defined(INTGEMM_THIS_IS_AVX512BW) || defined(INTGEMM_THIS_IS_AVX512VNNI)
TEST_CASE("MultiplyNoOverhang Signed8 Unquantize " INTGEMM_TEST_NAME, "[tile]") {
if (kCPU < CPUType::INTGEMM_ARCH) return;
TestMultiplyNoOverhangShapes_Unquantize<Signed8>(2.0f);
}
-#endif
// Due to unordered_unfurl in dot.inl, the inner dimension can change order.
// That impacts saturation. Then the test doesn't match reference on arches
diff --git a/tile/access.h b/tile/access.h
index f8ae709..6e440af 100644
--- a/tile/access.h
+++ b/tile/access.h
@@ -23,11 +23,11 @@ template <class T> class RowMajorAccess {
Content &Front() { return *data_; }
// TODO: SLOW. This is here for testing.
- template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m128i *from, CallbackImpl&) {
- SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from));
+ template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m128i *from, CallbackImpl& callback_impl) {
+ SlowWrite<A_rows, B_cols>(reinterpret_cast<const int32_t*>(from), callback_impl);
}
- template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m256i *from, CallbackImpl&) {
- SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from));
+ template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m256i *from, CallbackImpl& callback_impl) {
+ SlowWrite<A_rows, B_cols>(reinterpret_cast<const int32_t*>(from), callback_impl);
}
template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m512i *from, CallbackImpl& callback_impl) {
WriteImpl<A_rows, B_cols, B_cols>(from, callback_impl);
@@ -107,10 +107,10 @@ template <class T> class RowMajorAccess {
typename std::enable_if<!A_rows || !B_cols>::type
WriteImpl(const __m512i *, CallbackImpl&) {}
- template <Index A_rows, Index B_cols> void SlowWrite(const T *from) {
+ template <Index A_rows, Index B_cols, typename CallbackImpl> void SlowWrite(const int32_t *from, CallbackImpl& callback_impl) {
for (Index i = 0; i < A_rows; ++i) {
for (Index j = 0; j < B_cols; ++j) {
- data_[i * cols_ + j] = from[i * B_cols + j];
+ data_[i * cols_ + j] = callback_impl(from[i * B_cols + j], callbacks::OutputBufferInfo(i, j, 0, 0));
}
}
}