github.com/marian-nmt/intgemm.git
author    Kenneth Heafield <github@kheafield.com>  2020-04-05 00:22:49 +0300
committer Kenneth Heafield <github@kheafield.com>  2020-04-05 00:22:49 +0300
commit    b444029e291f874859000ad5527ce38895213f47 (patch)
tree      163d38b3d061b9f98c799d41e46142361ba35d0a
parent    b65be9edd266d3446d1475efc2d32cd3241874b7 (diff)

Comments

 tile/reduce.h   | 32
 tile/reduce.inl | 28
 2 files changed, 56 insertions(+), 4 deletions(-)
diff --git a/tile/reduce.h b/tile/reduce.h
index 641c403..188c7e9 100644
--- a/tile/reduce.h
+++ b/tile/reduce.h
@@ -1,11 +1,40 @@
+/* reduce.h: Horizontally reduce an arbitrary number of registers
+ * simultaneously. Given an array of registers, they will be horizontally
+ * reduced (i.e. summed if Sum32Op is used) with the results placed back into
+ * the array.
+ *
+ * This is the function:
+ * template <Index Valid, class Op> INTGEMM_TARGET static inline void Reduce32(Register *regs);
+ *
+ * Valid is the length of the array of Registers in the input.
+ *
+ * Op defines the reduction operation. It should support three architectures:
+ * INTGEMM_SSE2 static inline __m128i Run(__m128i first, __m128i second);
+ * INTGEMM_AVX2 static inline __m256i Run(__m256i first, __m256i second);
+ * INTGEMM_AVX512BW static inline __m512i Run(__m512i first, __m512i second);
+ * See Sum32Op for an example.
+ *
+ * regs serves as both input and output memory.
+ * Input: an array Register[Valid].
+ * Output: an array int32_t[Valid] of reduced values in the same order. This
+ * can be interpreted as registers with reduced values packed into them.
+ * Anything at index Valid or later is undefined in the output.
+ *
+ * The function is defined in each architecture's namespace, so:
+ * intgemm::SSE2::Reduce32
+ * intgemm::SSSE3::Reduce32
+ * intgemm::AVX2::Reduce32
+ * intgemm::AVX512BW::Reduce32
+ * intgemm::AVX512VNNI::Reduce32
+ */
#pragma once
-
#include "../intrinsics.h"
#include "../utils.h"
#include "../types.h"
namespace intgemm {
+// Op argument appropriate for summing 32-bit integers.
struct Sum32Op {
INTGEMM_SSE2 static inline __m128i Run(__m128i first, __m128i second) {
return add_epi32(first, second);
@@ -24,6 +53,7 @@ struct Sum32Op {
} // namespace intgemm
+// One implementation per width; the rest just import below.
#define INTGEMM_THIS_IS_SSE2
#include "reduce.inl"
#undef INTGEMM_THIS_IS_SSE2
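
For orientation, a minimal usage sketch of the interface the new header comment documents. This is not part of the commit: it assumes the intgemm headers are on the include path, that Register is the architecture's vector type (__m256i under AVX2), and the helper name SumThree is made up for illustration.

  #include <cstdint>
  #include "tile/reduce.h"

  // Horizontally sum three AVX2 accumulators in place.
  INTGEMM_AVX2 static void SumThree(intgemm::AVX2::Register *regs) {
    // Input: regs[0..2], each holding 8 int32 lanes.
    intgemm::AVX2::Reduce32<3, intgemm::Sum32Op>(regs);
    // Output: the first three int32_t values of regs are the lane sums;
    // anything at index 3 or later is undefined.
    const int32_t *sums = reinterpret_cast<const int32_t*>(regs);
    (void)sums;  // use sums[0], sums[1], sums[2]
  }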
diff --git a/tile/reduce.inl b/tile/reduce.inl
index 17769fb..f070723 100644
--- a/tile/reduce.inl
+++ b/tile/reduce.inl
@@ -1,3 +1,5 @@
+/* This file is included multiple times from reduce.h, once for each of the
+ * architectures below. */
#if defined(INTGEMM_THIS_IS_AVX512BW)
#define INTGEMM_ARCH AVX512BW
#define INTGEMM_TARGET INTGEMM_AVX512BW
@@ -16,6 +18,7 @@ namespace INTGEMM_ARCH {
struct RegisterPair { Register hi; Register lo; };
+/* Static loop callback for folding an even number of registers. */
template <class Op, class Folder> struct ReduceEvens {
template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) {
const Index i = Iterator::template I<0>();
@@ -23,7 +26,8 @@ template <class Op, class Folder> struct ReduceEvens {
regs[i] = Op::Run(ret.hi, ret.lo);
}
};
-
+/* Call a fold object to reduce one width. Does a static loop over pairs of
+ * registers, then handles an odd leftover register at the end. */
template <Index Valid, class Op, class Folder> INTGEMM_TARGET static inline void GenericReduce(Register *regs) {
StaticLoop<ReduceEvens<Op, Folder>, MakeStaticLoopIterator<Valid / 2>>(regs);
if (Valid & 1) {
@@ -32,6 +36,9 @@ template <Index Valid, class Op, class Folder> INTGEMM_TARGET static inline void
}
}
+/* These Folder structs say how to interleave even pairs of registers and
+ * fold an odd register over itself. Folding an odd register over itself is
+ * slightly faster than doing an even fold with garbage. */
struct Reduce32Folder {
INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) {
return RegisterPair { unpackhi_epi32(first, second), unpacklo_epi32(first, second) };
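
To make Even concrete under SSE2 with Sum32Op: the two unpacks interleave the registers so that adding hi to lo leaves partial sums for first in the even lanes and for second in the odd lanes. An illustrative stand-alone sketch using raw intrinsics (not part of the commit):

  #include <emmintrin.h>

  // first = [a0 a1 a2 a3], second = [b0 b1 b2 b3] as int32 lanes.
  static inline __m128i FoldEven32(__m128i first, __m128i second) {
    __m128i hi = _mm_unpackhi_epi32(first, second);  // [a2 b2 a3 b3]
    __m128i lo = _mm_unpacklo_epi32(first, second);  // [a0 b0 a1 b1]
    return _mm_add_epi32(hi, lo);  // [a0+a2  b0+b2  a1+a3  b1+b3]
  }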
@@ -72,6 +79,9 @@ struct Reduce128Folder {
#endif
#ifdef INTGEMM_THIS_IS_AVX512BW
+/* AVX512 is a special case due to multiple register widths for odd cases and
+ * its 512-bit length. We have to fold two more times over 128-bit lanes to
+ * reduce completely. */
struct Reduce128Folder {
INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) {
// TODO can this be optimized with a blend and a shuffle instruction?
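
As a concrete illustration of those two extra folds (a hypothetical sketch, not part of the commit), summing one 512-bit register of int32 lanes down to a single 128-bit register:

  #include <immintrin.h>

  static inline __m128i FoldLanes512(__m512i r) {
    // 512 -> 256: add the upper and lower 256-bit halves.
    __m256i half = _mm256_add_epi32(_mm512_castsi512_si256(r),
                                    _mm512_extracti64x4_epi64(r, 1));
    // 256 -> 128: add the upper and lower 128-bit lanes.
    return _mm_add_epi32(_mm256_castsi256_si128(half),
                         _mm256_extracti128_si256(half, 1));
  }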
@@ -95,6 +105,8 @@ struct Reduce256Folder {
}
};
+/* The common case for AVX512 where there are 4 registers to fold. This is the
+ * body of a static loop. */
template <class Op> struct ReduceFours {
// Collapse 4 AVX512 registers at once, interleaving 128-bit fields.
template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) {
@@ -112,14 +124,20 @@ template <class Op> struct ReduceFours {
}
};
-// non-type partial specialization ‘ReduceOverhang<0, Op>’ is not allowed
+/* Handle overhang when the number of AVX512 registers is not a multiple of 4.
+ * The numeric argument is how many are left over.
+ * I use an output argument (instead of return value) to avoid writing when
+ * nothing is left over.
+ *
+ * Partial specialization of functions isn't allowed, so use a class wrapper.
+ */
template <Index Valid, class Op> struct ReduceOverhang;
template <class Op> struct ReduceOverhang<0, Op> {
INTGEMM_TARGET static inline void Run(const Register *, Register &) {}
};
+// Overhang of 1 AVX512 register. Fold over itself going down to SSE2.
template <class Op> struct ReduceOverhang<1, Op> {
- // Overhang of 1 register: fold it overself to SSE2.
INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) {
AVX2::Register folded = Op::Run(_mm512_castsi512_si256(regs[0]), _mm512_extracti64x4_epi64(regs[0], 1));
SSE2::RegisterPair pair = AVX2::Reduce128Folder::Odd(folded);
@@ -127,6 +145,7 @@ template <class Op> struct ReduceOverhang<1, Op> {
to = _mm512_castsi128_si512(more);
}
};
+// Overhang of 2 AVX512 registers.
template <class Op> struct ReduceOverhang<2, Op> {
// Overhang of 2 registers: fold to AVX2.
INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) {
@@ -136,6 +155,7 @@ template <class Op> struct ReduceOverhang<2, Op> {
to = _mm512_castsi256_si512(folded);
}
};
+// Overhang of 3 AVX512 registers. Fold two together and one over itself.
template <class Op> struct ReduceOverhang<3, Op> {
INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) {
RegisterPair mix0pair = Reduce256Folder::Even(regs[0], regs[2]);
@@ -152,6 +172,8 @@ template <class Op> struct ReduceOverhang<3, Op> {
#endif
+/* The only public function: horizontally reduce registers with 32-bit values.
+ */
template <Index Valid, class Op> INTGEMM_TARGET static inline void Reduce32(Register *regs) {
GenericReduce<Valid, Op, Reduce32Folder>(regs);
GenericReduce<(Valid + 1) / 2, Op, Reduce64Folder>(regs);
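
Putting the pieces together, the contract Reduce32 promises can be modeled in scalar code. This model is illustrative only, not part of the commit; Reduce32Model and its array layout are made up for exposition:

  #include <cstdint>
  #include <cstring>

  // Valid "registers" of Lanes int32 values each; results packed to the front.
  template <unsigned Valid, unsigned Lanes>
  static void Reduce32Model(int32_t regs[][Lanes]) {
    int32_t out[Valid];
    for (unsigned r = 0; r < Valid; ++r) {
      out[r] = regs[r][0];
      for (unsigned l = 1; l < Lanes; ++l) out[r] += regs[r][l];  // Sum32Op
    }
    std::memcpy(regs, out, sizeof(out));  // indices >= Valid are undefined
  }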