diff options
Diffstat (limited to 'src/QuantUtilsAvx2.cc')
-rw-r--r-- | src/QuantUtilsAvx2.cc | 26 |
1 files changed, 22 insertions, 4 deletions
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index 7f43ced..4a5f458 100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -18,13 +18,16 @@ using namespace std; //////////////////////////////////////////////////////////////////////////////// // Utility functions +template <typename T> void QuantizeAvx2( const float* src, - uint8_t* dst, + T* dst, int len, const TensorQuantizationParams& qparams) { #if defined(__AVX2__) && defined(__FMA__) constexpr int VLEN = 8; + constexpr float min_val = std::numeric_limits<T>::min(); + constexpr float max_val = std::numeric_limits<T>::max(); std::size_t i = 0; __m256 inverse_scale_v = _mm256_set1_ps(1.f / qparams.scale); __m256i shuffle_mask_v = _mm256_set_epi8( @@ -67,8 +70,8 @@ void QuantizeAvx2( __m256 transformed_v = _mm256_fmadd_ps( src_v, inverse_scale_v, _mm256_set1_ps(qparams.zero_point)); __m256 clipped_v = _mm256_min_ps( - _mm256_max_ps(transformed_v, _mm256_set1_ps(0.f)), - _mm256_set1_ps(255.f)); + _mm256_max_ps(transformed_v, _mm256_set1_ps(min_val)), + _mm256_set1_ps(max_val)); __m256i rounded_v = _mm256_cvtps_epi32(clipped_v); // An instruction sequence to save 8 32-bit integers as 8 8-bit integers @@ -80,7 +83,7 @@ void QuantizeAvx2( for (; i < len; ++i) { float transformed = qparams.zero_point + src[i] / qparams.scale; - float clipped = std::min(std::max(transformed, 0.f), 255.f); + float clipped = std::min(std::max(transformed, min_val), max_val); // Not exactly the same behavior as the vectorized code. // The vectorized code above always rounds to even in halfway cases // (https://software.intel.com/en-us/node/523819), but std::nearbyint @@ -95,6 +98,21 @@ void QuantizeAvx2( #endif } +// Instantiate QuantizeAvx2 for known datatypes +template +void QuantizeAvx2<uint8_t>( + const float* src, + uint8_t* dst, + int len, + const TensorQuantizationParams& qparams); +template +void QuantizeAvx2<int8_t>( + const float* src, + int8_t* dst, + int len, + const TensorQuantizationParams& qparams); + + void FindMinMax(const float* a, float* min, float* max, int len) { if (len <= 0) { *min = 0.0f; |