Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'test/test.h')
-rw-r--r--test/test.h34
1 files changed, 31 insertions, 3 deletions
diff --git a/test/test.h b/test/test.h
index 7c294f8..af6e17a 100644
--- a/test/test.h
+++ b/test/test.h
@@ -26,9 +26,37 @@
namespace intgemm {
-void Compare(const float *float_ref, const float *int_ref, const float *int_test,
- std::size_t size, std::string test_info, float int_tolerance,
- float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance);
+template <typename Type>
+void Compare(const Type* reference, const Type* actual, Index size) {
+ for (Index i = 0; i < size; ++i) {
+ INFO("Inaccurate at " << i << ' ' << reference[i] << ' ' << actual[i]);
+ CHECK(reference[i] == actual[i]);
+ }
+}
+
+template <typename Type>
+void CompareEps(const Type* reference, const Type* actual, Index size, Type epsilon) {
+ for (Index i = 0; i < size; ++i) {
+ INFO("Inaccurate at " << i << ' ' << reference[i] << ' ' << actual[i]);
+ CHECK(fabs(reference[i] - actual[i]) < epsilon);
+ }
+}
+
+void CompareMSE(const float *float_ref, const float *int_ref, const float *int_test,
+ std::size_t size, std::string test_info, float int_tolerance,
+ float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance);
+
+template <typename Type>
+std::string PrintMatrix(const Type *mem, Index rows, Index cols) {
+ std::ostringstream out;
+ for (Index r = 0; r < rows; ++r) {
+ for (Index c = 0; c < cols; ++c) {
+ out << std::setw(4) << (int64_t) mem[r * cols + c] << ' ';
+ }
+ out << '\n';
+ }
+ return out.str();
+}
/*
* References