diff options
author | Kenneth Heafield <github@kheafield.com> | 2020-04-05 00:22:49 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2020-04-05 00:22:49 +0300 |
commit | b444029e291f874859000ad5527ce38895213f47 (patch) | |
tree | 163d38b3d061b9f98c799d41e46142361ba35d0a | |
parent | b65be9edd266d3446d1475efc2d32cd3241874b7 (diff) |
Comments
-rw-r--r-- | tile/reduce.h | 32 | ||||
-rw-r--r-- | tile/reduce.inl | 28 |
2 files changed, 56 insertions, 4 deletions
diff --git a/tile/reduce.h b/tile/reduce.h index 641c403..188c7e9 100644 --- a/tile/reduce.h +++ b/tile/reduce.h @@ -1,11 +1,40 @@ +/* reduce.h: Horizontally reduce an arbitrary number of registers + * simultaneously. Given an array of registers, they will be horizontally + * reduced (i.e. summed if Sum32Op is used) with the results placed back into + * the array. + * + * This is the function: + * template <Index Valid, class Op> INTGEMM_TARGET static inline void Reduce32(Register *regs); + * + * Valid is the length of the array of Registers in the input. + * + * Op defines the reduction operation. It should support three architectures: + * INTGEMM_SSE2 static inline __m128i Run(__m128i first, __m128i second); + * INTGEMM_AVX2 static inline __m256i Run(__m256i first, __m256i second); + * INTGEMM_AVX512BW static inline __m512i Run(__m512i first, __m512i second); + * See Sum32Op for an example. + * + * regs is memory to use. + * Input: an array Register[Valid]. + * Output: an array int32_t[Valid] of reduced values in the same order. This + * can be interpreted as registers with reduced values packed into them. + * Anything at index Valid or later is undefined in the output. + * + * The function is defined in each architecture's namespace, so: + * intgemm::SSE2::Reduce32 + * intgemm::SSSE3::Reduce32 + * intgemm::AVX2::Reduce32 + * intgemm::AVX512BW::Reduce32 + * intgemm::AVX512VNNI::Reduce32 + */ #pragma once - #include "../intrinsics.h" #include "../utils.h" #include "../types.h" namespace intgemm { +// Op argument appropriate for summing 32-bit integers. struct Sum32Op { INTGEMM_SSE2 static inline __m128i Run(__m128i first, __m128i second) { return add_epi32(first, second); @@ -24,6 +53,7 @@ struct Sum32Op { } // namespace intgemm +// One implementation per width; the rest just import below. 
#define INTGEMM_THIS_IS_SSE2 #include "reduce.inl" #undef INTGEMM_THIS_IS_SSE2 diff --git a/tile/reduce.inl b/tile/reduce.inl index 17769fb..f070723 100644 --- a/tile/reduce.inl +++ b/tile/reduce.inl @@ -1,3 +1,5 @@ +/* This file is included multiple times from reduce.h, once for each of the + * below architectures. */ #if defined(INTGEMM_THIS_IS_AVX512BW) #define INTGEMM_ARCH AVX512BW #define INTGEMM_TARGET INTGEMM_AVX512BW @@ -16,6 +18,7 @@ namespace INTGEMM_ARCH { struct RegisterPair { Register hi; Register lo; }; +/* Static loop callback for folding an even number of registers. */ template <class Op, class Folder> struct ReduceEvens { template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) { const Index i = Iterator::template I<0>(); @@ -23,7 +26,8 @@ template <class Op, class Folder> struct ReduceEvens { regs[i] = Op::Run(ret.hi, ret.lo); } }; - +/* Call a fold object to reduce one width. Does a static loop over pairs of + * registers then handles odd numbers at the end. */ template <Index Valid, class Op, class Folder> INTGEMM_TARGET static inline void GenericReduce(Register *regs) { StaticLoop<ReduceEvens<Op, Folder>, MakeStaticLoopIterator<Valid / 2>>(regs); if (Valid & 1) { @@ -32,6 +36,9 @@ template <Index Valid, class Op, class Folder> INTGEMM_TARGET static inline void } } +/* These Folder structs say how to interweave even pairs of registers and + * fold an odd register over itself. Folding an odd register over itself is + * slightly faster than doing an even fold with garbage. */ struct Reduce32Folder { INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) { return RegisterPair { unpackhi_epi32(first, second), unpacklo_epi32(first, second) }; @@ -72,6 +79,9 @@ struct Reduce128Folder { #endif #ifdef INTGEMM_THIS_IS_AVX512BW +/* AVX512 is a special case due to multiple register widths for odd cases and + * its length. We have to fold two more times over 128-bit lanes to reduce + * completely. 
*/ struct Reduce128Folder { INTGEMM_TARGET static inline RegisterPair Even(Register first, Register second) { // TODO can this be optimized with a blend and a shuffle instruction? @@ -95,6 +105,8 @@ struct Reduce256Folder { } }; +/* The common case for AVX512 where there are 4 registers to fold. This is the + * body of a static loop. */ template <class Op> struct ReduceFours { // Collapse 4 AVX512 registers at once, interleaving 128-bit fields. template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) { @@ -112,14 +124,20 @@ template <class Op> struct ReduceFours { } }; -// non-type partial specialization ‘ReduceOverhang<0, Op>’ is not allowed +/* Handle overhang when the number of AVX512 registers is not a multiple of 4. + * The numeric argument is how many are left over. + * I use an output argument (instead of return value) to avoid writing when + * nothing is left over. + * + * Partial specialization of functions isn't allowed, so use a class wrapper. + */ template <Index Valid, class Op> struct ReduceOverhang; template <class Op> struct ReduceOverhang<0, Op> { INTGEMM_TARGET static inline void Run(const Register *, Register &) {} }; +// Overhang of 1 AVX512 register. Fold over itself going down to SSE2. template <class Op> struct ReduceOverhang<1, Op> { - // Overhang of 1 register: fold it overself to SSE2. INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) { AVX2::Register folded = Op::Run(_mm512_castsi512_si256(regs[0]), _mm512_extracti64x4_epi64(regs[0], 1)); SSE2::RegisterPair pair = AVX2::Reduce128Folder::Odd(folded); @@ -127,6 +145,7 @@ template <class Op> struct ReduceOverhang<1, Op> { to = _mm512_castsi128_si512(more); } }; +// Overhang of 2 AVX512 registers. template <class Op> struct ReduceOverhang<2, Op> { // Overhang of 2 registers: fold to AVX2. 
INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) { @@ -136,6 +155,7 @@ template <class Op> struct ReduceOverhang<2, Op> { to = _mm512_castsi256_si512(folded); } }; +// Overhang of 3 AVX512 registers. Fold two together and one over itself. template <class Op> struct ReduceOverhang<3, Op> { INTGEMM_TARGET static inline void Run(const Register *regs, Register &to) { RegisterPair mix0pair = Reduce256Folder::Even(regs[0], regs[2]); @@ -152,6 +172,8 @@ template <class Op> struct ReduceOverhang<3, Op> { #endif +/* The only public function: horizontally reduce registers with 32-bit values. + */ template <Index Valid, class Op> INTGEMM_TARGET static inline void Reduce32(Register *regs) { GenericReduce<Valid, Op, Reduce32Folder>(regs); GenericReduce<(Valid + 1) / 2, Op, Reduce64Folder>(regs); |