author    | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-04-30 21:02:54 +0300
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-05-05 21:43:54 +0300
commit    | 987e40a494d9bb48e7ef94990c9bcf5cb2582ded
tree      | d4bbf21ffa758498692480f07bb51ef1c29a98c9
parent    | 1280a3715a9338caddc29f04176fe66daa79895e
Add basic support for callbacks to multiplication with tiles
-rw-r--r-- | benchmarks/benchmark_tile.cc   |  4
-rw-r--r-- | callbacks.h                    |  9
-rw-r--r-- | callbacks/configs.h            |  4
-rw-r--r-- | callbacks/implementations.inl  | 42
-rw-r--r-- | test/tile_test.inl             |  6
-rw-r--r-- | tile/access.h                  | 55
-rw-r--r-- | tile/multiply.h                |  1
-rw-r--r-- | tile/multiply.inl              | 15
-rw-r--r-- | vec_traits.h                   |  5
9 files changed, 97 insertions(+), 44 deletions(-)
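This commit threads a callback through the tile multiply path: `Multiply` and `MultiplyNoOverhang` now take a callback config, build the architecture-specific `callbacks::CallbackImpl` from it once per call, and apply it to each output register before it is written to C. The new `callbacks::Identity` config passes registers through unchanged, so the updated call sites keep the old behaviour. A minimal sketch of the new call shape, mirroring the test and benchmark changes in the diff below:

```cpp
// Sketch only, based on the call sites updated in this diff; `t` is the
// test fixture (TestMatricesRef) that bundles the A/B/C accessors.
// Identity<int32_t> stores the raw int32 sums, exactly as before.
TestMatricesRef t(shape);
Multiply<Signed8, 7, 3>(t.Accessor(), shape, callbacks::Identity<int32_t>());
MultiplyNoOverhang<Kernel>(t.Accessor(), shape, callbacks::Identity<int32_t>());
```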
diff --git a/benchmarks/benchmark_tile.cc b/benchmarks/benchmark_tile.cc
index f6ff0b7..801427f 100644
--- a/benchmarks/benchmark_tile.cc
+++ b/benchmarks/benchmark_tile.cc
@@ -93,9 +93,9 @@ template <Index A_rows, Index B_cols> static inline double BenchmarkNoOverhang(A
   typedef AVX512VNNI::UnrollKernel<A_rows, 1, B_cols, AVX512VNNI::Shifted8> Kernel;
   // Burn in.
   // TODO: different arches, guard against old compilers, etc.
-  AVX512VNNI::MultiplyNoOverhang<Kernel>(access, shape);
+  AVX512VNNI::MultiplyNoOverhang<Kernel>(access, shape, callbacks::Identity<Accessor::CContent>());
   for (std::size_t t = 0; t < kTries; ++t) {
-    AVX512VNNI::MultiplyNoOverhang<Kernel>(access, shape);
+    AVX512VNNI::MultiplyNoOverhang<Kernel>(access, shape, callbacks::Identity<Accessor::CContent>());
   }
   auto end = std::chrono::steady_clock::now();
   return std::chrono::duration<double>(end - start).count() / kTries;
diff --git a/callbacks.h b/callbacks.h
index 24f9009..439992e 100644
--- a/callbacks.h
+++ b/callbacks.h
@@ -14,6 +14,10 @@
 #include "callbacks/implementations.inl"
 #undef CALLBACKS_THIS_IS_SSE2
 
+#define CALLBACKS_THIS_IS_SSSE3
+#include "callbacks/implementations.inl"
+#undef CALLBACKS_THIS_IS_SSSE3
+
 #define CALLBACKS_THIS_IS_AVX2
 #include "callbacks/implementations.inl"
 #undef CALLBACKS_THIS_IS_AVX2
@@ -24,3 +28,8 @@
 #undef CALLBACKS_THIS_IS_AVX512BW
 #endif
 
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
+#define CALLBACKS_THIS_IS_AVX512VNNI
+#include "callbacks/implementations.inl"
+#undef CALLBACKS_THIS_IS_AVX512VNNI
+#endif
diff --git a/callbacks/configs.h b/callbacks/configs.h
index 1222448..fc8343f 100644
--- a/callbacks/configs.h
+++ b/callbacks/configs.h
@@ -20,6 +20,10 @@ struct Dummy {
 };
 
 template <typename Type>
+struct Identity {
+};
+
+template <typename Type>
 struct Write {
   Type* output_addr;
 
diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl
index 25f8aa3..86984fb 100644
--- a/callbacks/implementations.inl
+++ b/callbacks/implementations.inl
@@ -2,25 +2,35 @@
 #if defined(CALLBACKS_THIS_IS_SSE2)
   #define CPU_NAME SSE2
   #define CPU_ATTR INTGEMM_SSE2
+#elif defined(CALLBACKS_THIS_IS_SSSE3)
+  #define CPU_NAME SSSE3
+  #define CPU_ATTR INTGEMM_SSSE3
 #elif defined(CALLBACKS_THIS_IS_AVX2)
   #define CPU_NAME AVX2
   #define CPU_ATTR INTGEMM_AVX2
 #elif defined(CALLBACKS_THIS_IS_AVX512BW)
   #define CPU_NAME AVX512BW
   #define CPU_ATTR INTGEMM_AVX512BW
+#elif defined(CALLBACKS_THIS_IS_AVX512VNNI)
+  #define CPU_NAME AVX512VNNI
+  #define CPU_ATTR INTGEMM_AVX512VNNI
 #else
-  #error "Only SSE2, AVX2 and AVX512BW are supported"
+  #error "Only SSE2, SSSE3, AVX2, AVX512BW and AVX512VNNI are supported"
 #endif
 
-#if defined(CALLBACKS_THIS_IS_SSE2)
-  #define vi vector_t<CPUType::SSE2, int>
-  #define vf vector_t<CPUType::SSE2, float>
-  #define vd vector_t<CPUType::SSE2, double>
-#else
-  #define vi vector_t<CPUType::AVX2, int>
-  #define vf vector_t<CPUType::AVX2, float>
-  #define vd vector_t<CPUType::AVX2, double>
-#endif
+// #if defined(CALLBACKS_THIS_IS_SSE2)
+// #define vi vector_t<CPUType::SSE2, int>
+// #define vf vector_t<CPUType::SSE2, float>
+// #define vd vector_t<CPUType::SSE2, double>
+// #else
+// #define vi vector_t<CPUType::AVX2, int>
+// #define vf vector_t<CPUType::AVX2, float>
+// #define vd vector_t<CPUType::AVX2, double>
+// #endif
+
+#define vi vector_t<CPUType::CPU_NAME, int>
+#define vf vector_t<CPUType::CPU_NAME, float>
+#define vd vector_t<CPUType::CPU_NAME, double>
 
 namespace intgemm {
 namespace callbacks {
@@ -86,6 +96,18 @@ public:
 };
 
 /*
+ * Identity
+ */
+template <typename Type>
+class CallbackImpl<CPUType::CPU_NAME, Identity<Type>> {
+public:
+  CPU_ATTR CallbackImpl(const Identity<Type>&) {}
+  CPU_ATTR vector_t<CPUType::CPU_NAME, Type> operator()(vector_t<CPUType::CPU_NAME, Type> input, const OutputBufferInfo&) {
+    return input;
+  }
+};
+
+/*
  * Write
  */
 template <typename Type>
diff --git a/test/tile_test.inl b/test/tile_test.inl
index 865e922..688f3f9 100644
--- a/test/tile_test.inl
+++ b/test/tile_test.inl
@@ -173,7 +173,7 @@ template <class Kernel> void TestMultiplyNoOverhang(Tile shape) {
   CHECK(shape.inner % Kernel::kTile.inner == 0);
   CHECK(shape.B_cols % Kernel::kTile.B_cols == 0);
   TestMatricesRef t(shape);
-  MultiplyNoOverhang<Kernel>(t.Accessor(), shape);
+  MultiplyNoOverhang<Kernel>(t.Accessor(), shape, callbacks::Identity<int32_t>());
   CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
 /*  for (Index i = 0; i < shape.A_rows; ++i) {
     for (Index j = 0; j < shape.B_cols; ++j) {
@@ -286,10 +286,10 @@ TEST_CASE("Multiply " INTGEMM_TEST_NAME, "[tile][multiply]") {
   for (shape.A_rows = 1; shape.A_rows < 33; ++shape.A_rows) {
     for (shape.B_cols = 1; shape.B_cols < 33; ++shape.B_cols) {
       TestMatricesRef t(shape);
-      Multiply<Signed8, 7, 3>(t.Accessor(), shape);
+      Multiply<Signed8, 7, 3>(t.Accessor(), shape, callbacks::Identity<int32_t>());
       CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
       memset(t.C.begin(), 0, shape.A_rows * shape.B_cols * sizeof(int32_t));
-      Multiply<Signed8, 4, 5>(t.Accessor(), shape);
+      Multiply<Signed8, 4, 5>(t.Accessor(), shape, callbacks::Identity<int32_t>());
       CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
     }
   }
diff --git a/tile/access.h b/tile/access.h
index aa8365e..5faf9a1 100644
--- a/tile/access.h
+++ b/tile/access.h
@@ -3,6 +3,7 @@
 #include <type_traits>
 
 #include "../types.h"
+#include "../callbacks.h"
 
 namespace intgemm {
 
@@ -12,30 +13,36 @@ template <class T> class RowMajorAccess {
  public:
   typedef T Content;
 
-  RowMajorAccess(Content *data, Index cols)
-    : data_(data), cols_(cols) {}
+  RowMajorAccess(Content *data, Index cols) : RowMajorAccess(data, cols, 0, 0) {}
 
   RowMajorAccess<Content> Add(Index row, Index col) const {
-    return RowMajorAccess<Content>(data_ + row * cols_ + col, cols_);
+    return RowMajorAccess<Content>(data_ + row * cols_ + col, cols_, row_idx_ + row, col_idx_ + col);
   }
 
   const Content &Front() const { return *data_; }
   Content &Front() { return *data_; }
 
   // TODO: SLOW. This is here for testing.
-  template <Index A_rows, Index B_cols> void Write(const __m128i *from) { SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from)); }
-  template <Index A_rows, Index B_cols> void Write(const __m256i *from) { SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from)); }
-  template <Index A_rows, Index B_cols> void Write(const __m512i *from) {
-    WriteImpl<A_rows, B_cols, B_cols>(from);
+  template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m128i *from, CallbackImpl&) {
+    SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from));
+  }
+  template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m256i *from, CallbackImpl&) {
+    SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from));
+  }
+  template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m512i *from, CallbackImpl& callback_impl) {
+    WriteImpl<A_rows, B_cols, B_cols>(from, callback_impl);
   }
 
  private:
+  RowMajorAccess(Content *data, Index cols, Index row_idx, Index col_idx)
+    : data_(data), cols_(cols), row_idx_(row_idx), col_idx_(col_idx) {}
+
   // If there's a full register to write for a column, do that.
-  template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
+  template <Index A_rows, Index B_cols, Index ColRemain, typename CallbackImpl> INTGEMM_AVX512BW
   typename std::enable_if<A_rows && B_cols && (ColRemain >= 16)>::type
-  WriteImpl(const __m512i *from) {
-    _mm512_storeu_si512(data_, *from);
-    Add(0, 16).template WriteImpl<A_rows, B_cols, (ColRemain - 16)>(from + 1);
+  WriteImpl(const __m512i *from, CallbackImpl& callback_impl) {
+    _mm512_storeu_si512(data_, callback_impl(*from, callbacks::OutputBufferInfo(row_idx_, col_idx_, 0, 0)));
+    Add(0, 16).template WriteImpl<A_rows, B_cols, (ColRemain - 16)>(from + 1, callback_impl);
   }
 
   // TODO: test this more, also save it somewhere! Make sure compiler isn't recreating this every time.
@@ -59,9 +66,9 @@ template <class T> class RowMajorAccess {
   }
 
   // There is a mix of rows in a register and we need a scatter.
-  template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
+  template <Index A_rows, Index B_cols, Index ColRemain, typename CallbackImpl> INTGEMM_AVX512BW
   typename std::enable_if<(A_rows > 1) && ColRemain && (ColRemain < 16)>::type
-  WriteImpl(const __m512i *from) {
+  WriteImpl(const __m512i *from, CallbackImpl& callback_impl) {
     __m512i offsets = Offsets<B_cols, B_cols - ColRemain>();
     // We might be at the end of the data, in which case a mask is needed.
     constexpr Index remaining = (A_rows - 1) * B_cols + ColRemain;
@@ -70,33 +77,33 @@ template <class T> class RowMajorAccess {
     // The offsets add B_cols - ColRemain so they can be correct modulo the number of columns.
     // So we subtract that from the data pointer.
     int32_t *go_back = data_ - (B_cols - ColRemain);
-    _mm512_mask_i32scatter_epi32(go_back, mask, offsets, *from, sizeof(int32_t));
+    _mm512_mask_i32scatter_epi32(go_back, mask, offsets, callback_impl(*from, callbacks::OutputBufferInfo(row_idx_, col_idx_, 0, 0)), sizeof(int32_t));
     // We just wrote 16 values: ColRemain, the next row (all or partial), possibly the next etc.
     // 16 - ColRemain of the next row and whatever followed.
     constexpr Index Wrote = ((remaining < 16) ? remaining : 16);
     constexpr Index Position = (B_cols - ColRemain) + Wrote;
     // TODO: more testing on this.
-    Add(Position / B_cols, Position % B_cols - (B_cols - ColRemain)).template WriteImpl<A_rows - (Position / B_cols), B_cols, B_cols - (Position % B_cols)>(from + 1);
+    Add(Position / B_cols, Position % B_cols - (B_cols - ColRemain)).template WriteImpl<A_rows - (Position / B_cols), B_cols, B_cols - (Position % B_cols)>(from + 1, callback_impl);
   }
 
   // At clean end of column, move to next row.
-  template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
+  template <Index A_rows, Index B_cols, Index ColRemain, typename CallbackImpl> INTGEMM_AVX512BW
   typename std::enable_if<A_rows && B_cols && (ColRemain == 0)>::type
-  WriteImpl(const __m512i *from) {
-    Add(1, -B_cols).template WriteImpl<A_rows - 1, B_cols, B_cols>(from);
+  WriteImpl(const __m512i *from, CallbackImpl& callback_impl) {
+    Add(1, -B_cols).template WriteImpl<A_rows - 1, B_cols, B_cols>(from, callback_impl);
   }
 
   // On the last row, finish the last write with a mask.
-  template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
+  template <Index A_rows, Index B_cols, Index ColRemain, typename CallbackImpl> INTGEMM_AVX512BW
   typename std::enable_if<(A_rows == 1) && B_cols && (ColRemain < 16 && ColRemain > 0)>::type
-  WriteImpl(const __m512i *from) {
-    _mm512_mask_storeu_epi32(data_, (1 << ColRemain) - 1, *from);
+  WriteImpl(const __m512i *from, CallbackImpl& callback_impl) {
+    _mm512_mask_storeu_epi32(data_, (1 << ColRemain) - 1, callback_impl(*from, callbacks::OutputBufferInfo(row_idx_, col_idx_, 0, 0)));
   }
 
   // Nothing to write.
-  template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
+  template <Index A_rows, Index B_cols, Index ColRemain, typename CallbackImpl> INTGEMM_AVX512BW
   typename std::enable_if<!A_rows || !B_cols>::type
-  WriteImpl(const __m512i *) {}
+  WriteImpl(const __m512i *, CallbackImpl&) {}
 
   template <Index A_rows, Index B_cols> void SlowWrite(const T *from) {
     for (Index i = 0; i < A_rows; ++i) {
@@ -108,6 +115,8 @@ template <class T> class RowMajorAccess {
 
   Content *data_;
   Index cols_;
+  Index row_idx_;
+  Index col_idx_;
 };
 
 template <class T> class ColMajorAccess {
diff --git a/tile/multiply.h b/tile/multiply.h
index 294b3ed..0ef9a43 100644
--- a/tile/multiply.h
+++ b/tile/multiply.h
@@ -3,6 +3,7 @@
 #include "access.h"
 #include "dot.h"
 #include "reduce.h"
+#include "../callbacks.h"
 #include "../types.h"
 
 #include <cassert>
diff --git a/tile/multiply.inl b/tile/multiply.inl
index be86255..9036750 100644
--- a/tile/multiply.inl
+++ b/tile/multiply.inl
@@ -25,10 +25,13 @@ template <std::size_t... i> INTGEMM_TARGET static inline void Sum16To32(Register
 template <std::size_t... i> INTGEMM_TARGET static inline void Sum16To32(Register *, int32_t, index_sequence<i...>) {}
 
 /* Multiply assuming the matrix sizes are a multiple of the kernel size.
 */
-template <class Kernel, class AccessT> INTGEMM_TARGET __attribute__((flatten)) static inline void MultiplyNoOverhang(AccessT access, const Tile shape) {
+template <class Kernel, class AccessT, class Callback> INTGEMM_TARGET __attribute__((flatten)) static inline void MultiplyNoOverhang(AccessT access, const Tile shape, Callback callback) {
   assert(shape.A_rows % Kernel::kTile.A_rows == 0);
   assert(shape.inner % Kernel::kTile.inner == 0);
   assert(shape.B_cols % Kernel::kTile.B_cols == 0);
+
+  auto callback_impl = callbacks::CallbackImpl<CPUType::INTGEMM_ARCH, Callback>(callback);
+
   for (Index B_col = 0; B_col < shape.B_cols; B_col += Kernel::kTile.B_cols) {
     AccessT column_adjusted = access.BAdd(0, B_col).CAdd(0, B_col);
     for (Index A_row = 0; A_row < shape.A_rows; A_row += Kernel::kTile.A_rows) {
@@ -51,7 +54,7 @@ template <class Kernel, class AccessT> INTGEMM_TARGET __attribute__((flatten)) s
       Sum16To32(c_regs, typename Kernel::Packed::C(), make_index_sequence<Outputs>());
       // Horizontally add 32-bit values.
       Reduce32<Outputs, Sum32Op>(c_regs);
-      col_row.CAccessor().template Write<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(c_regs);
+      col_row.CAccessor().template Write<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(c_regs, callback_impl);
     }
   }
 }
@@ -64,7 +67,7 @@ template <class Kernel, class AccessT> INTGEMM_TARGET __attribute__((flatten)) s
  * A_rows and B_cols specify the unrolled kernel size to use for most of the
  * multiply; these impact speed but not output.
  */
-template <class Kernel, Index A_rows, Index B_cols, class AccessT> INTGEMM_TARGET static inline void Multiply(AccessT access, const Tile shape) {
+template <class Kernel, Index A_rows, Index B_cols, class Access, class Callback> INTGEMM_TARGET static inline void Multiply(Access access, const Tile shape, Callback callback) {
   // Still has to be a multiple of the underlying Kernel, but usually that's just 1 x sizeof(Register) x 1.
   assert(shape.A_rows % Kernel::kTile.A_rows == 0);
   assert(shape.inner % Kernel::kTile.inner == 0);
@@ -82,15 +85,15 @@ template <class Kernel, Index A_rows, Index B_cols, class AccessT> INTGEMM_TARGE
     shape.B_cols - overhang.B_cols
   };
   // Top left corner.
-  MultiplyNoOverhang<Big>(access, big_shape);
+  MultiplyNoOverhang<Big>(access, big_shape, callback);
   // Bottom currently including right side. TODO: unrolled kernel, rather than dumb loop.
   MultiplyNoOverhang<Kernel>(
       access.AAdd(big_shape.A_rows, 0).CAdd(big_shape.A_rows, 0),
-      Tile {overhang.A_rows, shape.inner, shape.B_cols});
+      Tile {overhang.A_rows, shape.inner, shape.B_cols}, callback);
   // Right side except bottom. TODO: unrolled kernel, rather than dumb loop.
   MultiplyNoOverhang<Kernel>(
       access.BAdd(0, big_shape.B_cols).CAdd(0, big_shape.B_cols),
-      Tile {big_shape.A_rows, shape.inner, overhang.B_cols});
+      Tile {big_shape.A_rows, shape.inner, overhang.B_cols}, callback);
 }
 
 } // namespace INTGEMM_ARCH
diff --git a/vec_traits.h b/vec_traits.h
index 6c4ec28..059aa83 100644
--- a/vec_traits.h
+++ b/vec_traits.h
@@ -28,6 +28,11 @@ template <> struct vector_s<CPUType::AVX512BW, int16_t> { using type = __m512i;
 template <> struct vector_s<CPUType::AVX512BW, int> { using type = __m512i; };
 template <> struct vector_s<CPUType::AVX512BW, float> { using type = __m512; };
 template <> struct vector_s<CPUType::AVX512BW, double> { using type = __m512d; };
+template <> struct vector_s<CPUType::AVX512VNNI, int8_t> { using type = __m512i; };
+template <> struct vector_s<CPUType::AVX512VNNI, int16_t> { using type = __m512i; };
+template <> struct vector_s<CPUType::AVX512VNNI, int> { using type = __m512i; };
+template <> struct vector_s<CPUType::AVX512VNNI, float> { using type = __m512; };
+template <> struct vector_s<CPUType::AVX512VNNI, double> { using type = __m512d; };
 
 template <CPUType CPUType_, typename ElemType_> using vector_t = typename vector_s<CPUType_, ElemType_>::type;
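`Identity` also serves as a template for further callbacks: a parameter struct goes in callbacks/configs.h, and a matching `CallbackImpl` specialization goes in callbacks/implementations.inl, where it is compiled once per `CPU_NAME`/`CPU_ATTR` architecture pass. A hypothetical skeleton of that pattern (the `Scale` config and its `factor` member are illustrative only, not part of this commit):

```cpp
// In callbacks/configs.h (hypothetical config carrying callback parameters):
struct Scale {
  float factor;
};

// In callbacks/implementations.inl (hypothetical), following the Identity
// specialization above; vi is the per-architecture int vector type.
template <> class CallbackImpl<CPUType::CPU_NAME, Scale> {
public:
  CPU_ATTR CallbackImpl(const Scale& config) : config_(config) {}

  // Receives each register of int32 sums plus its position in C.
  CPU_ATTR vi operator()(vi input, const OutputBufferInfo& info) {
    // A real callback would transform the lanes using config_ (and info for
    // position-dependent behaviour); this skeleton just passes through.
    (void)info;
    return input;
  }

private:
  Scale config_;
};
```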