diff options
Diffstat (limited to 'avx512_gemm.h')
-rw-r--r-- | avx512_gemm.h | 9 |
1 files changed, 3 insertions, 6 deletions
diff --git a/avx512_gemm.h b/avx512_gemm.h index b3499af..b8c4de1 100644 --- a/avx512_gemm.h +++ b/avx512_gemm.h @@ -264,17 +264,14 @@ struct AVX512_8bit { assert(size % 16 == 0); assert(reinterpret_cast<uintptr_t>(input) % 64 == 0); const __m512i neg127 = _mm512_set1_epi32(-127); - const __m128i pos127 = _mm_set1_epi8(127); + const __m512i pos127 = _mm512_set1_epi32(127); const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult); const float *end = input + size; for (; input < end; input += 16, output += 16) { __m512i asint = avx512f::QuantizerGrab(input, quant_mult_reg); asint = _mm512_max_epi32(asint, neg127); - - //First convert to 8 bit then add and finally store, - //because _mm512_mask_cvtsepi32_storeu_epi8 saturates to signed - __m128i as8bit = _mm512_cvtsepi32_epi8(asint); - *reinterpret_cast<__m128i*>(output) = _mm_add_epi8(as8bit, pos127); + asint = _mm512_add_epi32(asint, pos127); + _mm512_mask_cvtusepi32_storeu_epi8(output, 0xffff, asint); } } |