Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2020-04-04 19:43:27 +0300
committerKenneth Heafield <github@kheafield.com>2020-04-04 19:43:27 +0300
commitc35e3d2281ae50af675c8a323726620f9e23332b (patch)
treeee000c585ea0e385049fc276cf25446874a39549
parent57fe315ab631fb44c909f90868584536d05ec0ce (diff)
Does AVX512 reduce work?
-rw-r--r--tile/reduce.inl54
1 files changed, 44 insertions, 10 deletions
diff --git a/tile/reduce.inl b/tile/reduce.inl
index c3f7470..ab91b32 100644
--- a/tile/reduce.inl
+++ b/tile/reduce.inl
@@ -82,10 +82,6 @@ struct Pack128Folder {
_mm512_mask_permutex_epi64(second, 0x33, first, 2 | (3 << 2))
};
}
- INTGEMM_TARGET static inline AVX2::RegisterPair Odd(Register reg) {
- return AVX2::RegisterPair { _mm512_castsi512_si256(reg), _mm512_extracti64x4_epi64(reg, 1) };
- }
- INTGEMM_TARGET static inline Register OddUpcast(AVX2::Register reg) { return _mm512_castsi256_si512(reg); }
};
struct Pack256Folder {
@@ -97,10 +93,6 @@ struct Pack256Folder {
_mm512_mask_blend_epi64(0xf0, first, second)
};
}
- INTGEMM_TARGET static inline AVX2::RegisterPair Odd(Register reg) {
- return AVX2::RegisterPair { _mm512_castsi512_si256(reg), _mm512_extracti64x4_epi64(reg, 1) };
- }
- INTGEMM_TARGET static inline Register OddUpcast(AVX2::Register reg) { return _mm512_castsi256_si512(reg); }
};
template <class Op> struct PackFours {
@@ -111,13 +103,52 @@ template <class Op> struct PackFours {
// Do 256-bit interleaving first because it's slightly cheaper.
RegisterPair mix0pair = Pack256Folder::Even(in[0], in[2]);
RegisterPair mix1pair = Pack256Folder::Even(in[1], in[3]);
+ // 0 0 2 2
Register mix0 = Op::Run(mix0pair.hi, mix0pair.lo);
+ // 1 1 3 3
Register mix1 = Op::Run(mix1pair.hi, mix1pair.lo);
mix0pair = Pack128Folder::Even(mix0, mix1);
regs[i] = Op::Run(mix0pair.hi, mix0pair.lo);
}
};
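+// PackOverhang is a struct, presumably because a partially specialized function template fails to compile: non-type partial specialization ‘PackOverhang<0, Op>’ is not allowed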
+template <Index Valid, class Op> struct PackOverhang;
+
+template <class Op> struct PackOverhang<0, Op> {
+ INTGEMM_TARGET static inline void Run(const Register *, Register &) {}
+};
+template <class Op> struct PackOverhang<1, Op> {
+ // Overhang of 1 register: fold it over itself down to SSE2.
+ INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) {
+ AVX2::Register folded = Op::Run(_mm512_castsi512_si256(regs[0]), _mm512_extracti64x4_epi64(regs[0], 1));
+ SSE2::RegisterPair pair = AVX2::Pack128Folder::Odd(folded);
+ SSE2::Register more = Op::Run(pair.hi, pair.lo);
+ to = _mm512_castsi128_si512(more);
+ }
+};
+template <class Op> struct PackOverhang<2, Op> {
+ // Overhang of 2 registers: fold to AVX2.
+ INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) {
+ RegisterPair mixpair = Pack128Folder::Even(regs[0], regs[1]);
+ Register mix = Op::Run(mixpair.hi, mixpair.lo);
+ AVX2::Register folded = Op::Run(_mm512_castsi512_si256(mix), _mm512_extracti64x4_epi64(mix, 1));
+ to = _mm512_castsi256_si512(folded);
+ }
+};
+template <class Op> struct PackOverhang<3, Op> {
+ INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) {
+ RegisterPair mix0pair = Pack256Folder::Even(regs[0], regs[2]);
+ Register mix0022 = Op::Run(mix0pair.hi, mix0pair.lo);
+ // mix0022 128-bit blocks: 0 0 2 2
+
+ AVX2::Register fold11 = Op::Run(_mm512_castsi512_si256(regs[1]), _mm512_extracti64x4_epi64(regs[1], 1));
+ // fold11 128-bit blocks: 1 1
+
+ RegisterPair finish = Pack128Folder::Even(mix0022, _mm512_castsi256_si512(fold11));
+ to = Op::Run(finish.hi, finish.lo);
+ }
+};
#endif
@@ -128,8 +159,11 @@ template <Index Valid, class Op> INTGEMM_TARGET static inline void Pack32(Regist
#if defined(INTGEMM_THIS_IS_AVX2)
GenericPack<(Valid + 3) / 4, Op, Pack128Folder>(regs);
#elif defined(INTGEMM_THIS_IS_AVX512BW)
- StaticLoop<PackFours<Op>, MakeStaticLoopIterator<(Valid / 4)>>(regs);
- // TODO: non-multiples of 4 registers.
+ // Special handling for AVX512BW because we need to fold twice and it can actually go all the way down to SSE2.
+ constexpr Index remaining = (Valid + 3) / 4;
+ // Handle the registers that form complete groups of 4.
+ StaticLoop<PackFours<Op>, MakeStaticLoopIterator<(remaining / 4)>>(regs);
+ PackOverhang<remaining & 3, Op>::Run(regs + (remaining & ~3), *(regs + remaining / 4));
#endif
}