author     Kenneth Heafield <kpu@users.noreply.github.com>   2020-01-21 14:48:42 +0300
committer  GitHub <noreply@github.com>                       2020-01-21 14:48:42 +0300
commit     a5300207b996492d4507109da8d4e5323354c7ac (patch)
tree       e72c0ba1db3f7e1ca8dc5a837ea58ced7d14ac97
parent     03a4a9dbe4e1955efdb6c6f671636d9378755f45 (diff)
parent     e2c008d075e55bfb1538a42cf2fce113f039e6a8 (diff)
Merge pull request #55 from kpu/debug_add127
More tests for add127
-rw-r--r--   test/add127_test.cc | 301
-rw-r--r--   test/test.cc        |  15
-rw-r--r--   test/test.h         |   1
3 files changed, 264 insertions, 53 deletions
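
The tests below exercise intgemm's "add127" trick: on pre-VNNI hardware the 8-bit kernels rely on an unsigned-times-signed madd instruction, so PrepareA shifts the int8 activations up by 127 into uint8. The shifted product then exceeds the true one by 127 * sum_k B[k][c] in every output column c, and PrepareBiasFor8 folds that constant into the bias ahead of time, scaled by unquant_mult_forprep = -alpha*alpha/127 so the correction is subtracted when the bias is added back. A minimal scalar sketch of that correction, for illustration only (PrepareBiasRef is a hypothetical name; the library's real, vectorized entry point is the PrepareBiasFor8 call in the diff below):

    #include <cstdint>

    // Hypothetical scalar reference for the add127 bias correction.
    // B_quant holds the quantized weights (width x cols, row major) and alpha is
    // the clipping range used for quantization, so quant_mult = 127 / alpha and
    // unquant_mult = alpha * alpha / (127 * 127). After this runs, multiplying
    // (A + 127) by B and adding the corrected bias matches A * B plus the
    // original bias, up to rounding.
    void PrepareBiasRef(const int8_t *B_quant, float *bias, float alpha,
                        int width, int cols) {
      for (int c = 0; c < cols; ++c) {
        int32_t col_sum = 0;
        for (int k = 0; k < width; ++k) {
          col_sum += B_quant[k * cols + c]; // a row of ones times B, as in TestPrepareBias
        }
        // 127 * unquant_mult = alpha * alpha / 127, negated to cancel the shift.
        bias[c] += col_sum * (-alpha * alpha / 127.0f);
      }
    }

The rewritten TestPrepareBias checks exactly this identity: it feeds a row of ones through SlowRefInt with unquant_mult_forprep as the scale and compares the result against what PrepareBiasFor8 wrote.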
diff --git a/test/add127_test.cc b/test/add127_test.cc
index c803bad..86d630b 100644
--- a/test/add127_test.cc
+++ b/test/add127_test.cc
@@ -2,36 +2,6 @@
 
 namespace intgemm {
 
-void newBias(const float * input, float * bias, float* output, float alpha, Index rows, Index width, Index cols) {
-  AlignedVector<float> intermediate(rows*cols);
-  AlignedVector<float> ones(rows*width);
-  for (auto&& it : ones) {
-    it = 1;
-  }
-  SlowRefFloat(ones.begin(), input, intermediate.begin(), rows, width, cols);
-  for (auto&& it : intermediate) {
-    it = it*alpha;
-  }
-
-
-  for (Index c = 0; c<cols; c++) {
-    output[c] = bias[c] - intermediate.begin()[rows*c];
-  }
-
-}
-
-void SlowSumB(const int8_t * input, float * bias, float* output, float alpha, Index rows, Index cols) {
-  for (Index r = 0; r<rows; r++) {
-    for (Index c = 0; c<cols; c++) {
-      output[c] += input[r * cols + c];
-    }
-  }
-
-  for (Index c = 0; c<cols; c++) {
-    output[c] = bias[c] + output[c]*alpha;
-  }
-}
-
 void CompareAs(int8_t * output_old, uint8_t * output_new, Index rows, Index cols) {
   for (Index r = 0; r<rows; r++) {
     for (Index c = 0; c<cols; c++) {
@@ -91,19 +61,28 @@ template <class Routine> void TestPrepareBias(Index rows, Index cols) {
   AlignedVector<float> goldBias(cols);
 
   for (auto& it : goldBias) {
-    it = 0;
+    it = dist(gen);
   }
+  int i = 0;
   for (auto& it : inputBias) {
-    it = dist(gen);
+    it = goldBias[i];
+    i++;
   }
 
   float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f);
 
-  SlowSumB(B_quant.begin(), inputBias.begin(), goldBias.begin(), alpha, rows, cols);
+  Routine::PrepareBiasFor8(B_prep.begin(), rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, inputBias.begin(), inputBias.begin()));
 
-  Routine::PrepareBiasFor8(1, B_prep.begin(), 1, rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, inputBias.begin(), inputBias.begin()));
-
-  CompareBiases(goldBias.begin(), inputBias.begin(), cols);
+  int A_rows = 1;
+  AlignedVector<int8_t> A_prep2(A_rows*rows);
+  for (auto& it : A_prep2) {
+    it =1;
+  }
+  //Routine::Multiply(A_prep2.begin(), B_prep.begin(), A_rows, rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, goldBias.begin(), goldBias.begin()));
+  //CompareBiases(goldBias.begin(), inputBias.begin(), cols);
+  AlignedVector<float> slowint_C(cols);
+  SlowRefInt(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult_forprep, A_rows, rows, cols, goldBias.begin());
+  CompareBiases(slowint_C.begin(), inputBias.begin(), cols);
 }
 
 template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Index B_cols,
@@ -165,11 +144,133 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind
                 int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
 }
 
-/*
+template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index width, Index B_cols,
+        float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
+  std::ostringstream info;
+  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+  // Initialize A and B.
+  AlignedVector<float> A(A_rows * width);
+  AlignedVector<float> B(width * B_cols);
+  AlignedVector<float> bias(B_cols);
+  std::mt19937 gen;
+  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+  for (auto& it : A) {
+    it = dist(gen);
+  }
+  for (auto& it : B) {
+    it = dist(gen);
+  }
+  for (auto& it : bias) {
+    it = 0;
+  }
+
+  float alpha = 2.0f;
+  float quant_mult = 127/alpha;
+  float unquant_mult = 1.0/(quant_mult*quant_mult);
+
+  AlignedVector<uint8_t> A_prep(A.size());
+  AlignedVector<int8_t> A_prep_old(A.size());
+  AlignedVector<int8_t> B_prep(B.size());
+  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+  Routine::PrepareA(A.begin(), A_prep_old.begin(), quant_mult, A_rows, width); //Non shited version
+  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+  AlignedVector<float> test_C(A_rows * B_cols);
+
+  /*
+   * Reference non shift multiplication instead of slowint
+   */
+  AlignedVector<float> slowint_C(test_C.size());
+  Routine::Multiply(A_prep_old.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), slowint_C.begin()));
+
+  AlignedVector<float> float_C(test_C.size());
+  SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin());
+  /*
+   * Multiply8 shift multiplication
+   */
+  float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on
+  Routine::PrepareBiasFor8(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
+  Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
+
+  Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+          int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
+template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, Index B_cols,
+        float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
+  std::ostringstream info;
+  info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+  // Initialize A and B.
+  AlignedVector<float> A(A_rows * width);
+  AlignedVector<float> B(width * B_cols);
+  AlignedVector<float> bias(B_cols);
+  std::mt19937 gen;
+  std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+  for (auto& it : A) {
+    it = dist(gen);
+  }
+  for (auto& it : B) {
+    it = dist(gen);
+  }
+  for (auto& it : bias) {
+    it = 0;
+  }
+
+  float alpha = 2.0f;
+  float quant_mult = 127/alpha;
+  float unquant_mult = 1.0/(quant_mult*quant_mult);
+
+  AlignedVector<uint8_t> A_prep(A.size());
+  AlignedVector<int8_t> A_prep_old(A.size());
+  AlignedVector<int8_t> B_prep(B.size());
+  Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+  Routine::PrepareA(A.begin(), A_prep_old.begin(), quant_mult, A_rows, width); //Non shited version
+  Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+  AlignedVector<float> test_C(A_rows * B_cols);
+
+  /*
+   * Reference float multiplication
+   */
+  AlignedVector<int8_t> B_quant(B.size());
+  Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size());
+  AlignedVector<float> slowint_C(test_C.size());
+  // Taking the original A_preparation which means A would be int8_t
+  //SlowRefInt(A_prep.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, bias.begin());
+
+  AlignedVector<float> float_C(test_C.size());
+  SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin());
+  /*
+   * Multiply8 shift multiplication
+   */
+  //First prepare SlowInteger Bias:
+  AlignedVector<int8_t> A_prep2(1*width);
+  for (auto& it : A_prep2) {
+    it = 1;
+  }
+  AlignedVector<float> ShiftedBias(B_cols);
+  float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on
+  SlowRefInt(A_prep2.begin(), B_quant.begin(), ShiftedBias.begin(), unquant_mult_forprep, 1, width, B_cols, bias.begin());
+
+
+  //Now prepare Fast integer Bias
+  Routine::PrepareBiasFor8(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
+  Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
+
+  // Reference INT VERSION HERE with ADD127
+  // Taking the original A_preparation which means A would be int8_t
+  SlowRefInt(A_prep.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, ShiftedBias.begin());
+
+  Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+          int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
+
 // Bias
 TEST_CASE("PrepareBias SSSE3", "[Add127]") {
   if (kCPU < CPUType::SSSE3) return;
-  TestPrepareBias<SSSE3_8bit>(8,8);
   TestPrepareBias<SSSE3_8bit>(256,256);
   TestPrepareBias<SSSE3_8bit>(2048,256);
   TestPrepareBias<SSSE3_8bit>(512,512);
@@ -177,7 +278,6 @@ TEST_CASE("PrepareBias SSSE3", "[Add127]") {
 
 TEST_CASE("PrepareBias AVX2", "[Add127]") {
   if (kCPU < CPUType::AVX2) return;
-  TestPrepareBias<AVX2_8bit>(8,8);
   TestPrepareBias<AVX2_8bit>(256,256);
   TestPrepareBias<AVX2_8bit>(2048,256);
   TestPrepareBias<AVX2_8bit>(512,512);
@@ -186,12 +286,11 @@ TEST_CASE("PrepareBias AVX2", "[Add127]") {
 
 TEST_CASE("PrepareBias AVX512F", "[Add127]") {
   if (kCPU < CPUType::AVX512BW) return;
 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
-  TestPrepareBias<AVX512_8bit>(8,8);
   TestPrepareBias<AVX512_8bit>(256,256);
   TestPrepareBias<AVX512_8bit>(2048,256);
   TestPrepareBias<AVX512_8bit>(512,512);
 #endif
-}*/
+}
 
 //A
 TEST_CASE("PrepareA SSSE3", "[Add127]") {
@@ -225,23 +324,23 @@ TEST_CASE("PrepareA AVX512F", "[Add127]") {
 TEST_CASE ("Multiply SSSE3 8bit Shift with bias", "[Add127]") {
   if (kCPU < CPUType::SSSE3) return;
   TestMultiplyBiasNew<SSSE3_8bit>(1, 64, 8, 0.11, 0.1, 0.06, 0.05);
-  TestMultiplyBiasNew<SSSE3_8bit>(8, 256, 256, 0.45, 0.54, 0.17, 0.16); // 0.064, 0.026);
-  TestMultiplyBiasNew<SSSE3_8bit>(8, 2048, 256, 1.7, 1.7, 0.46, 0.43); // 4.4, 4.4);
-  TestMultiplyBiasNew<SSSE3_8bit>(320, 256, 256, 0.56, 0.64, 0.16, 0.15); // 0.1, 0.01);
-  TestMultiplyBiasNew<SSSE3_8bit>(472, 256, 256, 0.46, 0.62, 0.17, 0.16); // 0.1, 0.011);
-  TestMultiplyBiasNew<SSSE3_8bit>(248, 256, 256, 0.48, 0.64, 0.16, 0.15); // 0.1, 0.012);
-  TestMultiplyBiasNew<SSSE3_8bit>(200, 256, 256, 0.55, 0.74, 0.17, 0.16); // 0.1, 0.011);
+  TestMultiplyBiasNew<SSSE3_8bit>(8, 256, 256, 0.45, 0.54, 0.17, 0.16);
+  TestMultiplyBiasNew<SSSE3_8bit>(8, 2048, 256, 1.7, 1.7, 0.46, 0.43);
+  TestMultiplyBiasNew<SSSE3_8bit>(320, 256, 256, 0.56, 0.64, 0.16, 0.15);
+  TestMultiplyBiasNew<SSSE3_8bit>(472, 256, 256, 0.46, 0.62, 0.17, 0.16);
+  TestMultiplyBiasNew<SSSE3_8bit>(248, 256, 256, 0.48, 0.64, 0.16, 0.15);
+  TestMultiplyBiasNew<SSSE3_8bit>(200, 256, 256, 0.55, 0.74, 0.17, 0.16);
 }
 
 TEST_CASE ("Multiply AVX2 8bit Shift with bias", "[Add127]") {
   if (kCPU < CPUType::AVX2) return;
   TestMultiplyBiasNew<AVX2_8bit>(1, 64, 8, 0.11, 0.11, 0.06, 0.05);
-  TestMultiplyBiasNew<AVX2_8bit>(8, 256, 256, 0.49, 0.54, 0.17, 0.16); //0.1, 0);
-  TestMultiplyBiasNew<AVX2_8bit>(8, 2048, 256, 1.57, 1.66, 0.46, 0.46); //1.8, 1.8);
-  TestMultiplyBiasNew<AVX2_8bit>(320, 256, 256, 0.49, 0.64, 0.16, 0.15); //0.1, 0);
-  TestMultiplyBiasNew<AVX2_8bit>(472, 256, 256, 0.46, 0.62, 0.17, 0.16); //0.1, 0);
-  TestMultiplyBiasNew<AVX2_8bit>(248, 256, 256, 0.48, 0.64, 0.16, 0.15); //0.1, 0);
-  TestMultiplyBiasNew<AVX2_8bit>(200, 256, 256, 0.55, 0.74, 0.17, 0.16); //0.1, 0);
+  TestMultiplyBiasNew<AVX2_8bit>(8, 256, 256, 0.49, 0.54, 0.17, 0.16);
+  TestMultiplyBiasNew<AVX2_8bit>(8, 2048, 256, 1.57, 1.66, 0.46, 0.46);
+  TestMultiplyBiasNew<AVX2_8bit>(320, 256, 256, 0.49, 0.64, 0.16, 0.15);
+  TestMultiplyBiasNew<AVX2_8bit>(472, 256, 256, 0.46, 0.62, 0.17, 0.16);
+  TestMultiplyBiasNew<AVX2_8bit>(248, 256, 256, 0.48, 0.64, 0.16, 0.15);
+  TestMultiplyBiasNew<AVX2_8bit>(200, 256, 256, 0.55, 0.74, 0.17, 0.16);
 }
 #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
 TEST_CASE ("Multiply AVX512F 8bit Shift with bias", "[Add127]") {
@@ -269,4 +368,100 @@ TEST_CASE ("Multiply AVX512F 8bit Shift with bias", "[Add127]") {
 }
 #endif
 
+//Multiply old vs new
+TEST_CASE ("Multiply SSSE3 8bit Shift vs nonshift", "[Add127]") {
+  if (kCPU < CPUType::SSSE3) return;
+  TestMultiplyShiftNonShift<SSSE3_8bit>(1, 64, 8, 0.00001, 0.1, 0.06, 0.00001);
+  TestMultiplyShiftNonShift<SSSE3_8bit>(8, 256, 256, 0.00001, 0.54, 0.17, 0.00001);
+  TestMultiplyShiftNonShift<SSSE3_8bit>(8, 2048, 256, 17.9, 1.7, 0.46, 4.2); //Big difference here because the non-shift version is very bad
+  TestMultiplyShiftNonShift<SSSE3_8bit>(320, 256, 256, 1.2, 0.64, 0.16, 0.006);
+  TestMultiplyShiftNonShift<SSSE3_8bit>(472, 256, 256, 1.1, 0.62, 0.17, 0.006);
+  TestMultiplyShiftNonShift<SSSE3_8bit>(248, 256, 256, 0.9, 0.64, 0.16, 0.007);
+  TestMultiplyShiftNonShift<SSSE3_8bit>(200, 256, 256, 1, 0.74, 0.17, 0.006);
+}
+
+TEST_CASE ("Multiply AVX2 8bit Shift vs nonshift", "[Add127]") {
+  if (kCPU < CPUType::AVX2) return;
+  TestMultiplyShiftNonShift<AVX2_8bit>(1, 64, 8, 0.00001, 0.11, 0.06, 0.00001);
+  TestMultiplyShiftNonShift<AVX2_8bit>(8, 256, 256, 0.00001, 0.54, 0.17, 0.00001);
+  TestMultiplyShiftNonShift<AVX2_8bit>(8, 2048, 256, 9.4, 1.66, 0.46, 1.67); //Big difference here because the non-shift version is very bad
+  TestMultiplyShiftNonShift<AVX2_8bit>(320, 256, 256, 0.0001, 0.64, 0.16, 0.0001);
+  TestMultiplyShiftNonShift<AVX2_8bit>(472, 256, 256, 0.0001, 0.62, 0.17, 0.0001);
+  TestMultiplyShiftNonShift<AVX2_8bit>(248, 256, 256, 0.0001, 0.64, 0.16, 0.0001);
+  TestMultiplyShiftNonShift<AVX2_8bit>(200, 256, 256, 0.0001, 0.74, 0.17, 0.0001);
+}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
+TEST_CASE ("Multiply AVX512F 8bit Shift vs nonshift", "[Add127]") {
+  if (kCPU < CPUType::AVX512BW) return;
+  TestMultiplyShiftNonShift<AVX512_8bit>(1, 64, 8, 0.0001, 0.05, 0.03, 0.001);
+  TestMultiplyShiftNonShift<AVX512_8bit>(8, 256, 256, 0.0001, 0.22, 0.06, 0.001);
+  TestMultiplyShiftNonShift<AVX512_8bit>(8, 2048, 256, 3.51, 0.61, 0.17, 0.3);
+  TestMultiplyShiftNonShift<AVX512_8bit>(320, 256, 256, 0.0001, 0.27, 0.06, 0.001);
+  TestMultiplyShiftNonShift<AVX512_8bit>(472, 256, 256, 0.0001, 0.33, 0.06, 0.001);
+  TestMultiplyShiftNonShift<AVX512_8bit>(248, 256, 256, 0.0001, 0.27, 0.06, 0.001);
+  TestMultiplyShiftNonShift<AVX512_8bit>(200, 256, 256, 0.0001, 0.28, 0.06, 0.001);
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
+  TEST_CASE ("Multiply AVX512VNNI 8bit Shift vs nonshift", "[Add127]") {
+  if (kCPU < CPUType::AVX512VNNI) return;
+  TestMultiplyShiftNonShift<AVX512VNNI_8bit>(1, 64, 8, 0.00001, 0.05, 0.03, 0.00001);
+  TestMultiplyShiftNonShift<AVX512VNNI_8bit>(8, 256, 256, 0.00001, 0.22, 0.06, 0.00001);
+  TestMultiplyShiftNonShift<AVX512VNNI_8bit>(8, 2048, 256, 0.0001, 0.61, 0.17, 0.0001);
+  TestMultiplyShiftNonShift<AVX512VNNI_8bit>(320, 256, 256, 0.00001, 0.27, 0.06, 0.00001);
+  TestMultiplyShiftNonShift<AVX512VNNI_8bit>(472, 256, 256, 0.00001, 0.33, 0.06, 0.00001);
+  TestMultiplyShiftNonShift<AVX512VNNI_8bit>(248, 256, 256, 0.00001, 0.27, 0.06, 0.00001);
+  TestMultiplyShiftNonShift<AVX512VNNI_8bit>(200, 256, 256, 0.00001, 0.28, 0.06, 0.00001);
+  }
+#endif
+
+//Multiply Shift vs int shift implementation
+TEST_CASE ("Multiply SSSE3 8bit Shift vs Int", "[Add127]") {
+  if (kCPU < CPUType::SSSE3) return;
+  TestMultiplyShiftInt<SSSE3_8bit>(1, 64, 8, 0, 0.1, 0.06, 0);
+  TestMultiplyShiftInt<SSSE3_8bit>(8, 256, 256, 0, 0.54, 0.17, 0);
+  TestMultiplyShiftInt<SSSE3_8bit>(8, 2048, 256, 0, 1.7, 0.46, 0);
+  TestMultiplyShiftInt<SSSE3_8bit>(320, 256, 256, 0, 0.64, 0.16, 0);
+  TestMultiplyShiftInt<SSSE3_8bit>(472, 256, 256, 0, 0.62, 0.17, 0);
+  TestMultiplyShiftInt<SSSE3_8bit>(248, 256, 256, 0, 0.64, 0.16, 0);
+  TestMultiplyShiftInt<SSSE3_8bit>(200, 256, 256, 0, 0.74, 0.17, 0);
+}
+
+TEST_CASE ("Multiply AVX2 8bit Shift vs Int", "[Add127]") {
+  if (kCPU < CPUType::AVX2) return;
+  TestMultiplyShiftInt<AVX2_8bit>(1, 64, 8, 0, 0.11, 0.06, 0);
+  TestMultiplyShiftInt<AVX2_8bit>(8, 256, 256, 0, 0.54, 0.17, 0);
+  TestMultiplyShiftInt<AVX2_8bit>(8, 2048, 256, 0, 1.66, 0.46, 0);
+  TestMultiplyShiftInt<AVX2_8bit>(320, 256, 256, 0, 0.64, 0.16, 0);
+  TestMultiplyShiftInt<AVX2_8bit>(472, 256, 256, 0, 0.62, 0.17, 0);
+  TestMultiplyShiftInt<AVX2_8bit>(248, 256, 256, 0, 0.64, 0.16, 0);
+  TestMultiplyShiftInt<AVX2_8bit>(200, 256, 256, 0, 0.74, 0.17, 0);
+}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
+TEST_CASE ("Multiply AVX512F 8bit Shift vs Int", "[Add127]") {
+  if (kCPU < CPUType::AVX512BW) return;
+  TestMultiplyShiftInt<AVX512_8bit>(1, 64, 8, 0, 0.05, 0.03, 0);
+  TestMultiplyShiftInt<AVX512_8bit>(8, 256, 256, 0, 0.22, 0.06, 0);
+  TestMultiplyShiftInt<AVX512_8bit>(8, 2048, 256, 0, 0.61, 0.17, 0);
+  TestMultiplyShiftInt<AVX512_8bit>(320, 256, 256, 0, 0.27, 0.06, 0);
+  TestMultiplyShiftInt<AVX512_8bit>(472, 256, 256, 0, 0.33, 0.06, 0);
+  TestMultiplyShiftInt<AVX512_8bit>(248, 256, 256, 0, 0.27, 0.06, 0);
+  TestMultiplyShiftInt<AVX512_8bit>(200, 256, 256, 0, 0.28, 0.06, 0);
+}
+#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
+  TEST_CASE ("Multiply AVX512VNNI 8bit Shift vs Int", "[Add127]") {
+  if (kCPU < CPUType::AVX512VNNI) return;
+  TestMultiplyShiftInt<AVX512VNNI_8bit>(1, 64, 8, 0, 0.05, 0.03, 0);
+  TestMultiplyShiftInt<AVX512VNNI_8bit>(8, 256, 256, 0, 0.22, 0.06, 0);
+  TestMultiplyShiftInt<AVX512VNNI_8bit>(8, 2048, 256, 0, 0.61, 0.17, 0);
+  TestMultiplyShiftInt<AVX512VNNI_8bit>(320, 256, 256, 0, 0.27, 0.06, 0);
+  TestMultiplyShiftInt<AVX512VNNI_8bit>(472, 256, 256, 0, 0.33, 0.06, 0);
+  TestMultiplyShiftInt<AVX512VNNI_8bit>(248, 256, 256, 0, 0.27, 0.06, 0);
+  TestMultiplyShiftInt<AVX512VNNI_8bit>(200, 256, 256, 0, 0.28, 0.06, 0);
+  }
+#endif
+
 } //namespace intgemm
diff --git a/test/test.cc b/test/test.cc
index 88daaa2..2986d82 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -39,6 +39,21 @@ template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, flo
     }
   }
 }
+void SlowRefInt(const uint8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias) {
+  for (Index r = 0; r < A_rows; ++r) {
+    for (Index c = 0; c < B_cols; ++c) {
+      int32_t sum = 0;
+      for (Index k = 0; k < width; ++k) {
+        sum += static_cast<int16_t>(A[r * width + k]) * static_cast<int16_t>(B[k * B_cols + c]);
+      }
+      if (bias) {
+        C[r * B_cols + c] = sum * unquant_mult + bias[c];
+      } else {
+        C[r * B_cols + c] = sum * unquant_mult;
+      }
+    }
+  }
+}
 
 template void SlowRefInt<int8_t>(const int8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias);
 template void SlowRefInt<int16_t>(const int16_t *A, const int16_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias);
diff --git a/test/test.h b/test/test.h
index fc47da5..291ff45 100644
--- a/test/test.h
+++ b/test/test.h
@@ -25,6 +25,7 @@ void SlowRefFloat(const float *A, const float *B, float *C, Index A_rows, Index
 
 // Compute A*B slowly from integers.
 template <class Integer> void SlowRefInt(const Integer *A, const Integer *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias=nullptr);
+void SlowRefInt(const uint8_t *A, const int8_t *B, float *C, float unquant_mult, Index A_rows, Index width, Index B_cols, const float *bias=nullptr);
 
 void Compare(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info,
              float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance);
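
For reference, the end-to-end call sequence the new tests validate, condensed from TestMultiplyShiftNonShift above. This is a sketch built only from calls that appear in this diff; A, B, bias, and the *_prep buffers are assumed to be AlignedVector instances sized as in the test.

    float alpha = 2.0f;                                          // clipping range of the float inputs
    float quant_mult = 127 / alpha;
    float unquant_mult = 1.0f / (quant_mult * quant_mult);
    float unquant_mult_forprep = (-1) * alpha * alpha / 127.0f;  // minus to invert the add_ps later

    Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);  // uint8_t, shifted by +127
    Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);  // int8_t
    // Fold the +127 shift into the bias, once per weight matrix:
    Routine::PrepareBiasFor8(B_prep.begin(), width, B_cols,
        callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
    // Every subsequent multiply with these weights uses the corrected bias:
    Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols,
        callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));

The loose width-2048 tolerances in the Shift-vs-nonshift cases reflect the in-code comment that the old non-shift kernel loses accuracy at large inner dimensions, while the shift-based path with the precomputed bias tracks both the float and the slow integer references closely, which is why the Shift-vs-Int cases can demand an integer tolerance of zero.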