github.com/marian-nmt/intgemm.git
author    Kenneth Heafield <github@kheafield.com>  2020-04-18 19:11:41 +0300
committer Kenneth Heafield <github@kheafield.com>  2020-04-18 19:11:41 +0300
commit    27049f4f30347ae999311eb1445941e2ae88809e (patch)
tree      08d591d85651bf67571eee905aa2cc63787238df
parent    9801459266d44a38f6e2aade427491d3c3015218 (diff)
Tiled multiply with basic testing work
-rw-r--r--  test/tile_test.cc   |   1
-rw-r--r--  test/tile_test.inl  | 116
-rw-r--r--  tile/access.h       |  18
-rw-r--r--  tile/multiply.h     |  32
-rw-r--r--  tile/multiply.inl   |  71
-rw-r--r--  tile/reduce.inl     |   4
-rw-r--r--  types.h             |   3
7 files changed, 242 insertions(+), 3 deletions(-)
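In short, the commit adds a MultiplyNoOverhang entry point (tile/multiply.inl) that drives a register-tile kernel through the Access wrappers from tile/access.h. A minimal calling sketch, mirroring the test added below (the AVX2 namespace and the Signed8 kernel are assumptions for illustration; every dimension of shape must be a multiple of the kernel's tile):

    #include "tile/multiply.h"

    // Sketch only: the kernel and architecture choice (Signed8, AVX2) are assumed for illustration.
    void ExampleMultiply(int8_t *A, int8_t *B, int32_t *C, intgemm::Tile shape) {
      using namespace intgemm::AVX2;
      // A: A_rows x inner, row major.  B: inner x B_cols, column major.  C: A_rows x B_cols, row major.
      typedef Access<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, RowMajorAccess<int32_t> > AccessT;
      AccessT access(
          RowMajorAccess<int8_t>(A, shape.inner),
          ColMajorAccess<int8_t>(B, shape.inner),
          RowMajorAccess<int32_t>(C, shape.B_cols));
      MultiplyNoOverhang<AccessT, Signed8>(access, shape);
    }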
diff --git a/test/tile_test.cc b/test/tile_test.cc
index aa1380b..0b7d94c 100644
--- a/test/tile_test.cc
+++ b/test/tile_test.cc
@@ -1,6 +1,7 @@
#include "../aligned.h"
#include "../tile/access.h"
#include "../tile/dot.h"
+#include "../tile/multiply.h"
#include "../tile/reduce.h"
#include "test.h"
diff --git a/test/tile_test.inl b/test/tile_test.inl
index 3d13920..2beaff6 100644
--- a/test/tile_test.inl
+++ b/test/tile_test.inl
@@ -22,7 +22,6 @@
#error "Included without expected architecture"
#endif
-
namespace intgemm {
namespace INTGEMM_ARCH {
@@ -101,6 +100,121 @@ TEST_CASE("Reduce " INTGEMM_TEST_NAME, "[tile]") {
StaticLoop<Reduce32Test, MakeStaticLoopIterator<33>>();
}
+// Replicate the saturation behavior of the Signed8 kernel with 16-bit accumulation.
+template <class Access> void Signed8ReferenceMult(Access access, Tile problem) {
+ assert(problem.inner % 2 == 0);
+ for (Index a_row = 0; a_row < problem.A_rows; ++a_row) {
+ for (Index b_col = 0; b_col < problem.B_cols; ++b_col) {
+ Access acc = access.AAdd(a_row, 0).BAdd(0, b_col).CAdd(a_row, b_col);
+ // With VNNI, 8-bit products accumulate directly into 32-bit, so compute the dot product exactly.
+#ifdef INTGEMM_THIS_IS_AVX512VNNI
+ acc.CFront() = 0;
+ for (Index inner = 0; inner < problem.inner; ++inner) {
+ Access innermost = acc.AAdd(0, inner).BAdd(inner, 0);
+ acc.CFront() += static_cast<int32_t>(innermost.AFront()) * static_cast<int32_t>(innermost.BFront());
+ }
+#else
+ // Without VNNI, model maddubs: each adjacent pair of products saturates to 16-bit, then accumulates with saturating 16-bit adds.
+ int16_t accumulators[sizeof(Register) / sizeof(int16_t)] = {0};
+ for (Index inner = 0; inner < problem.inner; inner += 2) {
+ Access innermost = acc.AAdd(0, inner).BAdd(inner, 0);
+ int32_t product = static_cast<int32_t>(innermost.AFront()) * static_cast<int32_t>(innermost.BFront());
+ innermost = innermost.AAdd(0, 1).BAdd(1, 0);
+ product += static_cast<int32_t>(innermost.AFront()) * static_cast<int32_t>(innermost.BFront());
+ // Saturate to 16-bit for maddubs.
+ if (product > 32767) product = 32767;
+ if (product < -32768) product = -32768;
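+ // One 16-bit lane per product pair within a register; the same lane accumulates across registers, mirroring the kernel's per-lane sums.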
+ int16_t &accum = accumulators[(inner / 2) % (sizeof(Register) / sizeof(int16_t))];
+ // Saturating accumulation.
+ product += static_cast<int32_t>(accum);
+ if (product > 32767) product = 32767;
+ if (product < -32768) product = -32768;
+ accum = static_cast<int16_t>(product);
+ }
+ acc.CFront() = 0;
+ for (Index i = 0; i < sizeof(Register) / sizeof(int16_t); ++i) {
+ acc.CFront() += static_cast<int32_t>(accumulators[i]);
+ }
+#endif
+ }
+ }
+}
+
+void DumpMatrix(int8_t *m, Index rows, Index cols) {
+ std::cerr << rows << 'x' << cols << '\n';
+ for (Index i = 0; i < rows; ++i) {
+ for (Index j = 0; j < cols; ++j) {
+ std::cerr << (int16_t)m[i * cols + j] << ' ';
+ }
+ std::cerr << '\n';
+ }
+}
+
+#ifndef INTGEMM_THIS_IS_SSE2
+template <class Kernel> void TestMultiplyNoOverhang(Tile shape) {
+ // These are sanity checks on the arguments, not the code.
+ CHECK(shape.A_rows % Kernel::kTile.A_rows == 0);
+ CHECK(shape.inner % Kernel::kTile.inner == 0);
+ CHECK(shape.B_cols % Kernel::kTile.B_cols == 0);
+
+ AlignedVector<int8_t> A(shape.A_rows * shape.inner);
+ AlignedVector<int8_t> B(shape.inner * shape.B_cols);
+ std::mt19937 gen;
+ // int8_t is not a valid IntType for uniform_int_distribution, so draw wider and narrow.
+ std::uniform_int_distribution<int16_t> dist(-127, 127);
+ for (int8_t &it : A) it = static_cast<int8_t>(dist(gen));
+ for (int8_t &it : B) it = static_cast<int8_t>(dist(gen));
+
+ AlignedVector<int32_t> C_reference(shape.A_rows * shape.B_cols);
+ typedef Access<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, RowMajorAccess<int32_t> > AccessT;
+ AccessT ref_access(
+ RowMajorAccess<int8_t>(A.begin(), shape.inner),
+ ColMajorAccess<int8_t>(B.begin(), shape.inner),
+ RowMajorAccess<int32_t>(C_reference.begin(), shape.B_cols));
+ Signed8ReferenceMult<AccessT>(ref_access, shape);
+
+ AlignedVector<int32_t> C_test(shape.A_rows * shape.B_cols);
+ AccessT test_access(
+ RowMajorAccess<int8_t>(A.begin(), shape.inner),
+ ColMajorAccess<int8_t>(B.begin(), shape.inner),
+ RowMajorAccess<int32_t>(C_test.begin(), shape.B_cols));
+ MultiplyNoOverhang<AccessT, Kernel>(test_access, shape);
+ bool failed = false;
+ for (Index i = 0; i < shape.A_rows; ++i) {
+ for (Index j = 0; j < shape.B_cols; ++j) {
+ CHECK(C_reference[i * shape.B_cols + j] == C_test[i * shape.B_cols + j]);
+ if (C_reference[i * shape.B_cols + j] != C_test[i * shape.B_cols + j])
+ failed = true;
+ }
+ }
+ if (failed) {
+ std::cerr << "Failed A is ";
+ DumpMatrix(A.begin(), shape.A_rows, shape.inner);
+ std::cerr << "Failed B is ";
+ DumpMatrix(B.begin(), shape.inner, shape.B_cols);
+ }
+}
+
+TEST_CASE("MultiplyNoOverhang Signed8 " INTGEMM_TEST_NAME, "[tile]") {
+ if (kCPU < CPUType::INTGEMM_ARCH) return;
+ // Test small multiples.
+ TestMultiplyNoOverhang<Signed8>(Tile{1, sizeof(Register), 1});
+ TestMultiplyNoOverhang<Signed8>(Tile{2, sizeof(Register), 1});
+ TestMultiplyNoOverhang<Signed8>(Tile{1, 2 * sizeof(Register), 1});
+ TestMultiplyNoOverhang<Signed8>(Tile{1, sizeof(Register), 2});
+ TestMultiplyNoOverhang<Signed8>(Tile{2, 2 * sizeof(Register), 2});
+ // Try a bunch of shapes!
+ Tile shape;
+ for (shape.A_rows = 0; shape.A_rows <= 33; ++shape.A_rows) {
+ for (shape.inner = 0; shape.inner <= 9 * sizeof(Register); shape.inner += sizeof(Register)) {
+ for (shape.B_cols = 0; shape.B_cols <= 33; ++shape.B_cols) {
+ TestMultiplyNoOverhang<Signed8>(shape);
+ }
+ }
+ }
+}
+
+#endif
+
} // namespace INTGEMM_ARCH
} // namespace intgemm
diff --git a/tile/access.h b/tile/access.h
index e2fb1a8..d5e6119 100644
--- a/tile/access.h
+++ b/tile/access.h
@@ -20,6 +20,19 @@ template <class T> class RowMajorAccess {
const Content &Front() const { return *data_; }
Content &Front() { return *data_; }
+ // TODO: SLOW. This is here for testing.
+ template <Index A_rows, Index B_cols> void Write(const __m128i *from) { Write<A_rows, B_cols>(reinterpret_cast<const T*>(from)); }
+ template <Index A_rows, Index B_cols> void Write(const __m256i *from) { Write<A_rows, B_cols>(reinterpret_cast<const T*>(from)); }
+ template <Index A_rows, Index B_cols> void Write(const __m512i *from) { Write<A_rows, B_cols>(reinterpret_cast<const T*>(from)); }
+
+ template <Index A_rows, Index B_cols> void Write(const T *from) {
+ for (Index i = 0; i < A_rows; ++i) {
+ for (Index j = 0; j < B_cols; ++j) {
+ data_[i * cols_ + j] = from[i * B_cols + j];
+ }
+ }
+ }
+
private:
Content *data_;
Index cols_;
@@ -68,6 +81,11 @@ template <class AT, class BT, class CT> class Access {
return Access(a_, b_, c_.Add(row, col));
}
+ const A &AAccessor() const { return a_; }
+ const B &BAccessor() const { return b_; }
+ const C &CAccessor() const { return c_; }
+ C &CAccessor() { return c_; }
+
AContent &AFront() { return a_.Front(); }
const AContent &AFront() const { return a_.Front(); }
BContent &BFront() { return b_.Front(); }
diff --git a/tile/multiply.h b/tile/multiply.h
new file mode 100644
index 0000000..294b3ed
--- /dev/null
+++ b/tile/multiply.h
@@ -0,0 +1,32 @@
+#pragma once
+
+#include "access.h"
+#include "dot.h"
+#include "reduce.h"
+#include "../types.h"
+
+#include <cassert>
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
+#define INTGEMM_THIS_IS_AVX512VNNI
+#include "multiply.inl"
+#undef INTGEMM_THIS_IS_AVX512VNNI
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+#define INTGEMM_THIS_IS_AVX512BW
+#include "multiply.inl"
+#undef INTGEMM_THIS_IS_AVX512BW
+#endif
+
+#define INTGEMM_THIS_IS_AVX2
+#include "multiply.inl"
+#undef INTGEMM_THIS_IS_AVX2
+
+#define INTGEMM_THIS_IS_SSSE3
+#include "multiply.inl"
+#undef INTGEMM_THIS_IS_SSSE3
+
+#define INTGEMM_THIS_IS_SSE2
+#include "multiply.inl"
+#undef INTGEMM_THIS_IS_SSE2
diff --git a/tile/multiply.inl b/tile/multiply.inl
new file mode 100644
index 0000000..f1a22cf
--- /dev/null
+++ b/tile/multiply.inl
@@ -0,0 +1,71 @@
+#if defined(INTGEMM_THIS_IS_AVX512VNNI)
+#define INTGEMM_ARCH AVX512VNNI
+#define INTGEMM_TARGET INTGEMM_AVX512VNNI
+#elif defined(INTGEMM_THIS_IS_AVX512BW)
+#define INTGEMM_ARCH AVX512BW
+#define INTGEMM_TARGET INTGEMM_AVX512BW
+#elif defined(INTGEMM_THIS_IS_AVX2)
+#define INTGEMM_ARCH AVX2
+#define INTGEMM_TARGET INTGEMM_AVX2
+#elif defined(INTGEMM_THIS_IS_SSSE3)
+#define INTGEMM_ARCH SSSE3
+#define INTGEMM_TARGET INTGEMM_SSSE3
+#elif defined(INTGEMM_THIS_IS_SSE2)
+#define INTGEMM_ARCH SSE2
+#define INTGEMM_TARGET INTGEMM_SSE2
+#endif
+
+namespace intgemm {
+namespace INTGEMM_ARCH {
+
+// Upcast 16-bit to 32-bit if needed: madd against a register of ones sums adjacent 16-bit lanes into 32-bit.
+template <class T> INTGEMM_TARGET static inline void SumTo32(Register &reg);
+template <> INTGEMM_TARGET inline void SumTo32<int16_t>(Register &reg) {
+ reg = madd_epi16(reg, set1_epi16<Register>(1));
+}
+template <> INTGEMM_TARGET inline void SumTo32<int32_t>(Register &) {}
+
+template <class T> struct SumTo32Body {
+ template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) {
+ SumTo32<T>(regs[Iterator::template I<0>()]);
+ }
+};
+
+/* Multiply C = A * B, assuming every dimension of shape is a multiple of the kernel's tile size (no overhang handling). */
+template <class AccessT, class Kernel> INTGEMM_TARGET static inline void MultiplyNoOverhang(AccessT access, const Tile shape) {
+ assert(shape.A_rows % Kernel::kTile.A_rows == 0);
+ assert(shape.inner % Kernel::kTile.inner == 0);
+ assert(shape.B_cols % Kernel::kTile.B_cols == 0);
+ constexpr Index Outputs = Kernel::kTile.A_rows * Kernel::kTile.B_cols;
+ for (Index B_col = 0; B_col < shape.B_cols; B_col += Kernel::kTile.B_cols) {
+ AccessT column_adjusted = access.BAdd(0, B_col).CAdd(0, B_col);
+ for (Index A_row = 0; A_row < shape.A_rows; A_row += Kernel::kTile.A_rows) {
+ AccessT col_row = column_adjusted.AAdd(A_row, 0).CAdd(A_row, 0);
+
+ // Accumulate values in temporary C registers.
+ Register c_regs[Outputs] = {setzero_si<Register>()};
+ // If C is column major it would be better to have column-major registers
+ // since this determines the order used by Reduce32.
+ Access<typename AccessT::A, typename AccessT::B, RegisterRowMajorAccess> reg_access(
+ col_row.AAccessor(),
+ col_row.BAccessor(),
+ RegisterRowMajorAccess(c_regs, Kernel::kTile.B_cols));
+
+ for (Index inner = 0; inner < shape.inner; inner += Kernel::kTile.inner) {
+ Kernel::Run(reg_access.AAdd(0, inner).BAdd(inner, 0));
+ }
+
+ // If 16-bit, upcast to 32-bit while horizontally adding.
+ StaticLoop<SumTo32Body<typename Kernel::Packed::C>, MakeStaticLoopIterator<Outputs>>(c_regs);
+ // Horizontally add 32-bit values.
+ Reduce32<Outputs, Sum32Op>(c_regs);
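+ // The reduction leaves the Outputs 32-bit sums packed at the start of c_regs, in the order Write copies to C.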
+ col_row.CAccessor().template Write<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(c_regs);
+ }
+ }
+}
+
+} // namespace INTGEMM_ARCH
+} // namespace intgemm
+
+#undef INTGEMM_ARCH
+#undef INTGEMM_TARGET
diff --git a/tile/reduce.inl b/tile/reduce.inl
index f070723..ef9afc9 100644
--- a/tile/reduce.inl
+++ b/tile/reduce.inl
@@ -39,6 +39,7 @@ template <Index Valid, class Op, class Folder> INTGEMM_TARGET static inline void
/* These Folder structs say how to interweave even pairs of registers and
* fold an odd register over itself. Folding an odd register over itself is
* slightly faster than doing an even fold with garbage. */
+// TODO: _mm_hadd_epi32 for SSSE3 and _mm256_hadd_epi32 for AVX2
struct Reduce32Folder {
INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) {
return RegisterPair { unpackhi_epi32(first, second), unpacklo_epi32(first, second) };
@@ -172,8 +173,7 @@ template <class Op> struct ReduceOverhang<3, Op> {
#endif
-/* The only public function: horizontally reduce registers with 32-bit values.
- */
+/* Public function: horizontally reduce registers with 32-bit values. */
template <Index Valid, class Op> INTGEMM_TARGET static inline void Reduce32(Register *regs) {
GenericReduce<Valid, Op, Reduce32Folder>(regs);
GenericReduce<(Valid + 1) / 2, Op, Reduce64Folder>(regs);
diff --git a/types.h b/types.h
index 9bc8111..b1655d3 100644
--- a/types.h
+++ b/types.h
@@ -49,6 +49,9 @@ enum class CPUType {
extern const CPUType kCPU;
struct Tile {
+/* Tile() {}
+ Tile(Index in_A_rows, Index in_inner, Index in_B_cols)
+ : A_rows(in_A_rows), inner(in_inner), B_cols(in_B_cols) {} */
Index A_rows, inner, B_cols;
};