diff options
author | Kenneth Heafield <kheafiel@amazon.com> | 2020-03-23 15:57:28 +0300 |
---|---|---|
committer | Kenneth Heafield <kheafiel@amazon.com> | 2020-03-23 17:27:40 +0300 |
commit | 21f122d7d0aede96665580488f4d0e3fedd0fa57 (patch) | |
tree | 5070ad350fa3295221f912622c0beb469eea2fdf | |
parent | 65176b06d3caea37bd0d9d5154686f073f37ad6b (diff) |
OMP parallelization for Multiply
-rw-r--r-- | avx512_gemm.h | 5 | ||||
-rw-r--r-- | avx512vnni_gemm.h | 15 | ||||
-rw-r--r-- | intgemm.h | 11 | ||||
-rw-r--r-- | multiply.h | 41 | ||||
-rw-r--r-- | test/multiply_test.cc | 2 |
5 files changed, 52 insertions, 22 deletions
diff --git a/avx512_gemm.h b/avx512_gemm.h index 6286ccc..e56d043 100644 --- a/avx512_gemm.h +++ b/avx512_gemm.h @@ -329,11 +329,12 @@ struct AVX512_8bit { // There's 8 results for INTGEMM_AVX2 to handle. auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); const int simd_width = width / sizeof(Register); - const Register *B0_col = reinterpret_cast<const Register*>(B); // Added for AVX512. Register zeros = setzero_si<Register>(); // Go over 8 columns of B at a time. - for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { +#pragma omp for + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { + const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width; // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once. for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { // Iterate over shared (inner) dimension. diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h index 59f6405..6eb3be4 100644 --- a/avx512vnni_gemm.h +++ b/avx512vnni_gemm.h @@ -18,10 +18,11 @@ struct AVX512VNNI_8bit : public AVX512_8bit { assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); const int simd_width = width / sizeof(Register); - const Register *B0_col = reinterpret_cast<const Register*>(B); Register zeros = setzero_si<Register>(); // Go over 8 columns of B at a time. - for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { +#pragma omp for + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { + const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width; // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once. for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { // Iterate over shared (inner) dimension. @@ -79,10 +80,11 @@ struct AVX512VNNI_8bit : public AVX512_8bit { assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); const int simd_width = width / sizeof(Register); - const Register *B0_col = reinterpret_cast<const Register*>(B); Register zeros = setzero_si<Register>(); // Go over 8 columns of B at a time. - for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { +#pragma omp for + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { + const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width; // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once. for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { // Iterate over shared (inner) dimension. @@ -119,11 +121,12 @@ struct AVX512VNNI_8bit : public AVX512_8bit { assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); const int simd_width = width / sizeof(Register); - const Register *B0_col = reinterpret_cast<const Register*>(B); Register zeros = setzero_si<Register>(); const Register a = set1_epi8<Register>(1); // Go over 8 columns of B at a time. - for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { +#pragma omp for + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { + const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width; const Register *B_live = B0_col; //In order to make the code look as much as possible as the above function const Register *B_end = B_live + simd_width*8; @@ -281,7 +281,7 @@ private: }; template <typename Callback> -void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::Multiply<Callback>, AVX512_8bit::Multiply<Callback>, AVX2_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, Unsupported_8bit::Multiply); +void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512VNNI_8bit>, OMPParallelWrap<Callback, AVX512_8bit>, OMPParallelWrap<Callback, AVX2_8bit>, OMPParallelWrap<Callback, SSSE3_8bit>, Unsupported_8bit::Multiply<Callback>, Unsupported_8bit::Multiply<Callback>); /* * 8-bit matrix multiplication with shifting A by 127 @@ -344,7 +344,12 @@ private: }; template <class Callback> -void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::Multiply8Shift<Callback>, AVX512_8bit::Multiply8Shift<Callback>, AVX2_8bit::Multiply8Shift<Callback>, SSSE3_8bit::Multiply8Shift<Callback>, SSSE3_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift); +void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU( + OMPParallelWrap8Shift<Callback, AVX512VNNI_8bit>, + OMPParallelWrap8Shift<Callback, AVX512_8bit>, + OMPParallelWrap8Shift<Callback, AVX2_8bit>, + OMPParallelWrap8Shift<Callback, SSSE3_8bit>, + Unsupported_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift<Callback>); template <class Callback> void (*Int8Shift::PrepareBiasImpl<Callback>::run)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::PrepareBias<Callback>, AVX512_8bit::PrepareBias<Callback>, AVX2_8bit::PrepareBias<Callback>, SSSE3_8bit::PrepareBias<Callback>, SSSE3_8bit::PrepareBias<Callback>, Unsupported_8bit::PrepareBias); @@ -403,7 +408,7 @@ private: }; template <typename Callback> -void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_16bit::Multiply<Callback> /*TODO VNNI 16-bit. */, AVX512_16bit::Multiply<Callback>, AVX2_16bit::Multiply<Callback>, SSE2_16bit::Multiply<Callback>, SSE2_16bit::Multiply<Callback>, Unsupported_16bit::Multiply); +void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512_16bit> /*TODO VNNI 16-bit. */, OMPParallelWrap<Callback, AVX512_16bit>, OMPParallelWrap<Callback, AVX2_16bit>, OMPParallelWrap<Callback, SSE2_16bit>, OMPParallelWrap<Callback, SSE2_16bit>, Unsupported_16bit::Multiply<Callback>); extern const CPUType kCPU; @@ -176,8 +176,9 @@ template <typename Callback> target static void Multiply(const int16_t *A, const assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \ const int simd_width = width / (sizeof(Register) / sizeof(int16_t)); \ auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \ - const Register *B0_col = reinterpret_cast<const Register *>(B); \ - for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ + _Pragma("omp for") \ + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \ + const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \ /* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \ const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width); \ @@ -232,9 +233,10 @@ template <typename Callback> target static void Multiply(const int16_t *A, const assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \ const int simd_width = width / (sizeof(Register) / sizeof(int8_t)); \ auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \ - const Register *B0_col = reinterpret_cast<const Register *>(B); \ const Register a = set1_epi8<Register>(1); \ - for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ + _Pragma("omp for") \ + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \ + const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \ /*const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width);*/ \ /* These will be packed 16-bit integers containing sums for each row of B multiplied by the row of A. \ Iterate over shared (inner) dimension.*/ \ @@ -306,8 +308,9 @@ template <typename Callback> target static void Multiply(const int16_t *A, const assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \ const int simd_width = width / (sizeof(Register) / sizeof(int8_t)); \ auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \ - const Register *B0_col = reinterpret_cast<const Register *>(B); \ - for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ + _Pragma("omp for") \ + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \ + const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \ /* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \ const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width); \ @@ -530,9 +533,9 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3( assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \ const int simd_width = width / sizeof(Register); \ auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \ - const Register *B0_col = reinterpret_cast<const Register*>(B); \ - /*Go over 8 columns of B at a time.*/ \ - for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \ + _Pragma("omp for") \ + for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \ + const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \ /*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \ for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \ /*Iterate over shared (inner) dimension.*/ \ @@ -588,7 +591,25 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3( RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \ } \ } \ -} \ +} + +/* Wrap a multiply call in OMP parallelism. Here it launches threads then + * inside the implementation there is a pragma omp for. In gcc >= 8 these + * could have been the same but older compilers don't imbue target attributes + * on the hidden function created by pragma omp parallel. + * + * Also, gcc 7 is unable to deduce the function pointer type (for ChooseCPU) if + * I use typename Backend::Integer directly in the arguments. As a workaround, + * have a default template argument Integer then use that so it's resolved. + */ +template <class Callback, class Backend, class Integer = typename Backend::Integer> static inline void OMPParallelWrap(const Integer *A, const Integer *B, Index A_rows, Index width, Index B_cols, Callback callback) { +#pragma omp parallel + Backend::template Multiply<Callback>(A, B, A_rows, width, B_cols, callback); +} +template <class Callback, class Backend> static inline void OMPParallelWrap8Shift(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { +#pragma omp parallel + Backend::template Multiply8Shift<Callback>(A, B, A_rows, width, B_cols, callback); +} #define INTGEMM_MAXABSOLUTE(Register, target) \ target static inline float MaxAbsolute(const float *begin_float, const float *end_float) { \ diff --git a/test/multiply_test.cc b/test/multiply_test.cc index 260dd76..27d8a07 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -278,7 +278,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols); AlignedVector<float> test_C(A_rows * B_cols); - Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin())); + OMPParallelWrap<callbacks::UnquantizeAndWrite, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin())); // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence( // callbacks::Unquantize(unquant_mult), // callbacks::Write<float>(test_C.begin()) |