diff options
author | Daya Khudia <dskhudia@fb.com> | 2020-04-23 21:29:09 +0300 |
---|---|---|
committer | Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com> | 2020-04-23 21:42:46 +0300 |
commit | 91063b0c3e0327a4f4a42f768abe11e7f0780c7f (patch) | |
tree | 8c94ade96f733ebbf562c739f21c2b3b745947f0 | |
parent | be54f479bccbddec37e1d339c324b621c1001de0 (diff) |
Zero point addition after rounding in quantization routines (#362)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/362
We want to add the zero point after rounding to unify numerics for PyTorch quantization. However, we also retain the original version to preserve backward compatibility for C2.
Reviewed By: jspark1105
Differential Revision: D21188721
fbshipit-source-id: daaefabd7eafb39ca99eb2d9d90a4db7a8c26c32
-rw-r--r-- | include/fbgemm/QuantUtils.h | 26 | ||||
-rw-r--r-- | include/fbgemm/QuantUtilsAvx2.h | 2 | ||||
-rw-r--r-- | src/QuantUtils.cc | 64 | ||||
-rw-r--r-- | src/QuantUtilsAvx2.cc | 136 |
4 files changed, 128 insertions, 100 deletions
diff --git a/include/fbgemm/QuantUtils.h b/include/fbgemm/QuantUtils.h index 5cd249f..634c608 100644 --- a/include/fbgemm/QuantUtils.h +++ b/include/fbgemm/QuantUtils.h @@ -48,7 +48,7 @@ T2 clamp(T1 src, int precision, bool is_signed = false) { /// Quantize src using zero_point and scale, clamp to the specified precision, /// and convert it to type T -template <typename T> +template <typename T, bool LEGACY = true> T Quantize( float src, std::int32_t zero_point, @@ -65,20 +65,32 @@ T Quantize( // transformed_val is 127.499992 for src / scale. // Eventually 127.5 gets rounded to 128 while 127.499992 gets rounded to 127. float inv_scale = 1.0f / scale; - const float transformed_val = zero_point + src * inv_scale; + + float transformed_val = src * inv_scale; + // nearbyint here performs round-to-nearest-ties-to-even with + // default rounding mode. + // For example, nearbyint(1.4) is 1.0, nearbyint(1.5) is 2.0 + // and nearbyint(2.5) is 2.0 + // Adding zero_point before or after rounding can make a difference + // in exactly halfway cases. + if (LEGACY) { + transformed_val = std::nearbyint(zero_point + transformed_val); + } else { + transformed_val = zero_point + std::nearbyint(transformed_val); + } // Please note the use of double. Unlike float, a double can represent // all int32 values exactly. Using a float results in a float value > // INT32_MAX conversion to int32 in clamp function and hence an UBSAN error. 
- return clamp<double, T>( - std::nearbyint(transformed_val), result_precision, result_is_signed); + return clamp<double, T>(transformed_val, result_precision, result_is_signed); } -template <typename T> +template <typename T, bool LEGACY = true> T Quantize(float src, const TensorQuantizationParams& qparams) { - return Quantize<T>(src, qparams.zero_point, qparams.scale, qparams.precision); + return Quantize<T, LEGACY>( + src, qparams.zero_point, qparams.scale, qparams.precision); } -template <typename T> +template <typename T, bool LEGACY = true> FBGEMM_API void Quantize( const float* src, T* dst, diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h index a07a02e..24c6ae4 100644 --- a/include/fbgemm/QuantUtilsAvx2.h +++ b/include/fbgemm/QuantUtilsAvx2.h @@ -40,7 +40,7 @@ struct FBGEMM_API RequantizationParams { //////////////////////////////////////////////////////////////////////////////// // Utility functions -template <typename T = std::uint8_t> +template <typename T = std::uint8_t, bool LEGACY = true> void QuantizeAvx2( const float* src, T* dst, diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc index 9f96fb4..b982ddb 100644 --- a/src/QuantUtils.cc +++ b/src/QuantUtils.cc @@ -161,9 +161,9 @@ void ChooseRequantizationMultiplier( //////////////////////////////////////////////////////////////////////////////// // Utility functions -#define FBGEMM_SPECIALIZED_QUANTIZE(T) \ +#define FBGEMM_SPECIALIZED_QUANTIZE(T, LEGACY) \ template <> \ - FBGEMM_API void Quantize<T>( \ + FBGEMM_API void Quantize<T, LEGACY>( \ const float* src, \ T* dst, \ const int len, \ @@ -173,39 +173,45 @@ void ChooseRequantizationMultiplier( int i_begin, i_end; \ fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \ for (int i = i_begin; i < i_end; ++i) { \ - dst[i] = Quantize<T>(src[i], qparams); \ + dst[i] = Quantize<T, LEGACY>(src[i], qparams); \ } \ } -FBGEMM_SPECIALIZED_QUANTIZE(uint16_t) -FBGEMM_SPECIALIZED_QUANTIZE(int16_t) 
-FBGEMM_SPECIALIZED_QUANTIZE(int32_t) +FBGEMM_SPECIALIZED_QUANTIZE(uint16_t, true) +FBGEMM_SPECIALIZED_QUANTIZE(int16_t, true) +FBGEMM_SPECIALIZED_QUANTIZE(int32_t, true) +FBGEMM_SPECIALIZED_QUANTIZE(uint16_t, false) +FBGEMM_SPECIALIZED_QUANTIZE(int16_t, false) +FBGEMM_SPECIALIZED_QUANTIZE(int32_t, false) #undef FBGEMM_SPECIALIZED_QUANTIZE -#define FBGEMM_SPECIALIZED_QUANTIZE_AVX2(T) \ - template <> \ - FBGEMM_API void Quantize<T>( \ - const float* src, \ - T* dst, \ - int len, \ - const TensorQuantizationParams& qparams, \ - int thread_id, \ - int num_threads) { \ - bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \ - bool fma_support = cpuinfo_has_x86_fma3(); \ - int i_begin, i_end; \ - fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \ - if (avx2_support && fma_support && qparams.precision == 8) { \ - /* fast path */ \ - QuantizeAvx2<T>(&src[i_begin], &dst[i_begin], i_end - i_begin, qparams); \ - } else { \ - for (std::size_t i = i_begin; i < i_end; ++i) { \ - dst[i] = Quantize<T>(src[i], qparams); \ - } \ - } \ +#define FBGEMM_SPECIALIZED_QUANTIZE_AVX2(T, LEGACY) \ + template <> \ + FBGEMM_API void Quantize<T, LEGACY>( \ + const float* src, \ + T* dst, \ + int len, \ + const TensorQuantizationParams& qparams, \ + int thread_id, \ + int num_threads) { \ + bool avx2_support = cpuinfo_initialize() && fbgemmHasAvx2Support(); \ + bool fma_support = cpuinfo_has_x86_fma3(); \ + int i_begin, i_end; \ + fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end); \ + if (avx2_support && fma_support && qparams.precision == 8) { \ + /* fast path */ \ + QuantizeAvx2<T, LEGACY>( \ + &src[i_begin], &dst[i_begin], i_end - i_begin, qparams); \ + } else { \ + for (std::size_t i = i_begin; i < i_end; ++i) { \ + dst[i] = Quantize<T, LEGACY>(src[i], qparams); \ + } \ + } \ } -FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t) -FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t) +FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t, true) 
+FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t, true) +FBGEMM_SPECIALIZED_QUANTIZE_AVX2(int8_t, false) +FBGEMM_SPECIALIZED_QUANTIZE_AVX2(uint8_t, false) #undef FBGEMM_SPECIALIZED_QUANTIZE_AVX2 #define FBGEMM_SPECIALIZED_QUANTIZEGROUPWISEKCX(T) \ diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index 2e6a1c5..eb4e7fa 100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -20,7 +20,7 @@ using namespace std; // Utility functions // ASAN seems to have a false-positive for _mm_maskmoveu_si128 -template <typename T> +template <typename T, bool LEGACY> void NO_SANITIZE("address") QuantizeAvx2( const float* src, T* dst, @@ -28,60 +28,57 @@ void NO_SANITIZE("address") QuantizeAvx2( const TensorQuantizationParams& qparams) { #if defined(__AVX2__) && (defined(__FMA__) || defined(_MSC_VER)) constexpr int VLEN = 8; - constexpr float min_val = std::numeric_limits<T>::min(); - constexpr float max_val = std::numeric_limits<T>::max(); + constexpr int32_t min_val = std::numeric_limits<T>::min(); + constexpr int32_t max_val = std::numeric_limits<T>::max(); + // This is the largest int32 value less than int32_max + // that is exactly representable in float + constexpr int32_t int32_float_max_val = + std::numeric_limits<int32_t>::max() - 127; std::size_t i = 0; float inverse_scale = 1.f / qparams.scale; __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale); + // clang-format off __m256i shuffle_mask_v = _mm256_set_epi8( - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0x0c, - 0x08, - 0x04, - 0x00, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0xff, - 0x0c, - 0x08, - 0x04, - 0x00); + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, + 0x0c, 0x08, 0x04, 0x00); + // clang-format on __m256i permute_mask_v = _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x04, 0x00); for (; i < len / VLEN * VLEN; i += VLEN) { __m256 src_v = _mm256_loadu_ps(src + i); - __m256 transformed_v = _mm256_fmadd_ps( - src_v, inverse_scale_v, _mm256_set1_ps(qparams.zero_point)); - __m256 clipped_v = _mm256_min_ps( - _mm256_max_ps(transformed_v, _mm256_set1_ps(min_val)), - _mm256_set1_ps(max_val)); - __m256i rounded_v = _mm256_cvtps_epi32(clipped_v); + __m256 transformed_v; + if (LEGACY) { // static if + transformed_v = _mm256_fmadd_ps( + src_v, inverse_scale_v, _mm256_set1_ps(qparams.zero_point)); + } else { + transformed_v = _mm256_mul_ps(src_v, inverse_scale_v); + } + // If the floating point value is greater than int32_max, + // _mm256_cvtps_epi32 converts them to negative. Clip at int32_float_max_val + // to avoid this. + transformed_v = + _mm256_min_ps(transformed_v, _mm256_set1_ps(int32_float_max_val)); + + __m256i rounded_v = _mm256_cvtps_epi32(transformed_v); + if (!LEGACY) { + rounded_v = + _mm256_add_epi32(rounded_v, _mm256_set1_epi32(qparams.zero_point)); + } + __m256i clipped_v = _mm256_min_epi32( + _mm256_max_epi32(rounded_v, _mm256_set1_epi32(min_val)), + _mm256_set1_epi32(max_val)); // An instruction sequence to save 8 32-bit integers as 8 8-bit integers - rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v); - rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask_v); + clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v); + clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v); _mm_storel_epi64( - reinterpret_cast<__m128i*>(dst + i), _mm256_castsi256_si128(rounded_v)); + reinterpret_cast<__m128i*>(dst + i), _mm256_castsi256_si128(clipped_v)); } // Handle remainder using mask instructions so that @@ -93,19 +90,31 @@ void NO_SANITIZE("address") QuantizeAvx2( __m128i store_mask_v = _mm_load_si128( reinterpret_cast<const __m128i*>(internal::sse_epi8_masks[rem])); __m256 src_v = _mm256_maskload_ps(src + i, mask_v); - __m256 transformed_v = _mm256_fmadd_ps( - src_v, inverse_scale_v, 
_mm256_set1_ps(qparams.zero_point)); - __m256 clipped_v = _mm256_min_ps( - _mm256_max_ps(transformed_v, _mm256_set1_ps(min_val)), - _mm256_set1_ps(max_val)); - __m256i rounded_v = _mm256_cvtps_epi32(clipped_v); + __m256 transformed_v; + if (LEGACY) { + transformed_v = _mm256_fmadd_ps( + src_v, inverse_scale_v, _mm256_set1_ps(qparams.zero_point)); + } else { + transformed_v = _mm256_mul_ps(src_v, inverse_scale_v); + } + transformed_v = + _mm256_min_ps(transformed_v, _mm256_set1_ps(int32_float_max_val)); + + __m256i rounded_v = _mm256_cvtps_epi32(transformed_v); + if (!LEGACY) { + rounded_v = + _mm256_add_epi32(rounded_v, _mm256_set1_epi32(qparams.zero_point)); + } + __m256i clipped_v = _mm256_min_epi32( + _mm256_max_epi32(rounded_v, _mm256_set1_epi32(min_val)), + _mm256_set1_epi32(max_val)); // An instruction sequence to save "rem" number of 32-bit integers // as "rem" number of 8-bit integers - rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v); - rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask_v); + clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v); + clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v); _mm_maskmoveu_si128( - _mm256_castsi256_si128(rounded_v), + _mm256_castsi256_si128(clipped_v), store_mask_v, reinterpret_cast<char*>(dst + i)); } @@ -113,16 +122,17 @@ void NO_SANITIZE("address") QuantizeAvx2( } // Instantiate QuantizeAvx2 for known datatypes -template void QuantizeAvx2<uint8_t>( - const float* src, - uint8_t* dst, - int len, - const TensorQuantizationParams& qparams); -template void QuantizeAvx2<int8_t>( - const float* src, - int8_t* dst, - int len, - const TensorQuantizationParams& qparams); +#define SPECIALIZE_QUANTIZEAVX2(T, LEGACY) \ + template void QuantizeAvx2<T, LEGACY>( \ + const float* src, \ + T* dst, \ + int len, \ + const TensorQuantizationParams& qparams); +SPECIALIZE_QUANTIZEAVX2(uint8_t, true) +SPECIALIZE_QUANTIZEAVX2(int8_t, true) +SPECIALIZE_QUANTIZEAVX2(uint8_t, false) 
+SPECIALIZE_QUANTIZEAVX2(int8_t, false) +#undef SPECIALIZE_QUANTIZEAVX2 void FindMinMax(const float* a, float* min, float* max, int len) { if (len <= 0) { |