diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-04-25 14:07:25 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-04-25 14:22:08 +0300 |
commit | e78b9af9ba7ba419567a074deeb104943db6cd14 (patch) | |
tree | a261a99cc70a2c5e7e43e987802d491740b36385 | |
parent | 6377ee4d9f051d7be0c9c290bb33ab66f27ea900 (diff) |
Add multiply function for kernels 1x1x16 (static-multiply1x16)
-rw-r--r-- | tile/multiply.inl | 46 |
1 file changed, 46 insertions, 0 deletions
diff --git a/tile/multiply.inl b/tile/multiply.inl
index a1a92cf..42d2faf 100644
--- a/tile/multiply.inl
+++ b/tile/multiply.inl
@@ -56,6 +56,52 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s
   }
 }
 
+/* Multiply for shapes whose B_cols need not be a multiple of 16.  Whole 1x1x16
+ * unrolled tiles cover the leftmost columns; the remaining 1..15 columns go to
+ * one narrower unrolled kernel.  shape.A_rows and shape.inner must divide evenly
+ * by the 1x1x16 tile.  The A_rows and B_cols template arguments are unused. */
+template <class Access, class Kernel, Index A_rows, Index B_cols> INTGEMM_TARGET static inline void Multiply_Force1x16(Access access, const Tile shape) {
+  // Left part: the widest prefix made of whole 1x1x16 tiles.
+  typedef UnrollKernel<1, 1, 16, Kernel> Left;
+  const Tile overhang = {shape.A_rows % Left::kTile.A_rows, shape.inner % Left::kTile.inner, shape.B_cols % Left::kTile.B_cols};
+  const Tile left_shape = {shape.A_rows - overhang.A_rows, shape.inner - overhang.inner, shape.B_cols - overhang.B_cols};
+  // Only a B_cols remainder is handled below.  B_cols itself is deliberately not
+  // checked: requiring a multiple of 64 would make the right part unreachable.
+  assert(overhang.A_rows == 0);
+  assert(overhang.inner == 0);
+  MultiplyNoOverhang<Access, Left>(access, left_shape);
+
+  // Right part: the unroll width is a template argument, so the runtime
+  // remainder is dispatched to one of 15 compile-time instantiations.
+#define INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(overhang_value) \
+  case (overhang_value): { \
+    typedef UnrollKernel<1, 1, overhang_value, Kernel> Right; \
+    MultiplyNoOverhang<Access, Right>( \
+        access.BAdd(0, left_shape.B_cols).CAdd(0, left_shape.B_cols), \
+        Tile{left_shape.A_rows, shape.inner, overhang.B_cols}); \
+    break; \
+  }
+
+  switch (overhang.B_cols) {
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(1)
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(2)
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(3)
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(4)
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(5)
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(6)
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(7)
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(8)
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(9)
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(10)
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(11)
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(12)
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(13)
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(14)
+    INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART(15)
+  }
+#undef INTGEMM_UGLY_WAY_TO_IMPL_RIGHT_PART
+}
+
 /* Multiply matrices without
being a multiple of an unrolled kernel size. The * inner dimension still needs to be a multiple of sizeof(Register) for int8_t * or sizeof(Register) / 2 for int16_t. |