From b75a56d9b930acfe6c65d00beb09d145bd9f51dc Mon Sep 17 00:00:00 2001 From: Mateusz Chudyk Date: Thu, 6 Feb 2020 18:41:40 +0000 Subject: Straighten helper functions used in tests --- test/add127_test.cc | 17 +++++------------ test/multiply_test.cc | 15 ++------------- test/test.cc | 2 +- test/test.h | 34 +++++++++++++++++++++++++++++++--- 4 files changed, 39 insertions(+), 29 deletions(-) diff --git a/test/add127_test.cc b/test/add127_test.cc index ae5c08a..18afaa5 100644 --- a/test/add127_test.cc +++ b/test/add127_test.cc @@ -14,13 +14,6 @@ void CompareAs(int8_t * output_old, uint8_t * output_new, Index rows, Index cols } } -void CompareBiases(const float *bias_ref, const float *bias, Index cols) { - for (std::size_t i = 0; i < cols; ++i) { - INFO("Inaccurate at " << i << ' ' << bias_ref[i] << ' ' << bias[i]); - CHECK(fabs(bias_ref[i] - bias[i]) < 0.0001); - } -} - template void TestPrepareA(Index rows, Index cols) { std::mt19937 gen; // Go somewhat out of range too. @@ -79,12 +72,12 @@ template void TestPrepareBias(Index rows, Index cols) { it =1; } //Routine::Multiply(A_prep2.begin(), B_prep.begin(), A_rows, rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, goldBias.begin(), goldBias.begin())); - //CompareBiases(goldBias.begin(), inputBias.begin(), cols); + //CompareEps(goldBias.begin(), inputBias.begin(), cols, 0.0001f); AlignedVector slowint_C(cols); references::Multiply(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), A_rows, rows, cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) { return sum * unquant_mult_forprep + goldBias[info.col_idx]; }); - CompareBiases(slowint_C.begin(), inputBias.begin(), cols); + CompareEps(slowint_C.begin(), inputBias.begin(), cols, 0.0001f); } template void TestMultiplyBiasNew(Index A_rows, Index width, Index B_cols, @@ -146,7 +139,7 @@ template void TestMultiplyBiasNew(Index A_rows, Index width, Ind //Routine::PrepareBias(B.begin(), bias.begin(), alpha, width, B_cols); Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); - Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); } @@ -202,7 +195,7 @@ template void TestMultiplyShiftNonShift(Index A_rows, Index widt Routine::PrepareBias(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin())); Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin())); - Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); } @@ -280,7 +273,7 @@ template void TestMultiplyShiftInt(Index A_rows, Index width, In return sum * unquant_mult + ShiftedBias[info.col_idx]; }); - Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); } diff --git a/test/multiply_test.cc b/test/multiply_test.cc index 725fbca..59c62a9 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -54,17 +54,6 @@ INTGEMM_SSSE3 TEST_CASE("Transpose 8", "[transpose]") { } } -template std::string PrintMatrix(const T *mem, Index rows, Index cols) { - std::ostringstream out; - for (Index r = 0; r < rows; ++r) { - for (Index c = 0; c < cols; ++c) { - out << std::setw(4) << (int64_t) mem[r * cols + c] << ' '; - } - out << '\n'; - } - return out.str(); -} - template void TestPrepare(Index rows = 32, Index cols = 16) { std::mt19937 gen; // Go somewhat out of range too. @@ -306,7 +295,7 @@ template void TestMultiply(Index A_rows, Index width, Index B_co return sum; }); - Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); } @@ -359,7 +348,7 @@ template void TestMultiplyBias(Index A_rows, Index width, Index return sum + bias[info.col_idx]; }); - Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), + CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(), int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance); } diff --git a/test/test.cc b/test/test.cc index 62137a1..3559738 100644 --- a/test/test.cc +++ b/test/test.cc @@ -7,7 +7,7 @@ int main(int argc, char ** argv) { namespace intgemm { -void Compare(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info, +void CompareMSE(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info, float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance) { float int_sum = 0.0, float_sum = 0.0; for (std::size_t i = 0; i < size; ++i) { diff --git a/test/test.h b/test/test.h index 7c294f8..af6e17a 100644 --- a/test/test.h +++ b/test/test.h @@ -26,9 +26,37 @@ namespace intgemm { -void Compare(const float *float_ref, const float *int_ref, const float *int_test, - std::size_t size, std::string test_info, float int_tolerance, - float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance); +template +void Compare(const Type* reference, const Type* actual, Index size) { + for (Index i = 0; i < size; ++i) { + INFO("Inaccurate at " << i << ' ' << reference[i] << ' ' << actual[i]); + CHECK(reference[i] == actual[i]); + } +} + +template +void CompareEps(const Type* reference, const Type* actual, Index size, Type epsilon) { + for (Index i = 0; i < size; ++i) { + INFO("Inaccurate at " << i << ' ' << reference[i] << ' ' << actual[i]); + CHECK(fabs(reference[i] - actual[i]) < epsilon); + } +} + +void CompareMSE(const float *float_ref, const float *int_ref, const float *int_test, + std::size_t size, std::string test_info, float int_tolerance, + float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance); + +template +std::string PrintMatrix(const Type *mem, Index rows, Index cols) { + std::ostringstream out; + for (Index r = 0; r < rows; ++r) { + for (Index c = 0; c < cols; ++c) { + out << std::setw(4) << (int64_t) mem[r * cols + c] << ' '; + } + out << '\n'; + } + return out.str(); +} /* * References -- cgit v1.2.3