diff options
author | Kenneth Heafield <github@kheafield.com> | 2020-04-19 14:22:40 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2020-04-19 14:22:40 +0300 |
commit | a876043147c9ee6b7992b0bc23acd24ed475f7ea (patch) | |
tree | 5da5a049a2532dc11d7fa461a45f367e481d8884 | |
parent | a3a6a9b845ed5dc51e81e7cd8b9b9ba84855edaa (diff) |
Switch reduce to taking RegisterPair
-rw-r--r-- | tile/reduce.h | 18 | ||||
-rw-r--r-- | tile/reduce.inl | 35 |
2 files changed, 24 insertions, 29 deletions
diff --git a/tile/reduce.h b/tile/reduce.h index 188c7e9..7cffc88 100644 --- a/tile/reduce.h +++ b/tile/reduce.h @@ -34,19 +34,25 @@ namespace intgemm { +namespace SSE2 { struct RegisterPair { Register hi; Register lo; }; } +namespace AVX2 { struct RegisterPair { Register hi; Register lo; }; } +#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW +namespace AVX512BW { struct RegisterPair { Register hi; Register lo; }; } +#endif + // Op argument appropriate for summing 32-bit integers. struct Sum32Op { - INTGEMM_SSE2 static inline __m128i Run(__m128i first, __m128i second) { - return add_epi32(first, second); + INTGEMM_SSE2 static inline __m128i Run(SSE2::RegisterPair regs) { + return add_epi32(regs.hi, regs.lo); } - INTGEMM_AVX2 static inline __m256i Run(__m256i first, __m256i second) { - return add_epi32(first, second); + INTGEMM_AVX2 static inline __m256i Run(AVX2::RegisterPair regs) { + return add_epi32(regs.hi, regs.lo); } #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW - INTGEMM_AVX512BW static inline __m512i Run(__m512i first, __m512i second) { - return add_epi32(first, second); + INTGEMM_AVX512BW static inline __m512i Run(AVX512BW::RegisterPair regs) { + return add_epi32(regs.hi, regs.lo); } #endif }; diff --git a/tile/reduce.inl b/tile/reduce.inl index ef9afc9..9106c52 100644 --- a/tile/reduce.inl +++ b/tile/reduce.inl @@ -16,14 +16,11 @@ namespace intgemm { namespace INTGEMM_ARCH { -struct RegisterPair { Register hi; Register lo; }; - /* Static loop callback for folding an even number of registers. */ template <class Op, class Folder> struct ReduceEvens { template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) { const Index i = Iterator::template I<0>(); - RegisterPair ret = Folder::Even(regs[2 * i], regs[2 * i + 1]); - regs[i] = Op::Run(ret.hi, ret.lo); + regs[i] = Op::Run(Folder::Even(regs[2 * i], regs[2 * i + 1])); } }; /* Call a fold object to reduce one width. Does a static loop over pairs of @@ -31,8 +28,7 @@ template <class Op, class Folder> struct ReduceEvens { template <Index Valid, class Op, class Folder> INTGEMM_TARGET static inline void GenericReduce(Register *regs) { StaticLoop<ReduceEvens<Op, Folder>, MakeStaticLoopIterator<Valid / 2>>(regs); if (Valid & 1) { - auto values = Folder::Odd(regs[Valid - 1]); - regs[Valid / 2] = Folder::OddUpcast(Op::Run(values.lo, values.hi)); + regs[Valid / 2] = Folder::OddUpcast(Op::Run(Folder::Odd(regs[Valid - 1]))); } } @@ -114,14 +110,11 @@ template <class Op> struct ReduceFours { const Index i = Iterator::template I<0>(); const Register *in = regs + i * 4; // Do 256-bit interleaving first because it's slightly cheaper. - RegisterPair mix0pair = Reduce256Folder::Even(in[0], in[2]); - RegisterPair mix1pair = Reduce256Folder::Even(in[1], in[3]); // 0 0 2 2 - Register mix0 = Op::Run(mix0pair.hi, mix0pair.lo); + Register mix0 = Op::Run(Reduce256Folder::Even(in[0], in[2])); // 1 1 3 3 - Register mix1 = Op::Run(mix1pair.hi, mix1pair.lo); - mix0pair = Reduce128Folder::Even(mix0, mix1); - regs[i] = Op::Run(mix0pair.hi, mix0pair.lo); + Register mix1 = Op::Run(Reduce256Folder::Even(in[1], in[3])); + regs[i] = Op::Run(Reduce128Folder::Even(mix0, mix1)); } }; @@ -140,9 +133,8 @@ template <class Op> struct ReduceOverhang<0, Op> { // Overhang of 1 AVX512 register. Fold over itself going down to SSE2. template <class Op> struct ReduceOverhang<1, Op> { INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) { - AVX2::Register folded = Op::Run(_mm512_castsi512_si256(regs[0]), _mm512_extracti64x4_epi64(regs[0], 1)); - SSE2::RegisterPair pair = AVX2::Reduce128Folder::Odd(folded); - SSE2::Register more = Op::Run(pair.hi, pair.lo); + AVX2::Register folded = Op::Run(AVX2::RegisterPair {_mm512_castsi512_si256(regs[0]), _mm512_extracti64x4_epi64(regs[0], 1)}); + SSE2::Register more = Op::Run(AVX2::Reduce128Folder::Odd(folded)); to = _mm512_castsi128_si512(more); } }; @@ -150,24 +142,21 @@ template <class Op> struct ReduceOverhang<1, Op> { template <class Op> struct ReduceOverhang<2, Op> { // Overhang of 2 registers: fold to AVX2. INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) { - RegisterPair mixpair = Reduce128Folder::Even(regs[0], regs[1]); - Register mix = Op::Run(mixpair.hi, mixpair.lo); - AVX2::Register folded = Op::Run(_mm512_castsi512_si256(mix), _mm512_extracti64x4_epi64(mix, 1)); + Register mix = Op::Run(Reduce128Folder::Even(regs[0], regs[1])); + AVX2::Register folded = Op::Run(AVX2::RegisterPair{_mm512_castsi512_si256(mix), _mm512_extracti64x4_epi64(mix, 1)}); to = _mm512_castsi256_si512(folded); } }; // Overhang of 3 AVX512 registers. Fold two together and one overitself. template <class Op> struct ReduceOverhang<3, Op> { INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) { - RegisterPair mix0pair = Reduce256Folder::Even(regs[0], regs[2]); - Register mix0022 = Op::Run(mix0pair.hi, mix0pair.lo); + Register mix0022 = Op::Run(Reduce256Folder::Even(regs[0], regs[2])); // mix0022 128-bit bit blocks: 0 0 2 2 - AVX2::Register fold11 = Op::Run(_mm512_castsi512_si256(regs[1]), _mm512_extracti64x4_epi64(regs[1], 1)); + AVX2::Register fold11 = Op::Run(AVX2::RegisterPair{_mm512_castsi512_si256(regs[1]), _mm512_extracti64x4_epi64(regs[1], 1)}); // fold11 128-bit blocks: 1 1 - RegisterPair finish = Reduce128Folder::Even(mix0022, _mm512_castsi256_si512(fold11)); - to = Op::Run(finish.hi, finish.lo); + to = Op::Run(Reduce128Folder::Even(mix0022, _mm512_castsi256_si512(fold11))); } }; |