Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMateusz Chudyk <mateuszchudyk@gmail.com>2020-02-06 21:41:40 +0300
committerMateusz Chudyk <mateuszchudyk@gmail.com>2020-02-06 22:39:22 +0300
commitb75a56d9b930acfe6c65d00beb09d145bd9f51dc (patch)
tree88f86414552df42ac10b1177d09a6e22af7c2282
parenta72b13b72d04f0863decd46c5b9cdca24d962de3 (diff)
Straighten helper functions used in tests
-rw-r--r--test/add127_test.cc17
-rw-r--r--test/multiply_test.cc15
-rw-r--r--test/test.cc2
-rw-r--r--test/test.h34
4 files changed, 39 insertions, 29 deletions
diff --git a/test/add127_test.cc b/test/add127_test.cc
index ae5c08a..18afaa5 100644
--- a/test/add127_test.cc
+++ b/test/add127_test.cc
@@ -14,13 +14,6 @@ void CompareAs(int8_t * output_old, uint8_t * output_new, Index rows, Index cols
}
}
-void CompareBiases(const float *bias_ref, const float *bias, Index cols) {
- for (std::size_t i = 0; i < cols; ++i) {
- INFO("Inaccurate at " << i << ' ' << bias_ref[i] << ' ' << bias[i]);
- CHECK(fabs(bias_ref[i] - bias[i]) < 0.0001);
- }
-}
-
template <class Routine> void TestPrepareA(Index rows, Index cols) {
std::mt19937 gen;
// Go somewhat out of range too.
@@ -79,12 +72,12 @@ template <class Routine> void TestPrepareBias(Index rows, Index cols) {
it =1;
}
//Routine::Multiply(A_prep2.begin(), B_prep.begin(), A_rows, rows, cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, goldBias.begin(), goldBias.begin()));
- //CompareBiases(goldBias.begin(), inputBias.begin(), cols);
+ //CompareEps(goldBias.begin(), inputBias.begin(), cols, 0.0001f);
AlignedVector<float> slowint_C(cols);
references::Multiply(A_prep2.begin(), B_quant.begin(), slowint_C.begin(), A_rows, rows, cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
return sum * unquant_mult_forprep + goldBias[info.col_idx];
});
- CompareBiases(slowint_C.begin(), inputBias.begin(), cols);
+ CompareEps(slowint_C.begin(), inputBias.begin(), cols, 0.0001f);
}
template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Index B_cols,
@@ -146,7 +139,7 @@ template <class Routine> void TestMultiplyBiasNew(Index A_rows, Index width, Ind
//Routine::PrepareBias(B.begin(), bias.begin(), alpha, width, B_cols);
Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
- Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}
@@ -202,7 +195,7 @@ template <class Routine> void TestMultiplyShiftNonShift(Index A_rows, Index widt
Routine::PrepareBias(B_prep.begin(), width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult_forprep, bias.begin(), bias.begin()));
Routine::Multiply8Shift(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWrite(unquant_mult, bias.begin(), test_C.begin()));
- Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}
@@ -280,7 +273,7 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In
return sum * unquant_mult + ShiftedBias[info.col_idx];
});
- Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 725fbca..59c62a9 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -54,17 +54,6 @@ INTGEMM_SSSE3 TEST_CASE("Transpose 8", "[transpose]") {
}
}
-template <class T> std::string PrintMatrix(const T *mem, Index rows, Index cols) {
- std::ostringstream out;
- for (Index r = 0; r < rows; ++r) {
- for (Index c = 0; c < cols; ++c) {
- out << std::setw(4) << (int64_t) mem[r * cols + c] << ' ';
- }
- out << '\n';
- }
- return out.str();
-}
-
template <class Routine> void TestPrepare(Index rows = 32, Index cols = 16) {
std::mt19937 gen;
// Go somewhat out of range too.
@@ -306,7 +295,7 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
return sum;
});
- Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}
@@ -359,7 +348,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
return sum + bias[info.col_idx];
});
- Compare(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}
diff --git a/test/test.cc b/test/test.cc
index 62137a1..3559738 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -7,7 +7,7 @@ int main(int argc, char ** argv) {
namespace intgemm {
-void Compare(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info,
+void CompareMSE(const float *float_ref, const float *int_ref, const float *int_test, std::size_t size, std::string test_info,
float int_tolerance, float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance) {
float int_sum = 0.0, float_sum = 0.0;
for (std::size_t i = 0; i < size; ++i) {
diff --git a/test/test.h b/test/test.h
index 7c294f8..af6e17a 100644
--- a/test/test.h
+++ b/test/test.h
@@ -26,9 +26,37 @@
namespace intgemm {
-void Compare(const float *float_ref, const float *int_ref, const float *int_test,
- std::size_t size, std::string test_info, float int_tolerance,
- float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance);
+template <typename Type>
+void Compare(const Type* reference, const Type* actual, Index size) {
+ for (Index i = 0; i < size; ++i) {
+ INFO("Inaccurate at " << i << ' ' << reference[i] << ' ' << actual[i]);
+ CHECK(reference[i] == actual[i]);
+ }
+}
+
+template <typename Type>
+void CompareEps(const Type* reference, const Type* actual, Index size, Type epsilon) {
+ for (Index i = 0; i < size; ++i) {
+ INFO("Inaccurate at " << i << ' ' << reference[i] << ' ' << actual[i]);
+ CHECK(fabs(reference[i] - actual[i]) < epsilon);
+ }
+}
+
+void CompareMSE(const float *float_ref, const float *int_ref, const float *int_test,
+ std::size_t size, std::string test_info, float int_tolerance,
+ float float_tolerance, float MSE_float_tolerance, float MSE_int_tolerance);
+
+template <typename Type>
+std::string PrintMatrix(const Type *mem, Index rows, Index cols) {
+ std::ostringstream out;
+ for (Index r = 0; r < rows; ++r) {
+ for (Index c = 0; c < cols; ++c) {
+ out << std::setw(4) << (int64_t) mem[r * cols + c] << ' ';
+ }
+ out << '\n';
+ }
+ return out.str();
+}
/*
* References