diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-02-20 18:48:00 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-02-21 18:16:51 +0300 |
commit | 0a8508842cf757cb4b1ed4df9f7907c86ae4df0f (patch) | |
tree | 28c1d2c56786e2ba05b7534f05549da216f1f37b /avx512vnni_gemm.h | |
parent | 38d2e0e1040f19118a397702535af187d507e464 (diff) |
Use static loops in AVX512_VNNI Multiply 8bit function
Diffstat (limited to 'avx512vnni_gemm.h')
-rw-r--r-- | avx512vnni_gemm.h | 180 |
1 file changed, 99 insertions, 81 deletions
diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h index d47e2ca..d3d65e2 100644 --- a/avx512vnni_gemm.h +++ b/avx512vnni_gemm.h @@ -8,91 +8,109 @@ namespace intgemm { +// Rewrite these structs as template lambdas as soon as C++14 is used +struct AVX512VNNI_Multiply_InitALivesLoop { + template <typename Iterator, typename Type> + INTGEMM_AVX512VNNI static void body(const Type* A, Index A_rowidx, Index A_rows, Index width, const __m512i* A_lives[Iterator::total_iterations]) { + A_lives[Iterator::template I<0>()] = reinterpret_cast<const __m512i*>(A + (A_rowidx + Iterator::template I<0>()) * width); + } +}; + +struct AVX512VNNI_Multiply_InitSumsLoop { + template <typename Iterator> + INTGEMM_AVX512VNNI static void body(__m512i sums[Iterator::template N<0>()][Iterator::template N<1>()]) { + static constexpr auto Row = Iterator::template I<0>(); + static constexpr auto Column = Iterator::template I<1>(); + sums[Row][Column] = setzero_si<__m512i>(); + } +}; + +struct AVX512VNNI_Multiply_TileLoop { + template <typename Iterator> + INTGEMM_AVX512VNNI static void body(const __m512i* A_lives[Iterator::template N<0>()], + const __m512i* B_live, + __m512i sums[Iterator::template N<0>()][Iterator::template N<1>()]) { + static constexpr auto Row = Iterator::template I<0>(); + static constexpr auto Column = Iterator::template I<1>(); + auto neg_mask = _mm512_test_epi8_mask(*A_lives[Row], _mm512_set1_epi8(-128)); + sums[Row][Column] = _mm512_dpbusds_epi32(sums[Row][Column], _mm512_abs_epi8(*A_lives[Row]), _mm512_mask_sub_epi8(B_live[Column], neg_mask, setzero_si<__m512i>(), B_live[Column])); + } +}; + +struct AVX512VNNI_Multiply_IncreaseALivesLoop { + template <typename Iterator> + INTGEMM_AVX512VNNI static void body(const __m512i* A_lives[Iterator::total_iterations]) { + ++A_lives[Iterator::template I<0>()]; + } +}; + +struct AVX512VNNI_Multiply_MakeFinalOutputAndRunCallback { + template <typename Iterator, typename CallbackImpl> + INTGEMM_AVX512VNNI static void 
body(__m512i sums[Iterator::template N<0>()][8 * Iterator::template N<1>()], CallbackImpl callback_impl, Index A_rowidx, Index B_colidx, Index A_rows, Index B_cols) { + static constexpr auto Row = Iterator::template I<0>(); + static constexpr auto Column8 = Iterator::template I<1>(); + auto pack0123 = Pack0123(sums[Row][8 * Column8 + 0], sums[Row][8 * Column8 + 1], sums[Row][8 * Column8 + 2], sums[Row][8 * Column8 + 3]); + auto pack4567 = Pack0123(sums[Row][8 * Column8 + 4], sums[Row][8 * Column8 + 5], sums[Row][8 * Column8 + 6], sums[Row][8 * Column8 + 7]); + auto total = PermuteSummer(pack0123, pack4567); + RunCallback(callback_impl, total, A_rowidx + Iterator::template I<0>(), B_colidx + 8 * Column8, A_rows, B_cols); + } +}; + struct AVX512VNNI_8bit : public AVX512_8bit { - template <Index TileRows, Index TileColumnsMultiplier, typename Callback> \ + template <Index TileRows, Index TileColumnsMultiplier, typename Callback> INTGEMM_AVX512VNNI static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { - typedef __m512i Register; - assert(width % sizeof(Register) == 0); - assert(B_cols % 8 == 0); - assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0); - assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); + static constexpr Index TileColumns = 8 * TileColumnsMultiplier; + assert(A_rows % TileRows == 0); + assert(width % sizeof(__m512i) == 0); + assert(B_cols % TileColumns == 0); + assert(reinterpret_cast<uintptr_t>(A) % sizeof(__m512i) == 0); + assert(reinterpret_cast<uintptr_t>(B) % sizeof(__m512i) == 0); + + const int simd_width = width / sizeof(__m512i); auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); - const int simd_width = width / sizeof(Register); - const Register *B0_col = reinterpret_cast<const Register*>(B); - Register zeros = setzero_si<Register>(); - // Go over 8 columns of B at a time. 
- for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { - // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once. - for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { - // Iterate over shared (inner) dimension. - const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width); - const Register *A_end = A_live + simd_width; - const Register *B_live = B0_col; - // TODO: separate first step. - Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; - for (; A_live != A_end; ++A_live, B_live += 8) { - Register a = *A_live; - // Retrieve the conveniently consecutive values of B. - Register b0 = *B_live; - Register b1 = *(B_live + 1); - Register b2 = *(B_live + 2); - Register b3 = *(B_live + 3); - Register b4 = *(B_live + 4); - Register b5 = *(B_live + 5); - Register b6 = *(B_live + 6); - Register b7 = *(B_live + 7); - // Get a mask where a is negative. - __mmask64 neg_mask = _mm512_test_epi8_mask(a, _mm512_set1_epi8(-128)); - Register a_positive = _mm512_abs_epi8(a); - // Negate by subtracting from zero with a mask. 
- b0 = _mm512_mask_sub_epi8(b0, neg_mask, zeros, b0); - b1 = _mm512_mask_sub_epi8(b1, neg_mask, zeros, b1); - b2 = _mm512_mask_sub_epi8(b2, neg_mask, zeros, b2); - b3 = _mm512_mask_sub_epi8(b3, neg_mask, zeros, b3); - b4 = _mm512_mask_sub_epi8(b4, neg_mask, zeros, b4); - b5 = _mm512_mask_sub_epi8(b5, neg_mask, zeros, b5); - b6 = _mm512_mask_sub_epi8(b6, neg_mask, zeros, b6); - b7 = _mm512_mask_sub_epi8(b7, neg_mask, zeros, b7); - sum0 = _mm512_dpbusds_epi32(sum0, a_positive, b0); - sum1 = _mm512_dpbusds_epi32(sum1, a_positive, b1); - sum2 = _mm512_dpbusds_epi32(sum2, a_positive, b2); - sum3 = _mm512_dpbusds_epi32(sum3, a_positive, b3); - sum4 = _mm512_dpbusds_epi32(sum4, a_positive, b4); - sum5 = _mm512_dpbusds_epi32(sum5, a_positive, b5); - sum6 = _mm512_dpbusds_epi32(sum6, a_positive, b6); - sum7 = _mm512_dpbusds_epi32(sum7, a_positive, b7); + const __m512i *A_lives[TileRows]; + __m512i sums[TileRows][TileColumns]; + + /* Process with tile = (TileRows, TileColumns). */ + auto *B0_col = reinterpret_cast<const __m512i*>(B); + for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += TileColumns * simd_width, B0_colidx += TileColumns) { + for (Index A_rowidx = 0; A_rowidx < A_rows; A_rowidx += TileRows) { + StaticLoop<AVX512VNNI_Multiply_InitALivesLoop, MakeStaticLoopIterator<TileRows>>(A, A_rowidx, A_rows, width, A_lives); + StaticLoop<AVX512VNNI_Multiply_InitSumsLoop, MakeStaticLoopIterator<TileRows, TileColumns>>(sums); + /* Process a tile (use A as the loop variable so the add can be done where gcc likes it for branch prediction. 
*/ + auto* B_live = B0_col; + for (Index i = 0; i < simd_width; ++i, B_live += TileColumns) { + StaticLoop<AVX512VNNI_Multiply_TileLoop, MakeStaticLoopIterator<TileRows, TileColumns>>(A_lives, B_live, sums); + StaticLoop<AVX512VNNI_Multiply_IncreaseALivesLoop, MakeStaticLoopIterator<TileRows>>(A_lives); } - Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); - Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); - auto total = PermuteSummer(pack0123, pack4567); - callback_impl(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols)); + StaticLoop<AVX512VNNI_Multiply_MakeFinalOutputAndRunCallback, MakeStaticLoopIterator<TileRows, TileColumnsMultiplier>>(sums, callback_impl, A_rowidx, B0_colidx, A_rows, B_cols); } } } - template <Index TileRows, Index TileColumnsMultiplier, typename Callback> \ + template <Index TileRows, Index TileColumnsMultiplier, typename Callback> INTGEMM_AVX512VNNI static void Multiply8Shift(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { - typedef __m512i Register; - assert(width % sizeof(Register) == 0); + typedef __m512i __m512i; + assert(width % sizeof(__m512i) == 0); assert(B_cols % 8 == 0); - assert(reinterpret_cast<uintptr_t>(A) % sizeof(Register) == 0); - assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); + assert(reinterpret_cast<uintptr_t>(A) % sizeof(__m512i) == 0); + assert(reinterpret_cast<uintptr_t>(B) % sizeof(__m512i) == 0); auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); - const int simd_width = width / sizeof(Register); - const Register *B0_col = reinterpret_cast<const Register*>(B); - Register zeros = setzero_si<Register>(); + const int simd_width = width / sizeof(__m512i); + const __m512i *B0_col = reinterpret_cast<const __m512i*>(B); + __m512i zeros = setzero_si<__m512i>(); // Go over 8 columns of B at a time. 
for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { // Process one row of A at a time. Doesn't seem to be faster to do multiple rows of A at once. for (Index A_rowidx = 0; A_rowidx < A_rows; ++A_rowidx) { // Iterate over shared (inner) dimension. - const Register *A_live = reinterpret_cast<const Register *>(A + A_rowidx * width); - const Register *A_end = A_live + simd_width; - const Register *B_live = B0_col; + const __m512i *A_live = reinterpret_cast<const __m512i *>(A + A_rowidx * width); + const __m512i *A_end = A_live + simd_width; + const __m512i *B_live = B0_col; // TODO: separate first step. - Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; + __m512i sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; for (; A_live != A_end; ++A_live, B_live += 8) { - Register a = *A_live; + __m512i a = *A_live; //MultiplyAdd sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live); sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1)); @@ -103,8 +121,8 @@ struct AVX512VNNI_8bit : public AVX512_8bit { sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6)); sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7)); } - Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); - Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); + __m512i pack0123 = Pack0123(sum0, sum1, sum2, sum3); + __m512i pack4567 = Pack0123(sum4, sum5, sum6, sum7); auto total = PermuteSummer(pack0123, pack4567); callback_impl(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols)); } @@ -113,22 +131,22 @@ struct AVX512VNNI_8bit : public AVX512_8bit { template <typename Callback> INTGEMM_AVX512VNNI static void PrepareBias(const int8_t *B, Index width, Index B_cols, Callback callback) { - typedef __m512i Register; - assert(width % sizeof(Register) == 0); + typedef __m512i __m512i; + assert(width % sizeof(__m512i) == 0); assert(B_cols % 8 == 
0); - assert(reinterpret_cast<uintptr_t>(B) % sizeof(Register) == 0); + assert(reinterpret_cast<uintptr_t>(B) % sizeof(__m512i) == 0); auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); - const int simd_width = width / sizeof(Register); - const Register *B0_col = reinterpret_cast<const Register*>(B); - Register zeros = setzero_si<Register>(); - const Register a = set1_epi8<Register>(1); + const int simd_width = width / sizeof(__m512i); + const __m512i *B0_col = reinterpret_cast<const __m512i*>(B); + __m512i zeros = setzero_si<__m512i>(); + const __m512i a = set1_epi8<__m512i>(1); // Go over 8 columns of B at a time. for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { - const Register *B_live = B0_col; //In order to make the code look as much as possible as the above function - const Register *B_end = B_live + simd_width*8; + const __m512i *B_live = B0_col; //In order to make the code look as much as possible as the above function + const __m512i *B_end = B_live + simd_width*8; // TODO: separate first step. - Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; + __m512i sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; for (; B_live != B_end; B_live += 8) { // Retrieve the conveniently consecutive values of B. sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live); @@ -140,8 +158,8 @@ struct AVX512VNNI_8bit : public AVX512_8bit { sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6)); sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7)); } - Register pack0123 = Pack0123(sum0, sum1, sum2, sum3); - Register pack4567 = Pack0123(sum4, sum5, sum6, sum7); + __m512i pack0123 = Pack0123(sum0, sum1, sum2, sum3); + __m512i pack4567 = Pack0123(sum4, sum5, sum6, sum7); auto total = PermuteSummer(pack0123, pack4567); callback_impl(total, callbacks::OutputBufferInfo(0, B0_colidx, 1, B_cols)); } |