diff options
Diffstat (limited to 'test/test.h')
-rw-r--r-- | test/test.h | 34 |
1 files changed, 31 insertions, 3 deletions
diff --git a/test/test.h b/test/test.h index 7c294f8..af6e17a 100644 --- a/test/test.h +++ b/test/test.h @@ -26,9 +26,37 @@ namespace intgemm { -void Compare(const float *float_ref, const float *int_ref, const float *int_test, - std::size_t size, std::string test_info, float int_tolerance, - float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance); +template <typename Type> +void Compare(const Type* reference, const Type* actual, Index size) { + for (Index i = 0; i < size; ++i) { + INFO("Inaccurate at " << i << ' ' << reference[i] << ' ' << actual[i]); + CHECK(reference[i] == actual[i]); + } +} + +template <typename Type> +void CompareEps(const Type* reference, const Type* actual, Index size, Type epsilon) { + for (Index i = 0; i < size; ++i) { + INFO("Inaccurate at " << i << ' ' << reference[i] << ' ' << actual[i]); + CHECK(fabs(reference[i] - actual[i]) < epsilon); + } +} + +void CompareMSE(const float *float_ref, const float *int_ref, const float *int_test, + std::size_t size, std::string test_info, float int_tolerance, + float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance); + +template <typename Type> +std::string PrintMatrix(const Type *mem, Index rows, Index cols) { + std::ostringstream out; + for (Index r = 0; r < rows; ++r) { + for (Index c = 0; c < cols; ++c) { + out << std::setw(4) << (int64_t) mem[r * cols + c] << ' '; + } + out << '\n'; + } + return out.str(); +} /* * References |