diff options
author | Daya Khudia <dskhudia@fb.com> | 2019-06-22 04:10:52 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-06-22 04:20:52 +0300 |
commit | 278c146b929caf751f8e4daf31a039effe2bfb0c (patch) | |
tree | 7b3da3706cafbf0ff52a0455b7642568a88a5ebf | |
parent | 5b64af1469cf629aa7beb934eb898fd1e0b02719 (diff) |
fix flaky test (#100)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/100
The test fails in some cases depending on what random values got generated. See the comment in diff on why does it fail.
Reviewed By: jspark1105
Differential Revision: D15954045
fbshipit-source-id: d128ab7fa61f1b3210274120ac8f1e14c998f063
-rw-r--r-- | test/QuantUtilsTest.cc | 32 |
1 files changed, 29 insertions, 3 deletions
diff --git a/test/QuantUtilsTest.cc b/test/QuantUtilsTest.cc index 2bbd05e..ddb1f91 100644 --- a/test/QuantUtilsTest.cc +++ b/test/QuantUtilsTest.cc @@ -85,6 +85,32 @@ void runTests( } /** + * There can be off-by-one error in quantized values due to how the mid-point + * cases are rounded-off in vectorized vs scalar codes and due to adding of + * zero_point before rounding vs after rounding. We ignore such differences + * while comparing results. + */ +template <typename T> +::testing::AssertionResult isNear( + const vector<T>& res, + const vector<T>& res_ref) { + bool match = true; + if (res.size() == res_ref.size()) { + for (int i = 0; i < res.size(); ++i) { + if (!(res[i] == res_ref[i] || res[i] == res_ref[i] + 1 || + res[i] == res_ref[i] - 1)) { + match = false; + break; + } + } + } + if (match) + return ::testing::AssertionSuccess(); + else + return ::testing::AssertionFailure() << " Quantized results do not match"; +} + +/** * Test for QuantizeGroupwise */ TEST_P(QuantizeGroupwiseTest, quantizeTest) { @@ -151,7 +177,7 @@ TEST_P(QuantizeGroupwiseTest, quantizeTest) { inp, K, C, X, G, scales, zero_points_int32, dstint32, dstint32_ref); } - EXPECT_EQ(dstuint8, dstuint8_ref); - EXPECT_EQ(dstint8, dstint8_ref); - EXPECT_EQ(dstint32, dstint32_ref); + EXPECT_TRUE(isNear(dstuint8, dstuint8_ref)); + EXPECT_TRUE(isNear(dstint8, dstint8_ref)); + EXPECT_TRUE(isNear(dstint32, dstint32_ref)); } |