diff options
author | Kenneth Heafield <github@kheafield.com> | 2020-04-18 19:11:41 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2020-04-18 19:11:41 +0300 |
commit | 27049f4f30347ae999311eb1445941e2ae88809e (patch) | |
tree | 08d591d85651bf67571eee905aa2cc63787238df | |
parent | 9801459266d44a38f6e2aade427491d3c3015218 (diff) |
Tiled multiply with basic testing work
-rw-r--r-- | test/tile_test.cc | 1 | ||||
-rw-r--r-- | test/tile_test.inl | 116 | ||||
-rw-r--r-- | tile/access.h | 18 | ||||
-rw-r--r-- | tile/multiply.h | 32 | ||||
-rw-r--r-- | tile/multiply.inl | 71 | ||||
-rw-r--r-- | tile/reduce.inl | 4 | ||||
-rw-r--r-- | types.h | 3 |
7 files changed, 242 insertions, 3 deletions
diff --git a/test/tile_test.cc b/test/tile_test.cc index aa1380b..0b7d94c 100644 --- a/test/tile_test.cc +++ b/test/tile_test.cc @@ -1,6 +1,7 @@ #include "../aligned.h" #include "../tile/access.h" #include "../tile/dot.h" +#include "../tile/multiply.h" #include "../tile/reduce.h" #include "test.h" diff --git a/test/tile_test.inl b/test/tile_test.inl index 3d13920..2beaff6 100644 --- a/test/tile_test.inl +++ b/test/tile_test.inl @@ -22,7 +22,6 @@ #error "Included without expected architecture" #endif - namespace intgemm { namespace INTGEMM_ARCH { @@ -101,6 +100,121 @@ TEST_CASE("Reduce " INTGEMM_TEST_NAME, "[tile]") { StaticLoop<Reduce32Test, MakeStaticLoopIterator<33>>(); } +// Replicate the saturation behavior of the Signed8 kernel with 16-bit accumulation. +template <class Access> void Signed8ReferenceMult(Access access, Tile problem) { + assert(!problem.inner % 2); + for (Index a_row = 0; a_row < problem.A_rows; ++a_row) { + for (Index b_col = 0; b_col < problem.B_cols; ++b_col) { + Access acc = access.AAdd(a_row, 0).BAdd(0, b_col).CAdd(a_row, b_col); + // For VNNI, just do it accurately. +#ifdef INTGEMM_THIS_IS_AVX512VNNI + acc.CFront() = 0; + for (Index inner = 0; inner < problem.inner; ++inner) { + Access innermost = acc.AAdd(0, inner).BAdd(inner, 0); + acc.CFront() += static_cast<int32_t>(innermost.AFront()) * static_cast<int32_t>(innermost.BFront()); + } +#else + // For non-VNNI, do the saturation stuff. + int16_t accumulators[sizeof(Register) / sizeof(int16_t)] = {0}; + for (Index inner = 0; inner < problem.inner; inner += 2) { + Access innermost = acc.AAdd(0, inner).BAdd(inner, 0); + int32_t product = static_cast<int32_t>(innermost.AFront()) * static_cast<int32_t>(innermost.BFront()); + innermost = innermost.AAdd(0, 1).BAdd(1, 0); + product += static_cast<int32_t>(innermost.AFront()) * static_cast<int32_t>(innermost.BFront()); + // Saturate to 16-bit for maddubs. 
+ if (product > 32767) product = 32767; + if (product < -32768) product = -32768; + int16_t &accum = accumulators[(inner / 2) % (sizeof(Register) / sizeof(int16_t))]; + // Saturating accumulation. + product += static_cast<int32_t>(accum); + if (product > 32767) product = 32767; + if (product < -32768) product = -32768; + accum = static_cast<int16_t>(product); + } + acc.CFront() = 0; + for (Index i = 0; i < sizeof(Register) / sizeof(int16_t); ++i) { + acc.CFront() += static_cast<int32_t>(accumulators[i]); + } +#endif + } + } +} + +void DumpMatrix(int8_t *m, Index rows, Index cols) { + std::cerr << rows << 'x' << cols << '\n'; + for (Index i = 0; i < rows; ++i) { + for (Index j = 0; j < cols; ++j) { + std::cerr << (int16_t)m[i * cols + j] << ' '; + } + std::cerr << '\n'; + } +} + +#ifndef INTGEMM_THIS_IS_SSE2 +template <class Kernel> void TestMultiplyNoOverhang(Tile shape) { + // These are sanity checks on the arguments, not the code. + CHECK(shape.A_rows % Kernel::kTile.A_rows == 0); + CHECK(shape.inner % Kernel::kTile.inner == 0); + CHECK(shape.B_cols % Kernel::kTile.B_cols == 0); + + AlignedVector<int8_t> A(shape.A_rows * shape.inner); + AlignedVector<int8_t> B(shape.inner * shape.B_cols); + std::mt19937 gen; + std::uniform_int_distribution<int8_t> dist(-127,127); + for (int8_t &it : A) it = dist(gen); + for (int8_t &it : B) it = dist(gen); + + AlignedVector<int32_t> C_reference(shape.A_rows * shape.B_cols); + typedef Access<RowMajorAccess<int8_t>, ColMajorAccess<int8_t>, RowMajorAccess<int32_t> > AccessT; + AccessT ref_access( + RowMajorAccess<int8_t>(A.begin(), shape.inner), + ColMajorAccess<int8_t>(B.begin(), shape.inner), + RowMajorAccess<int32_t>(C_reference.begin(), shape.B_cols)); + Signed8ReferenceMult<AccessT>(ref_access, shape); + + AlignedVector<int32_t> C_test(shape.A_rows * shape.B_cols); + AccessT test_access( + RowMajorAccess<int8_t>(A.begin(), shape.inner), + ColMajorAccess<int8_t>(B.begin(), shape.inner), + RowMajorAccess<int32_t>(C_test.begin(), 
shape.B_cols)); + MultiplyNoOverhang<AccessT, Kernel>(test_access, shape); + bool failed = false; + for (Index i = 0; i < shape.A_rows; ++i) { + for (Index j = 0; j < shape.B_cols; ++j) { + CHECK(C_reference[i * shape.B_cols + j] == C_test[i * shape.B_cols + j]); + if (C_reference[i * shape.B_cols + j] != C_test[i * shape.B_cols + j]) + failed = true; + } + } + if (failed) { + std::cerr << "Failed A is "; + DumpMatrix(A.begin(), shape.A_rows, shape.inner); + std::cerr << "Failed B is "; + DumpMatrix(B.begin(), shape.inner, shape.B_cols); + } +} + +TEST_CASE("MultiplyNoOverhang Signed8 " INTGEMM_TEST_NAME, "[tile]") { + if (kCPU < CPUType::INTGEMM_ARCH) return; + // Test small multiples. + TestMultiplyNoOverhang<Signed8>(Tile{1,sizeof(Register),1}); + TestMultiplyNoOverhang<Signed8>(Tile{2, sizeof(Register), 1}); + TestMultiplyNoOverhang<Signed8>(Tile{1, 2 * sizeof(Register),1}); + TestMultiplyNoOverhang<Signed8>(Tile{1, sizeof(Register), 2}); + TestMultiplyNoOverhang<Signed8>(Tile{2, 2 * sizeof(Register), 2}); + // Try a bunch of shapes! + Tile shape; + for (shape.A_rows = 0; shape.A_rows <= 33; ++shape.A_rows) { + for (shape.inner = 0; shape.inner <= 9 * sizeof(Register); shape.inner += sizeof(Register)) { + for (shape.B_cols = 0; shape.B_cols <= 33; ++shape.B_cols) { + TestMultiplyNoOverhang<Signed8>(shape); + } + } + } +} + +#endif + } // namespace INTGEMM_ARCH } // namespace intgemm diff --git a/tile/access.h b/tile/access.h index e2fb1a8..d5e6119 100644 --- a/tile/access.h +++ b/tile/access.h @@ -20,6 +20,19 @@ template <class T> class RowMajorAccess { const Content &Front() const { return *data_; } Content &Front() { return *data_; } + // TODO: SLOW. This is here for testing. 
+ template <Index A_rows, Index B_cols> void Write(const __m128i *from) { Write<A_rows, B_cols>(reinterpret_cast<const T*>(from)); } + template <Index A_rows, Index B_cols> void Write(const __m256i *from) { Write<A_rows, B_cols>(reinterpret_cast<const T*>(from)); } + template <Index A_rows, Index B_cols> void Write(const __m512i *from) { Write<A_rows, B_cols>(reinterpret_cast<const T*>(from)); } + + template <Index A_rows, Index B_cols> void Write(const T *from) { + for (Index i = 0; i < A_rows; ++i) { + for (Index j = 0; j < B_cols; ++j) { + data_[i * cols_ + j] = from[i * B_cols + j]; + } + } + } + private: Content *data_; Index cols_; @@ -68,6 +81,11 @@ template <class AT, class BT, class CT> class Access { return Access(a_, b_, c_.Add(row, col)); } + const A &AAccessor() const { return a_; } + const B &BAccessor() const { return b_; } + const C &CAccessor() const { return c_; } + C &CAccessor() { return c_; } + AContent &AFront() { return a_.Front(); } const AContent &AFront() const { return a_.Front(); } BContent &BFront() { return b_.Front(); } diff --git a/tile/multiply.h b/tile/multiply.h new file mode 100644 index 0000000..294b3ed --- /dev/null +++ b/tile/multiply.h @@ -0,0 +1,32 @@ +#pragma once + +#include "access.h" +#include "dot.h" +#include "reduce.h" +#include "../types.h" + +#include <cassert> + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI +#define INTGEMM_THIS_IS_AVX512VNNI +#include "multiply.inl" +#undef INTGEMM_THIS_IS_AVX512VNNI +#endif + +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +#define INTGEMM_THIS_IS_AVX512BW +#include "multiply.inl" +#undef INTGEMM_THIS_IS_AVX512BW +#endif + +#define INTGEMM_THIS_IS_AVX2 +#include "multiply.inl" +#undef INTGEMM_THIS_IS_AVX2 + +#define INTGEMM_THIS_IS_SSSE3 +#include "multiply.inl" +#undef INTGEMM_THIS_IS_SSSE3 + +#define INTGEMM_THIS_IS_SSE2 +#include "multiply.inl" +#undef INTGEMM_THIS_IS_SSE2 diff --git a/tile/multiply.inl b/tile/multiply.inl new file mode 100644 index 0000000..f1a22cf --- /dev/null 
+++ b/tile/multiply.inl @@ -0,0 +1,71 @@ +#if defined(INTGEMM_THIS_IS_AVX512VNNI) +#define INTGEMM_ARCH AVX512VNNI +#define INTGEMM_TARGET INTGEMM_AVX512VNNI +#elif defined(INTGEMM_THIS_IS_AVX512BW) +#define INTGEMM_ARCH AVX512BW +#define INTGEMM_TARGET INTGEMM_AVX512BW +#elif defined(INTGEMM_THIS_IS_AVX2) +#define INTGEMM_ARCH AVX2 +#define INTGEMM_TARGET INTGEMM_AVX2 +#elif defined(INTGEMM_THIS_IS_SSSE3) +#define INTGEMM_ARCH SSSE3 +#define INTGEMM_TARGET INTGEMM_SSSE3 +#elif defined(INTGEMM_THIS_IS_SSE2) +#define INTGEMM_ARCH SSE2 +#define INTGEMM_TARGET INTGEMM_SSE2 +#endif + +namespace intgemm { +namespace INTGEMM_ARCH { + +// Upcast 16 to 32 if needed. +template <class T> INTGEMM_TARGET static inline void SumTo32(Register ®); +template <> INTGEMM_TARGET inline void SumTo32<int16_t>(Register ®) { + reg = madd_epi16(reg, set1_epi16<Register>(1)); +} +template <> INTGEMM_TARGET inline void SumTo32<int32_t>(Register &) {} + +template <class T> struct SumTo32Body { + template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) { + SumTo32<T>(regs[Iterator::template I<0>()]); + } +}; + +/* Multiply assuming the matrix sizes are a multiple of the kernel size. */ +template <class AccessT, class Kernel> INTGEMM_TARGET static inline void MultiplyNoOverhang(AccessT access, const Tile shape) { + assert(shape.A_rows % Kernel::kTile.A_rows == 0); + assert(shape.inner % Kernel::kTile.inner == 0); + assert(shape.B_cols % Kernel::kTile.B_cols == 0); + constexpr Index Outputs = Kernel::kTile.A_rows * Kernel::kTile.B_cols; + for (Index B_col = 0; B_col < shape.B_cols; B_col += Kernel::kTile.B_cols) { + AccessT column_adjusted = access.BAdd(0, B_col).CAdd(0, B_col); + for (Index A_row = 0; A_row < shape.A_rows; A_row += Kernel::kTile.A_rows) { + AccessT col_row = column_adjusted.AAdd(A_row, 0).CAdd(A_row, 0); + + // Accumulate values in temporary C registers. 
+ Register c_regs[Outputs] = {setzero_si<Register>()}; + // If C is column major it would be better to have column-major registers + // since this determines the order used by Reduce32. + Access<typename AccessT::A, typename AccessT::B, RegisterRowMajorAccess> reg_access( + col_row.AAccessor(), + col_row.BAccessor(), + RegisterRowMajorAccess(c_regs, Kernel::kTile.B_cols)); + + for (Index inner = 0; inner < shape.inner; inner += Kernel::kTile.inner) { + Kernel::Run(reg_access.AAdd(0, inner).BAdd(inner, 0)); + } + + // If 16-bit, upcast to 32-bit while horizontally adding. + StaticLoop<SumTo32Body<typename Kernel::Packed::C>, MakeStaticLoopIterator<Outputs>>(c_regs); + // Horizontally add 32-bit values. + Reduce32<Outputs, Sum32Op>(c_regs); + col_row.CAccessor().template Write<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(c_regs); + } + } +} + +} // namespace INTGEMM_ARCH +} // namespace intgemm + +#undef INTGEMM_ARCH +#undef INTGEMM_TARGET diff --git a/tile/reduce.inl b/tile/reduce.inl index f070723..ef9afc9 100644 --- a/tile/reduce.inl +++ b/tile/reduce.inl @@ -39,6 +39,7 @@ template <Index Valid, class Op, class Folder> INTGEMM_TARGET static inline void /* These Folder structs say how to interweave even pairs of registers and * fold an odd register over itself. Folding an odd register over itself is * slightly faster than doing an even fold with garbage. */ +// TODO: _mm_hadd_epi32 for SSSE3 and _mm256_hadd_epi32 for AVX2 struct Reduce32Folder { INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) { return RegisterPair { unpackhi_epi32(first, second), unpacklo_epi32(first, second) }; @@ -172,8 +173,7 @@ template <class Op> struct ReduceOverhang<3, Op> { #endif -/* The only public function: horizontally reduce registers with 32-bit values. - */ +/* Public function: horizontally reduce registers with 32-bit values. 
*/ template <Index Valid, class Op> INTGEMM_TARGET static inline void Reduce32(Register *regs) { GenericReduce<Valid, Op, Reduce32Folder>(regs); GenericReduce<(Valid + 1) / 2, Op, Reduce64Folder>(regs); @@ -49,6 +49,9 @@ enum class CPUType { extern const CPUType kCPU; struct Tile { +/* Tile() {} + Tile(Index in_A_rows, Index in_inner, Index in_B_cols) + : A_rows(in_A_rows), inner(in_inner), B_cols(in_B_cols) {} */ Index A_rows, inner, B_cols; }; |