Welcome to mirror list, hosted at ThFree Co, Russian Federation.

multiply.inl « tile - github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: f1a22cf154ede123e688eba044c00bba16cde0ae (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// Pick the architecture-specific names used by the rest of this file, based
// on which INTGEMM_THIS_IS_* macro is defined at the point of inclusion.
//   INTGEMM_ARCH   - becomes the name of the nested namespace opened below.
//   INTGEMM_TARGET - prefixed to every function declared below; it expands to
//                    another macro (INTGEMM_AVX2 etc.) that is defined elsewhere.
// The branches are tested in order, so if more than one INTGEMM_THIS_IS_*
// macro were defined, the first match in this chain would win.
// NOTE(review): presumably the including file defines exactly one of these
// macros before each inclusion - confirm against the includer.  There is no
// #else branch, so with none defined both names are left undefined.
#if defined(INTGEMM_THIS_IS_AVX512VNNI)
#define INTGEMM_ARCH AVX512VNNI
#define INTGEMM_TARGET INTGEMM_AVX512VNNI
#elif defined(INTGEMM_THIS_IS_AVX512BW)
#define INTGEMM_ARCH AVX512BW
#define INTGEMM_TARGET INTGEMM_AVX512BW
#elif defined(INTGEMM_THIS_IS_AVX2)
#define INTGEMM_ARCH AVX2
#define INTGEMM_TARGET INTGEMM_AVX2
#elif defined(INTGEMM_THIS_IS_SSSE3)
#define INTGEMM_ARCH SSSE3
#define INTGEMM_TARGET INTGEMM_SSSE3
#elif defined(INTGEMM_THIS_IS_SSE2)
#define INTGEMM_ARCH SSE2
#define INTGEMM_TARGET INTGEMM_SSE2
#endif

namespace intgemm {
namespace INTGEMM_ARCH {

/* SumTo32<T>: bring a register of accumulated values of type T up to 32-bit
 * values, in place.  Only the primary template is declared here; the two
 * explicit specializations below (int16_t and int32_t) provide the bodies.
 */
template <class T> INTGEMM_TARGET static inline void SumTo32(Register &reg);

// 16-bit case: madd the register against a register filled with 16-bit ones.
// NOTE(review): assuming madd_epi16 wraps the pmaddwd-style intrinsic, this
// multiplies every 16-bit lane by 1 and adds adjacent pairs into 32-bit lanes.
template <> INTGEMM_TARGET inline void SumTo32<int16_t>(Register &reg) {
  const Register ones = set1_epi16<Register>(1);
  reg = madd_epi16(reg, ones);
}

// 32-bit case: the register is left untouched.
template <> INTGEMM_TARGET inline void SumTo32<int32_t>(Register &) {}

/* Loop body for StaticLoop: applies SumTo32<T> to a single register of the
 * array.  The register is selected by the index Iterator::I<0>().
 */
template <class T> struct SumTo32Body {
  template <class Iterator> INTGEMM_TARGET static inline void body(Register *regs) {
    Register &target = regs[Iterator::template I<0>()];
    SumTo32<T>(target);
  }
};

/* Tiled multiply for shapes with no overhang: every dimension of `shape` must
 * be an exact multiple of the matching dimension of Kernel::kTile.  This is
 * only checked with assert, so release builds do not verify it.
 *
 * access: accessor bundle for A, B and C; it is offset per tile through its
 *         AAdd / BAdd / CAdd methods.
 * shape:  full problem size (A_rows, inner, B_cols).
 *
 * For each kernel-sized tile of C, the kernel is run across the whole inner
 * dimension into temporary registers, which are then reduced and written out
 * through the C accessor.
 */
template <class AccessT, class Kernel> INTGEMM_TARGET static inline void MultiplyNoOverhang(AccessT access, const Tile shape) {
  assert(shape.A_rows % Kernel::kTile.A_rows == 0);
  assert(shape.inner % Kernel::kTile.inner == 0);
  assert(shape.B_cols % Kernel::kTile.B_cols == 0);

  // Step sizes of the three loops, taken from the kernel's tile.
  constexpr Index kRowStep = Kernel::kTile.A_rows;
  constexpr Index kInnerStep = Kernel::kTile.inner;
  constexpr Index kColStep = Kernel::kTile.B_cols;
  // One accumulator register per output cell of a kernel tile.
  constexpr Index kRegisterCount = kRowStep * kColStep;

  // Same A and B accessors as the caller's, but with C redirected to registers.
  // The registers are laid out row-major; for a column-major C a column-major
  // layout would be preferable, because the layout fixes the order Reduce32 uses.
  using TileAccess = Access<typename AccessT::A, typename AccessT::B, RegisterRowMajorAccess>;

  for (Index col = 0; col < shape.B_cols; col += kColStep) {
    // Move B and C to the current block of columns.
    AccessT at_col = access.BAdd(0, col).CAdd(0, col);

    for (Index row = 0; row < shape.A_rows; row += kRowStep) {
      // Move A and C to the current block of rows.
      AccessT at_tile = at_col.AAdd(row, 0).CAdd(row, 0);

      // Temporary accumulators for this tile of C.  Only the first element has
      // an explicit initializer; aggregate initialization value-initializes
      // the remaining elements.
      Register sums[kRegisterCount] = {setzero_si<Register>()};
      TileAccess tile_access(
          at_tile.AAccessor(),
          at_tile.BAccessor(),
          RegisterRowMajorAccess(sums, kColStep));

      // Walk the shared dimension one kernel tile at a time.
      for (Index k = 0; k < shape.inner; k += kInnerStep) {
        Kernel::Run(tile_access.AAdd(0, k).BAdd(k, 0));
      }

      // 16-bit accumulators are widened to 32-bit (with a horizontal add);
      // 32-bit accumulators pass through unchanged.
      StaticLoop<SumTo32Body<typename Kernel::Packed::C>, MakeStaticLoopIterator<kRegisterCount>>(sums);
      // Horizontal add of the 32-bit values.
      Reduce32<kRegisterCount, Sum32Op>(sums);
      // Store the finished tile through the C accessor.
      at_tile.CAccessor().template Write<kRowStep, kColStep>(sums);
    }
  }
}

} // namespace INTGEMM_ARCH
} // namespace intgemm

// Drop the per-architecture names defined at the top of this file so they do
// not leak past it and can be defined again.
// NOTE(review): presumably this file is included once per architecture - confirm.
#undef INTGEMM_ARCH
#undef INTGEMM_TARGET