Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'avx512_gemm.h')
-rw-r--r--avx512_gemm.h19
1 files changed, 13 insertions, 6 deletions
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 2a5fff1..eba0322 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -229,17 +229,24 @@ struct AVX512_8bit {
// Convert to 8-bit signed integers.
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_AVX512BW static void Quantize(const float *input, int8_t *output, float quant_mult, Index size) {
- assert(size % 16 == 0);
- assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
+ assert(reinterpret_cast<uintptr_t>(input) % sizeof(__m512i) == 0);
const __m512i neg127 = _mm512_set1_epi32(-127);
const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult);
- const float *end = input + size;
- for (; input < end; input += 16, output += 16) {
- __m512i asint = avx512f::QuantizerGrab(input, quant_mult_reg);
+ const std::size_t kBatch = sizeof(__m512i) / sizeof(float);
+ const float *fast_input_end = input + (size & ~(kBatch - 1));
+ int8_t *fast_output_end = output + (size & ~(kBatch - 1));
+#pragma omp parallel for
+ for (const float *input_it = input; input_it < fast_input_end; input_it += kBatch) {
+ __m512i asint = avx512f::QuantizerGrab(input_it, quant_mult_reg);
asint = _mm512_max_epi32(asint, neg127);
// There doesn't seem to be an unmasked version.
- _mm512_mask_cvtsepi32_storeu_epi8(output, 0xffff, asint);
+ _mm512_mask_cvtsepi32_storeu_epi8(output + (input_it - input), 0xffff, asint);
}
+ std::size_t overhang = size & (kBatch - 1);
+ if (!overhang) return; // We needed a branch anyway for the empty case.
+ __m512i asint = avx512f::QuantizerGrab(fast_input_end, quant_mult_reg);
+ asint = _mm512_max_epi32(asint, neg127);
+ _mm512_mask_cvtsepi32_storeu_epi8(fast_output_end, (1 << overhang) - 1, asint);
}
// Preparing A for the signed/unsigned multiplication. Using add 127