github.com/marian-nmt/intgemm/intgemm.git
author    Kenneth Heafield <kpu@users.noreply.github.com>  2020-01-21 14:48:42 +0300
committer GitHub <noreply@github.com>  2020-01-21 14:48:42 +0300
commit    a5300207b996492d4507109da8d4e5323354c7ac (patch)
tree      e72c0ba1db3f7e1ca8dc5a837ea58ced7d14ac97
parent    03a4a9dbe4e1955efdb6c6f671636d9378755f45 (diff)
parent    e2c008d075e55bfb1538a42cf2fce113f039e6a8 (diff)
Merge pull request #55 from kpu/debug_add127
More tests for add127
-rw-r--r--  test/add127_test.cc | 301
-rw-r--r--  test/test.cc        |  15
-rw-r--r--  test/test.h         |   1
3 files changed, 264 insertions, 53 deletions
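Background for the tests below: in the "add127" (shift) scheme, PrepareA quantizes A to int8 and shifts every value by +127 into uint8_t, so the unsigned-by-signed 8-bit multiply path can be used; the constant offset is cancelled by folding a per-column correction of -127 * unquant_mult * sum_k B_quant[k][c] into the bias. Because unquant_mult = alpha^2/127^2, that correction equals a plain column sum of B_quant scaled by unquant_mult_forprep = -alpha*alpha/127, which is what the tests compute with SlowRefInt over a row of ones. The following is a minimal illustrative sketch of that slow reference, not part of intgemm; the function name and types are hypothetical.

#include <cstddef>
#include <cstdint>

// Slow reference for the shifted-bias correction checked by these tests
// (equivalent to SlowRefInt called with a row of ones and unquant_mult_forprep).
void ReferenceShiftedBias(const std::int8_t *B_quant, const float *bias,
                          float *shifted_bias, float alpha,
                          std::size_t width, std::size_t B_cols) {
  // -alpha^2/127 applied to a plain column sum equals -127 * column_sum * unquant_mult,
  // i.e. the term that cancels the +127 added to every element of A.
  const float unquant_mult_forprep = -alpha * alpha / 127.0f;
  for (std::size_t c = 0; c < B_cols; ++c) {
    std::int32_t column_sum = 0;
    for (std::size_t k = 0; k < width; ++k) {
      column_sum += B_quant[k * B_cols + c];
    }
    shifted_bias[c] = column_sum * unquant_mult_forprep + bias[c];
  }
}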
diff --git a/test/add127_test.cc b/test/add127_test.cc
index c803bad..86d630b 100644
--- a/test/add127_test.cc
+++ b/test/add127_test.cc
@@ -2,36 +2,6 @@
namespace intgemm {
-void newBias(const float * input, float * bias, float* output, float alpha, Index rows, Index width, Index cols) {
- AlignedVector<float> intermediate(rows*cols);
- AlignedVector<float> ones(rows*width);
- for (auto&& it : ones) {
- it = 1;
- }
- SlowRefFloat(ones.begin(), input, intermediate.begin(), rows, width, cols);
- for (auto&& it : intermediate) {
- it = it*alpha;
- }
-
-
- for (Index c = 0; c<cols; c++) {
- output[c] = bias[c] - intermediate.begin()[rows*c];
- }
-
-}
-
-void SlowSumB(const int8_t * input, float * bias, float* output, float alpha, Index rows, Index cols) {
- for (Index r = 0; r<rows; r++) {
- for (Index c = 0; c<cols; c++) {
- output[c] += input[r * cols + c];
- }
- }
-
- for (Index c = 0; c<cols; c++) {
- output[c] = bias[c] + output[c]*alpha;
- }
-}
-
void CompareAs(int8_t * output_old, uint8_t * output_new, Index rows, Index cols) {
for (Index r = 0; r<rows; r++) {
for (Index c = 0; c<cols; c++) {
@@ -91,19 +61,28 @@ template <class Routine> void TestPrepareBias(Index rows, Index cols) {
AlignedVector<float> goldBias(cols);
for (auto& it : goldBias) {
- it = 0;
+ it = dist(gen);
}
+ int i = 0;
for (auto& it : inputBias) {
- it = dist(gen);
+ it = goldBias[i];
+ i++;
}
float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f);
- SlowSumB(B_quant.begin(), inputBias.begin(), goldBias.begin(), alpha, rows, cols);
+ Routine::PrepareBiasFor8(B_prep.begin(), rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, inputBias.begin(), inputBias.begin()));
- Routine::PrepareBiasFor8(1, B_prep.begin(), 1, rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, inputBias.begin(), inputBias.begin()));
-
- CompareBiases(goldBias.begin(), inputBias.begin(), cols);
+ int A_rows = 1;
+ AlignedVector<int8_t> A_prep2(A_rows*rows);
+ for (auto& it : A_prep2) {
+ it =1;
+ }
+ //Routine::Multiply(A_prep2.begin(), B_prep.begin(), A_rows, rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, goldBias.begin(), goldBias.begin()));
+ //CompareBiases(goldBias.begin(), inputBias.begin(), cols);
+ AlignedVector<float> slowint_C(cols);
+ SlowRefInt(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult_forprep, A_rows, rows, cols, goldBias.begin());
+ CompareBiases(slowint_C.begin(), inputBias.begin(), cols);
}
template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Index B_cols,
@@ -165,11 +144,133 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind
int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}
-/*
+template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index width, Index B_cols,
+ float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
+ std::ostringstream info;
+ info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+ // Initialize A and B.
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ AlignedVector<float> bias(B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+ for (auto& it : bias) {
+ it = 0;
+ }
+
+ float alpha = 2.0f;
+ float quant_mult = 127/alpha;
+ float unquant_mult = 1.0/(quant_mult*quant_mult);
+
+ AlignedVector<uint8_t> A_prep(A.size());
+ AlignedVector<int8_t> A_prep_old(A.size());
+ AlignedVector<int8_t> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareA(A.begin(), A_prep_old.begin(), quant_mult, A_rows, width); //Non-shifted version
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+
+ /*
+ * Reference non shift multiplication instead of slowint
+ */
+ AlignedVector<float> slowint_C(test_C.size());
+ Routine::Multiply(A_prep_old.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), slowint_C.begin()));
+
+ AlignedVector<float> float_C(test_C.size());
+ SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin());
+ /*
+ * Multiply8 shift multiplication
+ */
+ float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on
+ Routine::PrepareBiasFor8(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
+ Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
+
+ Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
+template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, Index B_cols,
+ float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
+ std::ostringstream info;
+ info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+ // Initialize A and B.
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ AlignedVector<float> bias(B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+ for (auto& it : bias) {
+ it = 0;
+ }
+
+ float alpha = 2.0f;
+ float quant_mult = 127/alpha;
+ float unquant_mult = 1.0/(quant_mult*quant_mult);
+
+ AlignedVector<uint8_t> A_prep(A.size());
+ AlignedVector<int8_t> A_prep_old(A.size());
+ AlignedVector<int8_t> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareA(A.begin(), A_prep_old.begin(), quant_mult, A_rows, width); //Non-shifted version
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+
+ /*
+ * Reference float multiplication
+ */
+ AlignedVector<int8_t> B_quant(B.size());
+ Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size());
+ AlignedVector<float> slowint_C(test_C.size());
+ // Taking the original A_preparation which means A would be int8_t
+ //SlowRefInt(A_prep.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, bias.begin());
+
+ AlignedVector<float> float_C(test_C.size());
+ SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin());
+ /*
+ * Multiply8 shift multiplication
+ */
+ //First prepare SlowInteger Bias:
+ AlignedVector<int8_t> A_prep2(1*width);
+ for (auto& it : A_prep2) {
+ it = 1;
+ }
+ AlignedVector<float> ShiftedBias(B_cols);
+ float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on
+ SlowRefInt(A_prep2.begin(), B_quant.begin(), ShiftedBias.begin(), unquant_mult_forprep, 1, width, B_cols, bias.begin());
+
+
+ //Now prepare Fast integer Bias
+ Routine::PrepareBiasFor8(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
+ Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
+
+ // Reference INT VERSION HERE with ADD127
+ // Taking the original A_preparation which means A would be int8_t
+ SlowRefInt(A_prep.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, ShiftedBias.begin());
+
+ Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
+
// Bias
TEST_CASE("PrepareBias SSSE3", "[Add127]") {
if (kCPU < CPUType::SSSE3) return;
- TestPrepareBias<SSSE3_8bit>(8,8);
TestPrepareBias<SSSE3_8bit>(256,256);
TestPrepareBias<SSSE3_8bit>(2048,256);
TestPrepareBias<SSSE3_8bit>(512,512);
@@ -177,7 +278,6 @@ TEST_CASE("PrepareBias SSSE3", "[Add127]") {
TEST_CASE("PrepareBias AVX2", "[Add127]") {
if (kCPU < CPUType::AVX2) return;
- TestPrepareBias<AVX2_8bit>(8,8);
TestPrepareBias<AVX2_8bit>(256,256);
TestPrepareBias<AVX2_8bit>(2048,256);
TestPrepareBias<AVX2_8bit>(512,512);
@@ -186,12 +286,11 @@ TEST_CASE("PrepareBias AVX2", "[Add127]") {
TEST_CASE("PrepareBias AVX512F", "[Add127]") {
if (kCPU < CPUType::AVX512BW) return;
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
- TestPrepareBias<AVX512_8bit>(8,8);
TestPrepareBias<AVX512_8bit>(256,256);
TestPrepareBias<AVX512_8bit>(2048,256);
TestPrepareBias<AVX512_8bit>(512,512);
#endif
-}*/
+}
//A
TEST_CASE("PrepareA SSSE3", "[Add127]") {
@@ -225,23 +324,23 @@ TEST_CASE("PrepareA AVX512F", "[Add127]") {
TEST_CASE ("Multiply SSSE3 8bit Shift with bias", "[Add127]") {
if (kCPU < CPUType::SSSE3) return;
TestMultiplyBiasNew<SSSE3_8bit>(1, 64, 8, 0.11, 0.1, 0.06, 0.05);
- TestMultiplyBiasNew<SSSE3_8bit>(8, 256, 256, 0.45, 0.54, 0.17, 0.16); // 0.064, 0.026);
- TestMultiplyBiasNew<SSSE3_8bit>(8, 2048, 256, 1.7, 1.7, 0.46, 0.43); // 4.4, 4.4);
- TestMultiplyBiasNew<SSSE3_8bit>(320, 256, 256, 0.56, 0.64, 0.16, 0.15); // 0.1, 0.01);
- TestMultiplyBiasNew<SSSE3_8bit>(472, 256, 256, 0.46, 0.62, 0.17, 0.16); // 0.1, 0.011);
- TestMultiplyBiasNew<SSSE3_8bit>(248, 256, 256, 0.48, 0.64, 0.16, 0.15); // 0.1, 0.012);
- TestMultiplyBiasNew<SSSE3_8bit>(200, 256, 256, 0.55, 0.74, 0.17, 0.16); // 0.1, 0.011);
+ TestMultiplyBiasNew<SSSE3_8bit>(8, 256, 256, 0.45, 0.54, 0.17, 0.16);
+ TestMultiplyBiasNew<SSSE3_8bit>(8, 2048, 256, 1.7, 1.7, 0.46, 0.43);
+ TestMultiplyBiasNew<SSSE3_8bit>(320, 256, 256, 0.56, 0.64, 0.16, 0.15);
+ TestMultiplyBiasNew<SSSE3_8bit>(472, 256, 256, 0.46, 0.62, 0.17, 0.16);
+ TestMultiplyBiasNew<SSSE3_8bit>(248, 256, 256, 0.48, 0.64, 0.16, 0.15);
+ TestMultiplyBiasNew<SSSE3_8bit>(200, 256, 256, 0.55, 0.74, 0.17, 0.16);
}
TEST_CASE ("Multiply AVX2 8bit Shift with bias", "[Add127]") {
if (kCPU < CPUType::AVX2) return;
TestMultiplyBiasNew<AVX2_8bit>(1, 64, 8, 0.11, 0.11, 0.06, 0.05);
- TestMultiplyBiasNew<AVX2_8bit>(8, 256, 256, 0.49, 0.54, 0.17, 0.16); //0.1, 0);
- TestMultiplyBiasNew<AVX2_8bit>(8, 2048, 256, 1.57, 1.66, 0.46, 0.46); //1.8, 1.8);
- TestMultiplyBiasNew<AVX2_8bit>(320, 256, 256, 0.49, 0.64, 0.16, 0.15); //0.1, 0);
- TestMultiplyBiasNew<AVX2_8bit>(472, 256, 256, 0.46, 0.62, 0.17, 0.16); //0.1, 0);
- TestMultiplyBiasNew<AVX2_8bit>(248, 256, 256, 0.48, 0.64, 0.16, 0.15); //0.1, 0);
- TestMultiplyBiasNew<AVX2_8bit>(200, 256, 256, 0.55, 0.74, 0.17, 0.16); //0.1, 0);
+ TestMultiplyBiasNew<AVX2_8bit>(8, 256, 256, 0.49, 0.54, 0.17, 0.16);
+ TestMultiplyBiasNew<AVX2_8bit>(8, 2048, 256, 1.57, 1.66, 0.46, 0.46);
+ TestMultiplyBiasNew<AVX2_8bit>(320, 256, 256, 0.49, 0.64, 0.16, 0.15);
+ TestMultiplyBiasNew<AVX2_8bit>(472, 256, 256, 0.46, 0.62, 0.17, 0.16);
+ TestMultiplyBiasNew<AVX2_8bit>(248, 256, 256, 0.48, 0.64, 0.16, 0.15);
+ TestMultiplyBiasNew<AVX2_8bit>(200, 256, 256, 0.55, 0.74, 0.17, 0.16);
}
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
TEST_CASE ("Multiply AVX512F 8bit Shift with bias", "[Add127]") {
@@ -269,4 +368,100 @@ TEST_CASE ("Multiply AVX512F 8bit Shift with bias", "[Add127]") {
}
#endif
+//Multiply old vs new
+TEST_CASE ("Multiply SSSE3 8bit Shift vs nonshift", "[Add127]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestMultiplyShiftNonShift<SSSE3_8bit>(1, 64, 8, 0.00001, 0.1, 0.06, 0.00001);
+ TestMultiplyShiftNonShift<SSSE3_8bit>(8, 256, 256, 0.00001, 0.54, 0.17, 0.00001);
+ TestMultiplyShiftNonShift<SSSE3_8bit>(8, 2048, 256, 17.9, 1.7, 0.46, 4.2); //Big difference here because the non-shift version is very bad
+ TestMultiplyShiftNonShift<SSSE3_8bit>(320, 256, 256, 1.2, 0.64, 0.16, 0.006);
+ TestMultiplyShiftNonShift<SSSE3_8bit>(472, 256, 256, 1.1, 0.62, 0.17, 0.006);
+ TestMultiplyShiftNonShift<SSSE3_8bit>(248, 256, 256, 0.9, 0.64, 0.16, 0.007);
+ TestMultiplyShiftNonShift<SSSE3_8bit>(200, 256, 256, 1, 0.74, 0.17, 0.006);
+}
+
+TEST_CASE ("Multiply AVX2 8bit Shift vs nonshift", "[Add127]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyShiftNonShift<AVX2_8bit>(1, 64, 8, 0.00001, 0.11, 0.06, 0.00001);
+ TestMultiplyShiftNonShift<AVX2_8bit>(8, 256, 256, 0.00001, 0.54, 0.17, 0.00001);
+ TestMultiplyShiftNonShift<AVX2_8bit>(8, 2048, 256, 9.4, 1.66, 0.46, 1.67); //Big difference here because the non-shift version is very bad
+ TestMultiplyShiftNonShift<AVX2_8bit>(320, 256, 256, 0.0001, 0.64, 0.16, 0.0001);
+ TestMultiplyShiftNonShift<AVX2_8bit>(472, 256, 256, 0.0001, 0.62, 0.17, 0.0001);
+ TestMultiplyShiftNonShift<AVX2_8bit>(248, 256, 256, 0.0001, 0.64, 0.16, 0.0001);
+ TestMultiplyShiftNonShift<AVX2_8bit>(200, 256, 256, 0.0001, 0.74, 0.17, 0.0001);
+}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
+TEST_CASE ("Multiply AVX512F 8bit Shift vs nonshift", "[Add127]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyShiftNonShift<AVX512_8bit>(1, 64, 8, 0.0001, 0.05, 0.03, 0.001);
+ TestMultiplyShiftNonShift<AVX512_8bit>(8, 256, 256, 0.0001, 0.22, 0.06, 0.001);
+ TestMultiplyShiftNonShift<AVX512_8bit>(8, 2048, 256, 3.51, 0.61, 0.17, 0.3);
+ TestMultiplyShiftNonShift<AVX512_8bit>(320, 256, 256, 0.0001, 0.27, 0.06, 0.001);
+ TestMultiplyShiftNonShift<AVX512_8bit>(472, 256, 256, 0.0001, 0.33, 0.06, 0.001);
+ TestMultiplyShiftNonShift<AVX512_8bit>(248, 256, 256, 0.0001, 0.27, 0.06, 0.001);
+ TestMultiplyShiftNonShift<AVX512_8bit>(200, 256, 256, 0.0001, 0.28, 0.06, 0.001);
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
+ TEST_CASE ("Multiply AVX512VNNI 8bit Shift vs nonshift", "[Add127]") {
+ if (kCPU < CPUType::AVX512VNNI) return;
+ TestMultiplyShiftNonShift<AVX512VNNI_8bit>(1, 64, 8, 0.00001, 0.05, 0.03, 0.00001);
+ TestMultiplyShiftNonShift<AVX512VNNI_8bit>(8, 256, 256, 0.00001, 0.22, 0.06, 0.00001);
+ TestMultiplyShiftNonShift<AVX512VNNI_8bit>(8, 2048, 256, 0.0001, 0.61, 0.17, 0.0001);
+ TestMultiplyShiftNonShift<AVX512VNNI_8bit>(320, 256, 256, 0.00001, 0.27, 0.06, 0.00001);
+ TestMultiplyShiftNonShift<AVX512VNNI_8bit>(472, 256, 256, 0.00001, 0.33, 0.06, 0.00001);
+ TestMultiplyShiftNonShift<AVX512VNNI_8bit>(248, 256, 256, 0.00001, 0.27, 0.06, 0.00001);
+ TestMultiplyShiftNonShift<AVX512VNNI_8bit>(200, 256, 256, 0.00001, 0.28, 0.06, 0.00001);
+ }
+#endif
+
+//Multiply Shift vs int shift implementation
+TEST_CASE ("Multiply SSSE3 8bit Shift vs Int", "[Add127]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestMultiplyShiftInt<SSSE3_8bit>(1, 64, 8, 0, 0.1, 0.06, 0);
+ TestMultiplyShiftInt<SSSE3_8bit>(8, 256, 256, 0, 0.54, 0.17, 0);
+ TestMultiplyShiftInt<SSSE3_8bit>(8, 2048, 256, 0, 1.7, 0.46, 0);
+ TestMultiplyShiftInt<SSSE3_8bit>(320, 256, 256, 0, 0.64, 0.16, 0);
+ TestMultiplyShiftInt<SSSE3_8bit>(472, 256, 256, 0, 0.62, 0.17, 0);
+ TestMultiplyShiftInt<SSSE3_8bit>(248, 256, 256, 0, 0.64, 0.16, 0);
+ TestMultiplyShiftInt<SSSE3_8bit>(200, 256, 256, 0, 0.74, 0.17, 0);
+}
+
+TEST_CASE ("Multiply AVX2 8bit Shift vs Int", "[Add127]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyShiftInt<AVX2_8bit>(1, 64, 8, 0, 0.11, 0.06, 0);
+ TestMultiplyShiftInt<AVX2_8bit>(8, 256, 256, 0, 0.54, 0.17, 0);
+ TestMultiplyShiftInt<AVX2_8bit>(8, 2048, 256, 0, 1.66, 0.46, 0);
+ TestMultiplyShiftInt<AVX2_8bit>(320, 256, 256, 0, 0.64, 0.16, 0);
+ TestMultiplyShiftInt<AVX2_8bit>(472, 256, 256, 0, 0.62, 0.17, 0);
+ TestMultiplyShiftInt<AVX2_8bit>(248, 256, 256, 0, 0.64, 0.16, 0);
+ TestMultiplyShiftInt<AVX2_8bit>(200, 256, 256, 0, 0.74, 0.17, 0);
+}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
+TEST_CASE ("Multiply AVX512F 8bit Shift vs Int", "[Add127]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyShiftInt<AVX512_8bit>(1, 64, 8, 0, 0.05, 0.03, 0);
+ TestMultiplyShiftInt<AVX512_8bit>(8, 256, 256, 0, 0.22, 0.06, 0);
+ TestMultiplyShiftInt<AVX512_8bit>(8, 2048, 256, 0, 0.61, 0.17, 0);
+ TestMultiplyShiftInt<AVX512_8bit>(320, 256, 256, 0, 0.27, 0.06, 0);
+ TestMultiplyShiftInt<AVX512_8bit>(472, 256, 256, 0, 0.33, 0.06, 0);
+ TestMultiplyShiftInt<AVX512_8bit>(248, 256, 256, 0, 0.27, 0.06, 0);
+ TestMultiplyShiftInt<AVX512_8bit>(200, 256, 256, 0, 0.28, 0.06, 0);
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
+ TEST_CASE ("Multiply AVX512VNNI 8bit Shift vs Int", "[Add127]") {
+ if (kCPU < CPUType::AVX512VNNI) return;
+ TestMultiplyShiftInt<AVX512VNNI_8bit>(1, 64, 8, 0, 0.05, 0.03, 0);
+ TestMultiplyShiftInt<AVX512VNNI_8bit>(8, 256, 256, 0, 0.22, 0.06, 0);
+ TestMultiplyShiftInt<AVX512VNNI_8bit>(8, 2048, 256, 0, 0.61, 0.17, 0);
+ TestMultiplyShiftInt<AVX512VNNI_8bit>(320, 256, 256, 0, 0.27, 0.06, 0);
+ TestMultiplyShiftInt<AVX512VNNI_8bit>(472, 256, 256, 0, 0.33, 0.06, 0);
+ TestMultiplyShiftInt<AVX512VNNI_8bit>(248, 256, 256, 0, 0.27, 0.06, 0);
+ TestMultiplyShiftInt<AVX512VNNI_8bit>(200, 256, 256, 0, 0.28, 0.06, 0);
+ }
+#endif
+
} //namespace intgemm
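For orientation, the calls exercised in TestMultiplyShiftInt combine into the following end-to-end shifted multiply. This is a sketch assembled from the calls in the diff above, assuming the intgemm test headers are in scope (AlignedVector, callbacks, Index) and Routine is one of the 8-bit backends such as AVX2_8bit; the wrapper function and argument names are illustrative.

template <class Routine>
void ShiftedMultiplySketch(const float *A, const float *B, float *bias, float *C,
                           float alpha, Index A_rows, Index width, Index B_cols) {
  float quant_mult = 127.0f / alpha;
  float unquant_mult = 1.0f / (quant_mult * quant_mult);
  AlignedVector<uint8_t> A_prep(A_rows * width);  // shifted A: unsigned 8-bit
  AlignedVector<int8_t> B_prep(width * B_cols);
  Routine::PrepareA(A, A_prep.begin(), quant_mult, A_rows, width);
  Routine::PrepareB(B, B_prep.begin(), quant_mult, width, B_cols);
  // Fold the correction for the +127 shift into the bias in place.
  float unquant_mult_forprep = (-1) * alpha * alpha / 127.0f;
  Routine::PrepareBiasFor8(B_prep.begin(), width, B_cols,
      callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias, bias));
  // Multiply with the shifted A; the corrected bias cancels the offset.
  Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols,
      callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias, C));
}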
diff --git a/test/test.cc b/test/test.cc
index 88daaa2..2986d82 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -39,6 +39,21 @@ template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, flo
}
}
}
+void SlowRefInt(const uint8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias) {
+ for (Index r = 0; r < A_rows; ++r) {
+ for (Index c = 0; c < B_cols; ++c) {
+ int32_t sum = 0;
+ for (Index k = 0; k < width; ++k) {
+ sum += static_cast<int16_t>(A[r * width + k]) * static_cast<int16_t>(B[k * B_cols + c]);
+ }
+ if (bias) {
+ C[r * B_cols + c] = sum * unquant_mult + bias[c];
+ } else {
+ C[r * B_cols + c] = sum * unquant_mult;
+ }
+ }
+ }
+}
template void SlowRefInt<int8_t>(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias);
template void SlowRefInt<int16_t>(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias);
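The new uint8_t-by-int8_t overload of SlowRefInt gives the tests an exact integer reference for the shifted path: each product is widened to 16 bits and accumulated in a 32-bit sum, so no saturation occurs at the tested widths. A hypothetical call site, with buffer names mirroring TestMultiplyShiftInt, is shown below.

// Each uint8*int8 product is at most 255 * 128 = 32640 in magnitude, so even
// width = 2048 keeps the int32 accumulator far below 2^31.
AlignedVector<uint8_t> A_prep(A_rows * width);   // shifted A, values in [0, 255]
AlignedVector<int8_t> B_quant(width * B_cols);   // quantized B, values in [-127, 127]
AlignedVector<float> slowint_C(A_rows * B_cols);
SlowRefInt(A_prep.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult,
           A_rows, width, B_cols, ShiftedBias.begin());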
diff --git a/test/test.h b/test/test.h
index fc47da5..291ff45 100644
--- a/test/test.h
+++ b/test/test.h
@@ -25,6 +25,7 @@ void SlowRefFloat(const float *A, const float *B, float *C, Index A_rows, Index
// Compute A*B slowly from integers.
template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias=nullptr);
+void SlowRefInt(const uint8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias=nullptr);
void Compare(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info,
float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance);