Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'tile/multiply.inl')
-rw-r--r-- tile/multiply.inl 31
1 files changed, 30 insertions, 1 deletions
diff --git a/tile/multiply.inl b/tile/multiply.inl
index f1344d7..78dca55 100644
--- a/tile/multiply.inl
+++ b/tile/multiply.inl
@@ -29,11 +29,11 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s
assert(shape.A_rows % Kernel::kTile.A_rows == 0);
assert(shape.inner % Kernel::kTile.inner == 0);
assert(shape.B_cols % Kernel::kTile.B_cols == 0);
- constexpr Index Outputs = Kernel::kTile.A_rows * Kernel::kTile.B_cols;
for (Index B_col = 0; B_col < shape.B_cols; B_col += Kernel::kTile.B_cols) {
AccessT column_adjusted = access.BAdd(0, B_col).CAdd(0, B_col);
for (Index A_row = 0; A_row < shape.A_rows; A_row += Kernel::kTile.A_rows) {
AccessT col_row = column_adjusted.AAdd(A_row, 0).CAdd(A_row, 0);
+ constexpr Index Outputs = Kernel::kTile.A_rows * Kernel::kTile.B_cols;
// Accumulate values in temporary C registers.
Register c_regs[Outputs] = {setzero_si<Register>()};
@@ -56,6 +56,35 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s
}
}
+template <class Access, class Kernel, Index A_rows, Index B_cols> INTGEMM_TARGET static inline void Multiply(Access access, const Tile shape) {
+ // Multiply a shape that need not be a multiple of the unrolled tile: the unrolled kernel (Big) covers the part it divides evenly and the plain Kernel covers the overhang. Still has to be a multiple of the underlying Kernel, but usually that's just 1 x sizeof(Register) x 1.
+ assert(shape.A_rows % Kernel::kTile.A_rows == 0);
+ assert(shape.inner % Kernel::kTile.inner == 0);
+ assert(shape.B_cols % Kernel::kTile.B_cols == 0);
+
+ typedef UnrollKernel<A_rows, 1, B_cols, Kernel> Big; // Kernel unrolled by A_rows x 1 x B_cols; presumably (rows, inner, columns) in the same order as Tile's fields -- confirm against UnrollKernel.
+ Tile overhang = { // Per-dimension remainder of shape that Big's tile does not divide evenly.
+ shape.A_rows % Big::kTile.A_rows,
+ shape.inner % Big::kTile.inner,
+ shape.B_cols % Big::kTile.B_cols
+ };
+ Tile big_shape = { // shape with the overhang removed, i.e. a whole multiple of Big's tile in every dimension.
+ shape.A_rows - overhang.A_rows,
+ shape.inner - overhang.inner, // NOTE(review): the two calls below pass the full shape.inner, so this relies on overhang.inner being 0 (presumably Big::kTile.inner == Kernel::kTile.inner given the unroll factor of 1) -- confirm.
+ shape.B_cols - overhang.B_cols
+ };
+ // Top left corner: the big_shape.A_rows x big_shape.B_cols region, handled by the unrolled kernel.
+ MultiplyNoOverhang<Access, Big>(access, big_shape);
+ // Bottom currently including right side: overhang.A_rows rows starting at row big_shape.A_rows, across all shape.B_cols columns. TODO: unrolled kernel, rather than dumb loop.
+ MultiplyNoOverhang<Access, Kernel>(
+ access.AAdd(big_shape.A_rows, 0).CAdd(big_shape.A_rows, 0),
+ Tile {overhang.A_rows, shape.inner, shape.B_cols});
+ // Right side except bottom: overhang.B_cols columns starting at column big_shape.B_cols, limited to the first big_shape.A_rows rows so the bottom strip is not computed twice. TODO: unrolled kernel, rather than dumb loop.
+ MultiplyNoOverhang<Access, Kernel>(
+ access.BAdd(0, big_shape.B_cols).CAdd(0, big_shape.B_cols),
+ Tile {big_shape.A_rows, shape.inner, overhang.B_cols});
+}
+
} // namespace INTGEMM_ARCH
} // namespace intgemm