diff options
author | Kenneth Heafield <github@kheafield.com> | 2020-04-23 15:51:26 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2020-04-23 15:51:26 +0300 |
commit | a5be193995e2782e0f14122a35df0ac9c140ee3d (patch) | |
tree | 403d6e5a4a9eecebc4323ee5b6248edd2298c61e | |
parent | e12a54041950569742bac71167f8d8718e5932c1 (diff) |
General write working on AVX512, at least for tested cases
-rw-r--r-- | test/tile_test.inl | 15 | ||||
-rw-r--r-- | tile/access.h | 76 |
2 files changed, 53 insertions, 38 deletions
diff --git a/test/tile_test.inl b/test/tile_test.inl index 2a6c059..2e00040 100644 --- a/test/tile_test.inl +++ b/test/tile_test.inl @@ -245,6 +245,12 @@ TEST_CASE("MultiplyNoOverhang simple inner unroll " INTGEMM_TEST_NAME, "[tile][m TestMultiplyNoOverhang<Kernel>({5, sizeof(Register) * 2, 7}); } +TEST_CASE("MultiplyNoOverhang Simple 17 rows " INTGEMM_TEST_NAME, "[tile][multiply]") { + if (kCPU < CPUType::INTGEMM_ARCH) return; + typedef UnrollKernel<17, 1, 1, Signed8> Kernel; + TestMultiplyNoOverhang<Kernel>({17, sizeof(Register), 1}); +} + // Annoyingly, catch's cross-product stuff requires the first argument be a type, which is pretty useless for a cross-product of integers. TEMPLATE_TEST_CASE("MultiplyNoOverhang Unrolled Signed8 " INTGEMM_TEST_NAME, "[tile][multiply]", (UnrollKernel<1, 1, 1, Signed8>), @@ -270,13 +276,18 @@ TEMPLATE_TEST_CASE("MultiplyNoOverhang Unrolled Signed8 " INTGEMM_TEST_NAME, "[t (UnrollKernel<1, 1, 32, Signed8>), (UnrollKernel<2, 1, 1, Signed8>), (UnrollKernel<2, 1, 2, Signed8>), + (UnrollKernel<2, 1, 3, Signed8>), (UnrollKernel<3, 1, 1, Signed8>), (UnrollKernel<3, 1, 3, Signed8>), (UnrollKernel<4, 1, 1, Signed8>), (UnrollKernel<5, 1, 1, Signed8>), + (UnrollKernel<6, 1, 4, Signed8>), + (UnrollKernel<7, 1, 3, Signed8>), + (UnrollKernel<7, 1, 4, Signed8>), (UnrollKernel<15, 1, 1, Signed8>), - (UnrollKernel<16, 1, 1, Signed8>) -// (UnrollKernel<17, 1, 1, Signed8>) + (UnrollKernel<15, 1, 2, Signed8>), + (UnrollKernel<16, 1, 1, Signed8>), + (UnrollKernel<17, 1, 1, Signed8>) ) { if (kCPU < CPUType::INTGEMM_ARCH) return; TestMultiplyNoOverhangShapes<TestType>(); diff --git a/tile/access.h b/tile/access.h index 0279392..535a873 100644 --- a/tile/access.h +++ b/tile/access.h @@ -31,61 +31,65 @@ template <class T> class RowMajorAccess { private: // If there's a full register to write for a column, do that. - template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW - typename std::enable_if<A_rows && B_cols && (CurColumn >= 16)>::type + template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW + typename std::enable_if<A_rows && B_cols && (ColRemain >= 16)>::type WriteImpl(const __m512i *from) { _mm512_storeu_si512(data_, *from); - Add(0, 16).template WriteImpl<A_rows, B_cols, (CurColumn - 16)>(from + 1); + Add(0, 16).template WriteImpl<A_rows, B_cols, (ColRemain - 16)>(from + 1); } - // There is a mix of rows in a register and we need a scatter. Also, we are lucky enough to be at the beginning of a row. - // TODO case where we are not at the beginning of a register with _mm512_alignr_epi32. - template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW - typename std::enable_if<(A_rows > 1) && CurColumn && (CurColumn < 16) && (CurColumn == B_cols) /* beginning of row */>::type - WriteImpl(const __m512i *from) { - // TODO: test this more, also save it somewhere! Make sure compiler isn't recreating this every time. - const __m512i coefficients_lo = _mm512_set_epi32( - 15 / B_cols, 14 / B_cols, 13 / B_cols, 12 / B_cols, - 11 / B_cols, 10 / B_cols, 9 / B_cols, 8 / B_cols, - 7 / B_cols, 6 / B_cols, 5 / B_cols, 4 / B_cols, - 3 / B_cols, 2 / B_cols, 1 / B_cols, 0); + // TODO: test this more, also save it somewhere! Make sure compiler isn't recreating this every time. + template <Index B_cols, Index Off> INTGEMM_AVX512BW inline __m512i Offsets() { + const __m512i coefficients = _mm512_set_epi32( + (Off + 15) / B_cols, (Off + 14) / B_cols, (Off + 13) / B_cols, (Off + 12) / B_cols, + (Off + 11) / B_cols, (Off + 10) / B_cols, (Off + 9) / B_cols, (Off + 8) / B_cols, + (Off + 7) / B_cols, (Off + 6) / B_cols, (Off + 5) / B_cols, (Off + 4) / B_cols, + (Off + 3) / B_cols, (Off + 2) / B_cols, (Off + 1) / B_cols, Off / B_cols); + const __m512i row_offsets = _mm512_set_epi32( + (Off + 15) % B_cols, (Off + 14) % B_cols, (Off + 13) % B_cols, (Off + 12) % B_cols, + (Off + 11) % B_cols, (Off + 10) % B_cols, (Off + 9) % B_cols, (Off + 8) % B_cols, + (Off + 7) % B_cols, (Off + 6) % B_cols, (Off + 5) % B_cols, (Off + 4) % B_cols, + (Off + 3) % B_cols, (Off + 2) % B_cols, (Off + 1) % B_cols, Off % B_cols); + __m512i cols_reg = _mm512_set1_epi32(cols_); // Multiply by the number of columns for the offsets. - const __m512i multiplied = _mm512_mullo_epi32(cols_reg, coefficients_lo); - // Add row offsets. - const __m512i row_offsets = _mm512_set_epi32( - 15 % B_cols, 14 % B_cols, 13 % B_cols, 12 % B_cols, - 11 % B_cols, 10 % B_cols, 9 % B_cols, 8 % B_cols, - 7 % B_cols, 6 % B_cols, 5 % B_cols, 4 % B_cols, - 3 % B_cols, 2 % B_cols, 1 % B_cols, 0); - // These are the offsets to use if we're perfectly aligned: B_cols is a divisor or multiple of 16. - __m512i offsets = _mm512_add_epi32(row_offsets, multiplied); + const __m512i multiplied = _mm512_mullo_epi32(cols_reg, coefficients); + // These are the offsets to use if we're perfectly aligned at the beginning of a row. + return _mm512_add_epi32(row_offsets, multiplied); + } + + // There is a mix of rows in a register and we need a scatter. + template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW + typename std::enable_if<(A_rows > 1) && ColRemain && (ColRemain < 16)>::type + WriteImpl(const __m512i *from) { + __m512i offsets = Offsets<B_cols, B_cols - ColRemain>(); // We might be at the end of the data, in which case a mask is needed. - constexpr Index remaining = (A_rows - 1) * B_cols + CurColumn; - _mm512_mask_i32scatter_epi32(data_, (1 << remaining) - 1, offsets, *from, sizeof(int32_t)); - // We just wrote 16 values: CurColumn, the next row (all or partial), possibly the next etc. - // 16 - CurColumn of the next row and whatever followed. - constexpr Index WroteMore = ((remaining < 16) ? remaining : 16) - CurColumn; - // TODO: testing on this. - Add(WroteMore / B_cols, WroteMore % B_cols - CurColumn).template WriteImpl<A_rows - WroteMore / B_cols, B_cols, WroteMore % B_cols>(from); + constexpr Index remaining = (A_rows - 1) * B_cols + ColRemain; + _mm512_mask_i32scatter_epi32(data_ - (B_cols - ColRemain), static_cast<__mmask16>(1 << remaining) - 1, offsets, *from, sizeof(int32_t)); + // We just wrote 16 values: ColRemain, the next row (all or partial), possibly the next etc. + // 16 - ColRemain of the next row and whatever followed. + constexpr Index Wrote = ((remaining < 16) ? remaining : 16); + constexpr Index Position = (B_cols - ColRemain) + Wrote; + // TODO: more testing on this. + Add(Position / B_cols, Position % B_cols - (B_cols - ColRemain)).template WriteImpl<A_rows - (Position / B_cols), B_cols, B_cols - (Position % B_cols)>(from + 1); } // At clean end of column, move to next row. - template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW - typename std::enable_if<A_rows && B_cols && (CurColumn == 0)>::type + template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW + typename std::enable_if<A_rows && B_cols && (ColRemain == 0)>::type WriteImpl(const __m512i *from) { Add(1, -B_cols).template WriteImpl<A_rows - 1, B_cols, B_cols>(from); } // On the last row, finish the last write with a mask. - template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW - typename std::enable_if<(A_rows == 1) && B_cols && (CurColumn < 16 && CurColumn > 0)>::type + template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW + typename std::enable_if<(A_rows == 1) && B_cols && (ColRemain < 16 && ColRemain > 0)>::type WriteImpl(const __m512i *from) { - _mm512_mask_storeu_epi32(data_, (1 << CurColumn) - 1, *from); + _mm512_mask_storeu_epi32(data_, (1 << ColRemain) - 1, *from); } // Nothing to write. - template <Index A_rows, Index B_cols, Index CurColumn> INTGEMM_AVX512BW + template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW typename std::enable_if<!A_rows || !B_cols>::type WriteImpl(const __m512i *) {} |