Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Kenneth Heafield <github@kheafield.com> 2020-04-19 15:30:26 +0300
committer Kenneth Heafield <github@kheafield.com> 2020-04-19 15:30:26 +0300
commit a3a31149281df195e7a7a316f95845b9ef8e1b34 (patch)
tree 94decef35cc5f7de2efa51bf6e87eb202bf52732
parent 2e5f05e1d5a7b1e644a0a5a2a62afbd151d09b7b (diff)
Sum16To32 using variadic templates
-rw-r--r-- tile/multiply.inl 17
1 file changed, 5 insertions, 12 deletions
diff --git a/tile/multiply.inl b/tile/multiply.inl
index 3bec91f..f1344d7 100644
--- a/tile/multiply.inl
+++ b/tile/multiply.inl
@@ -19,16 +19,10 @@ namespace intgemm {
namespace INTGEMM_ARCH {
// Upcast 16 to 32 if needed.
-template <class T> struct SumTo32Body;
-template <> struct SumTo32Body<int16_t> {
- template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) {
- Register &reg = regs[Iterator::template I<0>()];
- reg = madd_epi16(reg, set1_epi16<Register>(1));
- }
-};
-template <> struct SumTo32Body<int32_t> {
- template <class Iterator> INTGEMM_TARGET static inline void body(Register *) {}
-};
+template <std::size_t... i> INTGEMM_TARGET static inline void Sum16To32(Register *regs, int16_t, index_sequence<i...>) {
+ unordered_unfurl((regs[i] = madd_epi16(regs[i], set1_epi16<Register>(1)))...);
+}
+template <std::size_t... i> INTGEMM_TARGET static inline void Sum16To32(Register *, int32_t, index_sequence<i...>) {}
/* Multiply assuming the matrix sizes are a multiple of the kernel size. */
template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) static inline void MultiplyNoOverhang(AccessT access, const Tile shape) {
@@ -54,8 +48,7 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s
Kernel::Run(reg_access.AAdd(0, inner).BAdd(inner, 0));
}
- // If 16-bit, upcast to 32-bit while horizontally adding.
- StaticLoop<SumTo32Body<typename Kernel::Packed::C>, MakeStaticLoopIterator<Outputs>>(c_regs);
+ Sum16To32(c_regs, typename Kernel::Packed::C(), make_index_sequence<Outputs>());
// Horizontally add 32-bit values.
Reduce32<Outputs, Sum32Op>(c_regs);
col_row.CAccessor().template Write<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(c_regs);