diff options
author | Kenneth Heafield <github@kheafield.com> | 2020-04-19 15:30:26 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2020-04-19 15:30:26 +0300 |
commit | a3a31149281df195e7a7a316f95845b9ef8e1b34 (patch) | |
tree | 94decef35cc5f7de2efa51bf6e87eb202bf52732 | |
parent | 2e5f05e1d5a7b1e644a0a5a2a62afbd151d09b7b (diff) |
Sum16To32 using variadic templates
-rw-r--r-- | tile/multiply.inl | 17 |
1 files changed, 5 insertions, 12 deletions
diff --git a/tile/multiply.inl b/tile/multiply.inl index 3bec91f..f1344d7 100644 --- a/tile/multiply.inl +++ b/tile/multiply.inl @@ -19,16 +19,10 @@ namespace intgemm { namespace INTGEMM_ARCH { // Upcast 16 to 32 if needed. -template <class T> struct SumTo32Body; -template <> struct SumTo32Body<int16_t> { - template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) { - Register ® = regs[Iterator::template I<0>()]; - reg = madd_epi16(reg, set1_epi16<Register>(1)); - } -}; -template <> struct SumTo32Body<int32_t> { - template <class Iterator> INTGEMM_TARGET static inline void body(Register *) {} -}; +template <std::size_t... i> INTGEMM_TARGET static inline void Sum16To32(Register *regs, int16_t, index_sequence<i...>) { + unordered_unfurl((regs[i] = madd_epi16(regs[i], set1_epi16<Register>(1)))...); +} +template <std::size_t... i> INTGEMM_TARGET static inline void Sum16To32(Register *, int32_t, index_sequence<i...>) {} /* Multiply assuming the matrix sizes are a multiple of the kernel size. */ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) static inline void MultiplyNoOverhang(AccessT access, const Tile shape) { @@ -54,8 +48,7 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s Kernel::Run(reg_access.AAdd(0, inner).BAdd(inner, 0)); } - // If 16-bit, upcast to 32-bit while horizontally adding. - StaticLoop<SumTo32Body<typename Kernel::Packed::C>, MakeStaticLoopIterator<Outputs>>(c_regs); + Sum16To32(c_regs, typename Kernel::Packed::C(), make_index_sequence<Outputs>()); // Horizontally add 32-bit values. Reduce32<Outputs, Sum32Op>(c_regs); col_row.CAccessor().template Write<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(c_regs); |