diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-04-25 01:51:59 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-04-25 04:00:56 +0300 |
commit | 9f30419eaafeb91fd2e2b17f0590cb505ee98d28 (patch) | |
tree | f78eb71ea57b9d711233890984fb25eda72bc33d | |
parent | 943a6baa291662d903a42070433a6d3d7f56428a (diff) |
Add UnquantizeAndWriteCallback tests
-rw-r--r-- | test/test_matrices.h | 34 | ||||
-rw-r--r-- | test/tile_test.inl | 100 |
2 files changed, 134 insertions, 0 deletions
diff --git a/test/test_matrices.h b/test/test_matrices.h index c9b7ec6..2e3acf7 100644 --- a/test/test_matrices.h +++ b/test/test_matrices.h @@ -39,4 +39,38 @@ struct TestMatrices8 { AlignedVector<int32_t> C; }; +struct TestMatricesUnquantizeAndWriteRowMajorAccess { + typedef Access<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, UnquantizeAndWriteRowMajorAccess<float>> AccessT; + + explicit TestMatricesUnquantizeAndWriteRowMajorAccess(Tile shape_in, float unquant_mult) : + shape(shape_in), + A(shape.A_rows * shape.inner), + B(shape.inner * shape.B_cols), + C(shape.A_rows * shape.B_cols), + unquant_mult_(unquant_mult) { + + std::mt19937 gen; + std::uniform_int_distribution<int8_t> dist(-127,127); + for (int8_t &it : A) it = dist(gen); + for (int8_t &it : B) it = dist(gen); + // C is uninitialized. + } + + AccessT Accessor() { + return AccessT( + RowMajorAccess<int8_t>(A.begin(), shape.inner), + ColMajorAccess<int8_t>(B.begin(), shape.inner), + UnquantizeAndWriteRowMajorAccess<float>(C.begin(), shape.B_cols, {unquant_mult_})); + } + + Tile shape; + AlignedVector<int8_t> A; + AlignedVector<int8_t> B; + // Uninitialized; for using tests to write to. + AlignedVector<float> C; + +private: + float unquant_mult_; +}; + } // namespace intgemm diff --git a/test/tile_test.inl b/test/tile_test.inl index 1da3556..1c1d0d3 100644 --- a/test/tile_test.inl +++ b/test/tile_test.inl @@ -141,6 +141,46 @@ template <class Access> void Signed8ReferenceMult(Access access, Tile problem) { } } +template <class Access> void Signed8ReferenceMult_UnquantizeAndWrite(Access access, Tile problem, float unquant_mult) { + assert(!problem.inner % 2); + for (Index a_row = 0; a_row < problem.A_rows; ++a_row) { + for (Index b_col = 0; b_col < problem.B_cols; ++b_col) { + Access acc = access.AAdd(a_row, 0).BAdd(0, b_col).CAdd(a_row, b_col); + // For VNNI, just do it accurately. 
+#ifdef INTGEMM_THIS_IS_AVX512VNNI + acc.CFront() = 0; + for (Index inner = 0; inner < problem.inner; ++inner) { + Access innermost = acc.AAdd(0, inner).BAdd(inner, 0); + acc.CFront() += static_cast<int32_t>(innermost.AFront()) * static_cast<int32_t>(innermost.BFront()); + } +#else + // For non-VNNI, do the saturation stuff. + int16_t accumulators[sizeof(Register) / sizeof(int16_t)] = {0}; + for (Index inner = 0; inner < problem.inner; inner += 2) { + Access innermost = acc.AAdd(0, inner).BAdd(inner, 0); + int32_t product = static_cast<int32_t>(innermost.AFront()) * static_cast<int32_t>(innermost.BFront()); + innermost = innermost.AAdd(0, 1).BAdd(1, 0); + product += static_cast<int32_t>(innermost.AFront()) * static_cast<int32_t>(innermost.BFront()); + // Saturate to 16-bit for maddubs. + if (product > 32767) product = 32767; + if (product < -32768) product = -32768; + int16_t &accum = accumulators[(inner / 2) % (sizeof(Register) / sizeof(int16_t))]; + // Saturating accumulation. + product += static_cast<int32_t>(accum); + if (product > 32767) product = 32767; + if (product < -32768) product = -32768; + accum = static_cast<int16_t>(product); + } + acc.CFront() = 0; + for (Index i = 0; i < sizeof(Register) / sizeof(int16_t); ++i) { + acc.CFront() += static_cast<int32_t>(accumulators[i]); + } +#endif + acc.CFront() *= unquant_mult; + } + } +} + void DumpMatrix(int8_t *m, Index rows, Index cols) { std::cerr << rows << 'x' << cols << '\n'; for (Index i = 0; i < rows; ++i) { @@ -166,6 +206,22 @@ struct TestMatricesRef : TestMatrices8 { AlignedVector<int32_t> C_reference; }; +struct TestMatricesRef_UnquantizeAndWrite : TestMatricesUnquantizeAndWriteRowMajorAccess { + TestMatricesRef_UnquantizeAndWrite(Tile shape_in, float unquant_mult) : + TestMatricesUnquantizeAndWriteRowMajorAccess(shape_in, unquant_mult), + C_reference(shape.A_rows * shape.B_cols) { + + AccessT ref_access( + RowMajorAccess<int8_t>(A.begin(), shape.inner), + ColMajorAccess<int8_t>(B.begin(), 
shape.inner), + UnquantizeAndWriteRowMajorAccess<float>(C_reference.begin(), shape.B_cols, {unquant_mult})); + Signed8ReferenceMult_UnquantizeAndWrite<AccessT>(ref_access, shape, unquant_mult); + } + + AlignedVector<float> C_reference; +}; + + #ifndef INTGEMM_THIS_IS_SSE2 template <class Kernel> void TestMultiplyNoOverhang(Tile shape) { // These are sanity checks on the arguments, not the code. @@ -182,6 +238,16 @@ template <class Kernel> void TestMultiplyNoOverhang(Tile shape) { }*/ } +template <class Kernel> void TestMultiplyNoOverhang_UnquantizeAndWrite(Tile shape, float unquant_mult) { + // These are sanity checks on the arguments, not the code. + CHECK(shape.A_rows % Kernel::kTile.A_rows == 0); + CHECK(shape.inner % Kernel::kTile.inner == 0); + CHECK(shape.B_cols % Kernel::kTile.B_cols == 0); + TestMatricesRef_UnquantizeAndWrite t(shape, unquant_mult); + MultiplyNoOverhang<TestMatricesRef_UnquantizeAndWrite::AccessT, Kernel>(t.Accessor(), shape); + CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(float))); +} + template <class Kernel> void TestMultiplyNoOverhangShapes() { Tile shape = Kernel::kTile; // Minimum size. @@ -201,6 +267,25 @@ template <class Kernel> void TestMultiplyNoOverhangShapes() { } } +template <class Kernel> void TestMultiplyNoOverhangShapes_UnquantizeAndWrite(float unquant_mult) { + Tile shape = Kernel::kTile; + // Minimum size. + TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>(shape, unquant_mult); + // Multiples on each dimension. 
+ TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>(Tile{shape.A_rows * 2, shape.inner, shape.B_cols}, unquant_mult); + TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>(Tile{shape.A_rows, shape.inner * 2, shape.B_cols}, unquant_mult); + TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>(Tile{shape.A_rows, shape.inner, shape.B_cols * 2}, unquant_mult); + TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>(Tile{shape.A_rows * 2, shape.inner * 2, shape.B_cols * 2}, unquant_mult); + // Try a bunch of shapes! + for (shape.A_rows = 0; shape.A_rows <= Kernel::kTile.A_rows * 9; shape.A_rows += Kernel::kTile.A_rows) { + for (shape.inner = 0; shape.inner <= Kernel::kTile.inner * 9; shape.inner += Kernel::kTile.inner) { + for (shape.B_cols = 0; shape.B_cols <= Kernel::kTile.B_cols * 9; shape.B_cols += Kernel::kTile.B_cols) { + TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>(shape, unquant_mult); + } + } + } +} + TEST_CASE("MultiplyNoOverhang Signed8 " INTGEMM_TEST_NAME, "[tile]") { if (kCPU < CPUType::INTGEMM_ARCH) return; TestMultiplyNoOverhangShapes<Signed8>(); @@ -219,6 +304,21 @@ TEST_CASE("MultiplyNoOverhang inner unroll " INTGEMM_TEST_NAME, "[tile][multiply TestMultiplyNoOverhang<Kernel>({1, sizeof(Register) * 4, 1}); TestMultiplyNoOverhangShapes<Kernel>(); } + +TEST_CASE("MultiplyNoOverhang Signed8 UnquantizeAndWrite " INTGEMM_TEST_NAME, "[tile]") { + if (kCPU < CPUType::INTGEMM_ARCH) return; + TestMultiplyNoOverhangShapes_UnquantizeAndWrite<Signed8>(1.7f); +} + +TEST_CASE("MultiplyNoOverhang inner unroll UnquantizeAndWrite " INTGEMM_TEST_NAME, "[tile][multiply]") { + if (kCPU < CPUType::INTGEMM_ARCH) return; + float unquant_mult = 1.7f; + typedef UnrollKernel<1, 2, 1, Signed8> Kernel; + Tile shape = {1, sizeof(Register) * 2, 1}; + TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>(shape, unquant_mult); + TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>({1, sizeof(Register) * 4, 1}, unquant_mult); + TestMultiplyNoOverhangShapes_UnquantizeAndWrite<Kernel>(unquant_mult); 
+} #endif // INTGEMM_THIS_IS_AVX512VNNI // If the inner dimension is just twice, then there isn't any non-determinism in saturation order. |