diff options
author | Kenneth Heafield <github@kheafield.com> | 2020-04-04 01:01:15 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2020-04-04 01:01:15 +0300 |
commit | 57fe315ab631fb44c909f90868584536d05ec0ce (patch) | |
tree | 329693c7ee39c8f40bf5c680d70ef10b38d0512f | |
parent | 9af1ca24ea9a427771f29dd2ab4f7e9a889bf934 (diff) |
Reduce: working for SSE2 and AVX2; work in progress on AVX512
-rw-r--r-- | test/tile_test.inl | 61 |
-rw-r--r-- | tile/reduce.h | 12 |
-rw-r--r-- | tile/reduce.inl | 125 |
3 files changed, 137 insertions, 61 deletions
diff --git a/test/tile_test.inl b/test/tile_test.inl index 98a1300..c612ace 100644 --- a/test/tile_test.inl +++ b/test/tile_test.inl @@ -57,43 +57,38 @@ TEST_CASE("Basic Tile " INTGEMM_TEST_NAME, "[tile]") { } } -INTGEMM_TARGET void DumpRegister(Register reg) { - int32_t values[sizeof(Register) / sizeof(int32_t)]; - memcpy(values, ®, sizeof(Register)); - for (std::size_t i = 0; i < sizeof(Register) / sizeof(int32_t); ++i) { - std::cout.width(11); - std::cout << values[i] << ' '; - } -} - -INTGEMM_TARGET void Pack32Test() { - const std::size_t kPack = sizeof(Register) / sizeof(int32_t); - Register regs[kPack]; - std::mt19937 gen; - //std::uniform_int_distribution<int32_t> dist(std::numeric_limits<int32_t>::min(), std::numeric_limits<int32_t>::max()); - std::uniform_int_distribution<int32_t> dist(0, 100); - std::vector<int32_t> reference(kPack, 0); - for (std::size_t i = 0; i < kPack; ++i) { - int32_t temp[kPack]; - for (std::size_t j = 0; j < kPack; ++j) { - temp[j] = dist(gen); - reference[j] += temp[j]; +struct Pack32Test { + template <typename Iterator> INTGEMM_TARGET static void body() { + constexpr Index Valid = Iterator::template I<0>(); + // A zero-length array is a compiler error, so force it to be longer. + constexpr Index ArrayLen = Valid ? Valid : 1; + const std::size_t kPack = sizeof(Register) / sizeof(int32_t); + Register regs[ArrayLen]; + std::mt19937 gen; + std::uniform_int_distribution<int32_t> dist(std::numeric_limits<int32_t>::min(), std::numeric_limits<int32_t>::max()); + int32_t reference[ArrayLen]; + memset(reference, 0, sizeof(reference)); + for (Index i = 0; i < Valid; ++i) { + int32_t temp[kPack]; + for (std::size_t j = 0; j < kPack; ++j) { + temp[j] = dist(gen); + reference[i] += temp[j]; + } + memcpy(®s[i], temp, sizeof(Register)); + } + // Decay type for template. 
+ Register *indirect = regs; + Pack32<Valid, Sum32Op>(indirect); + const int32_t *test = reinterpret_cast<const int32_t*>(regs); + for (Index i = 0; i < Valid; ++i) { + CHECK(test[i] == reference[i]); } - memcpy(®s[i], temp, sizeof(Register)); - } - Register *indirect = regs; - for (std::size_t i = 0; i < 4; ++i) { - DumpRegister(indirect[i]); - std::cout << '\n'; } - Pack32<3, Sum32Op>(indirect); - DumpRegister(indirect[0]); - std::cout << '\n'; -} +}; TEST_CASE("Reduce " INTGEMM_TEST_NAME, "[tile]") { - if (kCPU >= CPUType::INTGEMM_ARCH) - Pack32Test(); + if (kCPU < CPUType::INTGEMM_ARCH) return; + StaticLoop<Pack32Test, MakeStaticLoopIterator<33>>(); } } // namespace INTGEMM_ARCH diff --git a/tile/reduce.h b/tile/reduce.h index 0ab4e0d..51cf719 100644 --- a/tile/reduce.h +++ b/tile/reduce.h @@ -24,19 +24,19 @@ struct Sum32Op { } // namespace intgemm -#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW -#define INTGEMM_THIS_IS_AVX512BW +#define INTGEMM_THIS_IS_SSE2 #include "reduce.inl" -#undef INTGEMM_THIS_IS_AVX512BW -#endif +#undef INTGEMM_THIS_IS_SSE2 #define INTGEMM_THIS_IS_AVX2 #include "reduce.inl" #undef INTGEMM_THIS_IS_AVX2 -#define INTGEMM_THIS_IS_SSE2 +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +#define INTGEMM_THIS_IS_AVX512BW #include "reduce.inl" -#undef INTGEMM_THIS_IS_SSE2 +#undef INTGEMM_THIS_IS_AVX512BW +#endif namespace intgemm { diff --git a/tile/reduce.inl b/tile/reduce.inl index 3bfa899..c3f7470 100644 --- a/tile/reduce.inl +++ b/tile/reduce.inl @@ -14,42 +14,123 @@ namespace intgemm { namespace INTGEMM_ARCH { -template <class Op> struct Pack64Even { +struct RegisterPair { Register hi; Register lo; }; + +template <class Op, class Folder> struct PackEvens { template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) { const Index i = Iterator::template I<0>(); - Register hi = unpackhi_epi64(regs[2 * i], regs[2 * i + 1]); - Register lo = unpacklo_epi64(regs[2 * i], regs[2 * i + 1]); - regs[i] = Op::Run(hi, lo); + RegisterPair ret = 
Folder::Even(regs[2 * i], regs[2 * i + 1]); + regs[i] = Op::Run(ret.hi, ret.lo); } }; -template <Index Valid, class Op> INTGEMM_TARGET static inline void Pack64(Register *regs) { - StaticLoop<Pack64Even<Op>, MakeStaticLoopIterator<Valid / 2>>(regs); + +template <Index Valid, class Op, class Folder> INTGEMM_TARGET static inline void GenericPack(Register *regs) { + StaticLoop<PackEvens<Op, Folder>, MakeStaticLoopIterator<Valid / 2>>(regs); if (Valid & 1) { - // For the odd case, shuffle to form 0 g where g is garbage and 0 is accumlated. - Register shuffled = shuffle_epi32(regs[Valid - 1], 0xB0 /* CDAA */); - regs[Valid / 2] = Op::Run(shuffled, regs[Valid - 1]); + auto values = Folder::Odd(regs[Valid - 1]); + regs[Valid / 2] = Folder::OddUpcast(Op::Run(values.lo, values.hi)); } - // Now [0, (Valid + 1) / 2) contains registers to pack with 128-bit interleaving. } -template <class Op> struct Pack32Even { +struct Pack32Folder { + INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) { + return RegisterPair { unpackhi_epi32(first, second), unpacklo_epi32(first, second) }; + } + INTGEMM_TARGET static inline RegisterPair Odd(Register reg) { + // For the odd case, shuffle to form 0 g 0 g where g is garbage and 0 is accumlated. + return RegisterPair { reg, shuffle_epi32(reg, 0x31) }; + } + INTGEMM_TARGET static inline Register OddUpcast(Register reg) { return reg; } +}; + +struct Pack64Folder { + INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) { + return RegisterPair { unpackhi_epi64(first, second), unpacklo_epi64(first, second) }; + } + INTGEMM_TARGET static inline RegisterPair Odd(Register reg) { + // For the odd case, shuffle to form 0 g where g is garbage and 0 is accumlated. 
+ return RegisterPair { reg, shuffle_epi32(reg, 3 * 4 + 2) }; + } + INTGEMM_TARGET static inline Register OddUpcast(Register reg) { return reg; } +}; + +#ifdef INTGEMM_THIS_IS_AVX2 +struct Pack128Folder { + INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) { + return RegisterPair { + // This instruction generates 0s 1s 2s 3s 4f 5f 6f 7f + _mm256_permute2f128_si256(first, second, 0x21), + // This instruction generates 0f 1f 2f 3f 4s 5s 6s 7s + _mm256_blend_epi32(first, second, 0xf0) + }; + } + INTGEMM_TARGET static inline SSE2::RegisterPair Odd(Register reg) { + return SSE2::RegisterPair { _mm256_extracti128_si256(reg, 1), _mm256_castsi256_si128(reg) }; + } + INTGEMM_TARGET static inline Register OddUpcast(SSE2::Register reg) { return _mm256_castsi128_si256(reg); } +}; +#endif + +#ifdef INTGEMM_THIS_IS_AVX512BW +struct Pack128Folder { + INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) { + // TODO can this be optimized with a blend and a shuffle instruction? 
+ return RegisterPair { + // Form [0th 128-bit of first, 0th 128-bit second, 2nd 128-bit of first, 2nd 128-bit of second] + _mm512_mask_permutex_epi64(first, 0xcc, second, (0 << 4) | (1 << 6)), + // Form [1st 128-bit of first, 1st 128-bit of second, 3rd 128-bit of first, 3rd 128-bit of second] + _mm512_mask_permutex_epi64(second, 0x33, first, 2 | (3 << 2)) + }; + } + INTGEMM_TARGET static inline AVX2::RegisterPair Odd(Register reg) { + return AVX2::RegisterPair { _mm512_castsi512_si256(reg), _mm512_extracti64x4_epi64(reg, 1) }; + } + INTGEMM_TARGET static inline Register OddUpcast(AVX2::Register reg) { return _mm512_castsi256_si512(reg); } +}; + +struct Pack256Folder { + INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) { + return RegisterPair { + // This instruction generates first[2] first[3] second[0] second[1] + _mm512_shuffle_i64x2(first, second, 2 | (3 << 2) | (0 << 4) | (1 << 6)), + // This instruction generates first[0] first[1] second[2] second[3] + _mm512_mask_blend_epi64(0xf0, first, second) + }; + } + INTGEMM_TARGET static inline AVX2::RegisterPair Odd(Register reg) { + return AVX2::RegisterPair { _mm512_castsi512_si256(reg), _mm512_extracti64x4_epi64(reg, 1) }; + } + INTGEMM_TARGET static inline Register OddUpcast(AVX2::Register reg) { return _mm512_castsi256_si512(reg); } +}; + +template <class Op> struct PackFours { + // Collapse 4 AVX512 registers at once, interleaving 128-bit fields. template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) { const Index i = Iterator::template I<0>(); - Register hi = unpackhi_epi32(regs[2 * i], regs[2 * i + 1]); - Register lo = unpacklo_epi32(regs[2 * i], regs[2 * i + 1]); - regs[i] = Op::Run(hi, lo); + const Register *in = regs + i * 4; + // Do 256-bit interleaving first because it's slightly cheaper. 
+ RegisterPair mix0pair = Pack256Folder::Even(in[0], in[2]); + RegisterPair mix1pair = Pack256Folder::Even(in[1], in[3]); + Register mix0 = Op::Run(mix0pair.hi, mix0pair.lo); + Register mix1 = Op::Run(mix1pair.hi, mix1pair.lo); + mix0pair = Pack128Folder::Even(mix0, mix1); + regs[i] = Op::Run(mix0pair.hi, mix0pair.lo); } }; + +#endif + template <Index Valid, class Op> INTGEMM_TARGET static inline void Pack32(Register *regs) { - StaticLoop<Pack32Even<Op>, MakeStaticLoopIterator<Valid / 2>>(regs); - if (Valid & 1) { - // For the odd case, shuffle to form 0 g 0 g where g is garbage and 0 is accumlated. - Register shuffled = shuffle_epi32(regs[Valid - 1], 0x4C /* BADA */); - regs[Valid / 2] = Op::Run(shuffled, regs[Valid - 1]); - } - // Now [0, (Valid + 1) / 2) contains registers to pack with 64-bit interleaving. - Pack64<(Valid + 1) / 2, Op>(regs); + GenericPack<Valid, Op, Pack32Folder>(regs); + GenericPack<(Valid + 1) / 2, Op, Pack64Folder>(regs); + // SSE2 is done. +#if defined(INTGEMM_THIS_IS_AVX2) + GenericPack<(Valid + 3) / 4, Op, Pack128Folder>(regs); +#elif defined(INTGEMM_THIS_IS_AVX512BW) + StaticLoop<PackFours<Op>, MakeStaticLoopIterator<(Valid / 4)>>(regs); + // TODO: non-multiples of 4 registers. +#endif } } // namespace INTGEMM_ARCH |