diff options
author | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-04-20 17:40:38 +0300 |
---|---|---|
committer | Mateusz Chudyk <mateuszchudyk@gmail.com> | 2020-04-20 17:40:38 +0300 |
commit | e108ab87ab1f383228f14a532163ebb549e118c2 (patch) | |
tree | 5e03dbf671f7b6666f045ec27c32ba7e5c4243f8 | |
parent | 2e40c5a978d152ed3c8e6b6bec07016aed37d9b6 (diff) | |
parent | ec396d1b8d6f29e3a70924df4225cfd4050a1c2b (diff) |
Merge remote-tracking branch 'origin/master' into multiply-tiling-8x
-rw-r--r-- | avx512vnni_gemm.h | 29 | ||||
-rw-r--r-- | multiply.h | 3 |
2 files changed, 21 insertions, 11 deletions
diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h index 2383ba1..b9db526 100644 --- a/avx512vnni_gemm.h +++ b/avx512vnni_gemm.h @@ -8,6 +8,15 @@ namespace intgemm { +// Workaround extra vmovdqa64 https://gcc.gnu.org/bugzilla/show_bug.cgi?id=94663 +INTGEMM_AVX512VNNI static inline void VNNI8(__m512i &c, __m512i a, __m512i b) { +#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER) + asm ("vpdpbusds %2, %1, %0" : "+x"(c) : "x"(a), "mx"(b)); +#else + c = _mm512_dpbusds_epi32(c, a, b); +#endif +} + // Rewrite that loads of struct to template labdas as soon as c++14 is used struct AVX512VNNI_Multiply_InitALivesLoop { template <typename Iterator, typename Type> @@ -33,7 +42,7 @@ struct AVX512VNNI_Multiply_TileLoop { static constexpr auto Row = Iterator::template I<0>(); static constexpr auto Column = Iterator::template I<1>(); auto neg_mask = _mm512_test_epi8_mask(*A_lives[Row], _mm512_set1_epi8(-128)); - sums[Row][Column] = _mm512_dpbusds_epi32(sums[Row][Column], _mm512_abs_epi8(*A_lives[Row]), _mm512_mask_sub_epi8(B_live[Column], neg_mask, setzero_si<__m512i>(), B_live[Column])); + VNNI8(sums[Row][Column], _mm512_abs_epi8(*A_lives[Row]), _mm512_mask_sub_epi8(B_live[Column], neg_mask, setzero_si<__m512i>(), B_live[Column])); } }; @@ -63,7 +72,7 @@ struct AVX512VNNI_Multiply8Shift_TileLoop { __m512i sums[Iterator::template N<0>()][Iterator::template N<1>()]) { static constexpr auto Row = Iterator::template I<0>(); static constexpr auto Column = Iterator::template I<1>(); - sums[Row][Column] = _mm512_dpbusds_epi32(sums[Row][Column], *A_lives[Row], B_live[Column]); + VNNI8(sums[Row][Column], *A_lives[Row], B_live[Column]); } }; @@ -152,14 +161,14 @@ struct AVX512VNNI_8bit : public AVX512_8bit { __m512i sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros; for (; B_live != B_end; B_live += 8) { // Retrieve the conveniently consecutive values of B. - sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live); - sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1)); - sum2 = _mm512_dpbusds_epi32(sum2, a, *(B_live + 2)); - sum3 = _mm512_dpbusds_epi32(sum3, a, *(B_live + 3)); - sum4 = _mm512_dpbusds_epi32(sum4, a, *(B_live + 4)); - sum5 = _mm512_dpbusds_epi32(sum5, a, *(B_live + 5)); - sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6)); - sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7)); + VNNI8(sum0, a, *B_live); + VNNI8(sum1, a, *(B_live + 1)); + VNNI8(sum2, a, *(B_live + 2)); + VNNI8(sum3, a, *(B_live + 3)); + VNNI8(sum4, a, *(B_live + 4)); + VNNI8(sum5, a, *(B_live + 5)); + VNNI8(sum6, a, *(B_live + 6)); + VNNI8(sum7, a, *(B_live + 7)); } __m512i pack0123 = Pack0123(sum0, sum1, sum2, sum3); __m512i pack4567 = Pack0123(sum4, sum5, sum6, sum7); @@ -521,7 +521,8 @@ template <Index TileRows, Index TileColumnsMultiplier, class Backend, class Call #pragma omp parallel Backend::template Multiply<TileRows, TileColumnsMultiplier, Callback>(A, B, A_rows, width, B_cols, callback); } -template <Index TileRows, Index TileColumnsMultiplier, class Backend, class Callback> static inline void OMPParallelWrap8Shift(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { + +template <Index TileRows, Index TileColumnsMultiplier, class Backend, class Callback> static inline void OMPParallelWrap8Shift(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) { #pragma omp parallel Backend::template Multiply8Shift<TileRows, TileColumnsMultiplier, Callback>(A, B, A_rows, width, B_cols, callback); } |