diff options
author | Kenneth Heafield <github@kheafield.com> | 2020-04-24 01:47:09 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2020-04-24 01:47:09 +0300 |
commit | d73a47adc6e0ac88b10e127df7e083bf3bb6c46b (patch) | |
tree | 97f88d201bbf26736a7154a7cb245adf26f7951c | |
parent | b95575c0c039ee05910098590f6a46685988b248 (diff) |
Extract randomly generated matrix class
-rw-r--r-- | test/test_matrices.h | 42 | ||||
-rw-r--r-- | test/tile_test.cc | 2 | ||||
-rw-r--r-- | test/tile_test.inl | 45 |
3 files changed, 56 insertions, 33 deletions
diff --git a/test/test_matrices.h b/test/test_matrices.h new file mode 100644 index 0000000..c9b7ec6 --- /dev/null +++ b/test/test_matrices.h @@ -0,0 +1,42 @@ +#pragma once + +#include "../aligned.h" +#include "../tile/access.h" +#include <random> +// Yes both due to debacle. +#include <stdint.h> +#include <cstdint> + +namespace intgemm { + +struct TestMatrices8 { + typedef Access<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, RowMajorAccess<int32_t> > AccessT; + + explicit TestMatrices8(Tile shape_in) : + shape(shape_in), + A(shape.A_rows * shape.inner), + B(shape.inner * shape.B_cols), + C(shape.A_rows * shape.B_cols) { + + std::mt19937 gen; + std::uniform_int_distribution<int8_t> dist(-127,127); + for (int8_t &it : A) it = dist(gen); + for (int8_t &it : B) it = dist(gen); + // C is uninitialized. + } + + AccessT Accessor() { + return AccessT( + RowMajorAccess<int8_t>(A.begin(), shape.inner), + ColMajorAccess<int8_t>(B.begin(), shape.inner), + RowMajorAccess<int32_t>(C.begin(), shape.B_cols)); + } + + Tile shape; + AlignedVector<int8_t> A; + AlignedVector<int8_t> B; + // Uninitialized; for using tests to write to. + AlignedVector<int32_t> C; +}; + +} // namespace intgemm diff --git a/test/tile_test.cc b/test/tile_test.cc index 0b7d94c..2385a81 100644 --- a/test/tile_test.cc +++ b/test/tile_test.cc @@ -5,6 +5,8 @@ #include "../tile/reduce.h" #include "test.h" +#include "test_matrices.h" + #include <random> #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI diff --git a/test/tile_test.inl b/test/tile_test.inl index 24decca..1da3556 100644 --- a/test/tile_test.inl +++ b/test/tile_test.inl @@ -151,20 +151,11 @@ void DumpMatrix(int8_t *m, Index rows, Index cols) { } } -struct TestMatrices { - typedef Access<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, RowMajorAccess<int32_t> > AccessT; - - TestMatrices(Tile shape_in) : - shape(shape_in), - A(shape.A_rows * shape.inner), - B(shape.inner * shape.B_cols), +struct TestMatricesRef : TestMatrices8 { + TestMatricesRef(Tile shape_in) : + TestMatrices8(shape_in), C_reference(shape.A_rows * shape.B_cols) { - std::mt19937 gen; - std::uniform_int_distribution<int8_t> dist(-127,127); - for (int8_t &it : A) it = dist(gen); - for (int8_t &it : B) it = dist(gen); - AccessT ref_access( RowMajorAccess<int8_t>(A.begin(), shape.inner), ColMajorAccess<int8_t>(B.begin(), shape.inner), @@ -172,16 +163,6 @@ struct TestMatrices { Signed8ReferenceMult<AccessT>(ref_access, shape); } - AccessT AccessTest(AlignedVector<int32_t> &C_test) { - return AccessT( - RowMajorAccess<int8_t>(A.begin(), shape.inner), - ColMajorAccess<int8_t>(B.begin(), shape.inner), - RowMajorAccess<int32_t>(C_test.begin(), shape.B_cols)); - } - - Tile shape; - AlignedVector<int8_t> A; - AlignedVector<int8_t> B; AlignedVector<int32_t> C_reference; }; @@ -191,10 +172,9 @@ template <class Kernel> void TestMultiplyNoOverhang(Tile shape) { CHECK(shape.A_rows % Kernel::kTile.A_rows == 0); CHECK(shape.inner % Kernel::kTile.inner == 0); CHECK(shape.B_cols % Kernel::kTile.B_cols == 0); - TestMatrices t(shape); - AlignedVector<int32_t> C_test(shape.A_rows * shape.B_cols); - MultiplyNoOverhang<TestMatrices::AccessT, Kernel>(t.AccessTest(C_test), shape); - CHECK(!memcmp(t.C_reference.begin(), C_test.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t))); + TestMatricesRef t(shape); + MultiplyNoOverhang<TestMatricesRef::AccessT, Kernel>(t.Accessor(), shape); + CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t))); /* for (Index i = 0; i < shape.A_rows; ++i) { for (Index j = 0; j < shape.B_cols; ++j) { CHECK(t.C_reference[i * shape.B_cols + j] == C_test[i * shape.B_cols + j]); @@ -305,13 +285,12 @@ TEST_CASE("Multiply " INTGEMM_TEST_NAME, "[tile][multiply]") { Tile shape{1, sizeof(Register), 1}; for (shape.A_rows = 1; shape.A_rows < 33; ++shape.A_rows) { for (shape.B_cols = 1; shape.B_cols < 33; ++shape.B_cols) { - TestMatrices t(shape); - AlignedVector<int32_t> C_test(shape.A_rows * shape.B_cols); - Multiply<TestMatrices::AccessT, Signed8, 7, 3>(t.AccessTest(C_test), shape); - CHECK(!memcmp(t.C_reference.begin(), C_test.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t))); - memset(C_test.begin(), 0, shape.A_rows * shape.B_cols); - Multiply<TestMatrices::AccessT, Signed8, 4, 5>(t.AccessTest(C_test), shape); - CHECK(!memcmp(t.C_reference.begin(), C_test.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t))); + TestMatricesRef t(shape); + Multiply<TestMatricesRef::AccessT, Signed8, 7, 3>(t.Accessor(), shape); + CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t))); + memset(t.C.begin(), 0, shape.A_rows * shape.B_cols * sizeof(int32_t)); + Multiply<TestMatricesRef::AccessT, Signed8, 4, 5>(t.Accessor(), shape); + CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t))); } } } |