github.com/marian-nmt/intgemm.git
author     Kenneth Heafield <github@kheafield.com>    2020-04-18 01:35:39 +0300
committer  Kenneth Heafield <github@kheafield.com>    2020-04-18 01:35:39 +0300
commit     9801459266d44a38f6e2aade427491d3c3015218 (patch)
tree       06c9dea044937c11daa007608cc40e994c144c4b
parent     b444029e291f874859000ad5527ce38895213f47 (diff)
parent     fb96b0851cf420ac49c13b361a503afffe386ada (diff)
Merge remote-tracking branch 'origin/master' into static
-rw-r--r--  aligned.h              |  9
-rw-r--r--  avx512_gemm.h          |  5
-rw-r--r--  avx512vnni_gemm.h      | 15
-rw-r--r--  intgemm.h              | 11
-rw-r--r--  multiply.h             | 41
-rw-r--r--  test/multiply_test.cc  |  2
6 files changed, 56 insertions(+), 27 deletions(-)
diff --git a/aligned.h b/aligned.h
index 6e72afb..6af3c31 100644
--- a/aligned.h
+++ b/aligned.h
@@ -1,5 +1,6 @@
#pragma once
#include <cstdlib>
+#include <new>
#include <stdlib.h>
// 64-byte aligned simple vector.
@@ -10,11 +11,9 @@ template <class T> class AlignedVector {
public:
explicit AlignedVector(std::size_t size)
: size_(size) {
- #ifdef __APPLE__
- posix_memalign(reinterpret_cast<void **>(&mem_), 64, size * sizeof(T));
- #else
- mem_ = reinterpret_cast<T*>(aligned_alloc(64, (size * sizeof(T) + 63) & ~63)); // pedantic requirements for memory size on aligned_alloc in case it's not just a call to posix_memalign
- #endif
+ if (posix_memalign(reinterpret_cast<void **>(&mem_), 64, size * sizeof(T))) {
+ throw std::bad_alloc();
+ }
}
AlignedVector(const AlignedVector&) = delete;
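
The allocation change above replaces the platform split with posix_memalign everywhere and adds an out-of-memory check. A minimal standalone sketch of that pattern, with an illustrative helper name rather than the library's class:

#include <cstddef>   // std::size_t
#include <new>       // std::bad_alloc
#include <stdlib.h>  // posix_memalign, free

// Allocate 64-byte-aligned storage for `size` elements of T, throwing on
// failure as the constructor above now does.  Unlike C11 aligned_alloc,
// posix_memalign does not require the byte count to be a multiple of the
// alignment, which is why the rounding-up branch could be dropped.
template <class T> T *AllocateAligned64(std::size_t size) {
  void *mem = nullptr;
  if (posix_memalign(&mem, 64, size * sizeof(T))) {
    throw std::bad_alloc();
  }
  return static_cast<T*>(mem);  // release later with free()
}
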
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 623e21a..c6a473e 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -331,11 +331,12 @@ struct AVX512_8bit {
// There's 8 results for INTGEMM_AVX2 to handle.
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
const int simd_width = width / sizeof(Register);
- const Register *B0_col = reinterpret_cast<const Register*>(B);
// Added for AVX512.
Register zeros = setzero_si<Register>();
// Go over 8 columns of B at a time.
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+ const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
// Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
// Iterate over shared (inner) dimension.
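
The loop change above recurs in every kernel that follows: the old form advanced B0_col in the increment expression (and tested with !=), chaining each iteration to the previous one; the new form recomputes the pointer from the loop index, leaving independent iterations that "#pragma omp for" can distribute, with a strided loop in the canonical form OpenMP accepts. A simplified, self-contained sketch of the idea, assuming a column count that is a multiple of 8 and using illustrative names:

#include <cstddef>

// Before: for (Index i = 0; i != cols; col += 8 * stride, i += 8) { ... }
// After: the block pointer is derived from the index alone, so a
// work-sharing construct can hand out chunks of columns to threads.
static void ScaleColumnBlocks(float *base, std::size_t cols, std::size_t stride) {
  #pragma omp parallel
  {
    #pragma omp for
    for (std::size_t i = 0; i < cols; i += 8) {
      float *block = base + i * stride;             // independent per iteration
      for (std::size_t j = 0; j < 8 * stride; ++j)  // touch this 8-column block
        block[j] *= 2.0f;
    }
  }
}
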
diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h
index 59f6405..6eb3be4 100644
--- a/avx512vnni_gemm.h
+++ b/avx512vnni_gemm.h
@@ -18,10 +18,11 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
const int simd_width = width / sizeof(Register);
- const Register *B0_col = reinterpret_cast<const Register*>(B);
Register zeros = setzero_si<Register>();
// Go over 8 columns of B at a time.
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+ const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
// Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
// Iterate over shared (inner) dimension.
@@ -79,10 +80,11 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
const int simd_width = width / sizeof(Register);
- const Register *B0_col = reinterpret_cast<const Register*>(B);
Register zeros = setzero_si<Register>();
// Go over 8 columns of B at a time.
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+ const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
// Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
// Iterate over shared (inner) dimension.
@@ -119,11 +121,12 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0);
auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
const int simd_width = width / sizeof(Register);
- const Register *B0_col = reinterpret_cast<const Register*>(B);
Register zeros = setzero_si<Register>();
const Register a = set1_epi8<Register>(1);
// Go over 8 columns of B at a time.
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+#pragma omp for
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) {
+ const Register *B0_col = reinterpret_cast<const Register*>(B) + B0_colidx * simd_width;
const Register *B_live = B0_col; //In order to make the code look as much as possible as the above function
const Register *B_end = B_live + simd_width*8;
diff --git a/intgemm.h b/intgemm.h
index 8c5309b..95c2428 100644
--- a/intgemm.h
+++ b/intgemm.h
@@ -285,7 +285,7 @@ private:
};
template <typename Callback>
-void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::Multiply<Callback>, AVX512_8bit::Multiply<Callback>, AVX2_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, SSSE3_8bit::Multiply<Callback>, Unsupported_8bit::Multiply);
+void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512VNNI_8bit>, OMPParallelWrap<Callback, AVX512_8bit>, OMPParallelWrap<Callback, AVX2_8bit>, OMPParallelWrap<Callback, SSSE3_8bit>, Unsupported_8bit::Multiply<Callback>, Unsupported_8bit::Multiply<Callback>);
/*
* 8-bit matrix multiplication with shifting A by 127
@@ -348,7 +348,12 @@ private:
};
template <class Callback>
-void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::Multiply8Shift<Callback>, AVX512_8bit::Multiply8Shift<Callback>, AVX2_8bit::Multiply8Shift<Callback>, SSSE3_8bit::Multiply8Shift<Callback>, SSSE3_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift);
+void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(
+ OMPParallelWrap8Shift<Callback, AVX512VNNI_8bit>,
+ OMPParallelWrap8Shift<Callback, AVX512_8bit>,
+ OMPParallelWrap8Shift<Callback, AVX2_8bit>,
+ OMPParallelWrap8Shift<Callback, SSSE3_8bit>,
+ Unsupported_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift<Callback>);
template <class Callback>
void (*Int8Shift::PrepareBiasImpl<Callback>::run)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI_8bit::PrepareBias<Callback>, AVX512_8bit::PrepareBias<Callback>, AVX2_8bit::PrepareBias<Callback>, SSSE3_8bit::PrepareBias<Callback>, SSSE3_8bit::PrepareBias<Callback>, Unsupported_8bit::PrepareBias);
@@ -407,7 +412,7 @@ private:
};
template <typename Callback>
-void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512_16bit::Multiply<Callback> /*TODO VNNI 16-bit. */, AVX512_16bit::Multiply<Callback>, AVX2_16bit::Multiply<Callback>, SSE2_16bit::Multiply<Callback>, SSE2_16bit::Multiply<Callback>, Unsupported_16bit::Multiply);
+void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512_16bit> /*TODO VNNI 16-bit. */, OMPParallelWrap<Callback, AVX512_16bit>, OMPParallelWrap<Callback, AVX2_16bit>, OMPParallelWrap<Callback, SSE2_16bit>, OMPParallelWrap<Callback, SSE2_16bit>, Unsupported_16bit::Multiply<Callback>);
extern const CPUType kCPU;
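
The run pointers above are initialized once through ChooseCPU, whose arguments run from the newest instruction set down to an unsupported fallback. The sketch below only illustrates that dispatch pattern under assumed tier names; it is not intgemm's ChooseCPU:

// Illustrative CPU tiers, newest first, matching the argument order above.
enum class CPUTier { AVX512VNNI, AVX512BW, AVX2, SSSE3, SSE2, UNSUPPORTED };

// Placeholder for runtime detection; the real library sets kCPU via
// cpuid-style checks elsewhere.
static const CPUTier kDetected = CPUTier::AVX2;

// Return the first candidate the running CPU can execute.
template <class F> F ChooseCPUSketch(F vnni, F avx512, F avx2, F ssse3, F sse2, F unsupported) {
  switch (kDetected) {
    case CPUTier::AVX512VNNI: return vnni;
    case CPUTier::AVX512BW:   return avx512;
    case CPUTier::AVX2:       return avx2;
    case CPUTier::SSSE3:      return ssse3;
    case CPUTier::SSE2:       return sse2;
    default:                  return unsupported;
  }
}
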
diff --git a/multiply.h b/multiply.h
index 9790353..adb71ee 100644
--- a/multiply.h
+++ b/multiply.h
@@ -205,8 +205,9 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
const int simd_width = width / (sizeof(Register) / sizeof(int16_t)); \
auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
- const Register *B0_col = reinterpret_cast<const Register *>(B); \
- for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ _Pragma("omp for") \
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+ const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
/* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width); \
@@ -261,9 +262,10 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
const int simd_width = width / (sizeof(Register) / sizeof(int8_t)); \
auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
- const Register *B0_col = reinterpret_cast<const Register *>(B); \
const Register a = set1_epi8<Register>(1); \
- for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ _Pragma("omp for") \
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+ const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
/*const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width);*/ \
/* These will be packed 16-bit integers containing sums for each row of B multiplied by the row of A. \
Iterate over shared (inner) dimension.*/ \
@@ -335,8 +337,9 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
const int simd_width = width / (sizeof(Register) / sizeof(int8_t)); \
auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
- const Register *B0_col = reinterpret_cast<const Register *>(B); \
- for (Index B0_colidx = 0; B0_colidx < B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ _Pragma("omp for") \
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+ const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
/* Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
const Register *A_row = reinterpret_cast<const Register*>(A + A_rowidx * width); \
@@ -559,9 +562,9 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); \
const int simd_width = width / sizeof(Register); \
auto callback_impl = callbacks::CallbackImpl<cpu_type, Callback>(callback); \
- const Register *B0_col = reinterpret_cast<const Register*>(B); \
- /*Go over 8 columns of B at a time.*/ \
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { \
+ _Pragma("omp for") \
+ for (Index B0_colidx = 0; B0_colidx < B_cols; B0_colidx += 8) { \
+ const Register *B0_col = reinterpret_cast<const Register *>(B) + simd_width * B0_colidx; \
/*Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.*/ \
for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { \
/*Iterate over shared (inner) dimension.*/ \
@@ -617,7 +620,25 @@ INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
RunCallback(callback_impl, total, A_rowidx, B0_colidx, A_rows, B_cols); \
} \
} \
-} \
+}
+
+/* Wrap a multiply call in OMP parallelism. Here it launches threads then
+ * inside the implementation there is a pragma omp for. In gcc >= 8 these
+ * could have been the same but older compilers don't imbue target attributes
+ * on the hidden function created by pragma omp parallel.
+ *
+ * Also, gcc 7 is unable to deduce the function pointer type (for ChooseCPU) if
+ * I use typename Backend::Integer directly in the arguments. As a workaround,
+ * have a default template argument Integer then use that so it's resolved.
+ */
+template <class Callback, class Backend, class Integer = typename Backend::Integer> static inline void OMPParallelWrap(const Integer *A, const Integer *B, Index A_rows, Index width, Index B_cols, Callback callback) {
+#pragma omp parallel
+ Backend::template Multiply<Callback>(A, B, A_rows, width, B_cols, callback);
+}
+template <class Callback, class Backend> static inline void OMPParallelWrap8Shift(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
+#pragma omp parallel
+ Backend::template Multiply8Shift<Callback>(A, B, A_rows, width, B_cols, callback);
+}
#define INTGEMM_MAXABSOLUTE(Register, target) \
target static inline float MaxAbsolute(const float *begin_float, const float *end_float) { \
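
A self-contained sketch of the parallel/for split described in the comment above: the wrapper opens the parallel region, and the orphaned "omp for" inside the kernel it calls shares the loop across that region's threads. Names are illustrative:

#include <cstddef>

// Kernel containing only the work-sharing construct.  Called from inside an
// enclosing parallel region, the orphaned "omp for" splits iterations over
// that region's threads; called outside one, it simply runs serially.
static void AddOneKernel(float *data, std::size_t n) {
  #pragma omp for
  for (std::size_t i = 0; i < n; ++i) {
    data[i] += 1.0f;
  }
}

// Thin wrapper that only opens the parallel region, the same way
// OMPParallelWrap forwards to Backend::Multiply above.
static void AddOneParallel(float *data, std::size_t n) {
  #pragma omp parallel
  AddOneKernel(data, n);
}
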
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 0fd8231..a054753 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -278,7 +278,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
AlignedVector<float> test_C(A_rows * B_cols);
- Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
+ OMPParallelWrap<callbacks::UnquantizeAndWrite, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWrite(unquant_mult, test_C.begin()));
// Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence(
// callbacks::Unquantize(unquant_mult),
// callbacks::Write<float>(test_C.begin())