diff options
Diffstat (limited to 'avx512_gemm.h')
-rw-r--r-- | avx512_gemm.h | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/avx512_gemm.h b/avx512_gemm.h index 2a5fff1..eba0322 100644 --- a/avx512_gemm.h +++ b/avx512_gemm.h @@ -229,17 +229,24 @@ struct AVX512_8bit { // Convert to 8-bit signed integers. /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ INTGEMM_AVX512BW static void Quantize(const float *input, int8_t *output, float quant_mult, Index size) { - assert(size % 16 == 0); - assert(reinterpret_cast<uintptr_t>(input) % 64 == 0); + assert(reinterpret_cast<uintptr_t>(input) % sizeof(__m512i) == 0); const __m512i neg127 = _mm512_set1_epi32(-127); const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult); - const float *end = input + size; - for (; input < end; input += 16, output += 16) { - __m512i asint = avx512f::QuantizerGrab(input, quant_mult_reg); + const std::size_t kBatch = sizeof(__m512i) / sizeof(float); + const float *fast_input_end = input + (size & ~(kBatch - 1)); + int8_t *fast_output_end = output + (size & ~(kBatch - 1)); +#pragma omp parallel for + for (const float *input_it = input; input_it < fast_input_end; input_it += kBatch) { + __m512i asint = avx512f::QuantizerGrab(input_it, quant_mult_reg); asint = _mm512_max_epi32(asint, neg127); // There doesn't seem to be an unmasked version. - _mm512_mask_cvtsepi32_storeu_epi8(output, 0xffff, asint); + _mm512_mask_cvtsepi32_storeu_epi8(output + (input_it - input), 0xffff, asint); } + std::size_t overhang = size & (kBatch - 1); + if (!overhang) return; // We needed a branch anyway for the empty case. + __m512i asint = avx512f::QuantizerGrab(fast_input_end, quant_mult_reg); + asint = _mm512_max_epi32(asint, neg127); + _mm512_mask_cvtsepi32_storeu_epi8(fast_output_end, (1 << overhang) - 1, asint); } // Preparing A for the signed/unsigned multiplication. Using add 127 |