github.com/marian-nmt/intgemm/intgemm.git
author    Mateusz Chudyk <mateuszchudyk@gmail.com>  2020-04-30 21:02:54 +0300
committer Mateusz Chudyk <mateuszchudyk@gmail.com>  2020-05-05 21:43:54 +0300
commit    987e40a494d9bb48e7ef94990c9bcf5cb2582ded (patch)
tree      d4bbf21ffa758498692480f07bb51ef1c29a98c9
parent    1280a3715a9338caddc29f04176fe66daa79895e (diff)

    Add basic support for callbacks to multiplication with tiles
 benchmarks/benchmark_tile.cc  |  4
 callbacks.h                   |  9
 callbacks/configs.h           |  4
 callbacks/implementations.inl | 42
 test/tile_test.inl            |  6
 tile/access.h                 | 55
 tile/multiply.h               |  1
 tile/multiply.inl             | 15
 vec_traits.h                  |  5
 9 files changed, 97 insertions, 44 deletions
diff --git a/benchmarks/benchmark_tile.cc b/benchmarks/benchmark_tile.cc
index f6ff0b7..801427f 100644
--- a/benchmarks/benchmark_tile.cc
+++ b/benchmarks/benchmark_tile.cc
@@ -93,9 +93,9 @@ template <Index A_rows, Index B_cols> static inline double BenchmarkNoOverhang(A
typedef AVX512VNNI::UnrollKernel<A_rows, 1, B_cols, AVX512VNNI::Shifted8> Kernel;
// Burn in.
// TODO: different arches, guard against old compilers, etc.
- AVX512VNNI::MultiplyNoOverhang<Kernel>(access, shape);
+ AVX512VNNI::MultiplyNoOverhang<Kernel>(access, shape, callbacks::Identity<Accessor::CContent>());
for (std::size_t t = 0; t < kTries; ++t) {
- AVX512VNNI::MultiplyNoOverhang<Kernel>(access, shape);
+ AVX512VNNI::MultiplyNoOverhang<Kernel>(access, shape, callbacks::Identity<Accessor::CContent>());
}
auto end = std::chrono::steady_clock::now();
return std::chrono::duration<double>(end - start).count() / kTries;
diff --git a/callbacks.h b/callbacks.h
index 24f9009..439992e 100644
--- a/callbacks.h
+++ b/callbacks.h
@@ -14,6 +14,10 @@
#include "callbacks/implementations.inl"
#undef CALLBACKS_THIS_IS_SSE2
+#define CALLBACKS_THIS_IS_SSSE3
+#include "callbacks/implementations.inl"
+#undef CALLBACKS_THIS_IS_SSSE3
+
#define CALLBACKS_THIS_IS_AVX2
#include "callbacks/implementations.inl"
#undef CALLBACKS_THIS_IS_AVX2
@@ -24,3 +28,8 @@
#undef CALLBACKS_THIS_IS_AVX512BW
#endif
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
+#define CALLBACKS_THIS_IS_AVX512VNNI
+#include "callbacks/implementations.inl"
+#undef CALLBACKS_THIS_IS_AVX512VNNI
+#endif
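A note on the pattern being extended here: callbacks.h includes the same implementations.inl once per architecture, and each CALLBACKS_THIS_IS_* macro steers which CPU_NAME/CPU_ATTR that pass defines, so this hunk simply stamps out two more copies (SSSE3 and AVX512VNNI). The single-file sketch below imitates that idea; a macro stands in for the re-included file so it compiles on its own, and MAKE_CALLBACK_ARCH/WhoAmI are invented names, not intgemm's.

// Illustrative stand-in for the re-inclusion pattern: in intgemm the body below
// lives in implementations.inl and is #include'd once per architecture; here a
// macro plays that role so the sketch fits in one translation unit.
#include <cstdio>

#define MAKE_CALLBACK_ARCH(ARCH)                        \
  namespace ARCH {                                      \
  inline const char *WhoAmI() { return #ARCH; }         \
  }

MAKE_CALLBACK_ARCH(SSE2)
MAKE_CALLBACK_ARCH(SSSE3)      /* newly stamped out by this commit */
MAKE_CALLBACK_ARCH(AVX2)
MAKE_CALLBACK_ARCH(AVX512BW)
MAKE_CALLBACK_ARCH(AVX512VNNI) /* newly stamped out by this commit */

#undef MAKE_CALLBACK_ARCH

int main() {
  std::printf("%s %s\n", SSSE3::WhoAmI(), AVX512VNNI::WhoAmI());
  return 0;
}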
diff --git a/callbacks/configs.h b/callbacks/configs.h
index 1222448..fc8343f 100644
--- a/callbacks/configs.h
+++ b/callbacks/configs.h
@@ -20,6 +20,10 @@ struct Dummy {
};
template <typename Type>
+struct Identity {
+};
+
+template <typename Type>
struct Write {
Type* output_addr;
diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl
index 25f8aa3..86984fb 100644
--- a/callbacks/implementations.inl
+++ b/callbacks/implementations.inl
@@ -2,25 +2,35 @@
#if defined(CALLBACKS_THIS_IS_SSE2)
#define CPU_NAME SSE2
#define CPU_ATTR INTGEMM_SSE2
+#elif defined(CALLBACKS_THIS_IS_SSSE3)
+ #define CPU_NAME SSSE3
+ #define CPU_ATTR INTGEMM_SSSE3
#elif defined(CALLBACKS_THIS_IS_AVX2)
#define CPU_NAME AVX2
#define CPU_ATTR INTGEMM_AVX2
#elif defined(CALLBACKS_THIS_IS_AVX512BW)
#define CPU_NAME AVX512BW
#define CPU_ATTR INTGEMM_AVX512BW
+#elif defined(CALLBACKS_THIS_IS_AVX512VNNI)
+ #define CPU_NAME AVX512VNNI
+ #define CPU_ATTR INTGEMM_AVX512VNNI
#else
- #error "Only SSE2, AVX2 and AVX512BW are supported"
+ #error "Only SSE2, SSSE3, AVX2, AVX512BW and AVX512VNNI are supported"
#endif
-#if defined(CALLBACKS_THIS_IS_SSE2)
- #define vi vector_t<CPUType::SSE2, int>
- #define vf vector_t<CPUType::SSE2, float>
- #define vd vector_t<CPUType::SSE2, double>
-#else
- #define vi vector_t<CPUType::AVX2, int>
- #define vf vector_t<CPUType::AVX2, float>
- #define vd vector_t<CPUType::AVX2, double>
-#endif
+// #if defined(CALLBACKS_THIS_IS_SSE2)
+// #define vi vector_t<CPUType::SSE2, int>
+// #define vf vector_t<CPUType::SSE2, float>
+// #define vd vector_t<CPUType::SSE2, double>
+// #else
+// #define vi vector_t<CPUType::AVX2, int>
+// #define vf vector_t<CPUType::AVX2, float>
+// #define vd vector_t<CPUType::AVX2, double>
+// #endif
+
+#define vi vector_t<CPUType::CPU_NAME, int>
+#define vf vector_t<CPUType::CPU_NAME, float>
+#define vd vector_t<CPUType::CPU_NAME, double>
namespace intgemm {
namespace callbacks {
@@ -86,6 +96,18 @@ public:
};
/*
+ * Identity
+ */
+template <typename Type>
+class CallbackImpl<CPUType::CPU_NAME, Identity<Type>> {
+public:
+ CPU_ATTR CallbackImpl(const Identity<Type>&) {}
+ CPU_ATTR vector_t<CPUType::CPU_NAME, Type> operator()(vector_t<CPUType::CPU_NAME, Type> input, const OutputBufferInfo&) {
+ return input;
+ }
+};
+
+/*
* Write
*/
template <typename Type>
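For readers unfamiliar with the config/implementation split, here is a small self-contained sketch of what the new Identity callback does: a parameterless config struct selects a per-architecture functor that returns the accumulator register it is handed, unchanged. The names carry a Sketch suffix to make clear they are illustrative rather than the classes above, and SSE2 intrinsics are used so the example compiles on any x86-64 toolchain.

#include <emmintrin.h>   // SSE2 intrinsics
#include <cstdint>
#include <cstdio>

struct OutputBufferInfoSketch { uint32_t row_idx, col_idx; };

// Config: describes which callback to run (no parameters for identity).
template <typename Type> struct IdentitySketch {};

// Per-architecture implementation; the real code stamps one of these out per
// CPU_NAME by re-including implementations.inl.
template <typename Config> class CallbackImplSketch;

template <typename Type> class CallbackImplSketch<IdentitySketch<Type>> {
 public:
  explicit CallbackImplSketch(const IdentitySketch<Type> &) {}
  __m128i operator()(__m128i input, const OutputBufferInfoSketch &) const {
    return input;  // pass the register through unchanged
  }
};

int main() {
  CallbackImplSketch<IdentitySketch<int32_t>> callback{IdentitySketch<int32_t>{}};
  __m128i v = _mm_set_epi32(4, 3, 2, 1);
  __m128i out = callback(v, OutputBufferInfoSketch{0, 0});
  alignas(16) int32_t vals[4];
  _mm_store_si128(reinterpret_cast<__m128i *>(vals), out);
  std::printf("%d %d %d %d\n", vals[0], vals[1], vals[2], vals[3]);  // 1 2 3 4
  return 0;
}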
diff --git a/test/tile_test.inl b/test/tile_test.inl
index 865e922..688f3f9 100644
--- a/test/tile_test.inl
+++ b/test/tile_test.inl
@@ -173,7 +173,7 @@ template <class Kernel> void TestMultiplyNoOverhang(Tile shape) {
CHECK(shape.inner % Kernel::kTile.inner == 0);
CHECK(shape.B_cols % Kernel::kTile.B_cols == 0);
TestMatricesRef t(shape);
- MultiplyNoOverhang<Kernel>(t.Accessor(), shape);
+ MultiplyNoOverhang<Kernel>(t.Accessor(), shape, callbacks::Identity<int32_t>());
CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
/* for (Index i = 0; i < shape.A_rows; ++i) {
for (Index j = 0; j < shape.B_cols; ++j) {
@@ -286,10 +286,10 @@ TEST_CASE("Multiply " INTGEMM_TEST_NAME, "[tile][multiply]") {
for (shape.A_rows = 1; shape.A_rows < 33; ++shape.A_rows) {
for (shape.B_cols = 1; shape.B_cols < 33; ++shape.B_cols) {
TestMatricesRef t(shape);
- Multiply<Signed8, 7, 3>(t.Accessor(), shape);
+ Multiply<Signed8, 7, 3>(t.Accessor(), shape, callbacks::Identity<int32_t>());
CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
memset(t.C.begin(), 0, shape.A_rows * shape.B_cols * sizeof(int32_t));
- Multiply<Signed8, 4, 5>(t.Accessor(), shape);
+ Multiply<Signed8, 4, 5>(t.Accessor(), shape, callbacks::Identity<int32_t>());
CHECK(!memcmp(t.C_reference.begin(), t.C.begin(), shape.A_rows * shape.B_cols * sizeof(int32_t)));
}
}
diff --git a/tile/access.h b/tile/access.h
index aa8365e..5faf9a1 100644
--- a/tile/access.h
+++ b/tile/access.h
@@ -3,6 +3,7 @@
#include <type_traits>
#include "../types.h"
+#include "../callbacks.h"
namespace intgemm {
@@ -12,30 +13,36 @@ template <class T> class RowMajorAccess {
public:
typedef T Content;
- RowMajorAccess(Content *data, Index cols)
- : data_(data), cols_(cols) {}
+ RowMajorAccess(Content *data, Index cols) : RowMajorAccess(data, cols, 0, 0) {}
RowMajorAccess<Content> Add(Index row, Index col) const {
- return RowMajorAccess<Content>(data_ + row * cols_ + col, cols_);
+ return RowMajorAccess<Content>(data_ + row * cols_ + col, cols_, row_idx_ + row, col_idx_ + col);
}
const Content &Front() const { return *data_; }
Content &Front() { return *data_; }
// TODO: SLOW. This is here for testing.
- template <Index A_rows, Index B_cols> void Write(const __m128i *from) { SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from)); }
- template <Index A_rows, Index B_cols> void Write(const __m256i *from) { SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from)); }
- template <Index A_rows, Index B_cols> void Write(const __m512i *from) {
- WriteImpl<A_rows, B_cols, B_cols>(from);
+ template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m128i *from, CallbackImpl&) {
+ SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from));
+ }
+ template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m256i *from, CallbackImpl&) {
+ SlowWrite<A_rows, B_cols>(reinterpret_cast<const T*>(from));
+ }
+ template <Index A_rows, Index B_cols, typename CallbackImpl> void Write(const __m512i *from, CallbackImpl& callback_impl) {
+ WriteImpl<A_rows, B_cols, B_cols>(from, callback_impl);
}
private:
+ RowMajorAccess(Content *data, Index cols, Index row_idx, Index col_idx)
+ : data_(data), cols_(cols), row_idx_(row_idx), col_idx_(col_idx) {}
+
// If there's a full register to write for a column, do that.
- template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
+ template <Index A_rows, Index B_cols, Index ColRemain, typename CallbackImpl> INTGEMM_AVX512BW
typename std::enable_if<A_rows && B_cols && (ColRemain >= 16)>::type
- WriteImpl(const __m512i *from) {
- _mm512_storeu_si512(data_, *from);
- Add(0, 16).template WriteImpl<A_rows, B_cols, (ColRemain - 16)>(from + 1);
+ WriteImpl(const __m512i *from, CallbackImpl& callback_impl) {
+ _mm512_storeu_si512(data_, callback_impl(*from, callbacks::OutputBufferInfo(row_idx_, col_idx_, 0, 0)));
+ Add(0, 16).template WriteImpl<A_rows, B_cols, (ColRemain - 16)>(from + 1, callback_impl);
}
// TODO: test this more, also save it somewhere! Make sure compiler isn't recreating this every time.
@@ -59,9 +66,9 @@ template <class T> class RowMajorAccess {
}
// There is a mix of rows in a register and we need a scatter.
- template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
+ template <Index A_rows, Index B_cols, Index ColRemain, typename CallbackImpl> INTGEMM_AVX512BW
typename std::enable_if<(A_rows > 1) && ColRemain && (ColRemain < 16)>::type
- WriteImpl(const __m512i *from) {
+ WriteImpl(const __m512i *from, CallbackImpl& callback_impl) {
__m512i offsets = Offsets<B_cols, B_cols - ColRemain>();
// We might be at the end of the data, in which case a mask is needed.
constexpr Index remaining = (A_rows - 1) * B_cols + ColRemain;
@@ -70,33 +77,33 @@ template <class T> class RowMajorAccess {
// The offsets add B_cols - ColRemain so they can be correct modulo the number of columns.
// So we subtract that from the data pointer.
int32_t *go_back = data_ - (B_cols - ColRemain);
- _mm512_mask_i32scatter_epi32(go_back, mask, offsets, *from, sizeof(int32_t));
+ _mm512_mask_i32scatter_epi32(go_back, mask, offsets, callback_impl(*from, callbacks::OutputBufferInfo(row_idx_, col_idx_, 0, 0)), sizeof(int32_t));
// We just wrote 16 values: ColRemain, the next row (all or partial), possibly the next etc.
// 16 - ColRemain of the next row and whatever followed.
constexpr Index Wrote = ((remaining < 16) ? remaining : 16);
constexpr Index Position = (B_cols - ColRemain) + Wrote;
// TODO: more testing on this.
- Add(Position / B_cols, Position % B_cols - (B_cols - ColRemain)).template WriteImpl<A_rows - (Position / B_cols), B_cols, B_cols - (Position % B_cols)>(from + 1);
+ Add(Position / B_cols, Position % B_cols - (B_cols - ColRemain)).template WriteImpl<A_rows - (Position / B_cols), B_cols, B_cols - (Position % B_cols)>(from + 1, callback_impl);
}
// At clean end of column, move to next row.
- template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
+ template <Index A_rows, Index B_cols, Index ColRemain, typename CallbackImpl> INTGEMM_AVX512BW
typename std::enable_if<A_rows && B_cols && (ColRemain == 0)>::type
- WriteImpl(const __m512i *from) {
- Add(1, -B_cols).template WriteImpl<A_rows - 1, B_cols, B_cols>(from);
+ WriteImpl(const __m512i *from, CallbackImpl& callback_impl) {
+ Add(1, -B_cols).template WriteImpl<A_rows - 1, B_cols, B_cols>(from, callback_impl);
}
// On the last row, finish the last write with a mask.
- template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
+ template <Index A_rows, Index B_cols, Index ColRemain, typename CallbackImpl> INTGEMM_AVX512BW
typename std::enable_if<(A_rows == 1) && B_cols && (ColRemain < 16 && ColRemain > 0)>::type
- WriteImpl(const __m512i *from) {
- _mm512_mask_storeu_epi32(data_, (1 << ColRemain) - 1, *from);
+ WriteImpl(const __m512i *from, CallbackImpl& callback_impl) {
+ _mm512_mask_storeu_epi32(data_, (1 << ColRemain) - 1, callback_impl(*from, callbacks::OutputBufferInfo(row_idx_, col_idx_, 0, 0)));
}
// Nothing to write.
- template <Index A_rows, Index B_cols, Index ColRemain> INTGEMM_AVX512BW
+ template <Index A_rows, Index B_cols, Index ColRemain, typename CallbackImpl> INTGEMM_AVX512BW
typename std::enable_if<!A_rows || !B_cols>::type
- WriteImpl(const __m512i *) {}
+ WriteImpl(const __m512i *, CallbackImpl&) {}
template <Index A_rows, Index B_cols> void SlowWrite(const T *from) {
for (Index i = 0; i < A_rows; ++i) {
@@ -108,6 +115,8 @@ template <class T> class RowMajorAccess {
Content *data_;
Index cols_;
+ Index row_idx_;
+ Index col_idx_;
};
template <class T> class ColMajorAccess {
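The substantive change to RowMajorAccess above is the extra (row_idx_, col_idx_) bookkeeping: Add() now advances a logical position alongside the data pointer, so WriteImpl can hand the callback an OutputBufferInfo saying where in C it is writing. A stripped-down, runnable sketch of just that bookkeeping follows; the class and method names are illustrative.

#include <cstdint>
#include <cstdio>
#include <vector>

using Index = uint32_t;

template <class T> class RowMajorSketch {
 public:
  RowMajorSketch(T *data, Index cols) : RowMajorSketch(data, cols, 0, 0) {}

  // Advances both the data pointer and the logical (row, col) position.
  RowMajorSketch Add(Index row, Index col) const {
    return RowMajorSketch(data_ + row * cols_ + col, cols_,
                          row_idx_ + row, col_idx_ + col);
  }

  Index Row() const { return row_idx_; }
  Index Col() const { return col_idx_; }
  T &Front() { return *data_; }

 private:
  RowMajorSketch(T *data, Index cols, Index row, Index col)
      : data_(data), cols_(cols), row_idx_(row), col_idx_(col) {}
  T *data_;
  Index cols_, row_idx_, col_idx_;
};

int main() {
  std::vector<int32_t> C(8 * 8, 0);
  RowMajorSketch<int32_t> access(C.data(), 8);
  auto cell = access.Add(2, 3).Add(1, 1);    // indices accumulate across Adds
  std::printf("row=%u col=%u\n", cell.Row(), cell.Col());  // row=3 col=4
  cell.Front() = 42;                          // same element as C[3 * 8 + 4]
  std::printf("C[3][4]=%d\n", C[3 * 8 + 4]);
  return 0;
}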
diff --git a/tile/multiply.h b/tile/multiply.h
index 294b3ed..0ef9a43 100644
--- a/tile/multiply.h
+++ b/tile/multiply.h
@@ -3,6 +3,7 @@
#include "access.h"
#include "dot.h"
#include "reduce.h"
+#include "../callbacks.h"
#include "../types.h"
#include <cassert>
diff --git a/tile/multiply.inl b/tile/multiply.inl
index be86255..9036750 100644
--- a/tile/multiply.inl
+++ b/tile/multiply.inl
@@ -25,10 +25,13 @@ template <std::size_t... i> INTGEMM_TARGET static inline void Sum16To32(Register
template <std::size_t... i> INTGEMM_TARGET static inline void Sum16To32(Register *, int32_t, index_sequence<i...>) {}
/* Multiply assuming the matrix sizes are a multiple of the kernel size. */
-template <class Kernel, class AccessT> INTGEMM_TARGET __attribute__((flatten)) static inline void MultiplyNoOverhang(AccessT access, const Tile shape) {
+template <class Kernel, class AccessT, class Callback> INTGEMM_TARGET __attribute__((flatten)) static inline void MultiplyNoOverhang(AccessT access, const Tile shape, Callback callback) {
assert(shape.A_rows % Kernel::kTile.A_rows == 0);
assert(shape.inner % Kernel::kTile.inner == 0);
assert(shape.B_cols % Kernel::kTile.B_cols == 0);
+
+ auto callback_impl = callbacks::CallbackImpl<CPUType::INTGEMM_ARCH, Callback>(callback);
+
for (Index B_col = 0; B_col < shape.B_cols; B_col += Kernel::kTile.B_cols) {
AccessT column_adjusted = access.BAdd(0, B_col).CAdd(0, B_col);
for (Index A_row = 0; A_row < shape.A_rows; A_row += Kernel::kTile.A_rows) {
@@ -51,7 +54,7 @@ template <class Kernel, class AccessT> INTGEMM_TARGET __attribute__((flatten)) s
Sum16To32(c_regs, typename Kernel::Packed::C(), make_index_sequence<Outputs>());
// Horizontally add 32-bit values.
Reduce32<Outputs, Sum32Op>(c_regs);
- col_row.CAccessor().template Write<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(c_regs);
+ col_row.CAccessor().template Write<Kernel::kTile.A_rows, Kernel::kTile.B_cols>(c_regs, callback_impl);
}
}
}
@@ -64,7 +67,7 @@ template <class Kernel, class AccessT> INTGEMM_TARGET __attribute__((flatten)) s
* A_rows and B_cols specify the unrolled kernel size to use for most of the
* multiply; these impact speed but not output.
*/
-template <class Kernel, Index A_rows, Index B_cols, class AccessT> INTGEMM_TARGET static inline void Multiply(AccessT access, const Tile shape) {
+template <class Kernel, Index A_rows, Index B_cols, class Access, class Callback> INTGEMM_TARGET static inline void Multiply(Access access, const Tile shape, Callback callback) {
// Still has to be a multiple of the underlying Kernel, but usually that's just 1 x sizeof(Register) x 1.
assert(shape.A_rows % Kernel::kTile.A_rows == 0);
assert(shape.inner % Kernel::kTile.inner == 0);
@@ -82,15 +85,15 @@ template <class Kernel, Index A_rows, Index B_cols, class AccessT> INTGEMM_TARGE
shape.B_cols - overhang.B_cols
};
// Top left corner.
- MultiplyNoOverhang<Big>(access, big_shape);
+ MultiplyNoOverhang<Big>(access, big_shape, callback);
// Bottom currently including right side. TODO: unrolled kernel, rather than dumb loop.
MultiplyNoOverhang<Kernel>(
access.AAdd(big_shape.A_rows, 0).CAdd(big_shape.A_rows, 0),
- Tile {overhang.A_rows, shape.inner, shape.B_cols});
+ Tile {overhang.A_rows, shape.inner, shape.B_cols}, callback);
// Right side except bottom. TODO: unrolled kernel, rather than dumb loop.
MultiplyNoOverhang<Kernel>(
access.BAdd(0, big_shape.B_cols).CAdd(0, big_shape.B_cols),
- Tile {big_shape.A_rows, shape.inner, overhang.B_cols});
+ Tile {big_shape.A_rows, shape.inner, overhang.B_cols}, callback);
}
} // namespace INTGEMM_ARCH
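MultiplyNoOverhang now takes the callback config by value, builds the architecture-specific CallbackImpl once before the tile loops, and threads it down into every Write. The scalar sketch below mirrors that flow with invented names (ScaleBy, ImplFor, MultiplySketch); it shows only the config-in, impl-built-once, applied-per-write shape, not intgemm's API.

#include <cstdint>
#include <cstdio>

// Config: a plain description of the callback, cheap to copy into the multiply.
struct ScaleBy { float factor; };

// Implementation: built from the config, does the per-element work.
class ScaleImpl {
 public:
  explicit ScaleImpl(const ScaleBy &config) : factor_(config.factor) {}
  int32_t operator()(int32_t x) const { return static_cast<int32_t>(x * factor_); }
 private:
  float factor_;
};

// Maps config type -> implementation type, playing the role of
// callbacks::CallbackImpl<CPUType, Callback> in the real code.
template <class Config> struct ImplFor;
template <> struct ImplFor<ScaleBy> { using type = ScaleImpl; };

template <class Config>
void MultiplySketch(const int32_t *partial, int32_t *out, int n, Config config) {
  // Constructed once, outside the hot loop, as in MultiplyNoOverhang above.
  typename ImplFor<Config>::type impl(config);
  for (int i = 0; i < n; ++i) out[i] = impl(partial[i]);  // applied at each write
}

int main() {
  const int32_t accumulators[4] = {10, 20, 30, 40};
  int32_t out[4];
  MultiplySketch(accumulators, out, 4, ScaleBy{0.5f});
  for (int32_t x : out) std::printf("%d ", x);  // 5 10 15 20
  std::printf("\n");
  return 0;
}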
diff --git a/vec_traits.h b/vec_traits.h
index 6c4ec28..059aa83 100644
--- a/vec_traits.h
+++ b/vec_traits.h
@@ -28,6 +28,11 @@ template <> struct vector_s<CPUType::AVX512BW, int16_t> { using type = __m512i;
template <> struct vector_s<CPUType::AVX512BW, int> { using type = __m512i; };
template <> struct vector_s<CPUType::AVX512BW, float> { using type = __m512; };
template <> struct vector_s<CPUType::AVX512BW, double> { using type = __m512d; };
+template <> struct vector_s<CPUType::AVX512VNNI, int8_t> { using type = __m512i; };
+template <> struct vector_s<CPUType::AVX512VNNI, int16_t> { using type = __m512i; };
+template <> struct vector_s<CPUType::AVX512VNNI, int> { using type = __m512i; };
+template <> struct vector_s<CPUType::AVX512VNNI, float> { using type = __m512; };
+template <> struct vector_s<CPUType::AVX512VNNI, double> { using type = __m512d; };
template <CPUType CPUType_, typename ElemType_>
using vector_t = typename vector_s<CPUType_, ElemType_>::type;
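The vec_traits.h additions follow the file's existing pattern: vector_s is specialized for every (CPUType, element type) pair, and vector_t resolves to the register type for that pair, so AVX512VNNI gets the same 512-bit registers as AVX512BW. A compact standalone sketch of that trait pattern follows; the CPU enum and the chosen widths here are illustrative.

#include <immintrin.h>  // defines __m128i / __m256 on current GCC, Clang, and MSVC

enum class CPUSketch { SSE2, AVX2 };

template <CPUSketch cpu, typename Elem> struct vector_s_sketch;
template <> struct vector_s_sketch<CPUSketch::SSE2, int>   { using type = __m128i; };
template <> struct vector_s_sketch<CPUSketch::SSE2, float> { using type = __m128;  };
template <> struct vector_s_sketch<CPUSketch::AVX2, int>   { using type = __m256i; };
template <> struct vector_s_sketch<CPUSketch::AVX2, float> { using type = __m256;  };

template <CPUSketch cpu, typename Elem>
using vector_t_sketch = typename vector_s_sketch<cpu, Elem>::type;

// Each architecture maps the same element type to its own register width.
static_assert(sizeof(vector_t_sketch<CPUSketch::SSE2, int>) == 16,
              "SSE2 uses 128-bit registers");
static_assert(sizeof(vector_t_sketch<CPUSketch::AVX2, float>) == 32,
              "AVX2 uses 256-bit registers");

int main() { return 0; }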