diff options
Diffstat (limited to 'test/add127_test.cc')
-rw-r--r-- | test/add127_test.cc | 33 |
1 files changed, 25 insertions, 8 deletions
diff --git a/test/add127_test.cc b/test/add127_test.cc index d1b850d..ae5c08a 100644 --- a/test/add127_test.cc +++ b/test/add127_test.cc @@ -81,7 +81,9 @@ template <class Routine> void TestPrepareBias(Index rows, Index cols) { //Routine::Multiply(A_prep2.begin(), B_prep.begin(), A_rows, rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, goldBias.begin(), goldBias.begin())); //CompareBiases(goldBias.begin(), inputBias.begin(), cols); AlignedVector<float> slowint_C(cols); - SlowRefInt(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult_forprep, A_rows, rows, cols, goldBias.begin()); + references::Multiply(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), A_rows, rows, cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + return sum * unquant_mult_forprep + goldBias[info.col_idx]; + }); CompareBiases(slowint_C.begin(), inputBias.begin(), cols); } @@ -127,10 +129,14 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind // Taking the original A_preparation which means A would be int8_t AlignedVector<int8_t> A_prep2(A.size()); Routine::PrepareA(A.begin(), A_prep2.begin(), quant_mult, A_rows, width); - SlowRefInt(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, bias.begin()); + references::Multiply(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + return sum * unquant_mult + bias[info.col_idx]; + }); AlignedVector<float> float_C(test_C.size()); - SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin()); + references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { + return sum + bias[info.col_idx]; + }); /*ACTUAL MULTIPLICATION * @@ -185,7 +191,10 @@ template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index widt Routine::Multiply(A_prep_old.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), slowint_C.begin())); AlignedVector<float> float_C(test_C.size()); - SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin()); + references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { + return sum + bias[info.col_idx]; + }); + /* * Multiply8 shift multiplication */ @@ -238,10 +247,14 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size()); AlignedVector<float> slowint_C(test_C.size()); // Taking the original A_preparation which means A would be int8_t - //SlowRefInt(A_prep.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, bias.begin()); + // references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + // return sum * unquant_mult + bias[info.col_idx]; + // }); AlignedVector<float> float_C(test_C.size()); - SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin()); + references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) { + return sum + bias[info.col_idx]; + }); /* * Multiply8 shift multiplication */ @@ -252,7 +265,9 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In } AlignedVector<float> ShiftedBias(B_cols); float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on - SlowRefInt(A_prep2.begin(), B_quant.begin(), ShiftedBias.begin(), unquant_mult_forprep, 1, width, B_cols, bias.begin()); + references::Multiply(A_prep2.begin(), B_quant.begin(), ShiftedBias.begin(), 1, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + return sum * unquant_mult_forprep + bias[info.col_idx]; + }); //Now prepare Fast integer Bias @@ -261,7 +276,9 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In // Reference INT VERSION HERE with ADD127 // Taking the original A_preparation which means A would be int8_t - SlowRefInt(A_prep.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, ShiftedBias.begin()); + references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { + return sum * unquant_mult + ShiftedBias[info.col_idx]; + }); Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); |