github.com/marian-nmt/intgemm.git
author     Mateusz Chudyk <mateuszchudyk@gmail.com>  2019-07-09 22:04:46 +0300
committer  GitHub <noreply@github.com>  2019-07-09 22:04:46 +0300
commit     5466238858becaec459d154137dbd2d79baa0d3d (patch)
tree       6ef5f76263af5337fabe9ef3c1325edd35699c65
parent     937e42a5e4b8ae145e0082d82e1b4074355600c2 (diff)
parent     d06b3e5f15876b3c0691494cf49c9dea78120378 (diff)
Merge pull request #24 from kpu/fixes
Fixes
-rw-r--r--  avx512_gemm.h  216
-rw-r--r--  intgemm.cc       7
2 files changed, 115 insertions, 108 deletions
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 043dfae..8326a82 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -219,117 +219,117 @@ struct AVX512_8bit {
// allocate registers manually) and no sign instruction.
template <typename PostprocessPipeline>
INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) {
- typedef __m512i Integer;
- //typedef __m256 Float; // For quantization we only do 8 at a time.
- // This is copy-paste from Multiply8_SSE2OrAVX2.
- assert(width % sizeof(Integer) == 0);
- assert(B_cols % 8 == 0);
- assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
- assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
- // There's 8 results for INTGEMM_AVX2 to handle.
- auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
- const int simd_width = width / sizeof(Integer);
- const Integer *B0_col = reinterpret_cast<const Integer*>(B);
- // Added for AVX512.
- Integer zeros = setzero_si<Integer>();
- // Go over 8 columns of B at a time.
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
- // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
- for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
- // Iterate over shared (inner) dimension.
- const Integer *A_live = reinterpret_cast<const Integer *>(A + A_rowidx * width);
- const Integer *A_end = A_live + simd_width;
- const Integer *B_live = B0_col;
-
- // Do the first iteration to initialize the sums.
- __m512i a = *A_live;
- __mmask64 neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128));
- __m512i a_positive = _mm512_abs_epi8(a);
- // These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A.
- Integer sum0 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[0], neg_mask, zeros, B_live[0]));
- Integer sum1 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[1], neg_mask, zeros, B_live[1]));
- Integer sum2 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[2], neg_mask, zeros, B_live[2]));
- Integer sum3 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[3], neg_mask, zeros, B_live[3]));
- Integer sum4 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[4], neg_mask, zeros, B_live[4]));
- Integer sum5 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[5], neg_mask, zeros, B_live[5]));
- Integer sum6 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[6], neg_mask, zeros, B_live[6]));
- Integer sum7 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[7], neg_mask, zeros, B_live[7]));
-
- ++A_live;
- B_live += 8;
-
- // Use A as the loop variable so the add can be done where gcc likes it
- // for branch prediction.
- for (; A_live != A_end; ++A_live, B_live += 8) {
- // Unique code here: can we do an inline function?
- // Retrieve a. We will use this as the unsigned part.
- a = *A_live;
- // Retrieve the conveniently consecutive values of B.
- __m512i b0 = *B_live;
- __m512i b1 = *(B_live + 1);
- __m512i b2 = *(B_live + 2);
- __m512i b3 = *(B_live + 3);
- __m512i b4 = *(B_live + 4);
- __m512i b5 = *(B_live + 5);
- __m512i b6 = *(B_live + 6);
- __m512i b7 = *(B_live + 7);
-
- // Get a mask where a is negative.
- // Didn't seem to make a difference defining sign bits here vs at top
- neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128));
- a_positive = _mm512_abs_epi8(a);
-
- // Negate by subtracting from zero with a mask.
- b0 = _mm512_mask_sub_epi8(b0, neg_mask, zeros, b0);
- b1 = _mm512_mask_sub_epi8(b1, neg_mask, zeros, b1);
- b2 = _mm512_mask_sub_epi8(b2, neg_mask, zeros, b2);
- b3 = _mm512_mask_sub_epi8(b3, neg_mask, zeros, b3);
- b4 = _mm512_mask_sub_epi8(b4, neg_mask, zeros, b4);
- b5 = _mm512_mask_sub_epi8(b5, neg_mask, zeros, b5);
- b6 = _mm512_mask_sub_epi8(b6, neg_mask, zeros, b6);
- b7 = _mm512_mask_sub_epi8(b7, neg_mask, zeros, b7);
- // The magic 8-bit multiply then horizontal sum into 16-bit.
- b0 = _mm512_maddubs_epi16(a_positive, b0);
- b1 = _mm512_maddubs_epi16(a_positive, b1);
- b2 = _mm512_maddubs_epi16(a_positive, b2);
- b3 = _mm512_maddubs_epi16(a_positive, b3);
- b4 = _mm512_maddubs_epi16(a_positive, b4);
- b5 = _mm512_maddubs_epi16(a_positive, b5);
- b6 = _mm512_maddubs_epi16(a_positive, b6);
- b7 = _mm512_maddubs_epi16(a_positive, b7);
- // Now we have 16-bit results that are the sum of two multiplies.
- // Choosing to approximate and do adds.
- // Perhaps every so often we could accumulate by upcasting.
- sum0 = _mm512_adds_epi16(sum0, b0);
- sum1 = _mm512_adds_epi16(sum1, b1);
- sum2 = _mm512_adds_epi16(sum2, b2);
- sum3 = _mm512_adds_epi16(sum3, b3);
- sum4 = _mm512_adds_epi16(sum4, b4);
- sum5 = _mm512_adds_epi16(sum5, b5);
- sum6 = _mm512_adds_epi16(sum6, b6);
- sum7 = _mm512_adds_epi16(sum7, b7);
- // Unique code ends: can we do an inline function?
+ typedef __m512i Integer;
+ //typedef __m256 Float; // For quantization we only do 8 at a time.
+ // This is copy-paste from Multiply8_SSE2OrAVX2.
+ assert(width % sizeof(Integer) == 0);
+ assert(B_cols % 8 == 0);
+ assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
+ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
+ // There's 8 results for INTGEMM_AVX2 to handle.
+ auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
+ const int simd_width = width / sizeof(Integer);
+ const Integer *B0_col = reinterpret_cast<const Integer*>(B);
+ // Added for AVX512.
+ Integer zeros = setzero_si<Integer>();
+ // Go over 8 columns of B at a time.
+ for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+ // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once.
+ for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
+ // Iterate over shared (inner) dimension.
+ const Integer *A_live = reinterpret_cast<const Integer *>(A + A_rowidx * width);
+ const Integer *A_end = A_live + simd_width;
+ const Integer *B_live = B0_col;
+
+ // Do the first iteration to initialize the sums.
+ __m512i a = *A_live;
+ __mmask64 neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128));
+ __m512i a_positive = _mm512_abs_epi8(a);
+ // These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A.
+ Integer sum0 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[0], neg_mask, zeros, B_live[0]));
+ Integer sum1 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[1], neg_mask, zeros, B_live[1]));
+ Integer sum2 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[2], neg_mask, zeros, B_live[2]));
+ Integer sum3 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[3], neg_mask, zeros, B_live[3]));
+ Integer sum4 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[4], neg_mask, zeros, B_live[4]));
+ Integer sum5 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[5], neg_mask, zeros, B_live[5]));
+ Integer sum6 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[6], neg_mask, zeros, B_live[6]));
+ Integer sum7 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[7], neg_mask, zeros, B_live[7]));
+
+ ++A_live;
+ B_live += 8;
+
+ // Use A as the loop variable so the add can be done where gcc likes it
+ // for branch prediction.
+ for (; A_live != A_end; ++A_live, B_live += 8) {
+ // Unique code here: can we do an inline function?
+ // Retrieve a. We will use this as the unsigned part.
+ a = *A_live;
+ // Retrieve the conveniently consecutive values of B.
+ __m512i b0 = *B_live;
+ __m512i b1 = *(B_live + 1);
+ __m512i b2 = *(B_live + 2);
+ __m512i b3 = *(B_live + 3);
+ __m512i b4 = *(B_live + 4);
+ __m512i b5 = *(B_live + 5);
+ __m512i b6 = *(B_live + 6);
+ __m512i b7 = *(B_live + 7);
+
+ // Get a mask where a is negative.
+ // Didn't seem to make a difference defining sign bits here vs at top
+ neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128));
+ a_positive = _mm512_abs_epi8(a);
+
+ // Negate by subtracting from zero with a mask.
+ b0 = _mm512_mask_sub_epi8(b0, neg_mask, zeros, b0);
+ b1 = _mm512_mask_sub_epi8(b1, neg_mask, zeros, b1);
+ b2 = _mm512_mask_sub_epi8(b2, neg_mask, zeros, b2);
+ b3 = _mm512_mask_sub_epi8(b3, neg_mask, zeros, b3);
+ b4 = _mm512_mask_sub_epi8(b4, neg_mask, zeros, b4);
+ b5 = _mm512_mask_sub_epi8(b5, neg_mask, zeros, b5);
+ b6 = _mm512_mask_sub_epi8(b6, neg_mask, zeros, b6);
+ b7 = _mm512_mask_sub_epi8(b7, neg_mask, zeros, b7);
+ // The magic 8-bit multiply then horizontal sum into 16-bit.
+ b0 = _mm512_maddubs_epi16(a_positive, b0);
+ b1 = _mm512_maddubs_epi16(a_positive, b1);
+ b2 = _mm512_maddubs_epi16(a_positive, b2);
+ b3 = _mm512_maddubs_epi16(a_positive, b3);
+ b4 = _mm512_maddubs_epi16(a_positive, b4);
+ b5 = _mm512_maddubs_epi16(a_positive, b5);
+ b6 = _mm512_maddubs_epi16(a_positive, b6);
+ b7 = _mm512_maddubs_epi16(a_positive, b7);
+ // Now we have 16-bit results that are the sum of two multiplies.
+ // Choosing to approximate and do adds.
+ // Perhaps every so often we could accumulate by upcasting.
+ sum0 = _mm512_adds_epi16(sum0, b0);
+ sum1 = _mm512_adds_epi16(sum1, b1);
+ sum2 = _mm512_adds_epi16(sum2, b2);
+ sum3 = _mm512_adds_epi16(sum3, b3);
+ sum4 = _mm512_adds_epi16(sum4, b4);
+ sum5 = _mm512_adds_epi16(sum5, b5);
+ sum6 = _mm512_adds_epi16(sum6, b6);
+ sum7 = _mm512_adds_epi16(sum7, b7);
+ // Unique code ends: can we do an inline function?
+ }
+ // Upcast to 32-bit and horizontally add.
+ Integer ones = set1_epi16<Integer>(1);
+ sum0 = madd_epi16(sum0, ones);
+ sum1 = madd_epi16(sum1, ones);
+ sum2 = madd_epi16(sum2, ones);
+ sum3 = madd_epi16(sum3, ones);
+ sum4 = madd_epi16(sum4, ones);
+ sum5 = madd_epi16(sum5, ones);
+ sum6 = madd_epi16(sum6, ones);
+ sum7 = madd_epi16(sum7, ones);
+ Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3);
+ Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
+
+ auto total = PermuteSummer(pack0123, pack4567);
+ auto offset = A_rowidx * B_cols + B0_colidx;
+ auto result = inited_pipeline.run(total, offset);
+ writer(C, offset, result);
}
- // Upcast to 32-bit and horizontally add.
- Integer ones = set1_epi16<Integer>(1);
- sum0 = madd_epi16(sum0, ones);
- sum1 = madd_epi16(sum1, ones);
- sum2 = madd_epi16(sum2, ones);
- sum3 = madd_epi16(sum3, ones);
- sum4 = madd_epi16(sum4, ones);
- sum5 = madd_epi16(sum5, ones);
- sum6 = madd_epi16(sum6, ones);
- sum7 = madd_epi16(sum7, ones);
- Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3);
- Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
-
- auto total = PermuteSummer(pack0123, pack4567);
- auto offset = A_rowidx * B_cols + B0_colidx;
- auto result = inited_pipeline.run(total, offset);
- writer(C, offset, result);
}
}
-}
constexpr static const char *const kName = "8-bit AVX512";
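
For reference, the kernel above works around the maddubs restriction that the left operand is unsigned and the right operand is signed: it feeds _mm512_abs_epi8(a) as the unsigned side and negates the B registers wherever the corresponding byte of A was negative, so the signed product is unchanged. The scalar sketch below is illustrative only and not part of intgemm (the function name is made up); it just checks that identity.

// Scalar sketch of the sign-moving trick used by the AVX512 kernel above.
// a*b is rewritten as |a| * (b negated where a < 0), which is the form
// maddubs-style instructions accept (unsigned left operand, signed right).
#include <cassert>
#include <cstdint>

int16_t signed_mul_via_unsigned(int8_t a, int8_t b) {
  // |a| as an unsigned byte; note |-128| = 128 still fits in uint8_t,
  // mirroring _mm512_abs_epi8 feeding the unsigned operand.
  uint8_t a_positive = static_cast<uint8_t>(a < 0 ? -static_cast<int>(a) : a);
  // Negate b where a was negative, mirroring the masked _mm512_mask_sub_epi8.
  int8_t b_signed = (a < 0) ? static_cast<int8_t>(-b) : b;
  return static_cast<int16_t>(static_cast<int16_t>(a_positive) * b_signed);
}

int main() {
  for (int a = -128; a < 128; ++a) {
    for (int b = -128; b < 128; ++b) {
      // b = -128 cannot be negated in 8 bits; the masked subtraction in the
      // vector code wraps the same way, so that corner case is excluded here.
      if (a < 0 && b == -128) continue;
      assert(signed_mul_via_unsigned(static_cast<int8_t>(a),
                                     static_cast<int8_t>(b)) == a * b);
    }
  }
  return 0;
}

The resulting 16-bit sums are accumulated with saturating adds (_mm512_adds_epi16) and later widened by madd_epi16 against a vector of ones, which multiplies each 16-bit lane by 1 and adds adjacent pairs into 32-bit lanes.
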
diff --git a/intgemm.cc b/intgemm.cc
index 39ba227..6928f0c 100644
--- a/intgemm.cc
+++ b/intgemm.cc
@@ -26,4 +26,11 @@ const CPUType kCPU = ChooseCPU(CPUType::AVX512BW, CPUType::AVX2, CPUType::SSSE3,
float (*MaxAbsolute)(const float *begin, const float *end) = ChooseCPU(avx512f::MaxAbsolute, avx2::MaxAbsolute, sse2::MaxAbsolute, sse2::MaxAbsolute, Unsupported_MaxAbsolute);
+constexpr const char *const SSE2_16bit::kName;
+constexpr const char *const SSSE3_8bit::kName;
+constexpr const char *const AVX2_8bit::kName;
+constexpr const char *const AVX2_16bit::kName;
+constexpr const char *const AVX512_8bit::kName;
+constexpr const char *const AVX512_16bit::kName;
+
}
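
The intgemm.cc hunk adds namespace-scope definitions for the per-class kName members. In C++11/14 an in-class initializer of a constexpr static data member is only a declaration; if the member is ODR-used (for example, bound to a reference), a separate definition is required or the program fails to link. A minimal sketch of that rule, using a hypothetical Backend struct rather than the real intgemm classes:

// Minimal sketch (hypothetical Backend type, not the intgemm classes) of why
// out-of-line definitions such as the ones added above are needed in C++11/14.
#include <iostream>

struct Backend {
  static constexpr const char *const kName = "8-bit AVX512";  // declaration only
};

// Without this definition, the reference binding below ODR-uses kName and
// typically fails to link ("undefined reference to Backend::kName").
// In C++17, constexpr static data members are implicitly inline and this
// line becomes redundant (though still accepted).
constexpr const char *const Backend::kName;

void print_name(const char *const &name) { std::cout << name << '\n'; }

int main() {
  print_name(Backend::kName);  // binds the member to a reference: an ODR-use
  return 0;
}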