| author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2019-07-09 22:04:46 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-07-09 22:04:46 +0300 |
| commit | 5466238858becaec459d154137dbd2d79baa0d3d (patch) | |
| tree | 6ef5f76263af5337fabe9ef3c1325edd35699c65 | |
| parent | 937e42a5e4b8ae145e0082d82e1b4074355600c2 (diff) | |
| parent | d06b3e5f15876b3c0691494cf49c9dea78120378 (diff) | |
Merge pull request #24 from kpu/fixes
Fixes
| -rw-r--r-- | avx512_gemm.h | 216 |
| -rw-r--r-- | intgemm.cc | 7 |

2 files changed, 115 insertions, 108 deletions
```diff
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 043dfae..8326a82 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -219,117 +219,117 @@ struct AVX512_8bit {
   // allocate registers manually) and no sign instruction.
   template <typename PostprocessPipeline> INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, float *C, PostprocessPipeline pipeline, Index A_rows, Index width, Index B_cols) {
-  typedef __m512i Integer;
-  //typedef __m256 Float; // For quantization we only do 8 at a time.
-  // This is copy-paste from Multiply8_SSE2OrAVX2.
-  assert(width % sizeof(Integer) == 0);
-  assert(B_cols % 8 == 0);
-  assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
-  assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
-  // There's 8 results for INTGEMM_AVX2 to handle.
-  auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
-  const int simd_width = width / sizeof(Integer);
-  const Integer *B0_col = reinterpret_cast<const Integer*>(B);
-  // Added for AVX512.
-  Integer zeros = setzero_si<Integer>();
-  // Go over 8 columns of B at a time.
-  for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
-    // Process one row of A at a time.  Doesn't seem to be faster to do multiple rows of A at once.
-    for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
-      // Iterate over shared (inner) dimension.
-      const Integer *A_live = reinterpret_cast<const Integer *>(A + A_rowidx * width);
-      const Integer *A_end = A_live + simd_width;
-      const Integer *B_live = B0_col;
-
-      // Do the first iteration to initialize the sums.
-      __m512i a = *A_live;
-      __mmask64 neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128));
-      __m512i a_positive = _mm512_abs_epi8(a);
-      // These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A.
-      Integer sum0 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[0], neg_mask, zeros, B_live[0]));
-      Integer sum1 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[1], neg_mask, zeros, B_live[1]));
-      Integer sum2 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[2], neg_mask, zeros, B_live[2]));
-      Integer sum3 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[3], neg_mask, zeros, B_live[3]));
-      Integer sum4 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[4], neg_mask, zeros, B_live[4]));
-      Integer sum5 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[5], neg_mask, zeros, B_live[5]));
-      Integer sum6 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[6], neg_mask, zeros, B_live[6]));
-      Integer sum7 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[7], neg_mask, zeros, B_live[7]));
-
-      ++A_live;
-      B_live += 8;
-
-      // Use A as the loop variable so the add can be done where gcc likes it
-      // for branch prediction.
-      for (; A_live != A_end; ++A_live, B_live += 8) {
-        // Unique code here: can we do an inline function?
-        // Retrieve a.  We will use this as the unsigned part.
-        a = *A_live;
-        // Retrieve the conveniently consecutive values of B.
-        __m512i b0 = *B_live;
-        __m512i b1 = *(B_live + 1);
-        __m512i b2 = *(B_live + 2);
-        __m512i b3 = *(B_live + 3);
-        __m512i b4 = *(B_live + 4);
-        __m512i b5 = *(B_live + 5);
-        __m512i b6 = *(B_live + 6);
-        __m512i b7 = *(B_live + 7);
-
-        // Get a mask where a is negative.
-        // Didn't seem to make a difference definining sign bits here vs at top
-        neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128));
-        a_positive = _mm512_abs_epi8(a);
-
-        // Negate by subtracting from zero with a mask.
-        b0 = _mm512_mask_sub_epi8(b0, neg_mask, zeros, b0);
-        b1 = _mm512_mask_sub_epi8(b1, neg_mask, zeros, b1);
-        b2 = _mm512_mask_sub_epi8(b2, neg_mask, zeros, b2);
-        b3 = _mm512_mask_sub_epi8(b3, neg_mask, zeros, b3);
-        b4 = _mm512_mask_sub_epi8(b4, neg_mask, zeros, b4);
-        b5 = _mm512_mask_sub_epi8(b5, neg_mask, zeros, b5);
-        b6 = _mm512_mask_sub_epi8(b6, neg_mask, zeros, b6);
-        b7 = _mm512_mask_sub_epi8(b7, neg_mask, zeros, b7);
-        // The magic 8-bit multiply then horizontal sum into 16-bit.
-        b0 = _mm512_maddubs_epi16(a_positive, b0);
-        b1 = _mm512_maddubs_epi16(a_positive, b1);
-        b2 = _mm512_maddubs_epi16(a_positive, b2);
-        b3 = _mm512_maddubs_epi16(a_positive, b3);
-        b4 = _mm512_maddubs_epi16(a_positive, b4);
-        b5 = _mm512_maddubs_epi16(a_positive, b5);
-        b6 = _mm512_maddubs_epi16(a_positive, b6);
-        b7 = _mm512_maddubs_epi16(a_positive, b7);
-        // Now we have 16-bit results that are the sum of two multiplies.
-        // Choosing to approximate and do adds.
-        // Perhaps every so often we could accumulate by upcasting.
-        sum0 = _mm512_adds_epi16(sum0, b0);
-        sum1 = _mm512_adds_epi16(sum1, b1);
-        sum2 = _mm512_adds_epi16(sum2, b2);
-        sum3 = _mm512_adds_epi16(sum3, b3);
-        sum4 = _mm512_adds_epi16(sum4, b4);
-        sum5 = _mm512_adds_epi16(sum5, b5);
-        sum6 = _mm512_adds_epi16(sum6, b6);
-        sum7 = _mm512_adds_epi16(sum7, b7);
-        // Unique code ends: can we do an inline function?
+    typedef __m512i Integer;
+    //typedef __m256 Float; // For quantization we only do 8 at a time.
+    // This is copy-paste from Multiply8_SSE2OrAVX2.
+    assert(width % sizeof(Integer) == 0);
+    assert(B_cols % 8 == 0);
+    assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0);
+    assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0);
+    // There's 8 results for INTGEMM_AVX2 to handle.
+    auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline);
+    const int simd_width = width / sizeof(Integer);
+    const Integer *B0_col = reinterpret_cast<const Integer*>(B);
+    // Added for AVX512.
+    Integer zeros = setzero_si<Integer>();
+    // Go over 8 columns of B at a time.
+    for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+      // Process one row of A at a time.  Doesn't seem to be faster to do multiple rows of A at once.
+      for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) {
+        // Iterate over shared (inner) dimension.
+        const Integer *A_live = reinterpret_cast<const Integer *>(A + A_rowidx * width);
+        const Integer *A_end = A_live + simd_width;
+        const Integer *B_live = B0_col;
+
+        // Do the first iteration to initialize the sums.
+        __m512i a = *A_live;
+        __mmask64 neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128));
+        __m512i a_positive = _mm512_abs_epi8(a);
+        // These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A.
+        Integer sum0 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[0], neg_mask, zeros, B_live[0]));
+        Integer sum1 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[1], neg_mask, zeros, B_live[1]));
+        Integer sum2 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[2], neg_mask, zeros, B_live[2]));
+        Integer sum3 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[3], neg_mask, zeros, B_live[3]));
+        Integer sum4 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[4], neg_mask, zeros, B_live[4]));
+        Integer sum5 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[5], neg_mask, zeros, B_live[5]));
+        Integer sum6 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[6], neg_mask, zeros, B_live[6]));
+        Integer sum7 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[7], neg_mask, zeros, B_live[7]));
+
+        ++A_live;
+        B_live += 8;
+
+        // Use A as the loop variable so the add can be done where gcc likes it
+        // for branch prediction.
+        for (; A_live != A_end; ++A_live, B_live += 8) {
+          // Unique code here: can we do an inline function?
+          // Retrieve a.  We will use this as the unsigned part.
+          a = *A_live;
+          // Retrieve the conveniently consecutive values of B.
+          __m512i b0 = *B_live;
+          __m512i b1 = *(B_live + 1);
+          __m512i b2 = *(B_live + 2);
+          __m512i b3 = *(B_live + 3);
+          __m512i b4 = *(B_live + 4);
+          __m512i b5 = *(B_live + 5);
+          __m512i b6 = *(B_live + 6);
+          __m512i b7 = *(B_live + 7);
+
+          // Get a mask where a is negative.
+          // Didn't seem to make a difference definining sign bits here vs at top
+          neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128));
+          a_positive = _mm512_abs_epi8(a);
+
+          // Negate by subtracting from zero with a mask.
+          b0 = _mm512_mask_sub_epi8(b0, neg_mask, zeros, b0);
+          b1 = _mm512_mask_sub_epi8(b1, neg_mask, zeros, b1);
+          b2 = _mm512_mask_sub_epi8(b2, neg_mask, zeros, b2);
+          b3 = _mm512_mask_sub_epi8(b3, neg_mask, zeros, b3);
+          b4 = _mm512_mask_sub_epi8(b4, neg_mask, zeros, b4);
+          b5 = _mm512_mask_sub_epi8(b5, neg_mask, zeros, b5);
+          b6 = _mm512_mask_sub_epi8(b6, neg_mask, zeros, b6);
+          b7 = _mm512_mask_sub_epi8(b7, neg_mask, zeros, b7);
+          // The magic 8-bit multiply then horizontal sum into 16-bit.
+          b0 = _mm512_maddubs_epi16(a_positive, b0);
+          b1 = _mm512_maddubs_epi16(a_positive, b1);
+          b2 = _mm512_maddubs_epi16(a_positive, b2);
+          b3 = _mm512_maddubs_epi16(a_positive, b3);
+          b4 = _mm512_maddubs_epi16(a_positive, b4);
+          b5 = _mm512_maddubs_epi16(a_positive, b5);
+          b6 = _mm512_maddubs_epi16(a_positive, b6);
+          b7 = _mm512_maddubs_epi16(a_positive, b7);
+          // Now we have 16-bit results that are the sum of two multiplies.
+          // Choosing to approximate and do adds.
+          // Perhaps every so often we could accumulate by upcasting.
+          sum0 = _mm512_adds_epi16(sum0, b0);
+          sum1 = _mm512_adds_epi16(sum1, b1);
+          sum2 = _mm512_adds_epi16(sum2, b2);
+          sum3 = _mm512_adds_epi16(sum3, b3);
+          sum4 = _mm512_adds_epi16(sum4, b4);
+          sum5 = _mm512_adds_epi16(sum5, b5);
+          sum6 = _mm512_adds_epi16(sum6, b6);
+          sum7 = _mm512_adds_epi16(sum7, b7);
+          // Unique code ends: can we do an inline function?
+        }
+        // Upcast to 32-bit and horizontally add.
+        Integer ones = set1_epi16<Integer>(1);
+        sum0 = madd_epi16(sum0, ones);
+        sum1 = madd_epi16(sum1, ones);
+        sum2 = madd_epi16(sum2, ones);
+        sum3 = madd_epi16(sum3, ones);
+        sum4 = madd_epi16(sum4, ones);
+        sum5 = madd_epi16(sum5, ones);
+        sum6 = madd_epi16(sum6, ones);
+        sum7 = madd_epi16(sum7, ones);
+        Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3);
+        Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
+
+        auto total = PermuteSummer(pack0123, pack4567);
+        auto offset = A_rowidx * B_cols + B0_colidx;
+        auto result = inited_pipeline.run(total, offset);
+        writer(C, offset, result);
       }
-      // Upcast to 32-bit and horizontally add.
-      Integer ones = set1_epi16<Integer>(1);
-      sum0 = madd_epi16(sum0, ones);
-      sum1 = madd_epi16(sum1, ones);
-      sum2 = madd_epi16(sum2, ones);
-      sum3 = madd_epi16(sum3, ones);
-      sum4 = madd_epi16(sum4, ones);
-      sum5 = madd_epi16(sum5, ones);
-      sum6 = madd_epi16(sum6, ones);
-      sum7 = madd_epi16(sum7, ones);
-      Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3);
-      Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
-
-      auto total = PermuteSummer(pack0123, pack4567);
-      auto offset = A_rowidx * B_cols + B0_colidx;
-      auto result = inited_pipeline.run(total, offset);
-      writer(C, offset, result);
     }
   }
-}
 
   constexpr static const char *const kName = "8-bit AVX512";
diff --git a/intgemm.cc b/intgemm.cc
--- a/intgemm.cc
+++ b/intgemm.cc
@@ -26,4 +26,11 @@ const CPUType kCPU = ChooseCPU(CPUType::AVX512BW, CPUType::AVX2, CPUType::SSSE3,
 float (*MaxAbsolute)(const float *begin, const float *end) = ChooseCPU(avx512f::MaxAbsolute, avx2::MaxAbsolute, sse2::MaxAbsolute, sse2::MaxAbsolute, Unsupported_MaxAbsolute);
 
+constexpr const char *const SSE2_16bit::kName;
+constexpr const char *const SSSE3_8bit::kName;
+constexpr const char *const AVX2_8bit::kName;
+constexpr const char *const AVX2_16bit::kName;
+constexpr const char *const AVX512_8bit::kName;
+constexpr const char *const AVX512_16bit::kName;
+
 }
```
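The kernel above works around an instruction-set asymmetry: `_mm512_maddubs_epi16` (VPMADDUBSW) multiplies an *unsigned* 8-bit operand by a *signed* 8-bit operand, yet both A and B are signed here. The code therefore tests A's sign bits (`_mm512_test_epi8_mask` against `0x80`), takes `abs(a)` as the unsigned operand, and negates B wherever A was negative, which preserves the signed product. A minimal scalar sketch of that identity, not part of the patch (the function name is illustrative):

```cpp
#include <cassert>
#include <cstdint>

// One lane of the trick used above: a * b == |a| * (a < 0 ? -b : b).
int16_t signed_multiply_via_unsigned(int8_t a, int8_t b) {
  // Sign-bit test, mirroring _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128)).
  bool neg = (static_cast<uint8_t>(a) & 0x80) != 0;
  // |a| as the unsigned operand, mirroring _mm512_abs_epi8; note that
  // |-128| = 128 still fits once the lane is treated as unsigned.
  uint8_t a_positive = static_cast<uint8_t>(neg ? -static_cast<int>(a) : a);
  // Conditional negation, mirroring _mm512_mask_sub_epi8(b, neg_mask, zeros, b).
  // Like the vector subtract, this wraps for b == -128, so the identity
  // assumes B was quantized away from -128.
  int8_t b_adjusted = neg ? static_cast<int8_t>(-b) : b;
  return static_cast<int16_t>(a_positive) * static_cast<int16_t>(b_adjusted);
}

int main() {
  for (int a = -128; a <= 127; ++a)
    for (int b = -127; b <= 127; ++b)  // skip b == -128 (negation wraps)
      assert(signed_multiply_via_unsigned(static_cast<int8_t>(a),
                                          static_cast<int8_t>(b)) == a * b);
  return 0;
}
```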
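The accumulation strategy is also deliberate: `maddubs` already sums adjacent pairs of 8-bit products into 16-bit lanes, the loop then accumulates with *saturating* adds (`_mm512_adds_epi16` clamps rather than wraps, the approximation the "Choosing to approximate" comment accepts), and only after the loop does `madd_epi16(sum, ones)` widen to 32 bits by multiplying each 16-bit lane by 1 and summing adjacent lanes. A scalar sketch of those two primitives, for illustration only:

```cpp
#include <algorithm>
#include <cstdint>

// One lane of _mm512_adds_epi16: a 16-bit add that clamps to the int16_t
// range on overflow instead of wrapping.
int16_t adds_epi16_lane(int16_t x, int16_t y) {
  int32_t s = static_cast<int32_t>(x) + static_cast<int32_t>(y);
  return static_cast<int16_t>(std::clamp(s, -32768, 32767));
}

// One output lane of madd_epi16(sum, set1_epi16(1)): each 16-bit input is
// multiplied by 1 and adjacent lanes are summed into one 32-bit result,
// i.e. a widening horizontal pairwise add.
int32_t madd_by_ones_lane(int16_t even, int16_t odd) {
  return static_cast<int32_t>(even) + static_cast<int32_t>(odd);
}
```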
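The seven-line change to intgemm.cc is independent of the kernel: it adds namespace-scope definitions for each class's `static constexpr` member `kName`. Before C++17, a static constexpr data member that is odr-used (its address taken, or a reference bound to it) requires exactly one out-of-line definition, otherwise the program can fail to link; since C++17 such members are implicitly inline and the definitions are redundant but harmless. A minimal sketch with a stand-in class, not the project's real code:

```cpp
#include <iostream>

struct AVX512_8bit_like {  // stand-in for a class such as AVX512_8bit
  static constexpr const char *const kName = "8-bit AVX512";
};

// The kind of definition the patch adds; required pre-C++17 whenever kName
// is odr-used (deprecated but still allowed in C++17 and later).
constexpr const char *const AVX512_8bit_like::kName;

int main() {
  const char *const &name = AVX512_8bit_like::kName;  // odr-use of kName
  std::cout << name << '\n';
  return 0;
}
```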