diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-04-24 19:51:43 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-04-25 01:39:24 +0300 |
commit | a44dd8607f81e652dc8310a80aa70f79bcdc97f2 (patch) | |
tree | 316f194990df16f41d60b4f1a191b6a0596f76de | |
parent | 6377ee4d9f051d7be0c9c290bb33ab66f27ea900 (diff) |
Add WriteCallback
-rw-r--r-- | tile/access.h | 84 | ||||
-rw-r--r-- | tile/callbacks.h | 95 | ||||
-rw-r--r-- | tile/multiply.h | 1 | ||||
-rw-r--r-- | tile/multiply.inl | 3 |
4 files changed, 102 insertions(+), 81 deletions(-)
diff --git a/tile/access.h b/tile/access.h index e8e09e5..11f67ec 100644 --- a/tile/access.h +++ b/tile/access.h @@ -19,90 +19,12 @@ template <class T> class RowMajorAccess { return RowMajorAccess<Content>(data_ + row * cols_ + col, cols_); } + Index Cols() const { return cols_; } + const Content &Front() const { return *data_; } Content &Front() { return *data_; } - // TODO: SLOW. This is here for testing. - template <Index A_rows, Index B_cols> void Write(const __m128i *from) { SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from)); } - template <Index A_rows, Index B_cols> void Write(const __m256i *from) { SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from)); } - template <Index A_rows, Index B_cols> void Write(const __m512i *from) { - WriteImpl<A_rows, B_cols, B_cols>(from); - } - private: - // If there's a full register to write for a column, do that. - template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW - typename std::enable_if<A_rows && B_cols && (ColRemain >= 16)>::type - WriteImpl(const __m512i *from) { - _mm512_storeu_si512(data_, *from); - Add(0, 16).template WriteImpl<A_rows, B_cols, (ColRemain - 16)>(from + 1); - } - - // TODO: test this more, also save it somewhere! Make sure compiler isn't recreating this every time. 
- template <Index B_cols, Index Off> INTGEMM_AVX512BW inline __m512i Offsets() { - const __m512i coefficients = _mm512_set_epi32( - (Off + 15) / B_cols, (Off + 14) / B_cols, (Off + 13) / B_cols, (Off + 12) / B_cols, - (Off + 11) / B_cols, (Off + 10) / B_cols, (Off + 9) / B_cols, (Off + 8) / B_cols, - (Off + 7) / B_cols, (Off + 6) / B_cols, (Off + 5) / B_cols, (Off + 4) / B_cols, - (Off + 3) / B_cols, (Off + 2) / B_cols, (Off + 1) / B_cols, Off / B_cols); - const __m512i row_offsets = _mm512_set_epi32( - (Off + 15) % B_cols, (Off + 14) % B_cols, (Off + 13) % B_cols, (Off + 12) % B_cols, - (Off + 11) % B_cols, (Off + 10) % B_cols, (Off + 9) % B_cols, (Off + 8) % B_cols, - (Off + 7) % B_cols, (Off + 6) % B_cols, (Off + 5) % B_cols, (Off + 4) % B_cols, - (Off + 3) % B_cols, (Off + 2) % B_cols, (Off + 1) % B_cols, Off % B_cols); - - __m512i cols_reg = _mm512_set1_epi32(cols_); - // Multiply by the number of columns for the offsets. - const __m512i multiplied = _mm512_mullo_epi32(cols_reg, coefficients); - // These are the offsets to use if we're perfectly aligned at the beginning of a row. - return _mm512_add_epi32(row_offsets, multiplied); - } - - // There is a mix of rows in a register and we need a scatter. - template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW - typename std::enable_if<(A_rows > 1) && ColRemain && (ColRemain < 16)>::type - WriteImpl(const __m512i *from) { - __m512i offsets = Offsets<B_cols, B_cols - ColRemain>(); - // We might be at the end of the data, in which case a mask is needed. - constexpr Index remaining = (A_rows - 1) * B_cols + ColRemain; - // Compilers seem to complain a lot about shifting past the end :-( - constexpr __mmask16 mask = (remaining >= 16) ? 0xffff : (static_cast<__mmask16>(1 << remaining) - 1); - _mm512_mask_i32scatter_epi32(data_ - (B_cols - ColRemain), mask, offsets, *from, sizeof(int32_t)); - // We just wrote 16 values: ColRemain, the next row (all or partial), possibly the next etc. 
- // 16 - ColRemain of the next row and whatever followed. - constexpr Index Wrote = ((remaining < 16) ? remaining : 16); - constexpr Index Position = (B_cols - ColRemain) + Wrote; - // TODO: more testing on this. - Add(Position / B_cols, Position % B_cols - (B_cols - ColRemain)).template WriteImpl<A_rows - (Position / B_cols), B_cols, B_cols - (Position % B_cols)>(from + 1); - } - - // At clean end of column, move to next row. - template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW - typename std::enable_if<A_rows && B_cols && (ColRemain == 0)>::type - WriteImpl(const __m512i *from) { - Add(1, -B_cols).template WriteImpl<A_rows - 1, B_cols, B_cols>(from); - } - - // On the last row, finish the last write with a mask. - template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW - typename std::enable_if<(A_rows == 1) && B_cols && (ColRemain < 16 && ColRemain > 0)>::type - WriteImpl(const __m512i *from) { - _mm512_mask_storeu_epi32(data_, (1 << ColRemain) - 1, *from); - } - - // Nothing to write. 
- template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW - typename std::enable_if<!A_rows || !B_cols>::type - WriteImpl(const __m512i *) {} - - template <Index A_rows, Index B_cols> void SlowWrite(const T *from) { - for (Index i = 0; i < A_rows; ++i) { - for (Index j = 0; j < B_cols; ++j) { - data_[i * cols_ + j] = from[i * B_cols + j]; - } - } - } - Content *data_; Index cols_; }; @@ -118,6 +40,8 @@ template <class T> class ColMajorAccess { return ColMajorAccess<Content>(data_ + row + col * rows_, rows_); } + Index Rows() const { return rows_; } + const Content &Front() const { return *data_; } Content &Front() { return *data_; } diff --git a/tile/callbacks.h b/tile/callbacks.h new file mode 100644 index 0000000..3c7c00f --- /dev/null +++ b/tile/callbacks.h @@ -0,0 +1,95 @@ +#pragma once + +#include <type_traits> + +#include "../types.h" + +namespace intgemm { + +class WriteCallback { + public: + struct Config {}; + + // TODO: SLOW. This is here for testing. + template <Index A_rows, Index B_cols, typename Access> static void Run(Access access, const __m128i *from, const Config&) { Slow<A_rows, B_cols>(access, reinterpret_cast<const typename Access::Content*>(from)); } + template <Index A_rows, Index B_cols, typename Access> static void Run(Access access, const __m256i *from, const Config&) { Slow<A_rows, B_cols>(access, reinterpret_cast<const typename Access::Content*>(from)); } + template <Index A_rows, Index B_cols, typename Access> static void Run(Access access, const __m512i *from, const Config&) { + RunImpl<A_rows, B_cols, B_cols>(access, from); + } + + private: + // If there's a full register to write for a column, do that. 
+ template <Index A_rows, Index B_cols, Index ColRemain, typename Access> INTGEMM_AVX512BW + static typename std::enable_if<A_rows && B_cols && (ColRemain >= 16)>::type + RunImpl(Access access, const __m512i *from) { + _mm512_storeu_si512(&access.Front(), *from); + RunImpl<A_rows, B_cols, (ColRemain - 16)>(access.Add(0, 16), from + 1); + } + + // TODO: test this more, also save it somewhere! Make sure compiler isn't recreating this every time. + template <Index B_cols, Index Off, typename Access> INTGEMM_AVX512BW static inline __m512i Offsets(Access access) { + const __m512i coefficients = _mm512_set_epi32( + (Off + 15) / B_cols, (Off + 14) / B_cols, (Off + 13) / B_cols, (Off + 12) / B_cols, + (Off + 11) / B_cols, (Off + 10) / B_cols, (Off + 9) / B_cols, (Off + 8) / B_cols, + (Off + 7) / B_cols, (Off + 6) / B_cols, (Off + 5) / B_cols, (Off + 4) / B_cols, + (Off + 3) / B_cols, (Off + 2) / B_cols, (Off + 1) / B_cols, Off / B_cols); + const __m512i row_offsets = _mm512_set_epi32( + (Off + 15) % B_cols, (Off + 14) % B_cols, (Off + 13) % B_cols, (Off + 12) % B_cols, + (Off + 11) % B_cols, (Off + 10) % B_cols, (Off + 9) % B_cols, (Off + 8) % B_cols, + (Off + 7) % B_cols, (Off + 6) % B_cols, (Off + 5) % B_cols, (Off + 4) % B_cols, + (Off + 3) % B_cols, (Off + 2) % B_cols, (Off + 1) % B_cols, Off % B_cols); + + __m512i cols_reg = _mm512_set1_epi32(access.Cols()); + // Multiply by the number of columns for the offsets. + const __m512i multiplied = _mm512_mullo_epi32(cols_reg, coefficients); + // These are the offsets to use if we're perfectly aligned at the beginning of a row. + return _mm512_add_epi32(row_offsets, multiplied); + } + + // There is a mix of rows in a register and we need a scatter. 
+ template <Index A_rows, Index B_cols, Index ColRemain, typename Access> INTGEMM_AVX512BW + static typename std::enable_if<(A_rows > 1) && ColRemain && (ColRemain < 16)>::type + RunImpl(Access access, const __m512i *from) { + __m512i offsets = Offsets<B_cols, B_cols - ColRemain>(access); + // We might be at the end of the data, in which case a mask is needed. + constexpr Index remaining = (A_rows - 1) * B_cols + ColRemain; + // Compilers seem to complain a lot about shifting past the end :-( + constexpr __mmask16 mask = (remaining >= 16) ? 0xffff : (static_cast<__mmask16>(1 << remaining) - 1); + _mm512_mask_i32scatter_epi32(&access.Front() - (B_cols - ColRemain), mask, offsets, *from, sizeof(int32_t)); + // We just wrote 16 values: ColRemain, the next row (all or partial), possibly the next etc. + // 16 - ColRemain of the next row and whatever followed. + constexpr Index Wrote = ((remaining < 16) ? remaining : 16); + constexpr Index Position = (B_cols - ColRemain) + Wrote; + // TODO: more testing on this. + RunImpl<A_rows - (Position / B_cols), B_cols, B_cols - (Position % B_cols)>(access.Add(Position / B_cols, Position % B_cols - (B_cols - ColRemain)), from + 1); + } + + // At clean end of column, move to next row. + template <Index A_rows, Index B_cols, Index ColRemain, typename Access> INTGEMM_AVX512BW + static typename std::enable_if<A_rows && B_cols && (ColRemain == 0)>::type + RunImpl(Access access, const __m512i *from) { + RunImpl<A_rows - 1, B_cols, B_cols>(access.Add(1, -B_cols), from); + } + + // On the last row, finish the last write with a mask. + template <Index A_rows, Index B_cols, Index ColRemain, typename Access> INTGEMM_AVX512BW + static typename std::enable_if<(A_rows == 1) && B_cols && (ColRemain < 16 && ColRemain > 0)>::type + RunImpl(Access access, const __m512i *from) { + _mm512_mask_storeu_epi32(&access.Front(), (1 << ColRemain) - 1, *from); + } + + // Nothing to write. 
+ template <Index A_rows, Index B_cols, Index ColRemain, typename Access> INTGEMM_AVX512BW + static typename std::enable_if<!A_rows || !B_cols>::type + RunImpl(Access, const __m512i *) {} + + template <Index A_rows, Index B_cols, typename Access> static void Slow(Access access, const typename Access::Content *from) { + for (Index i = 0; i < A_rows; ++i) { + for (Index j = 0; j < B_cols; ++j) { + (&access.Front())[i * access.Cols() + j] = from[i * B_cols + j]; + } + } + } +}; + +} // namespace intgemm diff --git a/tile/multiply.h b/tile/multiply.h index 294b3ed..2c0ca46 100644 --- a/tile/multiply.h +++ b/tile/multiply.h @@ -1,6 +1,7 @@ #pragma once #include "access.h" +#include "callbacks.h" #include "dot.h" #include "reduce.h" #include "../types.h" diff --git a/tile/multiply.inl b/tile/multiply.inl index a1a92cf..4d208d9 100644 --- a/tile/multiply.inl +++ b/tile/multiply.inl @@ -51,7 +51,8 @@ template <class AccessT, class Kernel> INTGEMM_TARGET __attribute__((flatten)) s Sum16To32(c_regs, typename Kernel::Packed::C(), make_index_sequence<Outputs>()); // Horizontally add 32-bit values. Reduce32<Outputs, Sum32Op>(c_regs); - col_row.CAccessor().template Write<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(c_regs); + // col_row.CAccessor().template Write<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(c_regs); + WriteCallback::template Run<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(col_row.CAccessor(), c_regs); } } } |