diff options
Diffstat (limited to 'test/QuantUtilsTest.cc')
-rw-r--r-- | test/QuantUtilsTest.cc | 32 |
1 files changed, 29 insertions, 3 deletions
diff --git a/test/QuantUtilsTest.cc b/test/QuantUtilsTest.cc index 2bbd05e..ddb1f91 100644 --- a/test/QuantUtilsTest.cc +++ b/test/QuantUtilsTest.cc @@ -85,6 +85,32 @@ void runTests( } /** + * There can be off-by-one error in quantized values due to how the mid-point + * cases are rounded-off in vectorized vs scalar codes and due to adding of + * zero_point before rounding vs after rounding. We ignore such differences + * while comparing results. + */ +template <typename T> +::testing::AssertionResult isNear( + const vector<T>& res, + const vector<T>& res_ref) { + bool match = true; + if (res.size() == res_ref.size()) { + for (int i = 0; i < res.size(); ++i) { + if (!(res[i] == res_ref[i] || res[i] == res_ref[i] + 1 || + res[i] == res_ref[i] - 1)) { + match = false; + break; + } + } + } + if (match) + return ::testing::AssertionSuccess(); + else + return ::testing::AssertionFailure() << " Quantized results do not match"; +} + +/** * Test for QuantizeGroupwise */ TEST_P(QuantizeGroupwiseTest, quantizeTest) { @@ -151,7 +177,7 @@ TEST_P(QuantizeGroupwiseTest, quantizeTest) { inp, K, C, X, G, scales, zero_points_int32, dstint32, dstint32_ref); } - EXPECT_EQ(dstuint8, dstuint8_ref); - EXPECT_EQ(dstint8, dstint8_ref); - EXPECT_EQ(dstint32, dstint32_ref); + EXPECT_TRUE(isNear(dstuint8, dstuint8_ref)); + EXPECT_TRUE(isNear(dstint8, dstint8_ref)); + EXPECT_TRUE(isNear(dstint32, dstint32_ref)); } |