Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Kenneth Heafield <github@kheafield.com> 2020-04-20 14:54:57 +0300
committer: Kenneth Heafield <github@kheafield.com> 2020-04-20 14:54:57 +0300
commit: b872fd6a00d7f232d84427807a666806727e7b88 (patch)
tree: d61d0533ac5f22f89fdf23f6a11d87c49e971569
parent: fb96b0851cf420ac49c13b361a503afffe386ada (diff)
Workaround gcc bug producing extra move instructions
https://gcc.gnu.org/bugzilla/show_bug.cgi?id=94663

Improvement ranges from 3% (1x64x8) to 35% (8x2048x256) and is often 21-25%.

Benchmark program output (for each shape, the first row is BEFORE and the second row is AFTER):

Multiply 1 64 8 Samples=75
  BEFORE: 8-bit AVX512VNNI 64 65.4933 0.875698
  AFTER:  8-bit AVX512VNNI 62 64.8533 1.36256
Multiply 8 256 256 Samples=75
  BEFORE: 8-bit AVX512VNNI 13296 13385.3 36.0012
  AFTER:  8-bit AVX512VNNI 10754 10873.9 31.3479
Multiply 8 2048 256 Samples=75
  BEFORE: 8-bit AVX512VNNI 86800 86974.3 59.9597
  AFTER:  8-bit AVX512VNNI 64222 65428.6 222.893
Multiply 8 256 2048 Samples=75
  BEFORE: 8-bit AVX512VNNI 106780 107392 232.955
  AFTER:  8-bit AVX512VNNI 86176 88366.1 402.335
Multiply 320 256 256 Samples=75
  BEFORE: 8-bit AVX512VNNI 531720 533687 1419.3
  AFTER:  8-bit AVX512VNNI 436536 437186 352.487
Multiply 472 256 256 Samples=75
  BEFORE: 8-bit AVX512VNNI 785026 787784 2068.05
  AFTER:  8-bit AVX512VNNI 646240 647382 416.252
Multiply 248 256 256 Samples=75
  BEFORE: 8-bit AVX512VNNI 412282 413484 971.843
  AFTER:  8-bit AVX512VNNI 338368 338656 141.354
Multiply 200 256 256 Samples=75
  BEFORE: 8-bit AVX512VNNI 332578 333463 742.297
  AFTER:  8-bit AVX512VNNI 272890 273103 77.2789
Multiply 256 256 256 Samples=75
  BEFORE: 8-bit AVX512VNNI 425654 427240 1095.53
  AFTER:  8-bit AVX512VNNI 349418 349580 80.8586
Multiply 512 512 512 Samples=75
  BEFORE: 8-bit AVX512VNNI 3122382 3.13179e+06 4215.88
  AFTER:  8-bit AVX512VNNI 2493984 2.51602e+06 6052.1
Multiply 1024 1024 1024 Samples=3
  BEFORE: 8-bit AVX512VNNI 24927622 2.49795e+07 44940.9
  AFTER:  8-bit AVX512VNNI 19210646 1.9229e+07 17037
Multiply 4096 4096 128 Samples=3
  BEFORE: 8-bit AVX512VNNI 49870840 4.99655e+07 133057
  AFTER:  8-bit AVX512VNNI 46146812 4.62847e+07 205448
