diff options
Diffstat (limited to 'tile/multiply.inl')
-rw-r--r-- | tile/multiply.inl | 31 |
1 files changed, 30 insertions, 1 deletions
diff --git a/tile/multiply.inl b/tile/multiply.inl index f1344d7..78dca55 100644 --- a/tile/multiply.inl +++ b/tile/multiply.inl @@ -29,11 +29,11 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s assert(shape.A_rows % Kernel::kTile.A_rows == 0); assert(shape.inner % Kernel::kTile.inner == 0); assert(shape.B_cols % Kernel::kTile.B_cols == 0); - constexpr Index Outputs = Kernel::kTile.A_rows * Kernel::kTile.B_cols; for (Index B_col = 0; B_col < shape.B_cols; B_col += Kernel::kTile.B_cols) { AccessT column_adjusted = access.BAdd(0, B_col).CAdd(0, B_col); for (Index A_row = 0; A_row < shape.A_rows; A_row += Kernel::kTile.A_rows) { AccessT col_row = column_adjusted.AAdd(A_row, 0).CAdd(A_row, 0); + constexpr Index Outputs = Kernel::kTile.A_rows * Kernel::kTile.B_cols; // Accumulate values in temporary C registers. Register c_regs[Outputs] = {setzero_si<Register>()}; @@ -56,6 +56,35 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s } } +template <class Access, class Kernel, Index A_rows, Index B_cols> INTGEMM_TARGET static inline void Multiply(Access access, const Tile shape) { + // Still has to be a multiple of the underlying Kernel, but usually that's just 1 x sizeof(Register) x 1. + assert(shape.A_rows % Kernel::kTile.A_rows == 0); + assert(shape.inner % Kernel::kTile.inner == 0); + assert(shape.B_cols % Kernel::kTile.B_cols == 0); + + typedef UnrollKernel<A_rows, 1, B_cols, Kernel> Big; + Tile overhang = { + shape.A_rows % Big::kTile.A_rows, + shape.inner % Big::kTile.inner, + shape.B_cols % Big::kTile.B_cols + }; + Tile big_shape = { + shape.A_rows - overhang.A_rows, + shape.inner - overhang.inner, + shape.B_cols - overhang.B_cols + }; + // Top left corner. + MultiplyNoOverhang<Access, Big>(access, big_shape); + // Bottom currently including right side. TODO: unrolled kernel, rather than dumb loop. + MultiplyNoOverhang<Access, Kernel>( + access.AAdd(big_shape.A_rows, 0).CAdd(big_shape.A_rows, 0), + Tile {overhang.A_rows, shape.inner, shape.B_cols}); + // Right side except bottom. TODO: unrolled kernel, rather than dumb loop. + MultiplyNoOverhang<Access, Kernel>( + access.BAdd(0, big_shape.B_cols).CAdd(0, big_shape.B_cols), + Tile {big_shape.A_rows, shape.inner, overhang.B_cols}); +} + } // namespace INTGEMM_ARCH } // namespace intgemm |