diff options
author | Nikolay Bogoychev <nheart@gmail.com> | 2019-12-02 03:15:16 +0300 |
---|---|---|
committer | Nikolay Bogoychev <nheart@gmail.com> | 2019-12-02 03:15:16 +0300 |
commit | b7a99eef5cc23c9d1756724b42636cd44ab39f22 (patch) | |
tree | 1b0f20e10334a25e2e6bcfeee9223611b38e3c94 /avx512vnni_gemm.h | |
parent | c9ba9be09a42403513810bc6fe645f4219ecade2 (diff) |
VNNI based PrepareBiasFor8
Diffstat (limited to 'avx512vnni_gemm.h')
-rw-r--r-- | avx512vnni_gemm.h | 61 |
1 files changed, 43 insertions, 18 deletions
diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h index f8d4a61..bffa8e9 100644 --- a/avx512vnni_gemm.h +++ b/avx512vnni_gemm.h @@ -93,24 +93,15 @@ struct AVX512VNNI_8bit : public AVX512_8bit { Integer sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; for (; A_live != A_end; ++A_live, B_live += 8) { Integer a = *A_live; - // Retrieve the conveniently consecutive values of B. - Integer b0 = *B_live; - Integer b1 = *(B_live + 1); - Integer b2 = *(B_live + 2); - Integer b3 = *(B_live + 3); - Integer b4 = *(B_live + 4); - Integer b5 = *(B_live + 5); - Integer b6 = *(B_live + 6); - Integer b7 = *(B_live + 7); //MultiplyAdd - sum0 = _mm512_dpbusds_epi32(sum0, a, b0); - sum1 = _mm512_dpbusds_epi32(sum1, a, b1); - sum2 = _mm512_dpbusds_epi32(sum2, a, b2); - sum3 = _mm512_dpbusds_epi32(sum3, a, b3); - sum4 = _mm512_dpbusds_epi32(sum4, a, b4); - sum5 = _mm512_dpbusds_epi32(sum5, a, b5); - sum6 = _mm512_dpbusds_epi32(sum6, a, b6); - sum7 = _mm512_dpbusds_epi32(sum7, a, b7); + sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live); + sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1)); + sum2 = _mm512_dpbusds_epi32(sum2, a, *(B_live + 2)); + sum3 = _mm512_dpbusds_epi32(sum3, a, *(B_live + 3)); + sum4 = _mm512_dpbusds_epi32(sum4, a, *(B_live + 4)); + sum5 = _mm512_dpbusds_epi32(sum5, a, *(B_live + 5)); + sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6)); + sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7)); } Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); @@ -120,7 +111,41 @@ struct AVX512VNNI_8bit : public AVX512_8bit { } } - INTGEMM_PREPAREBIASFOR8(__m512i, INTGEMM_AVX512BW, CPUType::AVX2) + template <typename Callback> + INTGEMM_AVX512VNNI static void PrepareBiasFor8(const int8_t *B, Index width, Index B_cols, Callback callback) { + typedef __m512i Integer; + assert(width % sizeof(Integer) == 0); + assert(B_cols % 8 == 0); + 
assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); + auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback); + const int simd_width = width / sizeof(Integer); + const Integer *B0_col = reinterpret_cast<const Integer*>(B); + Integer zeros = setzero_si<Integer>(); + const Integer a = set1_epi8<Integer>(1); + // Go over 8 columns of B at a time. + for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) { + const Integer *B_live = B0_col; // Named B_live so that this loop reads like the kernel loop in the function above. + const Integer *B_end = B_live + simd_width*8; + + // TODO: separate first step. + Integer sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; + for (; B_live != B_end; B_live += 8) { + // Feed eight consecutive registers of B, one per accumulator, through dpbusds with a == set1_epi8(1), i.e. accumulate B's values multiplied by one. + sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live); + sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1)); + sum2 = _mm512_dpbusds_epi32(sum2, a, *(B_live + 2)); + sum3 = _mm512_dpbusds_epi32(sum3, a, *(B_live + 3)); + sum4 = _mm512_dpbusds_epi32(sum4, a, *(B_live + 4)); + sum5 = _mm512_dpbusds_epi32(sum5, a, *(B_live + 5)); + sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6)); + sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7)); + } + Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3); + Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7); + auto total = PermuteSummer(pack0123, pack4567); + callback_impl(total, callbacks::OutputBufferInfo(0, B0_colidx, 1, B_cols)); + } + } constexpr static const char *const kName = "8-bit AVX512VNNI"; |