author     James Reed <jamesreed@fb.com>  2019-08-29 21:11:31 +0300
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-08-29 21:26:05 +0300
commit     d4bfa96cdaab5c69dfee4d4c9daa9d565435cb2d (patch)
tree       a438ba751cf7495bea71711a374f90ef965f0561
parent     280fa17349b763eb474c423a6d1172f81df29103 (diff)
int8 specialization for AVX2 Quantize routine (#120)
Summary:
This adds a specialization for `int8` to the AVX2 `Quantize` routine.
I also tried adding a specialization for `int32` (the final datatype we support in PyTorch quantization), but it seemed to introduce numerical issues stemming from the difference between the two implementations:
https://github.com/pytorch/FBGEMM/blob/master/include/fbgemm/QuantUtils.h#L63
vs
https://github.com/pytorch/FBGEMM/blob/master/src/QuantUtilsAvx2.cc#L82
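The two links point at the scalar transform (`zero_point + src / scale`, rounded with `std::nearbyint`) and the vectorized one (an FMA against a precomputed `1.f / scale`); both formulas are visible in the diff below. As a standalone illustration of why that difference can matter more for wider types, here is a hypothetical demo with arbitrary values (not FBGEMM code):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <limits>

    // Illustration only: the scalar reference divides by scale, while the
    // AVX2 path multiplies by a precomputed reciprocal via FMA. The two
    // transforms can differ by an ULP, which after rounding can move a
    // value across a quantization step.
    int main() {
      const float scale = 0.3f;      // arbitrary demo value
      const float zero_point = 0.f;  // arbitrary demo value
      const float x = 2.85f;         // arbitrary demo value

      const float by_division = zero_point + x / scale;
      const float by_reciprocal = std::fma(x, 1.f / scale, zero_point);

      // Print both so you can see whether they straddle a rounding
      // boundary on your platform; equality is not guaranteed.
      std::printf("division:   %.9f -> %ld\n", by_division, std::lrintf(by_division));
      std::printf("reciprocal: %.9f -> %ld\n", by_reciprocal, std::lrintf(by_reciprocal));

      // A further hazard specific to int32: its clamp bound is not exactly
      // representable as float, so float(INT32_MAX) rounds up to 2^31.
      std::printf("float(INT32_MAX) = %.1f\n",
                  static_cast<float>(std::numeric_limits<std::int32_t>::max()));
      return 0;
    }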
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/120
Reviewed By: driazati
Differential Revision: D17115198
Pulled By: jamesr66a
fbshipit-source-id: 119145bb99235a7545389afa61483060200cc2b7
 include/fbgemm/QuantUtilsAvx2.h |  3
 src/QuantUtils.cc               | 38
 src/QuantUtilsAvx2.cc           | 26
 3 files changed, 45 insertions, 22 deletions
diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h
index 47f33a8..a001004 100644
--- a/include/fbgemm/QuantUtilsAvx2.h
+++ b/include/fbgemm/QuantUtilsAvx2.h
@@ -40,9 +40,10 @@ struct FBGEMM_API RequantizationParams {
 ////////////////////////////////////////////////////////////////////////////////
 // Utility functions
 
+template <typename T=std::uint8_t>
 void QuantizeAvx2(
     const float* src,
-    std::uint8_t* dst,
+    T* dst,
     int len,
     const TensorQuantizationParams& qparams);
 
diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc
index 5dde90b..a209efc 100644
--- a/src/QuantUtils.cc
+++ b/src/QuantUtils.cc
@@ -164,30 +164,34 @@ void ChooseRequantizationMultiplier(
       dst[i] = Quantize<T>(src[i], qparams); \
     } \
   }
-FBGEMM_SPECIALIZED_QUANTIZE(int8_t)
 FBGEMM_SPECIALIZED_QUANTIZE(uint16_t)
 FBGEMM_SPECIALIZED_QUANTIZE(int16_t)
 FBGEMM_SPECIALIZED_QUANTIZE(int32_t)
 #undef FBGEMM_SPECIALIZED_QUANTIZE
 
-template <>
-void Quantize<uint8_t>(
-    const float* src,
-    uint8_t* dst,
-    int len,
-    const TensorQuantizationParams& qparams) {
-  bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support();
-  bool fma_support = cpuinfo_has_x86_fma3();
-  if (avx2_support && fma_support && qparams.precision == 8) {
-    // fast path
-    QuantizeAvx2(src, dst, len, qparams);
-  } else {
-    for (std::size_t i = 0; i < len; ++i) {
-      dst[i] = Quantize<uint8_t>(src[i], qparams);
-    }
-  }
+#define FBGEMM_SPECIALIZED_QUANTIZE_AVX2(T) \
+  template <> \
+  void Quantize<T>( \
+      const float* src, \
+      T* dst, \
+      int len, \
+      const TensorQuantizationParams& qparams) { \
+    bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \
+    bool fma_support = cpuinfo_has_x86_fma3(); \
+    if (avx2_support && fma_support && qparams.precision == 8) { \
+      /* fast path */ \
+      QuantizeAvx2<T>(src, dst, len, qparams); \
+    } else { \
+      for (std::size_t i = 0; i < len; ++i) { \
+        dst[i] = Quantize<T>(src[i], qparams); \
+      } \
+    } \
 }
+FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t)
+FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t)
+#undef FBGEMM_SPECIALIZED_QUANTIZE_AVX2
+
 #define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(T) \
   template <> \
   void QuantizeGroupwise<T, layout_t::KCX>( \
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index 7f43ced..4a5f458 100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -18,13 +18,16 @@ using namespace std;
 ////////////////////////////////////////////////////////////////////////////////
 // Utility functions
 
+template <typename T>
 void QuantizeAvx2(
     const float* src,
-    uint8_t* dst,
+    T* dst,
     int len,
     const TensorQuantizationParams& qparams) {
 #if defined(__AVX2__) && defined(__FMA__)
   constexpr int VLEN = 8;
+  constexpr float min_val = std::numeric_limits<T>::min();
+  constexpr float max_val = std::numeric_limits<T>::max();
   std::size_t i = 0;
   __m256 inverse_scale_v = _mm256_set1_ps(1.f / qparams.scale);
   __m256i shuffle_mask_v = _mm256_set_epi8(
@@ -67,8 +70,8 @@ void QuantizeAvx2(
     __m256 transformed_v = _mm256_fmadd_ps(
         src_v, inverse_scale_v, _mm256_set1_ps(qparams.zero_point));
     __m256 clipped_v = _mm256_min_ps(
-        _mm256_max_ps(transformed_v, _mm256_set1_ps(0.f)),
-        _mm256_set1_ps(255.f));
+        _mm256_max_ps(transformed_v, _mm256_set1_ps(min_val)),
+        _mm256_set1_ps(max_val));
     __m256i rounded_v = _mm256_cvtps_epi32(clipped_v);
 
     // An instruction sequence to save 8 32-bit integers as 8 8-bit integers
@@ -80,7 +83,7 @@ void QuantizeAvx2(
 
   for (; i < len; ++i) {
     float transformed = qparams.zero_point + src[i] / qparams.scale;
-    float clipped = std::min(std::max(transformed, 0.f), 255.f);
+    float clipped = std::min(std::max(transformed, min_val), max_val);
     // Not exactly the same behavior as the vectorized code.
     // The vectorized code above always rounds to even in halfway cases
     // (https://software.intel.com/en-us/node/523819), but std::nearbyint
@@ -95,6 +98,21 @@ void QuantizeAvx2(
 #endif
 }
 
+// Instantiate QuantizeAvx2 for known datatypes
+template
+void QuantizeAvx2<uint8_t>(
+    const float* src,
+    uint8_t* dst,
+    int len,
+    const TensorQuantizationParams& qparams);
+template
+void QuantizeAvx2<int8_t>(
+    const float* src,
+    int8_t* dst,
+    int len,
+    const TensorQuantizationParams& qparams);
+
+
 void FindMinMax(const float* a, float* min, float* max, int len) {
   if (len <= 0) {
     *min = 0.0f;
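For context, a minimal caller sketch, assuming FBGEMM's headers are on the include path and using the four-argument `Quantize` signature shown in the diff; the quantization parameters are arbitrary demo values. With AVX2 and FMA detected at runtime and `qparams.precision == 8`, the call takes the new `QuantizeAvx2<int8_t>` fast path added by this commit, and otherwise falls back to the scalar loop:

    #include <cstdint>
    #include <vector>

    #include "fbgemm/QuantUtils.h"

    int main() {
      std::vector<float> src = {-1.5f, -0.25f, 0.0f, 0.7f, 1.3f};
      std::vector<std::int8_t> dst(src.size());

      fbgemm::TensorQuantizationParams qparams;
      qparams.scale = 0.05f;  // arbitrary demo parameters
      qparams.zero_point = 0;
      qparams.precision = 8;

      // Dispatches to QuantizeAvx2<int8_t> when AVX2 + FMA are available,
      // per the FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t) specialization
      // added in src/QuantUtils.cc above.
      fbgemm::Quantize<std::int8_t>(
          src.data(), dst.data(), static_cast<int>(src.size()), qparams);
      return 0;
    }

The explicit `template void QuantizeAvx2<uint8_t>(...)` and `<int8_t>(...)` instantiations at the end of src/QuantUtilsAvx2.cc are what let the template definition live in the .cc file while callers such as the sketch above link against just these two datatypes.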