github.com/marian-nmt/intgemm/intgemm.git
author    Kenneth Heafield <github@kheafield.com>  2020-04-24 00:18:53 +0300
committer Kenneth Heafield <github@kheafield.com>  2020-04-24 00:19:29 +0300
commit    ed5ba5e5eb56d6cd44913d8efb29f3fed248b039 (patch)
tree      3e3a518ba71a80f8b57e854eefc837c859b725e3
parent    88471a26388f38eb4007cfbd4e52ac1a60c88a79 (diff)
Basic general-sized multiply, not optimized yet
-rw-r--r-- test/tile_test.inl | 86
-rw-r--r-- tile/multiply.inl  | 31
2 files changed, 83 insertions(+), 34 deletions(-)
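
This commit does two things. In test/tile_test.inl, the shared setup (random int8 inputs and a reference product) moves into a TestMatrices struct so each test only allocates its own output buffer, and a new test exercises the general-sized path. In tile/multiply.inl, a new Multiply covers an arbitrary output shape by running an unrolled kernel over the largest tile-aligned region, then mopping up the leftover rows and columns with the base kernel. Below is a minimal sketch of that decomposition, with a plain struct standing in for intgemm's Tile; the names here are illustrative, not intgemm's API.

#include <cstdio>

// Stand-in for intgemm's Tile: just the three matrix dimensions.
struct Shape { unsigned A_rows, inner, B_cols; };

// Mirrors the decomposition in Multiply below: the top-left block is the
// largest multiple of the unrolled kernel's tile, the bottom strip takes
// the leftover rows across the full width, and the right strip takes the
// leftover columns for the rows the bottom strip did not cover.
void Decompose(Shape shape, Shape big_tile) {
  Shape overhang = {shape.A_rows % big_tile.A_rows,
                    shape.inner % big_tile.inner,
                    shape.B_cols % big_tile.B_cols};
  Shape big = {shape.A_rows - overhang.A_rows,
               shape.inner - overhang.inner,
               shape.B_cols - overhang.B_cols};
  std::printf("top left: %u x %u\n", big.A_rows, big.B_cols);
  std::printf("bottom:   %u x %u (includes the corner)\n", overhang.A_rows, shape.B_cols);
  std::printf("right:    %u x %u\n", big.A_rows, overhang.B_cols);
}

int main() {
  // A 33x34 output with a 7-row by 3-column unrolled tile, inner width 64:
  // prints "top left: 28 x 33", "bottom: 5 x 34", "right: 28 x 1".
  Decompose({33, 64, 34}, {7, 64, 3});
}
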
diff --git a/test/tile_test.inl b/test/tile_test.inl
index 69fce76..bd5c891 100644
--- a/test/tile_test.inl
+++ b/test/tile_test.inl
@@ -151,48 +151,55 @@ void DumpMatrix(int8_t *m, Index rows, Index cols) {
}
}
+struct TestMatrices {
+ typedef Access<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, RowMajorAccess<int32_t> > AccessT;
+
+ TestMatrices(Tile shape_in) :
+ shape(shape_in),
+ A(shape.A_rows * shape.inner),
+ B(shape.inner * shape.B_cols),
+ C_reference(shape.A_rows * shape.B_cols) {
+
+ std::mt19937 gen;
+ std::uniform_int_distribution<int8_t> dist(-127,127);
+ for (int8_t &it : A) it = dist(gen);
+ for (int8_t &it : B) it = dist(gen);
+
+ AccessT ref_access(
+ RowMajorAccess<int8_t>(A.begin(), shape.inner),
+ ColMajorAccess<int8_t>(B.begin(), shape.inner),
+ RowMajorAccess<int32_t>(C_reference.begin(), shape.B_cols));
+ Signed8ReferenceMult<AccessT>(ref_access, shape);
+ }
+
+ AccessT AccessTest(AlignedVector<int32_t> &C_test) {
+ return AccessT(
+ RowMajorAccess<int8_t>(A.begin(), shape.inner),
+ ColMajorAccess<int8_t>(B.begin(), shape.inner),
+ RowMajorAccess<int32_t>(C_test.begin(), shape.B_cols));
+ }
+
+ Tile shape;
+ AlignedVector<int8_t> A;
+ AlignedVector<int8_t> B;
+ AlignedVector<int32_t> C_reference;
+};
+
#ifndef INTGEMM_THIS_IS_SSE2
template <class Kernel> void TestMultiplyNoOverhang(Tile shape) {
// These are sanity checks on the arguments, not the code.
CHECK(shape.A_rows % Kernel::kTile.A_rows == 0);
CHECK(shape.inner % Kernel::kTile.inner == 0);
CHECK(shape.B_cols % Kernel::kTile.B_cols == 0);
-
- AlignedVector<int8_t> A(shape.A_rows * shape.inner);
- AlignedVector<int8_t> B(shape.inner * shape.B_cols);
- std::mt19937 gen;
- std::uniform_int_distribution<int8_t> dist(-127,127);
- for (int8_t &it : A) it = dist(gen);
- for (int8_t &it : B) it = dist(gen);
-
- AlignedVector<int32_t> C_reference(shape.A_rows * shape.B_cols);
- typedef Access<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, RowMajorAccess<int32_t> > AccessT;
- AccessT ref_access(
- RowMajorAccess<int8_t>(A.begin(), shape.inner),
- ColMajorAccess<int8_t>(B.begin(), shape.inner),
- RowMajorAccess<int32_t>(C_reference.begin(), shape.B_cols));
- Signed8ReferenceMult<AccessT>(ref_access, shape);
-
+ TestMatrices t(shape);
AlignedVector<int32_t> C_test(shape.A_rows * shape.B_cols);
- AccessT test_access(
- RowMajorAccess<int8_t>(A.begin(), shape.inner),
- ColMajorAccess<int8_t>(B.begin(), shape.inner),
- RowMajorAccess<int32_t>(C_test.begin(), shape.B_cols));
- MultiplyNoOverhang<AccessT, Kernel>(test_access, shape);
- bool failed = false;
- for (Index i = 0; i < shape.A_rows; ++i) {
+ MultiplyNoOverhang<TestMatrices::AccessT, Kernel>(t.AccessTest(C_test), shape);
+ CHECK(!memcmp(t.C_reference.begin(), C_test.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
+/* for (Index i = 0; i < shape.A_rows; ++i) {
for (Index j = 0; j < shape.B_cols; ++j) {
- CHECK(C_reference[i * shape.B_cols + j] == C_test[i * shape.B_cols + j]);
- if (C_reference[i * shape.B_cols + j] != C_test[i * shape.B_cols + j])
- failed = true;
+ CHECK(t.C_reference[i * shape.B_cols + j] == C_test[i * shape.B_cols + j]);
}
- }
- if (failed) {
- std::cerr << "Failed A is ";
- DumpMatrix(A.begin(), shape.A_rows, shape.inner);
- std::cerr << "Failed B is ";
- DumpMatrix(B.begin(), shape.inner, shape.B_cols);
- }
+ }*/
}
template <class Kernel> void TestMultiplyNoOverhangShapes() {
@@ -293,6 +300,19 @@ TEMPLATE_TEST_CASE("MultiplyNoOverhang Unrolled Signed8 " INTGEMM_TEST_NAME, "[t
TestMultiplyNoOverhangShapes<TestType>();
}
+TEST_CASE("Multiply " INTGEMM_TEST_NAME, "[tile][multiply]") {
+ if (kCPU < CPUType::INTGEMM_ARCH) return;
+ Tile shape{1, sizeof(Register), 1};
+ for (shape.A_rows = 1; shape.A_rows < 33; ++shape.A_rows) {
+ for (shape.B_cols = 1; shape.B_cols < 33; ++shape.B_cols) {
+ TestMatrices t(shape);
+ AlignedVector<int32_t> C_test(shape.A_rows * shape.B_cols);
+ Multiply<TestMatrices::AccessT, Signed8, 7, 3>(t.AccessTest(C_test), shape);
+ CHECK(!memcmp(t.C_reference.begin(), C_test.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
+ }
+ }
+}
+
#endif // no INTGEMM_THIS_IS_SSE2
} // namespace INTGEMM_ARCH
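
The Multiply test above sweeps A_rows and B_cols from 1 through 32 with inner fixed at one register width, so every remainder against the 7x3 unrolled tile gets covered, including all-overhang shapes. Note that memcmp counts bytes, hence the sizeof(int32_t) factor on the comparisons. Since the commented-out loop dropped the per-element diagnostics, here is a sketch of a standalone checker that reports the first mismatching cell; this helper is hypothetical and not part of the commit.

#include <cstdint>
#include <cstdio>

// Hypothetical element-wise check for row-major int32 outputs; unlike
// memcmp it reports where the first mismatch occurs.
bool SameProduct(const int32_t *ref, const int32_t *test,
                 unsigned rows, unsigned cols) {
  for (unsigned i = 0; i < rows; ++i) {
    for (unsigned j = 0; j < cols; ++j) {
      if (ref[i * cols + j] != test[i * cols + j]) {
        std::fprintf(stderr, "Mismatch at (%u,%u): %d vs %d\n",
                     i, j, ref[i * cols + j], test[i * cols + j]);
        return false;
      }
    }
  }
  return true;
}
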
diff --git a/tile/multiply.inl b/tile/multiply.inl
index f1344d7..78dca55 100644
--- a/tile/multiply.inl
+++ b/tile/multiply.inl
@@ -29,11 +29,11 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s
assert(shape.A_rows % Kernel::kTile.A_rows == 0);
assert(shape.inner % Kernel::kTile.inner == 0);
assert(shape.B_cols % Kernel::kTile.B_cols == 0);
- constexpr Index Outputs = Kernel::kTile.A_rows * Kernel::kTile.B_cols;
for (Index B_col = 0; B_col < shape.B_cols; B_col += Kernel::kTile.B_cols) {
AccessT column_adjusted = access.BAdd(0, B_col).CAdd(0, B_col);
for (Index A_row = 0; A_row < shape.A_rows; A_row += Kernel::kTile.A_rows) {
AccessT col_row = column_adjusted.AAdd(A_row, 0).CAdd(A_row, 0);
+ constexpr Index Outputs = Kernel::kTile.A_rows * Kernel::kTile.B_cols;
// Accumulate values in temporary C registers.
Register c_regs[Outputs] = {setzero_si<Register>()};
@@ -56,6 +56,35 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s
}
}
+template <class Access, class Kernel, Index A_rows, Index B_cols> INTGEMM_TARGET static inline void Multiply(Access access, const Tile shape) {
+ // The shape still has to be a multiple of the underlying Kernel's tile, but usually that's just 1 x sizeof(Register) x 1.
+ assert(shape.A_rows % Kernel::kTile.A_rows == 0);
+ assert(shape.inner % Kernel::kTile.inner == 0);
+ assert(shape.B_cols % Kernel::kTile.B_cols == 0);
+
+ typedef UnrollKernel<A_rows, 1, B_cols, Kernel> Big;
+ Tile overhang = {
+ shape.A_rows % Big::kTile.A_rows,
+ shape.inner % Big::kTile.inner,
+ shape.B_cols % Big::kTile.B_cols
+ };
+ Tile big_shape = {
+ shape.A_rows - overhang.A_rows,
+ shape.inner - overhang.inner,
+ shape.B_cols - overhang.B_cols
+ };
+ // Top left corner.
+ MultiplyNoOverhang<Access, Big>(access, big_shape);
+ // Bottom currently including right side. TODO: unrolled kernel, rather than dumb loop.
+ MultiplyNoOverhang<Access, Kernel>(
+ access.AAdd(big_shape.A_rows, 0).CAdd(big_shape.A_rows, 0),
+ Tile {overhang.A_rows, shape.inner, shape.B_cols});
+ // Right side except bottom. TODO: unrolled kernel, rather than dumb loop.
+ MultiplyNoOverhang<Access, Kernel>(
+ access.BAdd(0, big_shape.B_cols).CAdd(0, big_shape.B_cols),
+ Tile {big_shape.A_rows, shape.inner, overhang.B_cols});
+}
+
} // namespace INTGEMM_ARCH
} // namespace intgemm
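
Since UnrollKernel<A_rows, 1, B_cols, Kernel> leaves the inner unroll factor at 1, Big::kTile.inner should equal Kernel::kTile.inner, and the assert above then makes overhang.inner zero; only the row and column overhangs need the fallback calls. A usage sketch matching the test above, assuming the types from test/tile_test.inl:

// Any A_rows/B_cols work; the 7x3 register-tile unroll handles the bulk
// and the overhang strips fall back to the base Signed8 kernel.
Tile shape{20, sizeof(Register), 10};
TestMatrices t(shape);
AlignedVector<int32_t> C_test(shape.A_rows * shape.B_cols);
Multiply<TestMatrices::AccessT, Signed8, 7, 3>(t.AccessTest(C_test), shape);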