diff options
author | Kenneth Heafield <github@kheafield.com> | 2020-04-19 15:20:07 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2020-04-19 15:20:07 +0300 |
commit | 2e5f05e1d5a7b1e644a0a5a2a62afbd151d09b7b (patch) | |
tree | bd87b1d309895188665131d1b2b17d6c1d643af3 | |
parent | d8a3c5020b5c240e1547d6e1a9fffae69eca429c (diff) |
Replace StaticLoop with variadic template
-rw-r--r-- | tile/reduce.inl | 34 |
1 files changed, 19 insertions, 15 deletions
diff --git a/tile/reduce.inl b/tile/reduce.inl index 9106c52..b17cfca 100644 --- a/tile/reduce.inl +++ b/tile/reduce.inl @@ -16,17 +16,19 @@ namespace intgemm { namespace INTGEMM_ARCH { -/* Static loop callback for folding an even number of registers. */ +/* Static loop for folding an even number of registers. */ template <class Op, class Folder> struct ReduceEvens { - template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) { - const Index i = Iterator::template I<0>(); - regs[i] = Op::Run(Folder::Even(regs[2 * i], regs[2 * i + 1])); + template <std::size_t... i> INTGEMM_TARGET static inline void Run(Register *regs, index_sequence<i...>) { + using static_loop = int[]; + (void)static_loop {0, + (regs[i] = Op::Run(Folder::Even(regs[2 * i], regs[2 * i + 1])), 0)... + }; } }; /* Call a fold object to reduce one width. Does a static loop over pairs of * registers then handles odd numbers at the end */ template <Index Valid, class Op, class Folder> INTGEMM_TARGET static inline void GenericReduce(Register *regs) { - StaticLoop<ReduceEvens<Op, Folder>, MakeStaticLoopIterator<Valid / 2>>(regs); + ReduceEvens<Op, Folder>::Run(regs, make_index_sequence<Valid / 2>()); if (Valid & 1) { regs[Valid / 2] = Folder::OddUpcast(Op::Run(Folder::Odd(regs[Valid - 1]))); } @@ -106,15 +108,17 @@ struct Reduce256Folder { * body of a static loop. */ template <class Op> struct ReduceFours { // Collapse 4 AVX512 registers at once, interleaving 128-bit fields. - template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) { - const Index i = Iterator::template I<0>(); - const Register *in = regs + i * 4; - // Do 256-bit interleaving first because it's slightly cheaper. - // 0 0 2 2 - Register mix0 = Op::Run(Reduce256Folder::Even(in[0], in[2])); - // 1 1 3 3 - Register mix1 = Op::Run(Reduce256Folder::Even(in[1], in[3])); - regs[i] = Op::Run(Reduce128Folder::Even(mix0, mix1)); + template <std::size_t... i> INTGEMM_TARGET static inline void Run(Register *regs, index_sequence<i...>) { + using static_loop = int[]; + (void)static_loop {0, + // Do 256-bit interleaving first because it's slightly cheaper, then 128-bit. + (regs[i] = Op::Run(Reduce128Folder::Even( + // 0 0 2 2 + Op::Run(Reduce256Folder::Even(regs[i * 4], regs[i * 4 + 2])), + // 1 1 3 3 + Op::Run(Reduce256Folder::Even(regs[i * 4 + 1], regs[i * 4 + 3])) + )), 0)... + }; } }; @@ -173,7 +177,7 @@ template <Index Valid, class Op> INTGEMM_TARGET static inline void Reduce32(Regi // Special handling for AVX512BW because we need to fold twice and it can actually go all the way down to SSE2. constexpr Index remaining = (Valid + 3) / 4; // Handle registers a multiple of 4. - StaticLoop<ReduceFours<Op>, MakeStaticLoopIterator<(remaining / 4)>>(regs); + ReduceFours<Op>::Run(regs, make_index_sequence<remaining / 4>()); ReduceOverhang<remaining & 3, Op>::Run(regs + (remaining & ~3), *(regs + remaining / 4)); #endif } |