github.com/marian-nmt/intgemm/intgemm.git
author    Kenneth Heafield <kheafiel@amazon.com>    2020-03-23 15:57:28 +0300
committer Kenneth Heafield <kheafiel@amazon.com>    2020-03-23 17:27:40 +0300
commit    21f122d7d0aede96665580488f4d0e3fedd0fa57 (patch)
tree      5070ad350fa3295221f912622c0beb469eea2fdf
parent    65176b06d3caea37bd0d9d5154686f073f37ad6b (diff)
OMP parallelization for Multiply
-rw-r--r--  avx512_gemm.h           5
-rw-r--r--  avx512vnni_gemm.h      15
-rw-r--r--  intgemm.h              11
-rw-r--r--  multiply.h             41
-rw-r--r--  test/multiply_test.cc   2
5 files changed, 52 insertions, 22 deletions
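
The shape of the change is identical in every backend: a wrapper (OMPParallelWrap, added to multiply.h below) opens the `#pragma omp parallel` region, and the column loop inside each Multiply picks up the worksharing through `#pragma omp for`. A minimal standalone sketch of that split, using a plain scalar multiply in place of intgemm's SIMD kernels and prepared B layout:

#include <cstddef>

// Hypothetical scalar stand-in for one backend's Multiply; intgemm's real
// kernels are SIMD and read B in a prepared layout.
static void MultiplyBody(const float *A, const float *B, float *C,
                         std::size_t A_rows, std::size_t width, std::size_t B_cols) {
  // Orphaned worksharing: binds to whichever parallel region the caller opened.
  #pragma omp for
  for (std::size_t col = 0; col < B_cols; col += 8) {
    for (std::size_t row = 0; row < A_rows; ++row) {
      for (std::size_t c = col; c < col + 8 && c < B_cols; ++c) {
        float sum = 0.f;
        for (std::size_t k = 0; k < width; ++k)
          sum += A[row * width + k] * B[k * B_cols + c];
        C[row * B_cols + c] = sum;  // threads own disjoint columns: no write races
      }
    }
  }
}

static void Multiply(const float *A, const float *B, float *C,
                     std::size_t A_rows, std::size_t width, std::size_t B_cols) {
  #pragma omp parallel  // launch threads once here, as OMPParallelWrap does
  MultiplyBody(A, B, C, A_rows, width, B_cols);
}

Splitting the `parallel` region from the `for` keeps the worksharing pragma inside the target-attributed kernel; the comment added to multiply.h below explains why older gcc forces this split.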
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 6286ccc..e56d043 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -329,11 +329,12 @@ struct AVX512_8bit {
// There's 8 results for INTGEMM_AVX2 to handle.
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
const int simd_width = width / sizeof(Register);
- const Register *B0_col = reinterpret_cast<const Register*>(B);
// Added for AVX512.
Register zeros = setzero_si<Register>();
// Go over 8 columns of B at a time.
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+ const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
// Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
// Iterate over shared (inner) dimension.
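
The loop rewrite above is what makes the pragma legal, not a style change: `#pragma omp for` requires the loop in OpenMP's canonical form, with a single induction variable, a relational test (`<` rather than `!=`), and a simple increment. The old increment clause also bumped a second variable (`B0_col += 8 * simd_width`), so `B0_col` is now recomputed from `B0_colidx` at the top of each iteration instead. The same rewrite repeats in every backend below.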
diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h
index 59f6405..6eb3be4 100644
--- a/avx512vnni_gemm.h
+++ b/avx512vnni_gemm.h
@@ -18,10 +18,11 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
const int simd_width = width / sizeof(Register);
- const Register *B0_col = reinterpret_cast<const Register*>(B);
Register zeros = setzero_si<Register>();
// Go over 8 columns of B at a time.
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+ const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
// Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
// Iterate over shared (inner) dimension.
@@ -79,10 +80,11 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
const int simd_width = width / sizeof(Register);
- const Register *B0_col = reinterpret_cast<const Register*>(B);
Register zeros = setzero_si<Register>();
// Go over 8 columns of B at a time.
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+ const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
// Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
// Iterate over shared (inner) dimension.
@@ -119,11 +121,12 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
const int simd_width = width / sizeof(Register);
- const Register *B0_col = reinterpret_cast<const Register*>(B);
Register zeros = setzero_si<Register>();
const Register a = set1_epi8<Register>(1);
// Go over 8 columns of B at a time.
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+ const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
const Register *B_live = B0_col; //In order to make the code look as much as possible as the above function
const Register *B_end = B_live + simd_width*8;
diff --git a/intgemm.h b/intgemm.h
index 0c315fc..8fe7bb7 100644
--- a/intgemm.h
+++ b/intgemm.h
@@ -281,7 +281,7 @@ private:
};
template <typename Callback>
-void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::Multiply<Callback>, AVX512_8bit::Multiply<Callback>, AVX2_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, Unsupported_8bit::Multiply);
+void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512VNNI_8bit>, OMPParallelWrap<Callback, AVX512_8bit>, OMPParallelWrap<Callback, AVX2_8bit>, OMPParallelWrap<Callback, SSSE3_8bit>, Unsupported_8bit::Multiply<Callback>, Unsupported_8bit::Multiply<Callback>);
/*
* 8-bit matrix multiplication with shifting A by 127
@@ -344,7 +344,12 @@ private:
};
template <class Callback>
-void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::Multiply8Shift<Callback>, AVX512_8bit::Multiply8Shift<Callback>, AVX2_8bit::Multiply8Shift<Callback>, SSSE3_8bit::Multiply8Shift<Callback>, SSSE3_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift);
+void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(
+ OMPParallelWrap8Shift<Callback, AVX512VNNI_8bit>,
+ OMPParallelWrap8Shift<Callback, AVX512_8bit>,
+ OMPParallelWrap8Shift<Callback, AVX2_8bit>,
+ OMPParallelWrap8Shift<Callback, SSSE3_8bit>,
+ Unsupported_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift<Callback>);
template <class Callback>
void (*Int8Shift::PrepareBiasImpl<Callback>::run)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::PrepareBias<Callback>, AVX512_8bit::PrepareBias<Callback>, AVX2_8bit::PrepareBias<Callback>, SSSE3_8bit::PrepareBias<Callback>, SSSE3_8bit::PrepareBias<Callback>, Unsupported_8bit::PrepareBias);
@@ -403,7 +408,7 @@ private:
};
template <typename Callback>
-void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_16bit::Multiply<Callback> /*TODO VNNI 16-bit. */, AVX512_16bit::Multiply<Callback>, AVX2_16bit::Multiply<Callback>, SSE2_16bit::Multiply<Callback>, SSE2_16bit::Multiply<Callback>, Unsupported_16bit::Multiply);
+void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512_16bit> /*TODO VNNI 16-bit. */, OMPParallelWrap<Callback, AVX512_16bit>, OMPParallelWrap<Callback, AVX2_16bit>, OMPParallelWrap<Callback, SSE2_16bit>, OMPParallelWrap<Callback, SSE2_16bit>, Unsupported_16bit::Multiply<Callback>);
extern const CPUType kCPU;
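
For context, `MultiplyImpl<Callback>::run` is a function pointer bound once at static-initialization time; this commit only changes which functions ChooseCPU selects among, from the raw backend kernels to the OMP wrappers. A hedged, self-contained analogue of the idiom (hypothetical names and probe, not intgemm's actual ChooseCPU):

#include <cstdint>
#include <cstdio>

using MultiplyFn = void (*)(const std::int8_t *A, const std::int8_t *B);

static void MultiplyAVX512(const std::int8_t *, const std::int8_t *) {
  std::puts("AVX-512 kernel");
}
static void MultiplySSSE3(const std::int8_t *, const std::int8_t *) {
  std::puts("SSSE3 kernel");
}
static bool CPUHasAVX512() { return false; }  // stand-in for a CPUID probe

// Bound exactly once, like `MultiplyImpl<Callback>::run = ChooseCPU(...)`
// above; every later call dispatches through the pointer with no branching.
static MultiplyFn run = CPUHasAVX512() ? MultiplyAVX512 : MultiplySSSE3;

int main() { run(nullptr, nullptr); }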
diff --git a/multiply.h b/multiply.h
index 84d6737..359db45 100644
--- a/multiply.h
+++ b/multiply.h
@@ -176,8 +176,9 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
const int simd_width = width / (sizeof(Register) / sizeof(int16_t)); \
auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
- const Register *B0_col = reinterpret_cast<const Register *>(B); \
- for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ _Pragma("omp for") \
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+ const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
/* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width); \
@@ -232,9 +233,10 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
const int simd_width = width / (sizeof(Register) / sizeof(int8_t)); \
auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
- const Register *B0_col = reinterpret_cast<const Register *>(B); \
const Register a = set1_epi8<Register>(1); \
- for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ _Pragma("omp for") \
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+ const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
/*const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width);*/ \
/* These will be packed 16-bit integers containing sums for each row of B multiplied by the row of A. \
Iterate over shared (inner) dimension.*/ \
@@ -306,8 +308,9 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
const int simd_width = width / (sizeof(Register) / sizeof(int8_t)); \
auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
- const Register *B0_col = reinterpret_cast<const Register *>(B); \
- for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ _Pragma("omp for") \
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+ const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
/* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width); \
@@ -530,9 +533,9 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
const int simd_width = width / sizeof(Register); \
auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
- const Register *B0_col = reinterpret_cast<const Register*>(B); \
- /*Go over 8 columns of B at a time.*/ \
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ _Pragma("omp for") \
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+ const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
/*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
/*Iterate over shared (inner) dimension.*/ \
@@ -588,7 +591,25 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \
} \
} \
-} \
+}
+
+/* Wrap a multiply call in OMP parallelism. Here it launches threads then
+ * inside the implementation there is a pragma omp for. In gcc >= 8 these
+ * could have been the same but older compilers don't imbue target attributes
+ * on the hidden function created by pragma omp parallel.
+ *
+ * Also, gcc 7 is unable to deduce the function pointer type (for ChooseCPU) if
+ * I use typename Backend::Integer directly in the arguments. As a workaround,
+ * have a default template argument Integer then use that so it's resolved.
+ */
+template <class Callback, class Backend, class Integer = typename Backend::Integer> static inline void OMPParallelWrap(const Integer *A, const Integer *B, Index A_rows, Index width, Index B_cols, Callback callback) {
+#pragma omp parallel
+ Backend::template Multiply<Callback>(A, B, A_rows, width, B_cols, callback);
+}
+template <class Callback, class Backend> static inline void OMPParallelWrap8Shift(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
+#pragma omp parallel
+ Backend::template Multiply8Shift<Callback>(A, B, A_rows, width, B_cols, callback);
+}
#define INTGEMM_MAXABSOLUTE(Register, target) \
target static inline float MaxAbsolute(const float *begin_float, const float *end_float) { \
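
One subtlety in the new wrapper deserves isolating: the defaulted `Integer` template parameter. The comment above notes that gcc 7 cannot deduce the function pointer type for ChooseCPU when `typename Backend::Integer` appears directly in the parameter list. A minimal reproduction of the workaround pattern (hypothetical backend and callback; the gcc 7 behavior is as the comment states):

#include <cstdint>

struct Backend8bit {
  using Integer = std::int8_t;
  template <class Callback>
  static void Multiply(const Integer *A, const Integer *B, Callback cb) { cb(A, B); }
};

// Routing `typename Backend::Integer` through a defaulted template parameter
// gives the wrapper's signature a resolved type that gcc 7 can deduce.
template <class Callback, class Backend, class Integer = typename Backend::Integer>
static inline void OMPParallelWrap(const Integer *A, const Integer *B, Callback cb) {
#pragma omp parallel
  Backend::template Multiply<Callback>(A, B, cb);
}

int main() {
  auto cb = [](const std::int8_t *, const std::int8_t *) {};
  OMPParallelWrap<decltype(cb), Backend8bit>(nullptr, nullptr, cb);
}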
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 260dd76..27d8a07 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -278,7 +278,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
- Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
+ OMPParallelWrap<callbacks::UnquantizeAndWrite, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
// Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence(
// callbacks::Unquantize(unquant_mult),
// callbacks::Write<float>(test_C.begin())