diff options
author    : Kenneth Heafield <github@kheafield.com> | 2020-04-24 00:18:53 +0300
committer : Kenneth Heafield <github@kheafield.com> | 2020-04-24 00:19:29 +0300
commit    : ed5ba5e5eb56d6cd44913d8efb29f3fed248b039 (patch)
tree      : 3e3a518ba71a80f8b57e854eefc837c859b725e3
parent    : 88471a26388f38eb4007cfbd4e52ac1a60c88a79 (diff)
Basic general sized multiply, not optimized yet
 test/tile_test.inl | 86 (-rw-r--r--)
 tile/multiply.inl  | 31 (-rw-r--r--)
 2 files changed, 83 insertions(+), 34 deletions(-)
diff --git a/test/tile_test.inl b/test/tile_test.inl index 69fce76..bd5c891 100644 --- a/test/tile_test.inl +++ b/test/tile_test.inl @@ -151,48 +151,55 @@ void DumpMatrix(int8_t *m, Index rows, Index cols) { } } +struct TestMatrices { + typedef Access<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, RowMajorAccess<int32_t> > AccessT; + + TestMatrices(Tile shape_in) : + shape(shape_in), + A(shape.A_rows * shape.inner), + B(shape.inner * shape.B_cols), + C_reference(shape.A_rows * shape.B_cols) { + + std::mt19937 gen; + std::uniform_int_distribution<int8_t> dist(-127,127); + for (int8_t &it : A) it = dist(gen); + for (int8_t &it : B) it = dist(gen); + + AccessT ref_access( + RowMajorAccess<int8_t>(A.begin(), shape.inner), + ColMajorAccess<int8_t>(B.begin(), shape.inner), + RowMajorAccess<int32_t>(C_reference.begin(), shape.B_cols)); + Signed8ReferenceMult<AccessT>(ref_access, shape); + } + + AccessT AccessTest(AlignedVector<int32_t> &C_test) { + return AccessT( + RowMajorAccess<int8_t>(A.begin(), shape.inner), + ColMajorAccess<int8_t>(B.begin(), shape.inner), + RowMajorAccess<int32_t>(C_test.begin(), shape.B_cols)); + } + + Tile shape; + AlignedVector<int8_t> A; + AlignedVector<int8_t> B; + AlignedVector<int32_t> C_reference; +}; + #ifndef INTGEMM_THIS_IS_SSE2 template <class Kernel> void TestMultiplyNoOverhang(Tile shape) { // These are sanity checks on the arguments, not the code. 
CHECK(shape.A_rows % Kernel::kTile.A_rows == 0); CHECK(shape.inner % Kernel::kTile.inner == 0); CHECK(shape.B_cols % Kernel::kTile.B_cols == 0); - - AlignedVector<int8_t> A(shape.A_rows * shape.inner); - AlignedVector<int8_t> B(shape.inner * shape.B_cols); - std::mt19937 gen; - std::uniform_int_distribution<int8_t> dist(-127,127); - for (int8_t &it : A) it = dist(gen); - for (int8_t &it : B) it = dist(gen); - - AlignedVector<int32_t> C_reference(shape.A_rows * shape.B_cols); - typedef Access<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, RowMajorAccess<int32_t> > AccessT; - AccessT ref_access( - RowMajorAccess<int8_t>(A.begin(), shape.inner), - ColMajorAccess<int8_t>(B.begin(), shape.inner), - RowMajorAccess<int32_t>(C_reference.begin(), shape.B_cols)); - Signed8ReferenceMult<AccessT>(ref_access, shape); - + TestMatrices t(shape); AlignedVector<int32_t> C_test(shape.A_rows * shape.B_cols); - AccessT test_access( - RowMajorAccess<int8_t>(A.begin(), shape.inner), - ColMajorAccess<int8_t>(B.begin(), shape.inner), - RowMajorAccess<int32_t>(C_test.begin(), shape.B_cols)); - MultiplyNoOverhang<AccessT, Kernel>(test_access, shape); - bool failed = false; - for (Index i = 0; i < shape.A_rows; ++i) { + MultiplyNoOverhang<TestMatrices::AccessT, Kernel>(t.AccessTest(C_test), shape); + CHECK(!memcmp(t.C_reference.begin(), C_test.begin(), shape.A_rows * shape.B_cols)); +/* for (Index i = 0; i < shape.A_rows; ++i) { for (Index j = 0; j < shape.B_cols; ++j) { - CHECK(C_reference[i * shape.B_cols + j] == C_test[i * shape.B_cols + j]); - if (C_reference[i * shape.B_cols + j] != C_test[i * shape.B_cols + j]) - failed = true; + CHECK(t.C_reference[i * shape.B_cols + j] == C_test[i * shape.B_cols + j]); } - } - if (failed) { - std::cerr << "Failed A is "; - DumpMatrix(A.begin(), shape.A_rows, shape.inner); - std::cerr << "Failed B is "; - DumpMatrix(B.begin(), shape.inner, shape.B_cols); - } + }*/ } template <class Kernel> void TestMultiplyNoOverhangShapes() { @@ -293,6 +300,19 @@ 
TEMPLATE_TEST_CASE("MultiplyNoOverhang Unrolled Signed8 " INTGEMM_TEST_NAME, "[t TestMultiplyNoOverhangShapes<TestType>(); } +TEST_CASE("Multiply " INTGEMM_TEST_NAME, "[tile][multiply]") { + if (kCPU < CPUType::INTGEMM_ARCH) return; + Tile shape{1, sizeof(Register), 1}; + for (shape.A_rows = 1; shape.A_rows < 33; ++shape.A_rows) { + for (shape.B_cols = 1; shape.B_cols < 33; ++shape.B_cols) { + TestMatrices t(shape); + AlignedVector<int32_t> C_test(shape.A_rows * shape.B_cols); + Multiply<TestMatrices::AccessT, Signed8, 7, 3>(t.AccessTest(C_test), shape); + CHECK(!memcmp(t.C_reference.begin(), C_test.begin(), shape.A_rows * shape.B_cols)); + } + } +} + #endif // no INTGEMM_THIS_IS_SSE2 } // namespace INTGEMM_ARCH diff --git a/tile/multiply.inl b/tile/multiply.inl index f1344d7..78dca55 100644 --- a/tile/multiply.inl +++ b/tile/multiply.inl @@ -29,11 +29,11 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s assert(shape.A_rows % Kernel::kTile.A_rows == 0); assert(shape.inner % Kernel::kTile.inner == 0); assert(shape.B_cols % Kernel::kTile.B_cols == 0); - constexpr Index Outputs = Kernel::kTile.A_rows * Kernel::kTile.B_cols; for (Index B_col = 0; B_col < shape.B_cols; B_col += Kernel::kTile.B_cols) { AccessT column_adjusted = access.BAdd(0, B_col).CAdd(0, B_col); for (Index A_row = 0; A_row < shape.A_rows; A_row += Kernel::kTile.A_rows) { AccessT col_row = column_adjusted.AAdd(A_row, 0).CAdd(A_row, 0); + constexpr Index Outputs = Kernel::kTile.A_rows * Kernel::kTile.B_cols; // Accumulate values in temporary C registers. Register c_regs[Outputs] = {setzero_si<Register>()}; @@ -56,6 +56,35 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s } } +template <class Access, class Kernel, Index A_rows, Index B_cols> INTGEMM_TARGET static inline void Multiply(Access access, const Tile shape) { + // Still has to be a multiple of the underlying Kernel, but usually that's just 1 x sizeof(Register) x 1. 
+ assert(shape.A_rows % Kernel::kTile.A_rows == 0); + assert(shape.inner % Kernel::kTile.inner == 0); + assert(shape.B_cols % Kernel::kTile.B_cols == 0); + + typedef UnrollKernel<A_rows, 1, B_cols, Kernel> Big; + Tile overhang = { + shape.A_rows % Big::kTile.A_rows, + shape.inner % Big::kTile.inner, + shape.B_cols % Big::kTile.B_cols + }; + Tile big_shape = { + shape.A_rows - overhang.A_rows, + shape.inner - overhang.inner, + shape.B_cols - overhang.B_cols + }; + // Top left corner. + MultiplyNoOverhang<Access, Big>(access, big_shape); + // Bottom currently including right side. TODO: unrolled kernel, rather than dumb loop. + MultiplyNoOverhang<Access, Kernel>( + access.AAdd(big_shape.A_rows, 0).CAdd(big_shape.A_rows, 0), + Tile {overhang.A_rows, shape.inner, shape.B_cols}); + // Right side except bottom. TODO: unrolled kernel, rather than dumb loop. + MultiplyNoOverhang<Access, Kernel>( + access.BAdd(0, big_shape.B_cols).CAdd(0, big_shape.B_cols), + Tile {big_shape.A_rows, shape.inner, overhang.B_cols}); +} + } // namespace INTGEMM_ARCH } // namespace intgemm |