github.com/marian-nmt/FBGEMM.git
Diffstat (limited to 'src/QuantUtilsAvx2.cc')
-rw-r--r--  src/QuantUtilsAvx2.cc | 26
1 file changed, 22 insertions(+), 4 deletions(-)
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index 7f43ced..4a5f458 100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -18,13 +18,16 @@ using namespace std;
 ////////////////////////////////////////////////////////////////////////////////
 // Utility functions
 
+template <typename T>
 void QuantizeAvx2(
     const float* src,
-    uint8_t* dst,
+    T* dst,
     int len,
     const TensorQuantizationParams& qparams) {
 #if defined(__AVX2__) && defined(__FMA__)
   constexpr int VLEN = 8;
+  constexpr float min_val = std::numeric_limits<T>::min();
+  constexpr float max_val = std::numeric_limits<T>::max();
   std::size_t i = 0;
   __m256 inverse_scale_v = _mm256_set1_ps(1.f / qparams.scale);
   __m256i shuffle_mask_v = _mm256_set_epi8(
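What the new type-derived clip bounds evaluate to for the two types this diff instantiates, spelled out as a self-checking sketch (plain standard C++, nothing FBGEMM-specific assumed):

#include <cstdint>
#include <limits>

// The old code hard-wired the uint8_t range [0, 255]; the template now
// picks the range from T. These are the concrete bounds per type:
static_assert(std::numeric_limits<std::uint8_t>::min() == 0, "uint8 low");
static_assert(std::numeric_limits<std::uint8_t>::max() == 255, "uint8 high");
static_assert(std::numeric_limits<std::int8_t>::min() == -128, "int8 low");
static_assert(std::numeric_limits<std::int8_t>::max() == 127, "int8 high");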
@@ -67,8 +70,8 @@ void QuantizeAvx2(
     __m256 transformed_v = _mm256_fmadd_ps(
         src_v, inverse_scale_v, _mm256_set1_ps(qparams.zero_point));
     __m256 clipped_v = _mm256_min_ps(
-        _mm256_max_ps(transformed_v, _mm256_set1_ps(0.f)),
-        _mm256_set1_ps(255.f));
+        _mm256_max_ps(transformed_v, _mm256_set1_ps(min_val)),
+        _mm256_set1_ps(max_val));
     __m256i rounded_v = _mm256_cvtps_epi32(clipped_v);
 
     // An instruction sequence to save 8 32-bit integers as 8 8-bit integers
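The hunk ends at the comment announcing the 32-bit-to-8-bit store; the sequence itself lies outside the context shown. One way such a sequence can look (a sketch under AVX2, not necessarily the exact instructions this file uses; since the values were already clipped to T's range, keeping the low byte of each 32-bit lane is lossless):

#include <immintrin.h>
#include <cstdint>

inline void StoreLowBytes(__m256i rounded_v, std::uint8_t* dst) {
  // Within each 128-bit lane, gather byte 0 of every dword into the lane's
  // low 4 bytes; a mask byte with its high bit set (-1) writes zero.
  const __m256i shuffle_mask_v = _mm256_set_epi8(
      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0,
      -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 12, 8, 4, 0);
  __m256i packed = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v);
  // Move dword 4 (the high lane's packed bytes) next to dword 0.
  const __m256i permute_mask_v = _mm256_set_epi32(0, 0, 0, 0, 0, 0, 4, 0);
  packed = _mm256_permutevar8x32_epi32(packed, permute_mask_v);
  // Store the resulting 8 bytes.
  _mm_storel_epi64(
      reinterpret_cast<__m128i*>(dst), _mm256_castsi256_si128(packed));
}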
@@ -80,7 +83,7 @@ void QuantizeAvx2(
   for (; i < len; ++i) {
     float transformed = qparams.zero_point + src[i] / qparams.scale;
-    float clipped = std::min(std::max(transformed, 0.f), 255.f);
+    float clipped = std::min(std::max(transformed, min_val), max_val);
     // Not exactly the same behavior as the vectorized code.
     // The vectorized code above always rounds to even in halfway cases
     // (https://software.intel.com/en-us/node/523819), but std::nearbyint
     // does rounding based on the current rounding mode.
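The comment draws a real distinction: _mm256_cvtps_epi32 rounds halfway cases to even under the default MXCSR setting, while std::nearbyint obeys whatever rounding mode the floating-point environment currently holds. A standalone check of the scalar side (standard C++, assuming the usual FE_TONEAREST startup default):

#include <cfenv>
#include <cmath>
#include <cstdio>

int main() {
  // Default mode FE_TONEAREST: halfway cases go to the even neighbor.
  std::printf("%.1f -> %.0f\n", 2.5, std::nearbyint(2.5)); // 2.5 -> 2
  std::printf("%.1f -> %.0f\n", 3.5, std::nearbyint(3.5)); // 3.5 -> 4

  // After changing the environment, std::nearbyint follows suit, which is
  // exactly the divergence the comment warns about.
  std::fesetround(FE_UPWARD);
  std::printf("%.1f -> %.0f\n", 2.5, std::nearbyint(2.5)); // 2.5 -> 3
  return 0;
}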
@@ -95,6 +98,21 @@ void QuantizeAvx2(
 #endif
 }
+// Instantiate QuantizeAvx2 for known datatypes
+template
+void QuantizeAvx2<uint8_t>(
+    const float* src,
+    uint8_t* dst,
+    int len,
+    const TensorQuantizationParams& qparams);
+template
+void QuantizeAvx2<int8_t>(
+    const float* src,
+    int8_t* dst,
+    int len,
+    const TensorQuantizationParams& qparams);
+
+
 void FindMinMax(const float* a, float* min, float* max, int len) {
   if (len <= 0) {
     *min = 0.0f;
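The explicit instantiations keep the template's definition inside this .cc file while letting callers link against exactly the two listed types. A usage sketch; the include path, the fbgemm namespace, and the qparams field names are assumptions based on FBGEMM's public headers, not shown in this diff:

#include <cstdint>
#include <vector>

#include "fbgemm/QuantUtilsAvx2.h" // assumed header location

int main() {
  std::vector<float> src = {0.04f, -1.31f, 2.78f, 0.0f};
  std::vector<std::int8_t> dst(src.size());

  fbgemm::TensorQuantizationParams qparams; // assumed field names
  qparams.scale = 0.05f;
  qparams.zero_point = 0;

  // Resolves against the int8_t instantiation added above; a type other
  // than uint8_t or int8_t would compile here but fail at link time.
  fbgemm::QuantizeAvx2<std::int8_t>(
      src.data(), dst.data(), static_cast<int>(src.size()), qparams);
  return 0;
}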