
github.com/marian-nmt/intgemm/intgemm.git
author    Mateusz Chudyk <mateuszchudyk@gmail.com>  2020-04-25 01:51:59 +0300
committer Mateusz Chudyk <mateuszchudyk@gmail.com>  2020-04-25 04:00:56 +0300
commit    9f30419eaafeb91fd2e2b17f0590cb505ee98d28 (patch)
tree      f78eb71ea57b9d711233890984fb25eda72bc33d
parent    943a6baa291662d903a42070433a6d3d7f56428a (diff)
Add UnquantizeAndWriteCallback tests
 test/test_matrices.h |  34 ++++
 test/tile_test.inl   | 100 ++++++++
 2 files changed, 134 insertions(+), 0 deletions(-)
diff --git a/test/test_matrices.h b/test/test_matrices.h
index c9b7ec6..2e3acf7 100644
--- a/test/test_matrices.h
+++ b/test/test_matrices.h
@@ -39,4 +39,38 @@ struct TestMatrices8 {
AlignedVector<int32_t> C;
};
+struct TestMatricesUnquantizeAndWriteRowMajorAccess {
+ typedef Access<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, UnquantizeAndWriteRowMajorAccess<float>> AccessT;
+
+ explicit TestMatricesUnquantizeAndWriteRowMajorAccess(Tile shape_in, float unquant_mult) :
+ shape(shape_in),
+ A(shape.A_rows * shape.inner),
+ B(shape.inner * shape.B_cols),
+ C(shape.A_rows * shape.B_cols),
+ unquant_mult_(unquant_mult) {
+
+ std::mt19937 gen;
+ std::uniform_int_distribution<int8_t> dist(-127,127);
+ for (int8_t &it : A) it = dist(gen);
+ for (int8_t &it : B) it = dist(gen);
+ // C is uninitialized.
+ }
+
+ AccessT Accessor() {
+ return AccessT(
+ RowMajorAccess<int8_t>(A.begin(), shape.inner),
+ ColMajorAccess<int8_t>(B.begin(), shape.inner),
+ UnquantizeAndWriteRowMajorAccess<float>(C.begin(), shape.B_cols, {unquant_mult_}));
+ }
+
+ Tile shape;
+ AlignedVector<int8_t> A;
+ AlignedVector<int8_t> B;
+ // Uninitialized; for tests to write to.
+ AlignedVector<float> C;
+
+private:
+ float unquant_mult_;
+};
+
} // namespace intgemm
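
For context on what the new fixture exercises: UnquantizeAndWriteRowMajorAccess is expected to take the int32_t accumulator produced by the int8 multiply, scale it by unquant_mult, and store the result as a float into row-major C. A minimal standalone sketch of that behavior follows; the plain-pointer interface and the name scale_and_write are illustrative, not intgemm's API.

#include <cstdint>
#include <cstdio>

// Illustrative stand-in for the unquantize-and-write callback under test:
// scale an int32 accumulator by unquant_mult and store it as float, row-major.
void scale_and_write(int32_t accumulator, float unquant_mult,
                     float *C, int row, int col, int cols) {
  C[row * cols + col] = unquant_mult * static_cast<float>(accumulator);
}

int main() {
  float C[2 * 2] = {0.0f};
  scale_and_write(100, 1.7f, C, 0, 0, 2);
  scale_and_write(-50, 1.7f, C, 1, 1, 2);
  std::printf("%.1f %.1f\n", C[0], C[3]);  // prints 170.0 -85.0
}
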
diff --git a/test/tile_test.inl b/test/tile_test.inl
index 1da3556..1c1d0d3 100644
--- a/test/tile_test.inl
+++ b/test/tile_test.inl
@@ -141,6 +141,46 @@ template <class Access> void Signed8ReferenceMult(Access access, Tile problem) {
}
}
+template <class Access> void Signed8ReferenceMult_UnquantizeAndWrite(Access access, Tile problem, float unquant_mult) {
+ assert(!(problem.inner % 2));
+ for (Index a_row = 0; a_row < problem.A_rows; ++a_row) {
+ for (Index b_col = 0; b_col < problem.B_cols; ++b_col) {
+ Access acc = access.AAdd(a_row, 0).BAdd(0, b_col).CAdd(a_row, b_col);
+ // For VNNI, just do it accurately.
+#ifdef INTGEMM_THIS_IS_AVX512VNNI
+ acc.CFront() = 0;
+ for (Index inner = 0; inner < problem.inner; ++inner) {
+ Access innermost = acc.AAdd(0, inner).BAdd(inner, 0);
+ acc.CFront() += static_cast<int32_t>(innermost.AFront()) * static_cast<int32_t>(innermost.BFront());
+ }
+#else
+ // For non-VNNI, do the saturation stuff.
+ int16_t accumulators[sizeof(Register) / sizeof(int16_t)] = {0};
+ for (Index inner = 0; inner < problem.inner; inner += 2) {
+ Access innermost = acc.AAdd(0, inner).BAdd(inner, 0);
+ int32_t product = static_cast<int32_t>(innermost.AFront()) * static_cast<int32_t>(innermost.BFront());
+ innermost = innermost.AAdd(0, 1).BAdd(1, 0);
+ product += static_cast<int32_t>(innermost.AFront()) * static_cast<int32_t>(innermost.BFront());
+ // Saturate to 16-bit for maddubs.
+ if (product > 32767) product = 32767;
+ if (product < -32768) product = -32768;
+ int16_t &accum = accumulators[(inner / 2) % (sizeof(Register) / sizeof(int16_t))];
+ // Saturating accumulation.
+ product += static_cast<int32_t>(accum);
+ if (product > 32767) product = 32767;
+ if (product < -32768) product = -32768;
+ accum = static_cast<int16_t>(product);
+ }
+ acc.CFront() = 0;
+ for (Index i = 0; i < sizeof(Register) / sizeof(int16_t); ++i) {
+ acc.CFront() += static_cast<int32_t>(accumulators[i]);
+ }
+#endif
+ acc.CFront() *= unquant_mult;
+ }
+ }
+}
+
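
The non-VNNI branch above emulates maddubs-style arithmetic: each adjacent pair of int8 products is summed and saturated to 16 bits, then folded into a bank of 16-bit lanes with saturating adds, so large sums pin at 32767/-32768 instead of wrapping. A self-contained sketch of that saturation rule, assuming an illustrative helper saturate16 that is not intgemm code:

#include <cstdint>
#include <cstdio>

// Clamp a 32-bit sum into int16 range, as saturating SIMD adds do.
int16_t saturate16(int32_t x) {
  if (x > 32767) return 32767;
  if (x < -32768) return -32768;
  return static_cast<int16_t>(x);
}

int main() {
  // Two adjacent int8 products are summed, then saturated to 16 bits.
  int16_t pair = saturate16(127 * 127 + 127 * 127);  // 32258, fits
  // Accumulation is also saturating: 20000 + 32258 pins at 32767
  // rather than wrapping, which is what the reference above reproduces.
  int16_t accum = saturate16(20000 + static_cast<int32_t>(pair));
  std::printf("%d %d\n", pair, accum);  // prints 32258 32767
}
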
void DumpMatrix(int8_t *m, Index rows, Index cols) {
std::cerr << rows << 'x' << cols << '\n';
for (Index i = 0; i < rows; ++i) {
@@ -166,6 +206,22 @@ struct TestMatricesRef : TestMatrices8 {
AlignedVector<int32_t> C_reference;
};
+struct TestMatricesRef_UnquantizeAndWrite : TestMatricesUnquantizeAndWriteRowMajorAccess {
+ TestMatricesRef_UnquantizeAndWrite(Tile shape_in, float unquant_mult) :
+ TestMatricesUnquantizeAndWriteRowMajorAccess(shape_in, unquant_mult),
+ C_reference(shape.A_rows * shape.B_cols) {
+
+ AccessT ref_access(
+ RowMajorAccess<int8_t>(A.begin(), shape.inner),
+ ColMajorAccess<int8_t>(B.begin(), shape.inner),
+ UnquantizeAndWriteRowMajorAccess<float>(C_reference.begin(), shape.B_cols, {unquant_mult}));
+ Signed8ReferenceMult_UnquantizeAndWrite<AccessT>(ref_access, shape, unquant_mult);
+ }
+
+ AlignedVector<float> C_reference;
+};
+
+
#ifndef INTGEMM_THIS_IS_SSE2
template <class Kernel> void TestMultiplyNoOverhang(Tile shape) {
// These are sanity checks on the arguments, not the code.
@@ -182,6 +238,16 @@ template <class Kernel> void TestMultiplyNoOverhang(Tile shape) {
}*/
}
+template <class Kernel> void TestMultiplyNoOverhang_UnquantizeAndWrite(Tile shape, float unquant_mult) {
+ // These are sanity checks on the arguments, not the code.
+ CHECK(shape.A_rows % Kernel::kTile.A_rows == 0);
+ CHECK(shape.inner % Kernel::kTile.inner == 0);
+ CHECK(shape.B_cols % Kernel::kTile.B_cols == 0);
+ TestMatricesRef_UnquantizeAndWrite t(shape, unquant_mult);
+ MultiplyNoOverhang<TestMatricesRef_UnquantizeAndWrite::AccessT, Kernel>(t.Accessor(), shape);
+ CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(float)));
+}
+
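
The memcmp check above compares floats bitwise, which is sound here because the reference and the kernel are expected to perform the same arithmetic on the same inputs, so matching outputs are bit-identical and no epsilon tolerance is needed. A small illustration of the idea:

#include <cstring>
#include <cstdio>

int main() {
  // Two arrays computed by the same sequence of float operations
  // are bit-identical, so memcmp is a valid equality test.
  float a[3] = {1.7f * 100.0f, 1.7f * -50.0f, 1.7f * 32767.0f};
  float b[3] = {1.7f * 100.0f, 1.7f * -50.0f, 1.7f * 32767.0f};
  std::printf("%s\n", std::memcmp(a, b, sizeof(a)) ? "differ" : "identical");
  // Caveat: memcmp distinguishes +0.0f from -0.0f even though they compare
  // equal with ==; identical computations avoid that pitfall.
}
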
template <class Kernel> void TestMultiplyNoOverhangShapes() {
Tile shape = Kernel::kTile;
// Minimum size.
@@ -201,6 +267,25 @@ template <class Kernel> void TestMultiplyNoOverhangShapes() {
}
}
+template <class Kernel> void TestMultiplyNoOverhangShapes_UnquantizeAndWrite(float unquant_mult) {
+ Tile shape = Kernel::kTile;
+ // Minimum size.
+ TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>(shape, unquant_mult);
+ // Multiples on each dimension.
+ TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>(Tile{shape.A_rows * 2, shape.inner, shape.B_cols}, unquant_mult);
+ TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>(Tile{shape.A_rows, shape.inner * 2, shape.B_cols}, unquant_mult);
+ TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>(Tile{shape.A_rows, shape.inner, shape.B_cols * 2}, unquant_mult);
+ TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>(Tile{shape.A_rows * 2, shape.inner * 2, shape.B_cols * 2}, unquant_mult);
+ // Try a bunch of shapes!
+ for (shape.A_rows = 0; shape.A_rows <= Kernel::kTile.A_rows * 9; shape.A_rows += Kernel::kTile.A_rows) {
+ for (shape.inner = 0; shape.inner <= Kernel::kTile.inner * 9; shape.inner += Kernel::kTile.inner) {
+ for (shape.B_cols = 0; shape.B_cols <= Kernel::kTile.B_cols * 9; shape.B_cols += Kernel::kTile.B_cols) {
+ TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>(shape, unquant_mult);
+ }
+ }
+ }
+}
+
TEST_CASE("MultiplyNoOverhang Signed8 " INTGEMM_TEST_NAME, "[tile]") {
if (kCPU < CPUType::INTGEMM_ARCH) return;
TestMultiplyNoOverhangShapes<Signed8>();
@@ -219,6 +304,21 @@ TEST_CASE("MultiplyNoOverhang inner unroll " INTGEMM_TEST_NAME, "[tile][multiply
TestMultiplyNoOverhang<Kernel>({1, sizeof(Register) * 4, 1});
TestMultiplyNoOverhangShapes<Kernel>();
}
+
+TEST_CASE("MultiplyNoOverhang Signed8 UnquantizeAndWrite " INTGEMM_TEST_NAME, "[tile]") {
+ if (kCPU < CPUType::INTGEMM_ARCH) return;
+ TestMultiplyNoOverhangShapes_UnquantizeAndWrite<Signed8>(1.7f);
+}
+
+TEST_CASE("MultiplyNoOverhang inner unroll UnquantizeAndWrite " INTGEMM_TEST_NAME, "[tile][multiply]") {
+ if (kCPU < CPUType::INTGEMM_ARCH) return;
+ float unquant_mult = 1.7f;
+ typedef UnrollKernel<1, 2, 1, Signed8> Kernel;
+ Tile shape = {1, sizeof(Register) * 2, 1};
+ TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>(shape, unquant_mult);
+ TestMultiplyNoOverhang_UnquantizeAndWrite<Kernel>({1, sizeof(Register) * 4, 1}, unquant_mult);
+ TestMultiplyNoOverhangShapes_UnquantizeAndWrite<Kernel>(unquant_mult);
+}
#endif // INTGEMM_THIS_IS_AVX512VNNI
// If the inner dimension is just twice, then there isn't any non-determinism in saturation order.