diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-05-08 18:38:18 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-05-08 19:11:57 +0300 |
commit | 54c46094baa3ad6af888ba1db6f164e5fa988612 (patch) | |
tree | 64165c31191e0e61281730c85c449f1a982eb879 | |
parent | 180c3b3c318c31d561faa42bd9cc145a007328e2 (diff) |
Add scalar versions of callbacks
-rw-r--r-- | callbacks/implementations.inl | 21 | ||||
-rw-r--r-- | test/tile_test.inl | 30 | ||||
-rw-r--r-- | tile/access.h | 12 |
3 files changed, 43 insertions, 20 deletions
diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl index b54ca9a..fc075c4 100644 --- a/callbacks/implementations.inl +++ b/callbacks/implementations.inl @@ -58,6 +58,10 @@ public: run_callbacks(input, info, callbacks, make_sequence<sizeof...(Configs)>()); } + void operator()(int32_t input, const OutputBufferInfo& info) { + run_callbacks(input, info, callbacks, make_sequence<sizeof...(Configs)>()); + } + private: using CallbacksTupleType = std::tuple<CallbackImpl<CPUType::CPU_NAME, Configs>...>; @@ -82,6 +86,7 @@ private: RUN_CALLBACKS_PIPELINE_IMPL(vi) RUN_CALLBACKS_PIPELINE_IMPL(vf) RUN_CALLBACKS_PIPELINE_IMPL(vd) + RUN_CALLBACKS_PIPELINE_IMPL(int32_t) #undef RUN_CALLBACKS_PIPELINE_IMPL }; @@ -93,9 +98,14 @@ template <typename Type> class CallbackImpl<CPUType::CPU_NAME, Identity<Type>> { public: CPU_ATTR CallbackImpl(const Identity<Type>&) {} + CPU_ATTR vector_t<CPUType::CPU_NAME, Type> operator()(vector_t<CPUType::CPU_NAME, Type> input, const OutputBufferInfo&) { return input; } + + Type operator()(Type input, const OutputBufferInfo&) { + return input; + } }; /* @@ -111,6 +121,10 @@ public: return kernels::unquantize(input, unquant_mult); } + float operator()(int32_t input, const OutputBufferInfo&) { + return input * config.unquant_mult; + } + private: Unquantize config; vf unquant_mult; @@ -130,6 +144,8 @@ public: kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx); } + // TODO: Implement scalar version of operator() + private: UnquantizeAndWrite config; vf unquant_mult; @@ -147,6 +163,8 @@ public: kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx); } + // TODO: Implement scalar version of operator() + private: AddBiasAndWrite config; }; @@ -165,6 +183,9 @@ public: result = kernels::add_bias(result, config.bias_addr, info.col_idx); kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx); } + + // TODO: Implement scalar version of operator() + 
private: UnquantizeAndAddBiasAndWrite config; vf unquant_mult; diff --git a/test/tile_test.inl b/test/tile_test.inl index 814c674..70d0277 100644 --- a/test/tile_test.inl +++ b/test/tile_test.inl @@ -102,8 +102,11 @@ TEST_CASE("Reduce " INTGEMM_TEST_NAME, "[tile]") { } // Replicate the saturation behavior of the Signed8 kernel with 16-bit accumulation. -template <class Access, typename ScalarCallback> void Signed8ReferenceMult(Access access, Tile problem, ScalarCallback callback) { +template <class Access, typename Callback> void Signed8ReferenceMult(Access access, Tile problem, Callback callback) { assert(!(problem.inner % 2)); + + auto callback_impl = callbacks::CallbackImpl<CPUType::INTGEMM_ARCH, Callback>(callback); + for (Index a_row = 0; a_row < problem.A_rows; ++a_row) { for (Index b_col = 0; b_col < problem.B_cols; ++b_col) { Access acc = access.AAdd(a_row, 0).BAdd(0, b_col).CAdd(a_row, b_col); @@ -137,15 +140,11 @@ template <class Access, typename ScalarCallback> void Signed8ReferenceMult(Acces acc.CFront() += static_cast<int32_t>(accumulators[i]); } #endif - acc.CFront() = callback(acc.CFront()); + acc.CFront() = callback_impl(acc.CFront(), callbacks::OutputBufferInfo(a_row, b_col, 0, 0)); } } } -template <class Access> void Signed8ReferenceMult(Access access, Tile problem) { - Signed8ReferenceMult(access, problem, [](typename Access::CContent sum) { return sum; }); -} - void DumpMatrix(int8_t *m, Index rows, Index cols) { std::cerr << rows << 'x' << cols << '\n'; for (Index i = 0; i < rows; ++i) { @@ -165,23 +164,28 @@ struct TestMatricesRef : TestMatrices8 { RowMajorAccess<int8_t>(A.begin(), shape.inner), ColMajorAccess<int8_t>(B.begin(), shape.inner), RowMajorAccess<int32_t>(C_reference.begin(), shape.B_cols)); - Signed8ReferenceMult(ref_access, shape); + Signed8ReferenceMult(ref_access, shape, callbacks::Identity<int32_t>()); } AlignedVector<int32_t> C_reference; }; struct TestMatricesRef_Unquantize : TestMatrices<RowMajorAccess<int8_t>, 
ColMajorAccess<int8_t>, RowMajorAccess<float>> { - TestMatricesRef_Unquantize(Tile shape_in, float unquant_mult) : TestMatrices(shape_in), C_reference(shape.A_rows * shape.B_cols) { - AccessT ref_access({A.begin(), shape.inner}, {B.begin(), shape.inner}, {C_reference.begin(), shape.B_cols}); - Signed8ReferenceMult(ref_access, shape, [unquant_mult](typename AccessT::CContent value) { - return value * unquant_mult; - }); + TestMatricesRef_Unquantize(Tile shape_in, float unquant_mult) : + TestMatrices(shape_in), + C_reference(shape.A_rows * shape.B_cols) { + + AccessT ref_access( + RowMajorAccess<int8_t>(A.begin(), shape.inner), + ColMajorAccess<int8_t>(B.begin(), shape.inner), + RowMajorAccess<float>(C_reference.begin(), shape.B_cols)); + Signed8ReferenceMult(ref_access, shape, callbacks::Unquantize(unquant_mult)); } AlignedVector<float> C_reference; }; + #ifndef INTGEMM_THIS_IS_SSE2 template <class Kernel> void TestMultiplyNoOverhang(Tile shape) { // These are sanity checks on the arguments, not the code. @@ -251,12 +255,10 @@ TEST_CASE("MultiplyNoOverhang Signed8 " INTGEMM_TEST_NAME, "[tile]") { TestMultiplyNoOverhangShapes<Signed8>(); } -#if defined(INTGEMM_THIS_IS_AVX512BW) || defined(INTGEMM_THIS_IS_AVX512VNNI) TEST_CASE("MultiplyNoOverhang Signed8 Unquantize " INTGEMM_TEST_NAME, "[tile]") { if (kCPU < CPUType::INTGEMM_ARCH) return; TestMultiplyNoOverhangShapes_Unquantize<Signed8>(2.0f); } -#endif // Due to unordered_unfurl in dot.inl, the inner dimension can change order. // That impacts saturation. Then the test doesn't mach reference on arches diff --git a/tile/access.h b/tile/access.h index f8ae709..6e440af 100644 --- a/tile/access.h +++ b/tile/access.h @@ -23,11 +23,11 @@ template <class T> class RowMajorAccess { Content &Front() { return *data_; } // TODO: SLOW. This is here for testing. 
- template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m128i *from, CallbackImpl&) { - SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from)); + template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m128i *from, CallbackImpl& callback_impl) { + SlowWrite<A_rows, B_cols>(reinterpret_cast<const int32_t*>(from), callback_impl); } - template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m256i *from, CallbackImpl&) { - SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from)); + template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m256i *from, CallbackImpl& callback_impl) { + SlowWrite<A_rows, B_cols>(reinterpret_cast<const int32_t*>(from), callback_impl); } template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m512i *from, CallbackImpl& callback_impl) { WriteImpl<A_rows, B_cols, B_cols>(from, callback_impl); @@ -107,10 +107,10 @@ template <class T> class RowMajorAccess { typename std::enable_if<!A_rows || !B_cols>::type WriteImpl(const __m512i *, CallbackImpl&) {} - template <Index A_rows, Index B_cols> void SlowWrite(const T *from) { + template <Index A_rows, Index B_cols, typename CallbackImpl> void SlowWrite(const int32_t *from, CallbackImpl& callback_impl) { for (Index i = 0; i < A_rows; ++i) { for (Index j = 0; j < B_cols; ++j) { - data_[i * cols_ + j] = from[i * B_cols + j]; + data_[i * cols_ + j] = callback_impl(from[i * B_cols + j], callbacks::OutputBufferInfo(i, j, 0, 0)); } } } |