Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKenneth Heafield <github@kheafield.com>2020-04-23 15:51:26 +0300
committerKenneth Heafield <github@kheafield.com>2020-04-23 15:51:26 +0300
commita5be193995e2782e0f14122a35df0ac9c140ee3d (patch)
tree403d6e5a4a9eecebc4323ee5b6248edd2298c61e
parente12a54041950569742bac71167f8d8718e5932c1 (diff)
General write working on AVX512, at least for tested cases
-rw-r--r--test/tile_test.inl15
-rw-r--r--tile/access.h76
2 files changed, 53 insertions, 38 deletions
diff --git a/test/tile_test.inl b/test/tile_test.inl
index 2a6c059..2e00040 100644
--- a/test/tile_test.inl
+++ b/test/tile_test.inl
@@ -245,6 +245,12 @@ TEST_CASE("MultiplyNoOverhang simple inner unroll " INTGEMM_TEST_NAME, "[tile][m
TestMultiplyNoOverhang<Kernel>({5, sizeof(Register) * 2, 7});
}
+TEST_CASE("MultiplyNoOverhang Simple 17 rows " INTGEMM_TEST_NAME, "[tile][multiply]") {
+ if (kCPU < CPUType::INTGEMM_ARCH) return;
+ typedef UnrollKernel<17, 1, 1, Signed8> Kernel;
+ TestMultiplyNoOverhang<Kernel>({17, sizeof(Register), 1});
+}
+
// Annoyingly, catch's cross-product stuff requires the first argument be a type, which is pretty useless for a cross-product of integers.
TEMPLATE_TEST_CASE("MultiplyNoOverhang Unrolled Signed8 " INTGEMM_TEST_NAME, "[tile][multiply]",
(UnrollKernel<1, 1, 1, Signed8>),
@@ -270,13 +276,18 @@ TEMPLATE_TEST_CASE("MultiplyNoOverhang Unrolled Signed8 " INTGEMM_TEST_NAME, "[t
(UnrollKernel<1, 1, 32, Signed8>),
(UnrollKernel<2, 1, 1, Signed8>),
(UnrollKernel<2, 1, 2, Signed8>),
+ (UnrollKernel<2, 1, 3, Signed8>),
(UnrollKernel<3, 1, 1, Signed8>),
(UnrollKernel<3, 1, 3, Signed8>),
(UnrollKernel<4, 1, 1, Signed8>),
(UnrollKernel<5, 1, 1, Signed8>),
+ (UnrollKernel<6, 1, 4, Signed8>),
+ (UnrollKernel<7, 1, 3, Signed8>),
+ (UnrollKernel<7, 1, 4, Signed8>),
(UnrollKernel<15, 1, 1, Signed8>),
- (UnrollKernel<16, 1, 1, Signed8>)
-// (UnrollKernel<17, 1, 1, Signed8>)
+ (UnrollKernel<15, 1, 2, Signed8>),
+ (UnrollKernel<16, 1, 1, Signed8>),
+ (UnrollKernel<17, 1, 1, Signed8>)
) {
if (kCPU < CPUType::INTGEMM_ARCH) return;
TestMultiplyNoOverhangShapes<TestType>();
diff --git a/tile/access.h b/tile/access.h
index 0279392..535a873 100644
--- a/tile/access.h
+++ b/tile/access.h
@@ -31,61 +31,65 @@ template <class T> class RowMajorAccess {
private:
// If there's a full register to write for a column, do that.
- template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW
- typename std::enable_if<A_rows && B_cols && (CurColumn >= 16)>::type
+ template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
+ typename std::enable_if<A_rows && B_cols && (ColRemain >= 16)>::type
WriteImpl(const __m512i *from) {
_mm512_storeu_si512(data_, *from);
- Add(0, 16).template WriteImpl<A_rows, B_cols, (CurColumn - 16)>(from + 1);
+ Add(0, 16).template WriteImpl<A_rows, B_cols, (ColRemain - 16)>(from + 1);
}
- // There is a mix of rows in a register and we need a scatter. Also, we are lucky enough to be at the beginning of a row.
- // TODO case where we are not at the beginning of a register with _mm512_alignr_epi32.
- template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW
- typename std::enable_if<(A_rows > 1) && CurColumn && (CurColumn < 16) && (CurColumn == B_cols) /* beginning of row */>::type
- WriteImpl(const __m512i *from) {
- // TODO: test this more, also save it somewhere! Make sure compiler isn't recreating this every time.
- const __m512i coefficients_lo = _mm512_set_epi32(
- 15 / B_cols, 14 / B_cols, 13 / B_cols, 12 / B_cols,
- 11 / B_cols, 10 / B_cols, 9 / B_cols, 8 / B_cols,
- 7 / B_cols, 6 / B_cols, 5 / B_cols, 4 / B_cols,
- 3 / B_cols, 2 / B_cols, 1 / B_cols, 0);
+ // TODO: test this more, also save it somewhere! Make sure compiler isn't recreating this every time.
+ template <Index B_cols, Index Off> INTGEMM_AVX512BW inline __m512i Offsets() {
+ const __m512i coefficients = _mm512_set_epi32(
+ (Off + 15) / B_cols, (Off + 14) / B_cols, (Off + 13) / B_cols, (Off + 12) / B_cols,
+ (Off + 11) / B_cols, (Off + 10) / B_cols, (Off + 9) / B_cols, (Off + 8) / B_cols,
+ (Off + 7) / B_cols, (Off + 6) / B_cols, (Off + 5) / B_cols, (Off + 4) / B_cols,
+ (Off + 3) / B_cols, (Off + 2) / B_cols, (Off + 1) / B_cols, Off / B_cols);
+ const __m512i row_offsets = _mm512_set_epi32(
+ (Off + 15) % B_cols, (Off + 14) % B_cols, (Off + 13) % B_cols, (Off + 12) % B_cols,
+ (Off + 11) % B_cols, (Off + 10) % B_cols, (Off + 9) % B_cols, (Off + 8) % B_cols,
+ (Off + 7) % B_cols, (Off + 6) % B_cols, (Off + 5) % B_cols, (Off + 4) % B_cols,
+ (Off + 3) % B_cols, (Off + 2) % B_cols, (Off + 1) % B_cols, Off % B_cols);
+
__m512i cols_reg = _mm512_set1_epi32(cols_);
// Multiply by the number of columns for the offsets.
- const __m512i multiplied = _mm512_mullo_epi32(cols_reg, coefficients_lo);
- // Add row offsets.
- const __m512i row_offsets = _mm512_set_epi32(
- 15 % B_cols, 14 % B_cols, 13 % B_cols, 12 % B_cols,
- 11 % B_cols, 10 % B_cols, 9 % B_cols, 8 % B_cols,
- 7 % B_cols, 6 % B_cols, 5 % B_cols, 4 % B_cols,
- 3 % B_cols, 2 % B_cols, 1 % B_cols, 0);
- // These are the offsets to use if we're perfectly aligned: B_cols is a divisor or multiple of 16.
- __m512i offsets = _mm512_add_epi32(row_offsets, multiplied);
+ const __m512i multiplied = _mm512_mullo_epi32(cols_reg, coefficients);
+ // These are the offsets to use if we're perfectly aligned at the beginning of a row.
+ return _mm512_add_epi32(row_offsets, multiplied);
+ }
+
+ // There is a mix of rows in a register and we need a scatter.
+ template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
+ typename std::enable_if<(A_rows > 1) && ColRemain && (ColRemain < 16)>::type
+ WriteImpl(const __m512i *from) {
+ __m512i offsets = Offsets<B_cols, B_cols - ColRemain>();
// We might be at the end of the data, in which case a mask is needed.
- constexpr Index remaining = (A_rows - 1) * B_cols + CurColumn;
- _mm512_mask_i32scatter_epi32(data_, (1 << remaining) - 1, offsets, *from, sizeof(int32_t));
- // We just wrote 16 values: CurColumn, the next row (all or partial), possibly the next etc.
- // 16 - CurColumn of the next row and whatever followed.
- constexpr Index WroteMore = ((remaining < 16) ? remaining : 16) - CurColumn;
- // TODO: testing on this.
- Add(WroteMore / B_cols, WroteMore % B_cols - CurColumn).template WriteImpl<A_rows - WroteMore / B_cols, B_cols, WroteMore % B_cols>(from);
+ constexpr Index remaining = (A_rows - 1) * B_cols + ColRemain;
+ _mm512_mask_i32scatter_epi32(data_ - (B_cols - ColRemain), static_cast<__mmask16>(1 << remaining) - 1, offsets, *from, sizeof(int32_t));
+ // We just wrote 16 values: ColRemain, the next row (all or partial), possibly the next etc.
+ // 16 - ColRemain of the next row and whatever followed.
+ constexpr Index Wrote = ((remaining < 16) ? remaining : 16);
+ constexpr Index Position = (B_cols - ColRemain) + Wrote;
+ // TODO: more testing on this.
+ Add(Position / B_cols, Position % B_cols - (B_cols - ColRemain)).template WriteImpl<A_rows - (Position / B_cols), B_cols, B_cols - (Position % B_cols)>(from + 1);
}
// At clean end of column, move to next row.
- template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW
- typename std::enable_if<A_rows && B_cols && (CurColumn == 0)>::type
+ template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
+ typename std::enable_if<A_rows && B_cols && (ColRemain == 0)>::type
WriteImpl(const __m512i *from) {
Add(1, -B_cols).template WriteImpl<A_rows - 1, B_cols, B_cols>(from);
}
// On the last row, finish the last write with a mask.
- template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW
- typename std::enable_if<(A_rows == 1) && B_cols && (CurColumn < 16 && CurColumn > 0)>::type
+ template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
+ typename std::enable_if<(A_rows == 1) && B_cols && (ColRemain < 16 && ColRemain > 0)>::type
WriteImpl(const __m512i *from) {
- _mm512_mask_storeu_epi32(data_, (1 << CurColumn) - 1, *from);
+ _mm512_mask_storeu_epi32(data_, (1 << ColRemain) - 1, *from);
}
// Nothing to write.
- template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW
+ template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
typename std::enable_if<!A_rows || !B_cols>::type
WriteImpl(const __m512i *) {}