Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/OpenNMT/CTranslate2.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorGuillaume Klein <guillaumekln@users.noreply.github.com>2022-02-01 12:11:04 +0300
committerGitHub <noreply@github.com>2022-02-01 12:11:04 +0300
commit8a16008d0e4400ae4b51a3ee4cbe9c4d3f86277f (patch)
tree4ff313a6fd2f16938c3ee19791abcb6057b12140 /tests
parent76964fffa24d68219081a8055bf8caceb24f469e (diff)
Round value before cast in quantization (#704)
* Round value before cast in quantization * Update quantization formula in documentation * Add missing variables in lambda capture * Save binary version in named variable * Defer implicit cast to value assignment * Fix Python test * Remove non needed changes
Diffstat (limited to 'tests')
-rw-r--r--tests/ops_test.cc40
1 files changed, 32 insertions, 8 deletions
diff --git a/tests/ops_test.cc b/tests/ops_test.cc
index 73772455..9370394b 100644
--- a/tests/ops_test.cc
+++ b/tests/ops_test.cc
@@ -672,10 +672,22 @@ TEST_P(OpDeviceTest, QuantizeINT8) {
StorageView scale(DataType::FLOAT, device);
StorageView qa(DataType::INT8, device);
StorageView expected_scale({2}, std::vector<float>{12.7, 6.047619}, device);
- StorageView expected_qa(a.shape(), std::vector<int8_t>{-127, -38, 63, 25, 30, 127, -18, 0});
- ops::Quantize()(a, qa, scale);
- expect_storage_eq(scale, expected_scale);
- expect_storage_eq(qa, expected_qa);
+
+ // With rounding before cast.
+ {
+ StorageView expected_qa(a.shape(), std::vector<int8_t>{-127, -38, 64, 25, 30, 127, -18, 0});
+ ops::Quantize(ops::Quantize::ScaleType::GLOBAL, false, true)(a, qa, scale);
+ expect_storage_eq(scale, expected_scale);
+ expect_storage_eq(qa, expected_qa);
+ }
+
+ // Without rounding before cast (legacy behavior).
+ {
+ StorageView expected_qa(a.shape(), std::vector<int8_t>{-127, -38, 63, 25, 30, 127, -18, 0});
+ ops::Quantize(ops::Quantize::ScaleType::GLOBAL, false, false)(a, qa, scale);
+ expect_storage_eq(scale, expected_scale);
+ expect_storage_eq(qa, expected_qa);
+ }
}
TEST_P(OpDeviceTest, QuantizeINT8ZeroRow) {
@@ -684,10 +696,22 @@ TEST_P(OpDeviceTest, QuantizeINT8ZeroRow) {
StorageView scale(DataType::FLOAT, device);
StorageView qa(DataType::INT8, device);
StorageView expected_scale({2}, std::vector<float>{12.7, 1}, device);
- StorageView expected_qa(a.shape(), std::vector<int8_t>{-127, -38, 63, 25, 0, 0, 0, 0});
- ops::Quantize()(a, qa, scale);
- expect_storage_eq(scale, expected_scale);
- expect_storage_eq(qa, expected_qa);
+
+ // With rounding before cast.
+ {
+ StorageView expected_qa(a.shape(), std::vector<int8_t>{-127, -38, 64, 25, 0, 0, 0, 0});
+ ops::Quantize(ops::Quantize::ScaleType::GLOBAL, false, true)(a, qa, scale);
+ expect_storage_eq(scale, expected_scale);
+ expect_storage_eq(qa, expected_qa);
+ }
+
+ // Without rounding before cast (legacy behavior).
+ {
+ StorageView expected_qa(a.shape(), std::vector<int8_t>{-127, -38, 63, 25, 0, 0, 0, 0});
+ ops::Quantize(ops::Quantize::ScaleType::GLOBAL, false, false)(a, qa, scale);
+ expect_storage_eq(scale, expected_scale);
+ expect_storage_eq(qa, expected_qa);
+ }
}
TEST_P(OpDeviceFPTest, Multinomial) {