diff options
-rw-r--r-- | test/tile_test.inl | 71 |
1 file changed, 71 insertions, 0 deletions
diff --git a/test/tile_test.inl b/test/tile_test.inl index 70d0277..3f9e3d5 100644 --- a/test/tile_test.inl +++ b/test/tile_test.inl @@ -185,6 +185,29 @@ struct TestMatricesRef_Unquantize : TestMatrices<RowMajorAccess<int8_t>, ColMajo AlignedVector<float> C_reference; }; +struct TestMatricesRef_UnquantizeAndAddBias : TestMatrices<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, RowMajorAccess<float>> { + TestMatricesRef_UnquantizeAndAddBias(Tile shape_in, float unquant_mult) : + TestMatrices(shape_in), + bias(shape.B_cols), + C_reference(shape.A_rows * shape.B_cols) { + + std::mt19937 gen; + std::uniform_int_distribution<int> dist(-10, 10); + for (auto &it : bias) it = dist(gen); + + AccessT ref_access( + RowMajorAccess<int8_t>(A.begin(), shape.inner), + ColMajorAccess<int8_t>(B.begin(), shape.inner), + RowMajorAccess<float>(C_reference.begin(), shape.B_cols)); + Signed8ReferenceMult(ref_access, shape, callbacks::Sequence( + callbacks::Unquantize(unquant_mult), + callbacks::AddBias(bias.begin()) + )); + } + + AlignedVector<float> bias; + AlignedVector<float> C_reference; +}; #ifndef INTGEMM_THIS_IS_SSE2 template <class Kernel> void TestMultiplyNoOverhang(Tile shape) { @@ -212,6 +235,30 @@ template <class Kernel> void TestMultiplyNoOverhang_Unquantize(Tile shape, float CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t))); } +template <class Kernel> void TestMultiplyNoOverhang_UnquantizeAndAddBias(Tile shape, float unquant_mult) { + // These are sanity checks on the arguments, not the code. 
+ CHECK(shape.A_rows % Kernel::kTile.A_rows == 0); + CHECK(shape.inner % Kernel::kTile.inner == 0); + CHECK(shape.B_cols % Kernel::kTile.B_cols == 0); + TestMatricesRef_UnquantizeAndAddBias t(shape, unquant_mult); + MultiplyNoOverhang<Kernel>(t.Accessor(), shape, callbacks::Sequence( + callbacks::Unquantize(unquant_mult), + callbacks::AddBias(t.bias.begin()) + )); + // CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t))); + for (Index i = 0; i < shape.A_rows; ++i) { + for (Index j = 0; j < shape.B_cols; ++j) { + if (fabs(t.C[i * shape.B_cols + j] - t.C_reference[i * shape.B_cols + j]) < 0.0001f) { + SUCCEED(); + } + else { + printf("row x col = %lux%lu\n", i, j); + CHECK(t.C[i * shape.B_cols + j] == t.C_reference[i * shape.B_cols + j]); + } + } + } +} + template <class Kernel> void TestMultiplyNoOverhangShapes() { Tile shape = Kernel::kTile; // Minimum size. @@ -250,6 +297,25 @@ template <class Kernel> void TestMultiplyNoOverhangShapes_Unquantize(float unqua } } +template <class Kernel> void TestMultiplyNoOverhangShapes_UnquantizeAndAddBias(float unquant_mult) { + Tile shape = Kernel::kTile; + // Minimum size. + TestMultiplyNoOverhang_UnquantizeAndAddBias<Kernel>(shape, unquant_mult); + // Multiples on each dimension. + TestMultiplyNoOverhang_UnquantizeAndAddBias<Kernel>(Tile{shape.A_rows * 2, shape.inner, shape.B_cols}, unquant_mult); + TestMultiplyNoOverhang_UnquantizeAndAddBias<Kernel>(Tile{shape.A_rows, shape.inner * 2, shape.B_cols}, unquant_mult); + TestMultiplyNoOverhang_UnquantizeAndAddBias<Kernel>(Tile{shape.A_rows, shape.inner, shape.B_cols * 2}, unquant_mult); + TestMultiplyNoOverhang_UnquantizeAndAddBias<Kernel>(Tile{shape.A_rows * 2, shape.inner * 2, shape.B_cols * 2}, unquant_mult); + // Try a bunch of shapes! 
+ for (shape.A_rows = 0; shape.A_rows <= Kernel::kTile.A_rows * 9; shape.A_rows += Kernel::kTile.A_rows) { + for (shape.inner = 0; shape.inner <= Kernel::kTile.inner * 9; shape.inner += Kernel::kTile.inner) { + for (shape.B_cols = 0; shape.B_cols <= Kernel::kTile.B_cols * 9; shape.B_cols += Kernel::kTile.B_cols) { + TestMultiplyNoOverhang_UnquantizeAndAddBias<Kernel>(shape, unquant_mult); + } + } + } +} + TEST_CASE("MultiplyNoOverhang Signed8 " INTGEMM_TEST_NAME, "[tile]") { if (kCPU < CPUType::INTGEMM_ARCH) return; TestMultiplyNoOverhangShapes<Signed8>(); @@ -260,6 +326,11 @@ TEST_CASE("MultiplyNoOverhang Signed8 Unquantize " INTGEMM_TEST_NAME, "[tile]") TestMultiplyNoOverhangShapes_Unquantize<Signed8>(2.0f); } +TEST_CASE("MultiplyNoOverhang Signed8 UnquantizeAndAddBias " INTGEMM_TEST_NAME, "[tile]") { + if (kCPU < CPUType::INTGEMM_ARCH) return; + TestMultiplyNoOverhangShapes_UnquantizeAndAddBias<Signed8>(2.0f); +} + // Due to unordered_unfurl in dot.inl, the inner dimension can change order. // That impacts saturation. Then the test doesn't mach reference on arches // that use 16-bit saturating accumlation. So we only test inner unrolling on |