-rw-r--r-- avx512vnni_gemm.h 57
1 files changed, 33 insertions, 24 deletions
diff --git a/avx512vnni_gemm.h b/avx512vnni_gemm.h
index 6eb3be4..22c5c4e 100644
--- a/avx512vnni_gemm.h
+++ b/avx512vnni_gemm.h
@@ -8,6 +8,15 @@
namespace intgemm {
+// Workaround extra vmovdqa64 https://gcc.gnu.org/bugzilla/show_bug.cgi?id=94663 -- VNNI8 updates the accumulator c in place from a and b (c is taken by reference, a and b by value).
+INTGEMM_AVX512VNNI static inline void VNNI8(__m512i &c, __m512i a, __m512i b) {
+#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER) // taken only when __GNUC__ is defined and neither the clang nor the Intel compiler macro is
+ asm ("vpdpbusds %2, %1, %0" : "+x"(c) : "x"(a), "mx"(b)); // %0 = c (read-write register), %1 = a (register input), %2 = b (register or memory input)
+#else
+ c = _mm512_dpbusds_epi32(c, a, b); // all other compilers: use the intrinsic and assign the result back to c
+#endif
+}
+
struct AVX512VNNI_8bit : public AVX512_8bit {
template <typename Callback>
INTGEMM_AVX512VNNI static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
@@ -54,14 +63,14 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
b5 = _mm512_mask_sub_epi8(b5, neg_mask, zeros, b5);
b6 = _mm512_mask_sub_epi8(b6, neg_mask, zeros, b6);
b7 = _mm512_mask_sub_epi8(b7, neg_mask, zeros, b7);
- sum0 = _mm512_dpbusds_epi32(sum0, a_positive, b0);
- sum1 = _mm512_dpbusds_epi32(sum1, a_positive, b1);
- sum2 = _mm512_dpbusds_epi32(sum2, a_positive, b2);
- sum3 = _mm512_dpbusds_epi32(sum3, a_positive, b3);
- sum4 = _mm512_dpbusds_epi32(sum4, a_positive, b4);
- sum5 = _mm512_dpbusds_epi32(sum5, a_positive, b5);
- sum6 = _mm512_dpbusds_epi32(sum6, a_positive, b6);
- sum7 = _mm512_dpbusds_epi32(sum7, a_positive, b7);
+ VNNI8(sum0, a_positive, b0);
+ VNNI8(sum1, a_positive, b1);
+ VNNI8(sum2, a_positive, b2);
+ VNNI8(sum3, a_positive, b3);
+ VNNI8(sum4, a_positive, b4);
+ VNNI8(sum5, a_positive, b5);
+ VNNI8(sum6, a_positive, b6);
+ VNNI8(sum7, a_positive, b7);
}
Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
@@ -96,14 +105,14 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
for (; A_live != A_end; ++A_live, B_live += 8) {
Register a = *A_live;
//MultiplyAdd
- sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live);
- sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1));
- sum2 = _mm512_dpbusds_epi32(sum2, a, *(B_live + 2));
- sum3 = _mm512_dpbusds_epi32(sum3, a, *(B_live + 3));
- sum4 = _mm512_dpbusds_epi32(sum4, a, *(B_live + 4));
- sum5 = _mm512_dpbusds_epi32(sum5, a, *(B_live + 5));
- sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6));
- sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7));
+ VNNI8(sum0, a, *B_live);
+ VNNI8(sum1, a, *(B_live + 1));
+ VNNI8(sum2, a, *(B_live + 2));
+ VNNI8(sum3, a, *(B_live + 3));
+ VNNI8(sum4, a, *(B_live + 4));
+ VNNI8(sum5, a, *(B_live + 5));
+ VNNI8(sum6, a, *(B_live + 6));
+ VNNI8(sum7, a, *(B_live + 7));
}
Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
@@ -134,14 +143,14 @@ struct AVX512VNNI_8bit : public AVX512_8bit {
Register sum0 = zeros, sum1 = zeros, sum2 = zeros, sum3 = zeros, sum4 = zeros, sum5 = zeros, sum6 = zeros, sum7 = zeros;
for (; B_live != B_end; B_live += 8) {
// Retrieve the conveniently consecutive values of B.
- sum0 = _mm512_dpbusds_epi32(sum0, a, *B_live);
- sum1 = _mm512_dpbusds_epi32(sum1, a, *(B_live + 1));
- sum2 = _mm512_dpbusds_epi32(sum2, a, *(B_live + 2));
- sum3 = _mm512_dpbusds_epi32(sum3, a, *(B_live + 3));
- sum4 = _mm512_dpbusds_epi32(sum4, a, *(B_live + 4));
- sum5 = _mm512_dpbusds_epi32(sum5, a, *(B_live + 5));
- sum6 = _mm512_dpbusds_epi32(sum6, a, *(B_live + 6));
- sum7 = _mm512_dpbusds_epi32(sum7, a, *(B_live + 7));
+ VNNI8(sum0, a, *B_live);
+ VNNI8(sum1, a, *(B_live + 1));
+ VNNI8(sum2, a, *(B_live + 2));
+ VNNI8(sum3, a, *(B_live + 3));
+ VNNI8(sum4, a, *(B_live + 4));
+ VNNI8(sum5, a, *(B_live + 5));
+ VNNI8(sum6, a, *(B_live + 6));
+ VNNI8(sum7, a, *(B_live + 7));
}
Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);