Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Nikolay Bogoychev <nheart@gmail.com> 2019-12-02 03:15:16 +0300
committer: Nikolay Bogoychev <nheart@gmail.com> 2019-12-02 03:15:16 +0300
commit: b7a99eef5cc23c9d1756724b42636cd44ab39f22 (patch)
tree: 1b0f20e10334a25e2e6bcfeee9223611b38e3c94 /avx512vnni_gemm.h
parent: c9ba9be09a42403513810bc6fe645f4219ecade2 (diff)
VNNI based PrepareBiasFor8
Diffstat (limited to 'avx512vnni_gemm.h')
-rw-r--r-- avx512vnni_gemm.h 61
1 files changed, 43 insertions, 18 deletions
diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h
index f8d4a61..bffa8e9 100644
--- a/avx512vnni_gemm.h
+++ b/avx512vnni_gemm.h
@@ -93,24 +93,15 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
Integer sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
for (; A_live != A_end; ++A_live, B_live += 8) {
Integer a = *A_live;
- // Retrieve the conveniently consecutive values of B.
- Integer b0 = *B_live;
- Integer b1 = *(B_live + 1);
- Integer b2 = *(B_live + 2);
- Integer b3 = *(B_live + 3);
- Integer b4 = *(B_live + 4);
- Integer b5 = *(B_live + 5);
- Integer b6 = *(B_live + 6);
- Integer b7 = *(B_live + 7);
//MultiplyAdd
- sum0 = _mm512_dpbusds_epi32(sum0, a, b0);
- sum1 = _mm512_dpbusds_epi32(sum1, a, b1);
- sum2 = _mm512_dpbusds_epi32(sum2, a, b2);
- sum3 = _mm512_dpbusds_epi32(sum3, a, b3);
- sum4 = _mm512_dpbusds_epi32(sum4, a, b4);
- sum5 = _mm512_dpbusds_epi32(sum5, a, b5);
- sum6 = _mm512_dpbusds_epi32(sum6, a, b6);
- sum7 = _mm512_dpbusds_epi32(sum7, a, b7);
+ sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live);
+ sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1));
+ sum2 = _mm512_dpbusds_epi32(sum2, a, *(B_live + 2));
+ sum3 = _mm512_dpbusds_epi32(sum3, a, *(B_live + 3));
+ sum4 = _mm512_dpbusds_epi32(sum4, a, *(B_live + 4));
+ sum5 = _mm512_dpbusds_epi32(sum5, a, *(B_live + 5));
+ sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6));
+ sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7));
}
Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3);
Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
@@ -120,7 +111,41 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
}
}
- INTGEMM_PREPAREBIASFOR8(__m512i, INTGEMM_AVX512BW, CPUType::AVX2)
+ template <typename Callback> // PrepareBiasFor8: accumulates the int8 values of B, 8 columns at a time, and passes each result to callback.
+ INTGEMM_AVX512VNNI static void PrepareBiasFor8(const int8_t *B, Index width, Index B_cols, Callback callback) {
+ typedef __m512i Integer;
+ assert(width % sizeof(Integer) == 0); // width must be a whole number of 512-bit registers
+ assert(B_cols % 8 == 0); // columns are consumed in groups of 8 below
+ assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); // B is dereferenced as aligned Integer values
+ auto callback_impl = callbacks::CallbackImpl<CPUType::AVX2, Callback>(callback);
+ const int simd_width = width / sizeof(Integer);
+ const Integer *B0_col = reinterpret_cast<const Integer*>(B);
+ Integer zeros = setzero_si<Integer>();
+ const Integer a = set1_epi8<Integer>(1); // presumably every byte of a is 1 -- confirm against set1_epi8
+ // Go over 8 columns of B at a time; each outer iteration advances B0_col by 8 * simd_width registers.
+ for (Index B0_colidx = 0; B0_colidx != B_cols; B0_col += 8 * simd_width, B0_colidx += 8) {
+ const Integer *B_live = B0_col; // Named B_live so that this loop reads like the multiply loop in the function above.
+ const Integer *B_end = B_live + simd_width*8;
+
+ // TODO: separate first step.
+ Integer sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
+ for (; B_live != B_end; B_live += 8) {
+ // Read 8 consecutive registers of B. dpbusds multiplies unsigned bytes of a by signed bytes of B and adds each group of 4 products into a 32-bit lane with signed saturation; if a is all ones this accumulates the byte sums of B.
+ sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live);
+ sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1));
+ sum2 = _mm512_dpbusds_epi32(sum2, a, *(B_live + 2));
+ sum3 = _mm512_dpbusds_epi32(sum3, a, *(B_live + 3));
+ sum4 = _mm512_dpbusds_epi32(sum4, a, *(B_live + 4));
+ sum5 = _mm512_dpbusds_epi32(sum5, a, *(B_live + 5));
+ sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6));
+ sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7));
+ }
+ Integer pack0123 = Pack0123(sum0, sum1, sum2, sum3);
+ Integer pack4567 = Pack0123(sum4, sum5, sum6, sum7);
+ auto total = PermuteSummer(pack0123, pack4567); // NOTE(review): Pack0123/PermuteSummer presumably reduce the lanes of the 8 accumulators -- confirm against their definitions
+ callback_impl(total, callbacks::OutputBufferInfo(0, B0_colidx, 1, B_cols)); // NOTE(review): arguments look like (row 0, column offset, 1 row, B_cols columns) -- confirm against OutputBufferInfo
+ }
+ }
constexpr static const char *const kName = "8-bit AVX512VNNI";