Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2020-04-24 01:47:09 +0300
committerKenneth Heafield <github@kheafield.com>2020-04-24 01:47:09 +0300
commitd73a47adc6e0ac88b10e127df7e083bf3bb6c46b (patch)
tree97f88d201bbf26736a7154a7cb245adf26f7951c
parentb95575c0c039ee05910098590f6a46685988b248 (diff)
Extract randomly generated matrix class
-rw-r--r--test/test_matrices.h42
-rw-r--r--test/tile_test.cc2
-rw-r--r--test/tile_test.inl45
3 files changed, 56 insertions, 33 deletions
diff --git a/test/test_matrices.h b/test/test_matrices.h
new file mode 100644
index 0000000..c9b7ec6
--- /dev/null
+++ b/test/test_matrices.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#include "../aligned.h"
+#include "../tile/access.h"
+#include <random>
+// Yes both due to debacle.
+#include <stdint.h>
+#include <cstdint>
+
+namespace intgemm {
+
+struct TestMatrices8 {
+ typedef Access<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, RowMajorAccess<int32_t> > AccessT;
+
+ explicit TestMatrices8(Tile shape_in) :
+ shape(shape_in),
+ A(shape.A_rows * shape.inner),
+ B(shape.inner * shape.B_cols),
+ C(shape.A_rows * shape.B_cols) {
+
+ std::mt19937 gen;
+ std::uniform_int_distribution<int8_t> dist(-127,127);
+ for (int8_t &it : A) it = dist(gen);
+ for (int8_t &it : B) it = dist(gen);
+ // C is uninitialized.
+ }
+
+ AccessT Accessor() {
+ return AccessT(
+ RowMajorAccess<int8_t>(A.begin(), shape.inner),
+ ColMajorAccess<int8_t>(B.begin(), shape.inner),
+ RowMajorAccess<int32_t>(C.begin(), shape.B_cols));
+ }
+
+ Tile shape;
+ AlignedVector<int8_t> A;
+ AlignedVector<int8_t> B;
+ // Uninitialized; for using tests to write to.
+ AlignedVector<int32_t> C;
+};
+
+} // namespace intgemm
diff --git a/test/tile_test.cc b/test/tile_test.cc
index 0b7d94c..2385a81 100644
--- a/test/tile_test.cc
+++ b/test/tile_test.cc
@@ -5,6 +5,8 @@
#include "../tile/reduce.h"
#include "test.h"
+#include "test_matrices.h"
+
#include <random>
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
diff --git a/test/tile_test.inl b/test/tile_test.inl
index 24decca..1da3556 100644
--- a/test/tile_test.inl
+++ b/test/tile_test.inl
@@ -151,20 +151,11 @@ void DumpMatrix(int8_t *m, Index rows, Index cols) {
}
}
-struct TestMatrices {
- typedef Access<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, RowMajorAccess<int32_t> > AccessT;
-
- TestMatrices(Tile shape_in) :
- shape(shape_in),
- A(shape.A_rows * shape.inner),
- B(shape.inner * shape.B_cols),
+struct TestMatricesRef : TestMatrices8 {
+ TestMatricesRef(Tile shape_in) :
+ TestMatrices8(shape_in),
C_reference(shape.A_rows * shape.B_cols) {
- std::mt19937 gen;
- std::uniform_int_distribution<int8_t> dist(-127,127);
- for (int8_t &it : A) it = dist(gen);
- for (int8_t &it : B) it = dist(gen);
-
AccessT ref_access(
RowMajorAccess<int8_t>(A.begin(), shape.inner),
ColMajorAccess<int8_t>(B.begin(), shape.inner),
@@ -172,16 +163,6 @@ struct TestMatrices {
Signed8ReferenceMult<AccessT>(ref_access, shape);
}
- AccessT AccessTest(AlignedVector<int32_t> &C_test) {
- return AccessT(
- RowMajorAccess<int8_t>(A.begin(), shape.inner),
- ColMajorAccess<int8_t>(B.begin(), shape.inner),
- RowMajorAccess<int32_t>(C_test.begin(), shape.B_cols));
- }
-
- Tile shape;
- AlignedVector<int8_t> A;
- AlignedVector<int8_t> B;
AlignedVector<int32_t> C_reference;
};
@@ -191,10 +172,9 @@ template <class Kernel> void TestMultiplyNoOverhang(Tile shape) {
CHECK(shape.A_rows % Kernel::kTile.A_rows == 0);
CHECK(shape.inner % Kernel::kTile.inner == 0);
CHECK(shape.B_cols % Kernel::kTile.B_cols == 0);
- TestMatrices t(shape);
- AlignedVector<int32_t> C_test(shape.A_rows * shape.B_cols);
- MultiplyNoOverhang<TestMatrices::AccessT, Kernel>(t.AccessTest(C_test), shape);
- CHECK(!memcmp(t.C_reference.begin(), C_test.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
+ TestMatricesRef t(shape);
+ MultiplyNoOverhang<TestMatricesRef::AccessT, Kernel>(t.Accessor(), shape);
+ CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
/* for (Index i = 0; i < shape.A_rows; ++i) {
for (Index j = 0; j < shape.B_cols; ++j) {
CHECK(t.C_reference[i * shape.B_cols + j] == C_test[i * shape.B_cols + j]);
@@ -305,13 +285,12 @@ TEST_CASE("Multiply " INTGEMM_TEST_NAME, "[tile][multiply]") {
Tile shape{1, sizeof(Register), 1};
for (shape.A_rows = 1; shape.A_rows < 33; ++shape.A_rows) {
for (shape.B_cols = 1; shape.B_cols < 33; ++shape.B_cols) {
- TestMatrices t(shape);
- AlignedVector<int32_t> C_test(shape.A_rows * shape.B_cols);
- Multiply<TestMatrices::AccessT, Signed8, 7, 3>(t.AccessTest(C_test), shape);
- CHECK(!memcmp(t.C_reference.begin(), C_test.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
- memset(C_test.begin(), 0, shape.A_rows * shape.B_cols);
- Multiply<TestMatrices::AccessT, Signed8, 4, 5>(t.AccessTest(C_test), shape);
- CHECK(!memcmp(t.C_reference.begin(), C_test.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
+ TestMatricesRef t(shape);
+ Multiply<TestMatricesRef::AccessT, Signed8, 7, 3>(t.Accessor(), shape);
+ CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
+ memset(t.C.begin(), 0, shape.A_rows * shape.B_cols * sizeof(int32_t));
+ Multiply<TestMatricesRef::AccessT, Signed8, 4, 5>(t.Accessor(), shape);
+ CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
}
}
}