Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'test/QuantizationHelpers.cc')
-rw-r--r-- test/QuantizationHelpers.cc | 57
1 files changed, 57 insertions, 0 deletions
diff --git a/test/QuantizationHelpers.cc b/test/QuantizationHelpers.cc
new file mode 100644
index 0000000..354519b
--- /dev/null
+++ b/test/QuantizationHelpers.cc
@@ -0,0 +1,57 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "QuantizationHelpers.h"
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <limits>
+
+using namespace std;
+
namespace fbgemm2 {
namespace {

// Range that every adjusted pair sum a0*b0 + a1*b1 must stay within.
constexpr int kInt16Lowest = std::numeric_limits<int16_t>::lowest();
constexpr int kInt16Max = std::numeric_limits<int16_t>::max();

// Returns ceil(num / den) for den > 0 using integer arithmetic only.
// Built-in integer division truncates toward zero, which is already the
// ceiling for negative quotients; only positive inexact quotients need +1.
int ceilDivPositive(int num, int den) {
  const int quotient = num / den;
  return (num % den != 0 && num > 0) ? quotient + 1 : quotient;
}

// Returns floor(num / den) for den > 0 using integer arithmetic only.
// Truncation is already the floor for positive quotients; only negative
// inexact quotients need -1.
int floorDivPositive(int num, int den) {
  const int quotient = num / den;
  return (num % den != 0 && num < 0) ? quotient - 1 : quotient;
}

} // namespace

/**
 * @brief Make sure we won't have overflows from vpmaddubsw instruction.
 *
 * For every row i of A, every column j of B and every pair of consecutive
 * depth indices (kk, kk + 1), the sum
 *   A[i][kk] * B[kk][j] + A[i][kk + 1] * B[kk + 1][j]
 * is checked against the int16_t range. When it falls outside, B[kk + 1][j]
 * is overwritten in place with the value closest to the original direction
 * that brings the sum back to the violated bound, clamped to [-128, 127].
 *
 * @param m     Number of rows of A.
 * @param n     Number of columns of B.
 * @param k     Shared depth; A is read as Aint8[i * k + kk] and B as
 *              B[kk * n + j]. If k is odd the last depth index is left
 *              unpaired and is never inspected or modified.
 * @param Aint8 Unsigned 8-bit A matrix (read only).
 * @param B     B matrix, modified in place. Elements are converted to int
 *              before use.
 *
 * Only B[kk + 1][j] is ever modified. After clamping, an assert verifies
 * that the adjusted sum fits in int16_t; this can only fail when B holds
 * values outside [-128, 127] (possible for the float instantiation).
 * Because rows are processed in order and share B, an adjustment made for a
 * later row is not re-validated against earlier rows.
 */
template <typename T>
void avoidOverflow(int m, int n, int k, const uint8_t* Aint8, T* B) {
  // Largest even count of depth indices; a trailing odd element has no pair.
  const int paired_k = k / 2 * 2;

  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int kk = 0; kk < paired_k; kk += 2) {
        const int a0 = Aint8[i * k + kk];
        const int a1 = Aint8[i * k + kk + 1];
        const int b0 = static_cast<int>(B[kk * n + j]);
        const int b1 = static_cast<int>(B[(kk + 1) * n + j]);

        const int first_product = a0 * b0;
        const int sum_pair = first_product + a1 * b1;
        if (sum_pair >= kInt16Lowest && sum_pair <= kInt16Max) {
          continue;
        }

        // With a1 == 0 the second operand contributes nothing, so no choice
        // of b1 can repair the sum. Leave B untouched instead of dividing by
        // zero. Unreachable when B holds int8-range values, since
        // 255 * 128 still fits in int16_t.
        if (a1 == 0) {
          continue;
        }

        // Solve a0*b0 + a1*b1' == bound for b1', rounding toward the inside
        // of the valid range: up for the lower bound, down for the upper.
        int b1_adjusted = (sum_pair < kInt16Lowest)
            ? ceilDivPositive(kInt16Lowest - first_product, a1)
            : floorDivPositive(kInt16Max - first_product, a1);
        b1_adjusted = std::min(std::max(b1_adjusted, -128), 127);

        const int new_sum_pair = first_product + a1 * b1_adjusted;
        assert(new_sum_pair >= kInt16Lowest && new_sum_pair <= kInt16Max);
        (void)new_sum_pair; // silence unused warning when NDEBUG is defined

        B[(kk + 1) * n + j] = static_cast<T>(b1_adjusted);
      }
    } // for each j
  } // for each i
}

template void
avoidOverflow(int m, int n, int k, const uint8_t* Aint8, int8_t* B);
template void
avoidOverflow(int m, int n, int k, const uint8_t* Aint8, float* B);
} // namespace fbgemm2