diff options
author | Kenneth Heafield <github@kheafield.com> | 2020-04-19 19:38:53 +0300 |
---|---|---|
committer | Kenneth Heafield <github@kheafield.com> | 2020-04-19 19:38:53 +0300 |
commit | 92b0b0288c035d91554846b04d41b28e4fe7c906 (patch) | |
tree | 1c5966e740c92e9618f389aa2fc1ca59eac6ed76 | |
parent | 40dd33b765cf2b6b711850eac46c2df4b225da12 (diff) |
Work around gcc _mm512_dpbusds_epi32 spurious vmovdqa64 instructions
Use
asm ("vpdpbusds %2, %1, %0" : "+x"(c) : "x"(a), "mx"(b));
instead of
c = _mm512_dpbusds_epi32(c, a, b);
-rw-r--r-- | tile/dot.inl | 45 |
1 file changed, 45 insertions, 0 deletions
diff --git a/tile/dot.inl b/tile/dot.inl index 9af7b1b..257930e 100644 --- a/tile/dot.inl +++ b/tile/dot.inl @@ -18,6 +18,43 @@ #define INTGEMM_TARGET INTGEMM_SSE2 #endif +/* gcc _mm512_dpbusds_epi32 is slow because it inserts spurious vmovdqa64 instructions. + * Simple test program: + * #include <immintrin.h> + * + * __m512i Foo(const __m512i *a, const __m512i b0, const __m512i b1, std::size_t count) { + * register __m512i c0 = _mm512_setzero_epi32(); + * register __m512i c1 = _mm512_setzero_epi32(); + * for (std::size_t i = 0; i < count; ++i) { + * c0 = _mm512_dpbusds_epi32(c0, a[i], b0); + * c1 = _mm512_dpbusds_epi32(c1, a[i], b1); + * } + * // Do not optimize away + * return _mm512_sub_epi32(c0, c1); + * } + * Then with g++ (Gentoo 9.2.0-r2 p3) 9.2.0 run as + * g++ -mavx512vnni -O3 example.cc -S + * We get some inefficient asm: + * .L3: + * vmovdqa64 (%rdi), %zmm6 + * vmovdqa64 %zmm3, %zmm0 + * vmovdqa64 %zmm4, %zmm2 + * addq $64, %rdi + * vpdpbusds %zmm5, %zmm6, %zmm0 + * vpdpbusds %zmm1, %zmm6, %zmm2 + * vmovdqa64 %zmm0, %zmm3 + * vmovdqa64 %zmm2, %zmm4 + * cmpq %rdi, %rax + * jne .L3 + * + * Why does it copy from zmm3 to zmm0, then copy zmm0 to zmm3 each loop???? + * So for gcc instead of + * c = _mm512_dpbusds_epi32(c, a, b); + * I use: + * asm ("vpdpbusds %2, %1, %0" : "+x"(c) : "x"(a), "mx"(b)); + * and that works better in the test program. 
+ */ + namespace intgemm { namespace INTGEMM_ARCH { @@ -51,7 +88,11 @@ struct Shifted8 { const Register &a = reinterpret_cast<const Register&>(access.AFront()); const Register &b = reinterpret_cast<const Register&>(access.BFront()); #ifdef INTGEMM_THIS_IS_AVX512VNNI +#ifdef __GNUC__ + asm ("vpdpbusds %2, %1, %0" : "+x"(access.CFront()) : "x"(a), "mx"(b)); +#else access.CFront() = _mm512_dpbusds_epi32(access.CFront(), a, b); +#endif #else const Register ones = set1_epi16<Register>(1); Register mult = maddubs_epi16(a, b); @@ -89,7 +130,11 @@ struct Signed8 { // c += |a| * b_signed #if defined(INTGEMM_THIS_IS_AVX512VNNI) +#ifdef __GNUC__ + asm ("vpdpbusds %2, %1, %0" : "+x"(access.CFront()) : "x"(a_positive), "mx"(b_signed)); +#else access.CFront() = _mm512_dpbusds_epi32(access.CFront(), a_positive, b_signed); +#endif #else Register mult = maddubs_epi16(a_positive, b_signed); access.CFront() = adds_epi16(access.CFront(), mult); |