diff options
Diffstat (limited to 'avx512_gemm.h')
-rw-r--r-- | avx512_gemm.h | 56 |
1 files changed, 28 insertions, 28 deletions
diff --git a/avx512_gemm.h b/avx512_gemm.h index 91fdd8a..2a5fff1 100644 --- a/avx512_gemm.h +++ b/avx512_gemm.h @@ -70,12 +70,12 @@ INTGEMM_AVX512BW inline __m512i QuantizerGrabHalves(const float *input0, const f // being used for the quantizer. class QuantizeTile16 { public: - typedef __m512i Integer; + typedef __m512i Register; /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ INTGEMM_AVX512BW explicit QuantizeTile16(float mult) : mult_reg_(_mm512_set1_ps(mult)) {} - INTGEMM_AVX512BW Integer ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { + INTGEMM_AVX512BW Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { auto input0 = input; auto input1 = input + 16 + (cols_left <= 16 ? cols * (row_step - 1) : 0); auto g0 = QuantizerGrabHalves(input0, input1, mult_reg_); @@ -98,24 +98,24 @@ class QuantizeTile16 { class QuantizeTile8 { public: - typedef __m512i Integer; + typedef __m512i Register; /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ INTGEMM_AVX512BW explicit QuantizeTile8(float mult) : mult_reg_(_mm512_set1_ps(mult)) {} - INTGEMM_AVX512BW Integer ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { + INTGEMM_AVX512BW Register ConsecutiveWithWrapping(const float *input, Index cols_left, Index cols, Index row_step) const { static const __m512i neg127 = _mm512_set1_epi8(-127); static const __m512i shuffle_param = _mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0); const float* inputs[4]; for (int i = 0; i < sizeof(inputs) / sizeof(inputs[0]); ++i) { - while (cols_left < sizeof(Integer) / sizeof(float)) { + while (cols_left < sizeof(Register) / sizeof(float)) { input += cols * (row_step - 1); cols_left += cols; } inputs[i] = input; - input += sizeof(Integer) / sizeof(float); - cols_left -= sizeof(Integer) / sizeof(float); + input += sizeof(Register) / sizeof(float); + cols_left -= sizeof(Register) / sizeof(float); } auto g0 = QuantizerGrab(inputs[0], mult_reg_); @@ -292,41 +292,41 @@ struct AVX512_8bit { // allocate registers manually) and no sign instruction. template <typename Callback> INTGEMM_AVX512BW static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { - typedef __m512i Integer; + typedef __m512i Register; //typedef __m256 Float; // For quantization we only do 8 at a time. // This is copy-paste from Multiply8_SSE2OrAVX2. - assert(width % sizeof(Integer) == 0); + assert(width % sizeof(Register) == 0); assert(B_cols % 8 == 0); - assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); - assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); + assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0); + assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); // There's 8 results for INTGEMM_AVX2 to handle. auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); - const int simd_width = width / sizeof(Integer); - const Integer *B0_col = reinterpret_cast<const Integer*>(B); + const int simd_width = width / sizeof(Register); + const Register *B0_col = reinterpret_cast<const Register*>(B); // Added for AVX512. - Integer zeros = setzero_si<Integer>(); + Register zeros = setzero_si<Register>(); // Go over 8 columns of B at a time. for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once. for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { // Iterate over shared (inner) dimension. - const Integer *A_live = reinterpret_cast<const Integer *>(A + A_rowidx * width); - const Integer *A_end = A_live + simd_width; - const Integer *B_live = B0_col; + const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width); + const Register *A_end = A_live + simd_width; + const Register *B_live = B0_col; // Do the first iteration to initialize the sums. __m512i a = *A_live; __mmask64 neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128)); __m512i a_positive = _mm512_abs_epi8(a); // These will be packed 16-bit integers containing sums for each column of B multiplied by the row of A. - Integer sum0 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[0], neg_mask, zeros, B_live[0])); - Integer sum1 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[1], neg_mask, zeros, B_live[1])); - Integer sum2 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[2], neg_mask, zeros, B_live[2])); - Integer sum3 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[3], neg_mask, zeros, B_live[3])); - Integer sum4 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[4], neg_mask, zeros, B_live[4])); - Integer sum5 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[5], neg_mask, zeros, B_live[5])); - Integer sum6 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[6], neg_mask, zeros, B_live[6])); - Integer sum7 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[7], neg_mask, zeros, B_live[7])); + Register sum0 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[0], neg_mask, zeros, B_live[0])); + Register sum1 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[1], neg_mask, zeros, B_live[1])); + Register sum2 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[2], neg_mask, zeros, B_live[2])); + Register sum3 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[3], neg_mask, zeros, B_live[3])); + Register sum4 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[4], neg_mask, zeros, B_live[4])); + Register sum5 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[5], neg_mask, zeros, B_live[5])); + Register sum6 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[6], neg_mask, zeros, B_live[6])); + Register sum7 = maddubs_epi16(a_positive, _mm512_mask_sub_epi8(B_live[7], neg_mask, zeros, B_live[7])); ++A_live; B_live += 8; @@ -384,7 +384,7 @@ struct AVX512_8bit { // Unique code ends: can we do an inline function? } // Upcast to 32-bit and horizontally add. - Integer ones = set1_epi16<Integer>(1); + Register ones = set1_epi16<Register>(1); sum0 = madd_epi16(sum0, ones); sum1 = madd_epi16(sum1, ones); sum2 = madd_epi16(sum2, ones); @@ -393,8 +393,8 @@ struct AVX512_8bit { sum5 = madd_epi16(sum5, ones); sum6 = madd_epi16(sum6, ones); sum7 = madd_epi16(sum7, ones); - Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); - Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); auto total = PermuteSummer(pack0123, pack4567); callback_impl(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols)); |