Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Nikolay Bogoychev <nheart@gmail.com> 2019-06-13 11:35:00 +0300
committer: Nikolay Bogoychev <nheart@gmail.com> 2019-06-13 11:35:00 +0300
commit: fe8146452aecc39f4e348be23928e7ed9baaefc1 (patch)
tree: 6b70d7351024490ae9e03786e7d714c6581ae277
parent: 4593f46f4a447a98d4b5c09419f0d07f0bfccc77 (diff)
Fix avx512f prepareA and tests
-rw-r--r-- avx512_gemm.h 10
-rw-r--r-- test/add127_test.cc 14
2 files changed, 13 insertions, 11 deletions
diff --git a/avx512_gemm.h b/avx512_gemm.h
index 4033589..ef01c88 100644
--- a/avx512_gemm.h
+++ b/avx512_gemm.h
@@ -215,15 +215,17 @@ struct AVX512_8bit {
assert(size % 16 == 0);
assert(reinterpret_cast<uintptr_t>(input) % 64 == 0);
const __m512i neg127 = _mm512_set1_epi32(-127);
- const __m512i pos127 = _mm512_set1_epi32(127);
+ const __m128i pos127 = _mm_set1_epi8(127);
const __m512 quant_mult_reg = _mm512_set1_ps(quant_mult);
const float *end = input + size;
for (; input < end; input += 16, output += 16) {
__m512i asint = avx512f::QuantizerGrab(input, quant_mult_reg);
asint = _mm512_max_epi32(asint, neg127);
- asint = add_epi32(asint, pos127); //Maybe could do +128 and remove the above line
- // There doesn't seem to be an unmasked version.
- _mm512_mask_cvtsepi32_storeu_epi8(output, 0xffff, asint);
+
+ //First convert to 8 bit then add and finally store,
+ //because _mm512_mask_cvtsepi32_storeu_epi8 saturates to signed
+ __m128i as8bit = _mm512_cvtsepi32_epi8(asint);
+ *reinterpret_cast<__m128i*>(output) = _mm_add_epi8(as8bit, pos127);
}
}
diff --git a/test/add127_test.cc b/test/add127_test.cc
index 5b828cb..7d4d916 100644
--- a/test/add127_test.cc
+++ b/test/add127_test.cc
@@ -215,13 +215,13 @@ TEST_CASE ("Multiply AVX2 8bit with new bias", "[Add127]") {
TEST_CASE ("Multiply AVX512F 8bit with new bias", "[Add127]") {
if (kCPU < CPU_AVX512BW) return;
- TestMultiplyBiasNew<AVX512_8bit>(1, 64, 8, 1.2, 1.2, 0.064, 0.05);
- TestMultiplyBiasNew<AVX512_8bit>(8, 256, 256, 17, 17, 3.6, 3.6); //0.1, 0);
- TestMultiplyBiasNew<AVX512_8bit>(8, 2048, 256, 132, 132, 41.0, 41.0); //1.8, 1.8);
- TestMultiplyBiasNew<AVX512_8bit>(320, 256, 256, 18, 18, 3.7, 3.7); //0.1, 0);
- TestMultiplyBiasNew<AVX512_8bit>(472, 256, 256, 28, 28, 3.9, 3.9); //0.1, 0);
- TestMultiplyBiasNew<AVX512_8bit>(248, 256, 256, 25, 25, 3.9, 3.9); //0.1, 0);
- TestMultiplyBiasNew<AVX512_8bit>(200, 256, 256, 19, 19, 3.6, 3.6); //0.1, 0);
+ TestMultiplyBiasNew<AVX512_8bit>(1, 64, 8, 0.11, 0.11, 0.06, 0.05);
+ TestMultiplyBiasNew<AVX512_8bit>(8, 256, 256, 7.5, 7.5, 0.99, 0.99); //, 1.6, 1.6); //0.1, 0);
+ TestMultiplyBiasNew<AVX512_8bit>(8, 2048, 256, 109, 109, 31.0, 31.0); //1.8, 1.8);
+ TestMultiplyBiasNew<AVX512_8bit>(320, 256, 256, 9, 9, 1.1, 1.1); //0.1, 0);
+ TestMultiplyBiasNew<AVX512_8bit>(472, 256, 256, 10, 10, 1.2, 1.2); //0.1, 0);
+ TestMultiplyBiasNew<AVX512_8bit>(248, 256, 256, 8.2, 8.2, 1.1, 1.1); //0.1, 0);
+ TestMultiplyBiasNew<AVX512_8bit>(200, 256, 256, 8.3, 8.3, 1.2, 1.2); //0.1, 0);
}
} //namespace intgemm