diff options
Diffstat (limited to 'test/add127_test.cc')
-rw-r--r-- | test/add127_test.cc | 50 |
1 files changed, 30 insertions, 20 deletions
diff --git a/test/add127_test.cc b/test/add127_test.cc index d1b850d..18afaa5 100644 --- a/test/add127_test.cc +++ b/test/add127_test.cc @@ -14,13 +14,6 @@ void CompareAs(int8_t * output_old, uint8_t * output_new, Index rows, Index cols } } -void CompareBiases(const float *bias_ref, const float *bias, Index cols) { - for (std::size_t i = 0; i < cols; ++i) { - INFO("Inaccurate at " << i << ' ' << bias_ref[i] << ' ' << bias[i]); - CHECK(fabs(bias_ref[i] - bias[i]) < 0.0001); - } -} - template <class Routine> void TestPrepareA(Index rows, Index cols) { std::mt19937 gen; // Go somewhat out of range too. @@ -79,10 +72,12 @@ template <class Routine> void TestPrepareBias(Index rows, Index cols) { it =1; } //Routine::Multiply(A_prep2.begin(), B_prep.begin(), A_rows, rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, goldBias.begin(), goldBias.begin())); - //CompareBiases(goldBias.begin(), inputBias.begin(), cols); + //CompareEps(goldBias.begin(), inputBias.begin(), cols, 0.0001f); AlignedVector<float> slowint_C(cols); - SlowRefInt(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult_forprep, A_rows, rows, cols, goldBias.begin()); - CompareBiases(slowint_C.begin(), inputBias.begin(), cols); + references::Multiply(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), A_rows, rows, cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + return sum * unquant_mult_forprep + goldBias[info.col_idx]; + }); + CompareEps(slowint_C.begin(), inputBias.begin(), cols, 0.0001f); } template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Index B_cols, @@ -127,10 +122,14 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind // Taking the original A_preparation which means A would be int8_t AlignedVector<int8_t> A_prep2(A.size()); Routine::PrepareA(A.begin(), A_prep2.begin(), quant_mult, A_rows, width); - SlowRefInt(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, bias.begin()); + references::Multiply(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + return sum * unquant_mult + bias[info.col_idx]; + }); AlignedVector<float> float_C(test_C.size()); - SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin()); + references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { + return sum + bias[info.col_idx]; + }); /*ACTUAL MULTIPLICATION * @@ -140,7 +139,7 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind //Routine::PrepareBias(B.begin(), bias.begin(), alpha, width, B_cols); Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); - Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); } @@ -185,7 +184,10 @@ template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index widt Routine::Multiply(A_prep_old.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), slowint_C.begin())); AlignedVector<float> float_C(test_C.size()); - SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin()); + references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { + return sum + bias[info.col_idx]; + }); + /* * Multiply8 shift multiplication */ @@ -193,7 +195,7 @@ template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index widt Routine::PrepareBias(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin())); Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); - Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); } @@ -238,10 +240,14 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size()); AlignedVector<float> slowint_C(test_C.size()); // Taking the original A_preparation which means A would be int8_t - //SlowRefInt(A_prep.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, bias.begin()); + // references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + // return sum * unquant_mult + bias[info.col_idx]; + // }); AlignedVector<float> float_C(test_C.size()); - SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin()); + references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { + return sum + bias[info.col_idx]; + }); /* * Multiply8 shift multiplication */ @@ -252,7 +258,9 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In } AlignedVector<float> ShiftedBias(B_cols); float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on - SlowRefInt(A_prep2.begin(), B_quant.begin(), ShiftedBias.begin(), unquant_mult_forprep, 1, width, B_cols, bias.begin()); + references::Multiply(A_prep2.begin(), B_quant.begin(), ShiftedBias.begin(), 1, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + return sum * unquant_mult_forprep + bias[info.col_idx]; + }); //Now prepare Fast integer Bias @@ -261,9 +269,11 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In // Reference INT VERSION HERE with ADD127 // Taking the original A_preparation which means A would be int8_t - SlowRefInt(A_prep.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, ShiftedBias.begin()); + references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + return sum * unquant_mult + ShiftedBias[info.col_idx]; + }); - Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); } |