github.com/marian-nmt/intgemm/intgemm.git
Diffstat (limited to 'test/add127_test.cc')
-rw-r--r--  test/add127_test.cc | 33 +++++++++++++++++++++++++--------
1 file changed, 25 insertions(+), 8 deletions(-)
diff --git a/test/add127_test.cc b/test/add127_test.cc
index d1b850d..ae5c08a 100644
--- a/test/add127_test.cc
+++ b/test/add127_test.cc
@@ -81,7 +81,9 @@ template <class Routine> void TestPrepareBias(Index rows, Index cols) {
//Routine::Multiply(A_prep2.begin(), B_prep.begin(), A_rows, rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, goldBias.begin(), goldBias.begin()));
//CompareBiases(goldBias.begin(), inputBias.begin(), cols);
AlignedVector<float> slowint_C(cols);
- SlowRefInt(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult_forprep, A_rows, rows, cols, goldBias.begin());
+ references::Multiply(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), A_rows, rows, cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+ return sum * unquant_mult_forprep + goldBias[info.col_idx];
+ });
CompareBiases(slowint_C.begin(), inputBias.begin(), cols);
}
@@ -127,10 +129,14 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind
// Taking the original A_preparation which means A would be int8_t
AlignedVector<int8_t> A_prep2(A.size());
Routine::PrepareA(A.begin(), A_prep2.begin(), quant_mult, A_rows, width);
- SlowRefInt(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, bias.begin());
+ references::Multiply(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+ return sum * unquant_mult + bias[info.col_idx];
+ });
AlignedVector<float> float_C(test_C.size());
- SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin());
+ references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ return sum + bias[info.col_idx];
+ });
/*ACTUAL MULTIPLICATION
*
@@ -185,7 +191,10 @@ template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index widt
Routine::Multiply(A_prep_old.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), slowint_C.begin()));
AlignedVector<float> float_C(test_C.size());
- SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin());
+ references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ return sum + bias[info.col_idx];
+ });
+
/*
* Multiply8 shift multiplication
*/
@@ -238,10 +247,14 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In
Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, B.size());
AlignedVector<float> slowint_C(test_C.size());
// Taking the original A_preparation which means A would be int8_t
- //SlowRefInt(A_prep.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, bias.begin());
+ // references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+ // return sum * unquant_mult + bias[info.col_idx];
+ // });
AlignedVector<float> float_C(test_C.size());
- SlowRefFloat(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, bias.begin());
+ references::MultiplyFF(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](float sum, const callbacks::OutputBufferInfo& info) {
+ return sum + bias[info.col_idx];
+ });
/*
* Multiply8 shift multiplication
*/
@@ -252,7 +265,9 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In
}
AlignedVector<float> ShiftedBias(B_cols);
float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f); //Minus one to invert add_ps later on
- SlowRefInt(A_prep2.begin(), B_quant.begin(), ShiftedBias.begin(), unquant_mult_forprep, 1, width, B_cols, bias.begin());
+ references::Multiply(A_prep2.begin(), B_quant.begin(), ShiftedBias.begin(), 1, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+ return sum * unquant_mult_forprep + bias[info.col_idx];
+ });
//Now prepare Fast integer Bias
@@ -261,7 +276,9 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In
// Reference INT VERSION HERE with ADD127
// Taking the original A_preparation which means A would be int8_t
- SlowRefInt(A_prep.begin(), B_quant.begin(), slowint_C.begin(), unquant_mult, A_rows, width, B_cols, ShiftedBias.begin());
+ references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+ return sum * unquant_mult + ShiftedBias[info.col_idx];
+ });
Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
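(Editor's note on the unquant_mult_forprep constant above: with the shifted 8-bit kernels, each signed int8 element of A is offset by +127 so it fits the unsigned operand of the maddubs-style multiply, and every output element therefore picks up an extra 127 * sum_k B_q[k][c]:

  sum_k (A_q[r][k] + 127) * B_q[k][c] = sum_k A_q[r][k] * B_q[k][c] + 127 * sum_k B_q[k][c]

Unquantizing the extra term by (alpha/127)^2 gives (alpha^2/127) * sum_k B_q[k][c]. Assuming, as the surrounding test appears to set up, that A_prep2 in the ShiftedBias hunk is a single row of ones, the references::Multiply call with rows = 1 computes exactly those column sums, so

  ShiftedBias[c] = bias[c] - (alpha^2/127) * sum_k B_q[k][c]

The negated multiplier, unquant_mult_forprep = -alpha*alpha/127.0f, bakes the subtraction into the bias so the kernel can simply add it, which is what the "Minus one to invert add_ps later on" comment refers to.)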