Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2020-04-19 14:22:40 +0300
committerKenneth Heafield <github@kheafield.com>2020-04-19 14:22:40 +0300
commita876043147c9ee6b7992b0bc23acd24ed475f7ea (patch)
tree5da5a049a2532dc11d7fa461a45f367e481d8884
parenta3a6a9b845ed5dc51e81e7cd8b9b9ba84855edaa (diff)
Switch reduce to taking RegisterPair
-rw-r--r--tile/reduce.h18
-rw-r--r--tile/reduce.inl35
2 files changed, 24 insertions, 29 deletions
diff --git a/tile/reduce.h b/tile/reduce.h
index 188c7e9..7cffc88 100644
--- a/tile/reduce.h
+++ b/tile/reduce.h
@@ -34,19 +34,25 @@
namespace intgemm {
+namespace SSE2 { struct RegisterPair { Register hi; Register lo; }; }
+namespace AVX2 { struct RegisterPair { Register hi; Register lo; }; }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+namespace AVX512BW { struct RegisterPair { Register hi; Register lo; }; }
+#endif
+
// Op argument appropriate for summing 32-bit integers.
struct Sum32Op {
- INTGEMM_SSE2 static inline __m128i Run(__m128i first, __m128i second) {
- return add_epi32(first, second);
+ INTGEMM_SSE2 static inline __m128i Run(SSE2::RegisterPair regs) {
+ return add_epi32(regs.hi, regs.lo);
}
- INTGEMM_AVX2 static inline __m256i Run(__m256i first, __m256i second) {
- return add_epi32(first, second);
+ INTGEMM_AVX2 static inline __m256i Run(AVX2::RegisterPair regs) {
+ return add_epi32(regs.hi, regs.lo);
}
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- INTGEMM_AVX512BW static inline __m512i Run(__m512i first, __m512i second) {
- return add_epi32(first, second);
+ INTGEMM_AVX512BW static inline __m512i Run(AVX512BW::RegisterPair regs) {
+ return add_epi32(regs.hi, regs.lo);
}
#endif
};
diff --git a/tile/reduce.inl b/tile/reduce.inl
index ef9afc9..9106c52 100644
--- a/tile/reduce.inl
+++ b/tile/reduce.inl
@@ -16,14 +16,11 @@
namespace intgemm {
namespace INTGEMM_ARCH {
-struct RegisterPair { Register hi; Register lo; };
-
/* Static loop callback for folding an even number of registers. */
template <class Op, class Folder> struct ReduceEvens {
template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) {
const Index i = Iterator::template I<0>();
- RegisterPair ret = Folder::Even(regs[2 * i], regs[2 * i + 1]);
- regs[i] = Op::Run(ret.hi, ret.lo);
+ regs[i] = Op::Run(Folder::Even(regs[2 * i], regs[2 * i + 1]));
}
};
/* Call a fold object to reduce one width. Does a static loop over pairs of
@@ -31,8 +28,7 @@ template <class Op, class Folder> struct ReduceEvens {
template <Index Valid, class Op, class Folder> INTGEMM_TARGET static inline void GenericReduce(Register *regs) {
StaticLoop<ReduceEvens<Op, Folder>, MakeStaticLoopIterator<Valid / 2>>(regs);
if (Valid & 1) {
- auto values = Folder::Odd(regs[Valid - 1]);
- regs[Valid / 2] = Folder::OddUpcast(Op::Run(values.lo, values.hi));
+ regs[Valid / 2] = Folder::OddUpcast(Op::Run(Folder::Odd(regs[Valid - 1])));
}
}
@@ -114,14 +110,11 @@ template <class Op> struct ReduceFours {
const Index i = Iterator::template I<0>();
const Register *in = regs + i * 4;
// Do 256-bit interleaving first because it's slightly cheaper.
- RegisterPair mix0pair = Reduce256Folder::Even(in[0], in[2]);
- RegisterPair mix1pair = Reduce256Folder::Even(in[1], in[3]);
// 0 0 2 2
- Register mix0 = Op::Run(mix0pair.hi, mix0pair.lo);
+ Register mix0 = Op::Run(Reduce256Folder::Even(in[0], in[2]));
// 1 1 3 3
- Register mix1 = Op::Run(mix1pair.hi, mix1pair.lo);
- mix0pair = Reduce128Folder::Even(mix0, mix1);
- regs[i] = Op::Run(mix0pair.hi, mix0pair.lo);
+ Register mix1 = Op::Run(Reduce256Folder::Even(in[1], in[3]));
+ regs[i] = Op::Run(Reduce128Folder::Even(mix0, mix1));
}
};
@@ -140,9 +133,8 @@ template <class Op> struct ReduceOverhang<0, Op> {
// Overhang of 1 AVX512 register. Fold over itself going down to SSE2.
template <class Op> struct ReduceOverhang<1, Op> {
INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) {
- AVX2::Register folded = Op::Run(_mm512_castsi512_si256(regs[0]), _mm512_extracti64x4_epi64(regs[0], 1));
- SSE2::RegisterPair pair = AVX2::Reduce128Folder::Odd(folded);
- SSE2::Register more = Op::Run(pair.hi, pair.lo);
+ AVX2::Register folded = Op::Run(AVX2::RegisterPair {_mm512_castsi512_si256(regs[0]), _mm512_extracti64x4_epi64(regs[0], 1)});
+ SSE2::Register more = Op::Run(AVX2::Reduce128Folder::Odd(folded));
to = _mm512_castsi128_si512(more);
}
};
@@ -150,24 +142,21 @@ template <class Op> struct ReduceOverhang<1, Op> {
template <class Op> struct ReduceOverhang<2, Op> {
// Overhang of 2 registers: fold to AVX2.
INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) {
- RegisterPair mixpair = Reduce128Folder::Even(regs[0], regs[1]);
- Register mix = Op::Run(mixpair.hi, mixpair.lo);
- AVX2::Register folded = Op::Run(_mm512_castsi512_si256(mix), _mm512_extracti64x4_epi64(mix, 1));
+ Register mix = Op::Run(Reduce128Folder::Even(regs[0], regs[1]));
+ AVX2::Register folded = Op::Run(AVX2::RegisterPair{_mm512_castsi512_si256(mix), _mm512_extracti64x4_epi64(mix, 1)});
to = _mm512_castsi256_si512(folded);
}
};
// Overhang of 3 AVX512 registers. Fold two together and one overitself.
template <class Op> struct ReduceOverhang<3, Op> {
INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) {
- RegisterPair mix0pair = Reduce256Folder::Even(regs[0], regs[2]);
- Register mix0022 = Op::Run(mix0pair.hi, mix0pair.lo);
+ Register mix0022 = Op::Run(Reduce256Folder::Even(regs[0], regs[2]));
// mix0022 128-bit bit blocks: 0 0 2 2
- AVX2::Register fold11 = Op::Run(_mm512_castsi512_si256(regs[1]), _mm512_extracti64x4_epi64(regs[1], 1));
+ AVX2::Register fold11 = Op::Run(AVX2::RegisterPair{_mm512_castsi512_si256(regs[1]), _mm512_extracti64x4_epi64(regs[1], 1)});
// fold11 128-bit blocks: 1 1
- RegisterPair finish = Reduce128Folder::Even(mix0022, _mm512_castsi256_si512(fold11));
- to = Op::Run(finish.hi, finish.lo);
+ to = Op::Run(Reduce128Folder::Even(mix0022, _mm512_castsi256_si512(fold11)));
}
};