Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2020-04-20 17:40:38 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2020-04-20 17:40:38 +0300
commite108ab87ab1f383228f14a532163ebb549e118c2 (patch)
tree5e03dbf671f7b6666f045ec27c32ba7e5c4243f8
parent2e40c5a978d152ed3c8e6b6bec07016aed37d9b6 (diff)
parentec396d1b8d6f29e3a70924df4225cfd4050a1c2b (diff)
Merge remote-tracking branch 'origin/master' into multiply-tiling-8x
-rw-r--r--avx512vnni_gemm.h29
-rw-r--r--multiply.h3
2 files changed, 21 insertions, 11 deletions
diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h
index 2383ba1..b9db526 100644
--- a/avx512vnni_gemm.h
+++ b/avx512vnni_gemm.h
@@ -8,6 +8,15 @@
namespace intgemm {
+// Workaround extra vmovdqa64 https://gcc.gnu.org/bugzilla/show_bug.cgi?id=94663
+INTGEMM_AVX512VNNI static inline void VNNI8(__m512i &c, __m512i a, __m512i b) {
+#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+ asm ("vpdpbusds %2, %1, %0" : "+x"(c) : "x"(a), "mx"(b));
+#else
+ c = _mm512_dpbusds_epi32(c, a, b);
+#endif
+}
+
// Rewrite that loads of struct to template labdas as soon as c++14 is used
struct AVX512VNNI_Multiply_InitALivesLoop {
template <typename Iterator, typename Type>
@@ -33,7 +42,7 @@ struct AVX512VNNI_Multiply_TileLoop {
static constexpr auto Row = Iterator::template I<0>();
static constexpr auto Column = Iterator::template I<1>();
auto neg_mask = _mm512_test_epi8_mask(*A_lives[Row], _mm512_set1_epi8(-128));
- sums[Row][Column] = _mm512_dpbusds_epi32(sums[Row][Column], _mm512_abs_epi8(*A_lives[Row]), _mm512_mask_sub_epi8(B_live[Column], neg_mask, setzero_si<__m512i>(), B_live[Column]));
+ VNNI8(sums[Row][Column], _mm512_abs_epi8(*A_lives[Row]), _mm512_mask_sub_epi8(B_live[Column], neg_mask, setzero_si<__m512i>(), B_live[Column]));
}
};
@@ -63,7 +72,7 @@ struct AVX512VNNI_Multiply8Shift_TileLoop {
__m512i sums[Iterator::template N<0>()][Iterator::template N<1>()]) {
static constexpr auto Row = Iterator::template I<0>();
static constexpr auto Column = Iterator::template I<1>();
- sums[Row][Column] = _mm512_dpbusds_epi32(sums[Row][Column], *A_lives[Row], B_live[Column]);
+ VNNI8(sums[Row][Column], *A_lives[Row], B_live[Column]);
}
};
@@ -152,14 +161,14 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
__m512i sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
for (; B_live != B_end; B_live += 8) {
// Retrieve the conveniently consecutive values of B.
- sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live);
- sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1));
- sum2 = _mm512_dpbusds_epi32(sum2, a, *(B_live + 2));
- sum3 = _mm512_dpbusds_epi32(sum3, a, *(B_live + 3));
- sum4 = _mm512_dpbusds_epi32(sum4, a, *(B_live + 4));
- sum5 = _mm512_dpbusds_epi32(sum5, a, *(B_live + 5));
- sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6));
- sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7));
+ VNNI8(sum0, a, *B_live);
+ VNNI8(sum1, a, *(B_live + 1));
+ VNNI8(sum2, a, *(B_live + 2));
+ VNNI8(sum3, a, *(B_live + 3));
+ VNNI8(sum4, a, *(B_live + 4));
+ VNNI8(sum5, a, *(B_live + 5));
+ VNNI8(sum6, a, *(B_live + 6));
+ VNNI8(sum7, a, *(B_live + 7));
}
__m512i pack0123 = Pack0123(sum0, sum1, sum2, sum3);
__m512i pack4567 = Pack0123(sum4, sum5, sum6, sum7);
diff --git a/multiply.h b/multiply.h
index 17ec65c..5c01677 100644
--- a/multiply.h
+++ b/multiply.h
@@ -521,7 +521,8 @@ template <Index TileRows, Index TileColumnsMultiplier, class Backend, class Call
#pragma omp parallel
Backend::template Multiply<TileRows, TileColumnsMultiplier, Callback>(A, B, A_rows, width, B_cols, callback);
}
-template <Index TileRows, Index TileColumnsMultiplier, class Backend, class Callback> static inline void OMPParallelWrap8Shift(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
+
+template <Index TileRows, Index TileColumnsMultiplier, class Backend, class Callback> static inline void OMPParallelWrap8Shift(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
#pragma omp parallel
Backend::template Multiply8Shift<TileRows, TileColumnsMultiplier, Callback>(A, B, A_rows, width, B_cols, callback);
}