diff options
Diffstat (limited to 'avx512vnni_gemm.h')
-rw-r--r-- | avx512vnni_gemm.h | 98 |
1 files changed, 49 insertions, 49 deletions
diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h index 3f616a6..59f6405 100644 --- a/avx512vnni_gemm.h +++ b/avx512vnni_gemm.h @@ -11,39 +11,39 @@ namespace intgemm { struct AVX512VNNI_8bit : public AVX512_8bit { template <typename Callback> INTGEMM_AVX512VNNI static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { - typedef __m512i Integer; - assert(width % sizeof(Integer) == 0); + typedef __m512i Register; + assert(width % sizeof(Register) == 0); assert(B_cols % 8 == 0); - assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); - assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); + assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0); + assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); - const int simd_width = width / sizeof(Integer); - const Integer *B0_col = reinterpret_cast<const Integer*>(B); - Integer zeros = setzero_si<Integer>(); + const int simd_width = width / sizeof(Register); + const Register *B0_col = reinterpret_cast<const Register*>(B); + Register zeros = setzero_si<Register>(); // Go over 8 columns of B at a time. for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once. for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { // Iterate over shared (inner) dimension. - const Integer *A_live = reinterpret_cast<const Integer *>(A + A_rowidx * width); - const Integer *A_end = A_live + simd_width; - const Integer *B_live = B0_col; + const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width); + const Register *A_end = A_live + simd_width; + const Register *B_live = B0_col; // TODO: separate first step. - Integer sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; + Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; for (; A_live != A_end; ++A_live, B_live += 8) { - Integer a = *A_live; + Register a = *A_live; // Retrieve the conveniently consecutive values of B. - Integer b0 = *B_live; - Integer b1 = *(B_live + 1); - Integer b2 = *(B_live + 2); - Integer b3 = *(B_live + 3); - Integer b4 = *(B_live + 4); - Integer b5 = *(B_live + 5); - Integer b6 = *(B_live + 6); - Integer b7 = *(B_live + 7); + Register b0 = *B_live; + Register b1 = *(B_live + 1); + Register b2 = *(B_live + 2); + Register b3 = *(B_live + 3); + Register b4 = *(B_live + 4); + Register b5 = *(B_live + 5); + Register b6 = *(B_live + 6); + Register b7 = *(B_live + 7); // Get a mask where a is negative. __mmask64 neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128)); - Integer a_positive = _mm512_abs_epi8(a); + Register a_positive = _mm512_abs_epi8(a); // Negate by subtracting from zero with a mask. b0 = _mm512_mask_sub_epi8(b0, neg_mask, zeros, b0); b1 = _mm512_mask_sub_epi8(b1, neg_mask, zeros, b1); @@ -62,8 +62,8 @@ struct AVX512VNNI_8bit : public AVX512_8bit { sum6 = _mm512_dpbusds_epi32(sum6, a_positive, b6); sum7 = _mm512_dpbusds_epi32(sum7, a_positive, b7); } - Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); - Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); auto total = PermuteSummer(pack0123, pack4567); callback_impl(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols)); } @@ -72,27 +72,27 @@ struct AVX512VNNI_8bit : public AVX512_8bit { template <typename Callback> INTGEMM_AVX512VNNI static void Multiply8Shift(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { - typedef __m512i Integer; - assert(width % sizeof(Integer) == 0); + typedef __m512i Register; + assert(width % sizeof(Register) == 0); assert(B_cols % 8 == 0); - assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); - assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); + assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0); + assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); - const int simd_width = width / sizeof(Integer); - const Integer *B0_col = reinterpret_cast<const Integer*>(B); - Integer zeros = setzero_si<Integer>(); + const int simd_width = width / sizeof(Register); + const Register *B0_col = reinterpret_cast<const Register*>(B); + Register zeros = setzero_si<Register>(); // Go over 8 columns of B at a time. for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once. for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { // Iterate over shared (inner) dimension. - const Integer *A_live = reinterpret_cast<const Integer *>(A + A_rowidx * width); - const Integer *A_end = A_live + simd_width; - const Integer *B_live = B0_col; + const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width); + const Register *A_end = A_live + simd_width; + const Register *B_live = B0_col; // TODO: separate first step. - Integer sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; + Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; for (; A_live != A_end; ++A_live, B_live += 8) { - Integer a = *A_live; + Register a = *A_live; //MultiplyAdd sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live); sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1)); @@ -103,8 +103,8 @@ struct AVX512VNNI_8bit : public AVX512_8bit { sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6)); sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7)); } - Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); - Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); auto total = PermuteSummer(pack0123, pack4567); callback_impl(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols)); } @@ -113,22 +113,22 @@ struct AVX512VNNI_8bit : public AVX512_8bit { template <typename Callback> INTGEMM_AVX512VNNI static void PrepareBias(const int8_t *B, Index width, Index B_cols, Callback callback) { - typedef __m512i Integer; - assert(width % sizeof(Integer) == 0); + typedef __m512i Register; + assert(width % sizeof(Register) == 0); assert(B_cols % 8 == 0); - assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); + assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); - const int simd_width = width / sizeof(Integer); - const Integer *B0_col = reinterpret_cast<const Integer*>(B); - Integer zeros = setzero_si<Integer>(); - const Integer a = set1_epi8<Integer>(1); + const int simd_width = width / sizeof(Register); + const Register *B0_col = reinterpret_cast<const Register*>(B); + Register zeros = setzero_si<Register>(); + const Register a = set1_epi8<Register>(1); // Go over 8 columns of B at a time. for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { - const Integer *B_live = B0_col; //In order to make the code look as much as possible as the above function - const Integer *B_end = B_live + simd_width*8; + const Register *B_live = B0_col; //In order to make the code look as much as possible as the above function + const Register *B_end = B_live + simd_width*8; // TODO: separate first step. - Integer sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; + Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; for (; B_live != B_end; B_live += 8) { // Retrieve the conveniently consecutive values of B. sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live); @@ -140,8 +140,8 @@ struct AVX512VNNI_8bit : public AVX512_8bit { sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6)); sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7)); } - Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); - Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); + Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); + Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); auto total = PermuteSummer(pack0123, pack4567); callback_impl(total, callbacks::OutputBufferInfo(0, B0_colidx, 1, B_cols)); } |