diff options
author | Kenneth Heafield <github@kheafield.com> | 2020-04-04 19:43:27 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2020-04-04 19:43:27 +0300 |
commit | c35e3d2281ae50af675c8a323726620f9e23332b (patch) | |
tree | ee000c585ea0e385049fc276cf25446874a39549 | |
parent | 57fe315ab631fb44c909f90868584536d05ec0ce (diff) |
Does AVX512 reduce work?
-rw-r--r-- | tile/reduce.inl | 54 |
1 files changed, 44 insertions, 10 deletions
diff --git a/tile/reduce.inl b/tile/reduce.inl index c3f7470..ab91b32 100644 --- a/tile/reduce.inl +++ b/tile/reduce.inl @@ -82,10 +82,6 @@ struct Pack128Folder { _mm512_mask_permutex_epi64(second, 0x33, first, 2 | (3 << 2)) }; } - INTGEMM_TARGET static inline AVX2::RegisterPair Odd(Register reg) { - return AVX2::RegisterPair { _mm512_castsi512_si256(reg), _mm512_extracti64x4_epi64(reg, 1) }; - } - INTGEMM_TARGET static inline Register OddUpcast(AVX2::Register reg) { return _mm512_castsi256_si512(reg); } }; struct Pack256Folder { @@ -97,10 +93,6 @@ struct Pack256Folder { _mm512_mask_blend_epi64(0xf0, first, second) }; } - INTGEMM_TARGET static inline AVX2::RegisterPair Odd(Register reg) { - return AVX2::RegisterPair { _mm512_castsi512_si256(reg), _mm512_extracti64x4_epi64(reg, 1) }; - } - INTGEMM_TARGET static inline Register OddUpcast(AVX2::Register reg) { return _mm512_castsi256_si512(reg); } }; template <class Op> struct PackFours { @@ -111,13 +103,52 @@ template <class Op> struct PackFours { // Do 256-bit interleaving first because it's slightly cheaper. RegisterPair mix0pair = Pack256Folder::Even(in[0], in[2]); RegisterPair mix1pair = Pack256Folder::Even(in[1], in[3]); + // 0 0 2 2 Register mix0 = Op::Run(mix0pair.hi, mix0pair.lo); + // 1 1 3 3 Register mix1 = Op::Run(mix1pair.hi, mix1pair.lo); mix0pair = Pack128Folder::Even(mix0, mix1); regs[i] = Op::Run(mix0pair.hi, mix0pair.lo); } }; +// non-type partial specialization ‘PackOverhang<0, Op>’ is not allowed +template <Index Valid, class Op> struct PackOverhang; + +template <class Op> struct PackOverhang<0, Op> { + INTGEMM_TARGET static inline void Run(const Register *, Register &) {} +}; +template <class Op> struct PackOverhang<1, Op> { + // Overhang of 1 register: fold it overself to SSE2. + INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) { + AVX2::Register folded = Op::Run(_mm512_castsi512_si256(regs[0]), _mm512_extracti64x4_epi64(regs[0], 1)); + SSE2::RegisterPair pair = AVX2::Pack128Folder::Odd(folded); + SSE2::Register more = Op::Run(pair.hi, pair.lo); + to = _mm512_castsi128_si512(more); + } +}; +template <class Op> struct PackOverhang<2, Op> { + // Overhang of 2 registers: fold to AVX2. + INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) { + RegisterPair mixpair = Pack128Folder::Even(regs[0], regs[1]); + Register mix = Op::Run(mixpair.hi, mixpair.lo); + AVX2::Register folded = Op::Run(_mm512_castsi512_si256(mix), _mm512_extracti64x4_epi64(mix, 1)); + to = _mm512_castsi256_si512(folded); + } +}; +template <class Op> struct PackOverhang<3, Op> { + INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) { + RegisterPair mix0pair = Pack256Folder::Even(regs[0], regs[2]); + Register mix0022 = Op::Run(mix0pair.hi, mix0pair.lo); + // mix0022 128-bit bit blocks: 0 0 2 2 + + AVX2::Register fold11 = Op::Run(_mm512_castsi512_si256(regs[1]), _mm512_extracti64x4_epi64(regs[1], 1)); + // fold11 128-bit blocks: 1 1 + + RegisterPair finish = Pack128Folder::Even(mix0022, _mm512_castsi256_si512(fold11)); + to = Op::Run(finish.hi, finish.lo); + } +}; #endif @@ -128,8 +159,11 @@ template <Index Valid, class Op> INTGEMM_TARGET static inline void Pack32(Regist #if defined(INTGEMM_THIS_IS_AVX2) GenericPack<(Valid + 3) / 4, Op, Pack128Folder>(regs); #elif defined(INTGEMM_THIS_IS_AVX512BW) - StaticLoop<PackFours<Op>, MakeStaticLoopIterator<(Valid / 4)>>(regs); - // TODO: non-multiples of 4 registers. + // Special handling for AVX512BW because we need to fold twice and it can actually go all the way down to SSE2. + constexpr Index remaining = (Valid + 3) / 4; + // Handle registers a multiple of 4. + StaticLoop<PackFours<Op>, MakeStaticLoopIterator<(remaining / 4)>>(regs); + PackOverhang<remaining & 3, Op>::Run(regs + (remaining & ~3), *(regs + remaining / 4)); #endif } |