Diffstat (limited to 'src/QuantUtilsAvx2.cc')
-rw-r--r-- (was -rwxr-xr-x) | src/QuantUtilsAvx2.cc | 760 |
1 file changed, 510 insertions(+), 250 deletions(-)
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index 821999e..66828ae 100755..100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -18,16 +18,20 @@ using namespace std;
 ////////////////////////////////////////////////////////////////////////////////
 // Utility functions

+template <typename T>
 void QuantizeAvx2(
     const float* src,
-    uint8_t* dst,
+    T* dst,
     int len,
     const TensorQuantizationParams& qparams) {
-#if defined(__AVX2__) && defined(__FMA__)
-  constexpr int VLEN = 8;
-  std::size_t i = 0;
-  __m256 inverse_scale_v = _mm256_set1_ps(1.f / qparams.scale);
-  __m256i shuffle_mask_v = _mm256_set_epi8(
+  // original compile condition - #if defined(__AVX2__) && (defined(__FMA__) || defined(_MSC_VER))
+  if (fbgemm::fbgemmHasAvx2Support()) {
+    constexpr int VLEN = 8;
+    constexpr float min_val = std::numeric_limits<T>::min();
+    constexpr float max_val = std::numeric_limits<T>::max();
+    std::size_t i = 0;
+    __m256 inverse_scale_v = _mm256_set1_ps(1.f / qparams.scale);
+    __m256i shuffle_mask_v = _mm256_set_epi8(
       0xff,
       0xff,
       0xff,
@@ -60,41 +64,56 @@ void QuantizeAvx2(
       0x08,
       0x04,
       0x00);
-  __m256i permute_mask_v =
+    __m256i permute_mask_v =
       _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);

-  for (; i < len / VLEN * VLEN; i += VLEN) {
-    __m256 src_v = _mm256_loadu_ps(src + i);
-    __m256 transformed_v = _mm256_fmadd_ps(
+    for (; i < len / VLEN * VLEN; i += VLEN) {
+      __m256 src_v = _mm256_loadu_ps(src + i);
+      __m256 transformed_v = _mm256_fmadd_ps(
         src_v, inverse_scale_v, _mm256_set1_ps(qparams.zero_point));
-    __m256 clipped_v = _mm256_min_ps(
-        _mm256_max_ps(transformed_v, _mm256_set1_ps(0.f)),
-        _mm256_set1_ps(255.f));
-    __m256i rounded_v = _mm256_cvtps_epi32(clipped_v);
-
-    // An instruction sequence to save 8 32-bit integers as 8 8-bit integers
-    rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v);
-    rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask_v);
-    _mm_storel_epi64(
+      __m256 clipped_v = _mm256_min_ps(
+          _mm256_max_ps(transformed_v, _mm256_set1_ps(min_val)),
+          _mm256_set1_ps(max_val));
+      __m256i rounded_v = _mm256_cvtps_epi32(clipped_v);
+
+      // An instruction sequence to save 8 32-bit integers as 8 8-bit integers
+      rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v);
+      rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask_v);
+      _mm_storel_epi64(
         reinterpret_cast<__m128i*>(dst + i),
         _mm256_castsi256_si128(rounded_v));
-  }
+    }

-  for (; i < len; ++i) {
-    float transformed = qparams.zero_point + src[i] / qparams.scale;
-    float clipped = std::min(std::max(transformed, 0.f), 255.f);
-    // Not exactly the same behavior as the vectorized code.
-    // The vectorized code above always rounds to even in halfway cases
-    // (https://software.intel.com/en-us/node/523819), but std::nearbyint
-    // does the same only when the current rounding mode is FE_TONEAREST.
-    // However, in practice, this should not be a problem because most cases
-    // use the default rounding mode FE_TONEAREST.
-    // Note that we cannot implement the same behavior as the vectorized code
-    // using std::round because it does rounding away from zero in halfway
-    // cases.
-    dst[i] = nearbyint(clipped);
+    for (; i < len; ++i) {
+      float transformed = qparams.zero_point + src[i] / qparams.scale;
+      float clipped = std::min(std::max(transformed, min_val), max_val);
+      // Not exactly the same behavior as the vectorized code.
+      // The vectorized code above always rounds to even in halfway cases
+      // (https://software.intel.com/en-us/node/523819), but std::nearbyint
+      // does the same only when the current rounding mode is FE_TONEAREST.
+      // However, in practice, this should not be a problem because most cases
+      // use the default rounding mode FE_TONEAREST.
+      // Note that we cannot implement the same behavior as the vectorized code
+      // using std::round because it does rounding away from zero in halfway
+      // cases.
+      dst[i] = nearbyint(clipped);
+    }
   }
-#endif
 }

+// Instantiate QuantizeAvx2 for known datatypes
+template
+void QuantizeAvx2<uint8_t>(
+    const float* src,
+    uint8_t* dst,
+    int len,
+    const TensorQuantizationParams& qparams);
+template
+void QuantizeAvx2<int8_t>(
+    const float* src,
+    int8_t* dst,
+    int len,
+    const TensorQuantizationParams& qparams);
+
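Editor's note on the rounding comment above: a minimal standalone sketch (not part of this patch) of the difference it describes. Under the default FE_TONEAREST rounding mode, std::nearbyint rounds halfway cases to even, matching _mm256_cvtps_epi32, while std::round always rounds them away from zero.

```cpp
#include <cfenv>
#include <cmath>
#include <cstdio>

int main() {
  std::fesetround(FE_TONEAREST); // the default mode; set explicitly for clarity
  for (float x : {0.5f, 1.5f, 2.5f, -0.5f, -1.5f}) {
    std::printf("x=%5.1f  nearbyint=%5.1f  round=%5.1f\n",
                x, std::nearbyint(x), std::round(x));
  }
  // nearbyint: 0.0 2.0 2.0 -0.0 -2.0 (ties to even)
  // round:     1.0 2.0 3.0 -1.0 -2.0 (ties away from zero)
  return 0;
}
```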
 void FindMinMax(const float* a, float* min, float* max, int len) {
   if (len <= 0) {
     *min = 0.0f;
@@ -105,24 +124,24 @@ void FindMinMax(const float* a, float* min, float* max, int len) {
   float temp_min = *a, temp_max = *a;
   int i = 0;

-#ifdef __AVX__
-  __m256 min_v = _mm256_set1_ps(*a), max_v = _mm256_set1_ps(*a);
-  constexpr int VLEN = 8;
-  if (len >= VLEN) {
-    for (; i < len / VLEN * VLEN; i += VLEN) {
-      min_v = _mm256_min_ps(min_v, _mm256_loadu_ps(a + i));
-      max_v = _mm256_max_ps(max_v, _mm256_loadu_ps(a + i));
-    }
+  if (fbgemm::fbgemmHasAvx2Support()) {
+    __m256 min_v = _mm256_set1_ps(*a), max_v = _mm256_set1_ps(*a);
+    constexpr int VLEN = 8;
+    if (len >= VLEN) {
+      for (; i < len / VLEN * VLEN; i += VLEN) {
+        min_v = _mm256_min_ps(min_v, _mm256_loadu_ps(a + i));
+        max_v = _mm256_max_ps(max_v, _mm256_loadu_ps(a + i));
+      }

-    float min_buf[VLEN], max_buf[VLEN];
-    _mm256_storeu_ps(min_buf, min_v);
-    _mm256_storeu_ps(max_buf, max_v);
-    for (int j = 0; j < VLEN; ++j) {
-      temp_min = std::min(temp_min, min_buf[j]);
-      temp_max = std::max(temp_max, max_buf[j]);
+      float min_buf[VLEN], max_buf[VLEN];
+      _mm256_storeu_ps(min_buf, min_v);
+      _mm256_storeu_ps(max_buf, max_v);
+      for (int j = 0; j < VLEN; ++j) {
+        temp_min = std::min(temp_min, min_buf[j]);
+        temp_max = std::max(temp_max, max_buf[j]);
+      }
     }
   }
-#endif

   for (; i < len; i++) {
     temp_min = std::min(temp_min, a[i]);
@@ -135,15 +154,15 @@ void FindMinMax(const float* a, float* min, float* max, int len) {
 ////////////////////////////////////////////////////////////////////////////////
 // Requantization (with floats)

-#ifdef __AVX2__
 void RequantizeAvx2(
     const int32_t* src,
     uint8_t* dst,
     int len,
     const RequantizationParams& params) {
-  DoNothing<> doNothingObj{};
-  int32_t Bq_zero_point[] = { 0 };
-  ReQuantizeOutput<false /* FUSE_RELU */> requantizeObj(
+  if (fbgemm::fbgemmHasAvx2Support()) {
+    DoNothing<> doNothingObj{};
+    int32_t Bq_zero_point[] = { 0 };
+    ReQuantizeOutput<false /* FUSE_RELU */> requantizeObj(
       doNothingObj,
       &params.real_multiplier,
       params.target_qparams.zero_point,
@@ -153,7 +172,8 @@ void RequantizeAvx2(
       nullptr, // col_offsets
       nullptr, // bias
       len); // ncol
-  requantizeObj.f<inst_set_t::avx2>(dst, src, {0, 1, 0, len}, 0, 0);
+    requantizeObj.f<inst_set_t::avx2>(dst, src, { 0, 1, 0, len }, 0, 0);
+  }
 }
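Throughout this patch, compile-time `#ifdef __AVX2__` / `#ifdef __AVX__` guards are replaced by a runtime branch on fbgemm::fbgemmHasAvx2Support(), so one binary can carry the AVX2 path and fall back gracefully on older CPUs. The sketch below is an illustrative stand-in for such a runtime check using GCC/Clang builtins; it is not FBGEMM's actual implementation, which lives elsewhere in the library.

```cpp
#include <cstdio>

// Hypothetical stand-in for a runtime AVX2 capability check.
bool hasAvx2() {
  __builtin_cpu_init();                  // populate the CPU feature cache
  return __builtin_cpu_supports("avx2"); // query the AVX2 feature bit
}

int main() {
  std::printf("AVX2 %s\n", hasAvx2() ? "available" : "not available");
  return 0;
}
```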

 void RequantizeFixedPointAvx2(
@@ -161,24 +181,26 @@ void RequantizeFixedPointAvx2(
     uint8_t* dst,
     int len,
     const RequantizationParams& params) {
-  constexpr int VLEN = 8;
+  if (fbgemm::fbgemmHasAvx2Support()) {
+    constexpr int VLEN = 8;

-  __m256i b = _mm256_set1_epi32(params.multiplier);
+    __m256i b = _mm256_set1_epi32(params.multiplier);

-  // AVX2 doesn't support arithmetic right shift.
-  // As a work around, we convert 64-bit multiplied results to uint64_t by
-  // adding 0x8000000000000000ULL, logical right shift, and subtract by
-  // (0x8000000000000000ULL >> right_shift).
-  __m256i pre_shift_nudge = _mm256_set1_epi64x(
+    // AVX2 doesn't support arithmetic right shift.
+    // As a workaround, we convert 64-bit multiplied results to uint64_t by
+    // adding 0x8000000000000000ULL, logical right shift, and subtract by
+    // (0x8000000000000000ULL >> right_shift).
+    __m256i pre_shift_nudge = _mm256_set1_epi64x(
       (1ll << (params.right_shift - 1)) + 0x8000000000000000ULL);
-  __m256i post_shift_nudge = _mm256_set1_epi64x(
+    __m256i post_shift_nudge = _mm256_set1_epi64x(
       params.target_qparams.zero_point -
       (0x8000000000000000ULL >> params.right_shift));

-  __m256i min_v = _mm256_set1_epi32(numeric_limits<uint8_t>::min());
-  __m256i max_v = _mm256_set1_epi32(numeric_limits<uint8_t>::max());
+    __m256i min_v = _mm256_set1_epi32(numeric_limits<uint8_t>::min());
+    __m256i max_v = _mm256_set1_epi32(numeric_limits<uint8_t>::max());

-  __m256i shuffle_mask_v = _mm256_set_epi8(
+    __m256i shuffle_mask_v = _mm256_set_epi8(
       0xff,
       0xff,
       0xff,
@@ -211,75 +233,68 @@ void RequantizeFixedPointAvx2(
       0x08,
       0x04,
       0x00);
-  __m256i permute_mask_v =
+    __m256i permute_mask_v =
       _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);

-  int i = 0;
-  for (; i < len / VLEN * VLEN; i += VLEN) {
-    __m256i a_v = _mm256_loadu_si256((const __m256i*)(src + i));
+    int i = 0;
+    for (; i < len / VLEN * VLEN; i += VLEN) {
+      __m256i a_v = _mm256_loadu_si256((const __m256i*)(src + i));

-    // a = a0 | a1 | a2 | a3 | a4 | a5 | a6 | a7
-    // b = b0 | b1 | b3 | b3 | b4 | b5 | b6 | b7
-    __m256i a_even_v = a_v;
-    __m256i a_odd_v = _mm256_srli_si256(a_v, 4);
+      // a = a0 | a1 | a2 | a3 | a4 | a5 | a6 | a7
+      // b = b0 | b1 | b2 | b3 | b4 | b5 | b6 | b7
+      __m256i a_even_v = a_v;
+      __m256i a_odd_v = _mm256_srli_si256(a_v, 4);

-    __m256i ab_even_v = _mm256_mul_epi32(a_even_v, b);
-    __m256i ab_odd_v = _mm256_mul_epi32(a_odd_v, b);
+      __m256i ab_even_v = _mm256_mul_epi32(a_even_v, b);
+      __m256i ab_odd_v = _mm256_mul_epi32(a_odd_v, b);

-    __m256i even_rounded_v = _mm256_add_epi64(ab_even_v, pre_shift_nudge);
-    __m256i odd_rounded_v = _mm256_add_epi64(ab_odd_v, pre_shift_nudge);
+      __m256i even_rounded_v = _mm256_add_epi64(ab_even_v, pre_shift_nudge);
+      __m256i odd_rounded_v = _mm256_add_epi64(ab_odd_v, pre_shift_nudge);

-    __m256i even_result_v = _mm256_add_epi64(
+      __m256i even_result_v = _mm256_add_epi64(
         _mm256_srli_epi64(even_rounded_v, params.right_shift),
         post_shift_nudge);
-    __m256i odd_result_v = _mm256_add_epi64(
+      __m256i odd_result_v = _mm256_add_epi64(
         _mm256_srli_epi64(odd_rounded_v, params.right_shift),
         post_shift_nudge);
-    odd_result_v = _mm256_slli_si256(odd_result_v, 4);
+      odd_result_v = _mm256_slli_si256(odd_result_v, 4);

-    // even_result_v has numbers we want in its even 32-bit SIMD lanes, and
-    // odd_result_v has numbers we want in its odd 32-bit SIMD lanes.
-    // Use blend to combine them.
-    __m256i result_v = _mm256_blend_epi32(even_result_v, odd_result_v, 0xaa);
-    __m256i clipped_v =
+      // even_result_v has numbers we want in its even 32-bit SIMD lanes, and
+      // odd_result_v has numbers we want in its odd 32-bit SIMD lanes.
+      // Use blend to combine them.
+      __m256i result_v = _mm256_blend_epi32(even_result_v, odd_result_v, 0xaa);
+      __m256i clipped_v =
         _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, result_v));

-    clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v);
-    clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v);
-    *(int64_t*)(dst + i) = _mm256_extract_epi64(clipped_v, 0);
-  }
+      clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v);
+      clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v);
+      *(int64_t*)(dst + i) = _mm256_extract_epi64(clipped_v, 0);
+    }

-  for (; i < len; ++i) {
-    int64_t ab_64 =
+    for (; i < len; ++i) {
+      int64_t ab_64 =
         static_cast<int64_t>(src[i]) * static_cast<int64_t>(params.multiplier);
-    int64_t nudge = 1ll << std::max(0, params.right_shift - 1);
-    int64_t quantized_down = params.target_qparams.zero_point +
+      int64_t nudge = 1ll << std::max(0, params.right_shift - 1);
+      int64_t quantized_down = params.target_qparams.zero_point +
         ((ab_64 + nudge) >> params.right_shift);
-    dst[i] = std::min<int64_t>(std::max<int64_t>(quantized_down, 0l), 255l);
+      dst[i] = std::min<int64_t>(std::max<int64_t>(quantized_down, 0l), 255l);
+    }
   }
 }
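A scalar sketch (illustrative, not FBGEMM code) of the unsigned-shift workaround described in the comment inside RequantizeFixedPointAvx2: adding 2^63 maps the signed 64-bit range [-2^63, 2^63) onto the unsigned range [0, 2^64), after which a logical right shift followed by subtracting 2^63 >> s reproduces the arithmetic shift.

```cpp
#include <cassert>
#include <cstdint>

// For signed 64-bit x and 0 < s < 63:
//   (x >> s) == int64_t((uint64_t(x) + 2^63) >> s) - int64_t(2^63 >> s)
int64_t arithShiftViaLogical(int64_t x, int s) {
  uint64_t biased = static_cast<uint64_t>(x) + 0x8000000000000000ULL;
  return static_cast<int64_t>((biased >> s) - (0x8000000000000000ULL >> s));
}

int main() {
  // x >> s on negative values is an arithmetic shift on mainstream
  // compilers (and guaranteed since C++20).
  for (int64_t x : {int64_t{-100000}, int64_t{-1}, int64_t{0}, int64_t{12345}}) {
    for (int s = 1; s < 63; ++s) {
      assert(arithShiftViaLogical(x, s) == (x >> s));
    }
  }
  return 0;
}
```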
-#else
-// dummy implementations to avoid link errors
-void RequantizeAvx2(const int32_t* src, uint8_t* dst, int len, const RequantizationParams& params) {
-  assert(false && "RequantizeAvx2() was called unexpectedly in non-AVX2 build");
-}
-void RequantizeFixedPointAvx2(const int32_t* src, uint8_t* dst, int len, const RequantizationParams& params) {
-  assert(false && "RequantizeFixedPointAvx2() was called unexpectedly in non-AVX2 build");
-}
-#endif

 template <
     bool A_SYMMETRIC,
     bool B_SYMMETRIC,
     QuantizationGranularity Q_GRAN,
     bool HAS_BIAS,
-    bool FUSE_RELU>
+    bool FUSE_RELU,
+    typename BIAS_TYPE>
 void requantizeOutputProcessingAvx2(
     uint8_t* out,
     const int32_t* inp,
     const block_type_t& block,
     int ld_out,
     int ld_in,
-    const requantizationParams_t& r) {
+    const requantizationParams_t<BIAS_TYPE>& r) {
   // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c
   // using AVX2 instructions
   int quant_param_idx = 0;
@@ -290,6 +305,15 @@ void requantizeOutputProcessingAvx2(
   }
   __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);

+  // Broadcasted reciprocal of act_times_w_scale
+  __m256 act_times_w_rcp_v;
+  if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) {
+    if (is_same<BIAS_TYPE, float>::value) {
+      act_times_w_rcp_v =
+          _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]);
+    }
+  }
+
   __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
   __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
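For orientation, a scalar reference (an illustrative sketch; the helper name and argument list are assumptions, not the FBGEMM API) of the per-element math the vectorized code below performs when BIAS_TYPE is float: the float bias is first brought into the int32 accumulator's scale by dividing by act_scale * weight_scale (or multiplying by the broadcast reciprocal above), then the sum is scaled by C_multiplier and shifted to the output zero point.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

uint8_t requantizeScalar(
    int32_t acc,             // int32 accumulator (offsets already subtracted)
    float bias,              // float bias for this output channel
    float act_times_w_scale, // act_scale * weight_scale
    float C_multiplier,      // roughly act_times_w_scale / out_scale
    int32_t C_zero_point) {
  // Rescale the float bias into the accumulator domain and add it.
  float x = static_cast<float>(acc) + bias / act_times_w_scale;
  // Scale, round (ties to even under the default mode), shift, saturate.
  long rounded = std::lrintf(x * C_multiplier) + C_zero_point;
  return static_cast<uint8_t>(std::min(255L, std::max(0L, rounded)));
}
```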
@@ -399,22 +423,76 @@ void requantizeOutputProcessingAvx2(
         }
         w_v = _mm256_sub_epi32(w_v, row_offset_v);
       }
+      __m256 xf_v, yf_v, zf_v, wf_v;
       if (HAS_BIAS) {
-        x_v = _mm256_add_epi32(
-            x_v,
-            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
-        y_v = _mm256_add_epi32(
-            y_v,
-            _mm256_loadu_si256(
-                reinterpret_cast<const __m256i*>(r.bias + j + VLEN)));
-        z_v = _mm256_add_epi32(
-            z_v,
-            _mm256_loadu_si256(
-                reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
-        w_v = _mm256_add_epi32(
-            w_v,
-            _mm256_loadu_si256(
-                reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+        if (is_same<BIAS_TYPE, float>::value) {
+          __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+          if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+            x_bias_v = _mm256_div_ps(
+                _mm256_loadu_ps(
+                    reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
+                _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
+            y_bias_v = _mm256_div_ps(
+                _mm256_loadu_ps(
+                    reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
+                _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
+            z_bias_v = _mm256_div_ps(
+                _mm256_loadu_ps(
+                    reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
+                _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
+            w_bias_v = _mm256_div_ps(
+                _mm256_loadu_ps(
+                    reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
+                _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+          } else {
+            x_bias_v = _mm256_mul_ps(
+                _mm256_loadu_ps(
+                    reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
+                act_times_w_rcp_v);
+            y_bias_v = _mm256_mul_ps(
+                _mm256_loadu_ps(
+                    reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
+                act_times_w_rcp_v);
+            z_bias_v = _mm256_mul_ps(
+                _mm256_loadu_ps(
+                    reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
+                act_times_w_rcp_v);
+            w_bias_v = _mm256_mul_ps(
+                _mm256_loadu_ps(
+                    reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
+                act_times_w_rcp_v);
+          }
+          xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+          yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
+          zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
+          wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
+        } else {
+          x_v = _mm256_add_epi32(
+              x_v,
+              _mm256_loadu_si256(
+                  reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN)));
+          y_v = _mm256_add_epi32(
+              y_v,
+              _mm256_loadu_si256(
+                  reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN)));
+          z_v = _mm256_add_epi32(
+              z_v,
+              _mm256_loadu_si256(
+                  reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
+          w_v = _mm256_add_epi32(
+              w_v,
+              _mm256_loadu_si256(
+                  reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+          xf_v = _mm256_cvtepi32_ps(x_v);
+          yf_v = _mm256_cvtepi32_ps(y_v);
+          zf_v = _mm256_cvtepi32_ps(z_v);
+          wf_v = _mm256_cvtepi32_ps(w_v);
+        }
+      } else {
+        xf_v = _mm256_cvtepi32_ps(x_v);
+        yf_v = _mm256_cvtepi32_ps(y_v);
+        zf_v = _mm256_cvtepi32_ps(z_v);
+        wf_v = _mm256_cvtepi32_ps(w_v);
       }

       /*
@@ -431,22 +509,19 @@ void requantizeOutputProcessingAvx2(
        */
       __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;
       if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
-        x_scaled_v = _mm256_mul_ps(
-            _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
-        y_scaled_v = _mm256_mul_ps(
-            _mm256_cvtepi32_ps(y_v),
-            _mm256_loadu_ps(r.C_multiplier + j + VLEN));
-        z_scaled_v = _mm256_mul_ps(
-            _mm256_cvtepi32_ps(z_v),
-            _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
-        w_scaled_v = _mm256_mul_ps(
-            _mm256_cvtepi32_ps(w_v),
-            _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
+        x_scaled_v =
+            _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j + 0 * VLEN));
+        y_scaled_v =
+            _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + 1 * VLEN));
+        z_scaled_v =
+            _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
+        w_scaled_v =
+            _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
       } else {
-        x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
-        y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
-        z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
-        w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+        x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+        y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+        z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+        w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
       }

       /*
@@ -533,18 +608,35 @@ void requantizeOutputProcessingAvx2(
         }
         x_v = _mm256_sub_epi32(x_v, row_offset_v);
       }
+      __m256 xf_v;
       if (HAS_BIAS) {
-        x_v = _mm256_add_epi32(
-            x_v,
-            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
+        if (is_same<BIAS_TYPE, float>::value) {
+          __m256 x_bias_v;
+          if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+            x_bias_v = _mm256_div_ps(
+                _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)),
+                _mm256_loadu_ps(r.act_times_w_scale + j));
+          } else {
+            x_bias_v = _mm256_mul_ps(
+                _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)),
+                act_times_w_rcp_v);
+          }
+          xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+        } else {
+          x_v = _mm256_add_epi32(
+              x_v,
+              _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
+          xf_v = _mm256_cvtepi32_ps(x_v);
+        }
+      } else {
+        xf_v = _mm256_cvtepi32_ps(x_v);
       }

       __m256 x_scaled_v;
       if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
-        x_scaled_v = _mm256_mul_ps(
-            _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
+        x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j));
       } else {
-        x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+        x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
       }
       __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);

@@ -582,6 +674,7 @@ void requantizeOutputProcessingAvx2(
     int remainder = block.col_start + block.col_size - j;
     if (remainder > 0) {
+      // clang-format off
       alignas(64) const int masks[8][8] = {
           // NOTE: clang-format wants to use a different formatting but the
           // current formatting should be easier to read.
@@ -594,6 +687,7 @@ void requantizeOutputProcessingAvx2(
           { -1, -1, -1, -1, -1, -1,  0,  0, },
           { -1, -1, -1, -1, -1, -1, -1,  0, },
       };
+      // clang-format on
       __m256i mask_v = _mm256_load_si256(
           reinterpret_cast<const __m256i*>(masks[remainder]));

@@ -615,17 +709,40 @@ void requantizeOutputProcessingAvx2(
       }
       x_v = _mm256_sub_epi32(x_v, row_offset_v);
     }
+
+    __m256 xf_v;
     if (HAS_BIAS) {
-      x_v = _mm256_add_epi32(x_v, _mm256_maskload_epi32(r.bias + j, mask_v));
+      if (is_same<BIAS_TYPE, float>::value) {
+        __m256 x_bias_v;
+        if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+          x_bias_v = _mm256_div_ps(
+              _mm256_maskload_ps(
+                  reinterpret_cast<const float*>(r.bias + j), mask_v),
+              _mm256_maskload_ps(r.act_times_w_scale + j, mask_v));
+        } else {
+          x_bias_v = _mm256_mul_ps(
+              _mm256_maskload_ps(
+                  reinterpret_cast<const float*>(r.bias + j), mask_v),
+              act_times_w_rcp_v);
+        }
+        xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+      } else {
+        x_v = _mm256_add_epi32(
+            x_v,
+            _mm256_maskload_epi32(
+                reinterpret_cast<const int*>(r.bias + j), mask_v));
+        xf_v = _mm256_cvtepi32_ps(x_v);
+      }
+    } else {
+      xf_v = _mm256_cvtepi32_ps(x_v);
     }

     __m256 x_scaled_v;
     if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
-      x_scaled_v = _mm256_mul_ps(
-          _mm256_cvtepi32_ps(x_v),
-          _mm256_maskload_ps(r.C_multiplier + j, mask_v));
+      x_scaled_v =
+          _mm256_mul_ps(xf_v, _mm256_maskload_ps(r.C_multiplier + j, mask_v));
     } else {
-      x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+      x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
     }
     __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
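The masks table and the _mm256_maskload_epi32 / _mm256_maskload_ps calls above handle the last (< 8) columns without reading past the end of a buffer: lanes whose mask element has the sign bit set are loaded, the rest are zeroed. A standalone sketch of the mechanism (illustrative, not from the patch; compile with -mavx2):

```cpp
#include <immintrin.h>
#include <cstdio>

int main() {
  // Mask for a remainder of 3: first three lanes active.
  alignas(32) const int mask3[8] = {-1, -1, -1, 0, 0, 0, 0, 0};
  int src[3] = {10, 20, 30}; // only 3 valid elements
  __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(mask3));
  __m256i v = _mm256_maskload_epi32(src, mask_v); // inactive lanes read as 0
  alignas(32) int out[8];
  _mm256_store_si256(reinterpret_cast<__m256i*>(out), v);
  for (int x : out) {
    std::printf("%d ", x); // prints: 10 20 30 0 0 0 0 0
  }
  return 0;
}
```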
@@ -767,6 +884,7 @@ void requantizeForFloatAvx2(
     int remainder = block.col_start + block.col_size - j;
     if (remainder > 0) {
+      // clang-format off
       alignas(64) const int masks[8][8] = {
           // NOTE: clang-format wants to use a different formatting but the
           // current formatting should be easier to read.
@@ -779,6 +897,7 @@ void requantizeForFloatAvx2(
           { -1, -1, -1, -1, -1, -1,  0,  0, },
           { -1, -1, -1, -1, -1, -1, -1,  0, },
       };
+      // clang-format on
       __m256i mask_v = _mm256_load_si256(
           reinterpret_cast<const __m256i*>(masks[remainder]));

@@ -831,14 +950,15 @@ template <
     QuantizationGranularity Q_GRAN,
     bool HAS_BIAS,
     bool FUSE_RELU,
-    int C_PER_G>
+    int C_PER_G,
+    typename BIAS_TYPE>
 void requantizeOutputProcessingGConvAvx2(
     uint8_t* out,
     const int32_t* inp,
     const block_type_t& block,
     int ld_out,
     int ld_in,
-    const requantizationParams_t& r) {
+    const requantizationParams_t<BIAS_TYPE>& r) {
   // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c
   // using AVX2 instructions
   int quant_param_idx = 0;
@@ -849,6 +969,14 @@ void requantizeOutputProcessingGConvAvx2(
   }
   __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);

+  // Broadcasted reciprocal of act_times_w_scale
+  __m256 act_times_w_rcp_v;
+  if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) {
+    if (is_same<BIAS_TYPE, float>::value) {
+      act_times_w_rcp_v =
+          _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]);
+    }
+  }
   __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
   __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));

@@ -1095,22 +1223,135 @@ void requantizeOutputProcessingGConvAvx2(
         }
         w_v = _mm256_sub_epi32(w_v, row_offset_v);
       }
+      __m256 xf_v, yf_v, zf_v, wf_v;
       if (HAS_BIAS) {
-        x_v = _mm256_add_epi32(
-            x_v,
-            _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
-        y_v = _mm256_add_epi32(
-            y_v,
-            _mm256_loadu_si256(
-                reinterpret_cast<const __m256i*>(r.bias + j + VLEN)));
-        z_v = _mm256_add_epi32(
-            z_v,
-            _mm256_loadu_si256(
-                reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
-        w_v = _mm256_add_epi32(
-            w_v,
-            _mm256_loadu_si256(
-                reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+        if (is_same<BIAS_TYPE, float>::value) {
+          __m256 x_bias_v = _mm256_loadu_ps(
+              reinterpret_cast<const float*>(r.bias + j + 0 * VLEN));
+          __m256 y_bias_v = _mm256_loadu_ps(
+              reinterpret_cast<const float*>(r.bias + j + 1 * VLEN));
+          __m256 z_bias_v = _mm256_loadu_ps(
+              reinterpret_cast<const float*>(r.bias + j + 2 * VLEN));
+          __m256 w_bias_v = _mm256_loadu_ps(
+              reinterpret_cast<const float*>(r.bias + j + 3 * VLEN));
+          if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+            x_bias_v = _mm256_div_ps(
+                x_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
+            y_bias_v = _mm256_div_ps(
+                y_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
+            z_bias_v = _mm256_div_ps(
+                z_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
+            w_bias_v = _mm256_div_ps(
+                w_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+          } else if (Q_GRAN == QuantizationGranularity::GROUP) {
+            __m256 divisor_v;
+            if (C_PER_G == 4) {
+              divisor_v = _mm256_insertf128_ps(
+                  _mm256_castps128_ps256(
+                      _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 0])),
+                  _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 1]),
+                  1);
+              x_bias_v = _mm256_div_ps(x_bias_v, divisor_v);
+
+              divisor_v = _mm256_insertf128_ps(
+                  _mm256_castps128_ps256(
+                      _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 2])),
+                  _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 3]),
+                  1);
+              y_bias_v = _mm256_div_ps(y_bias_v, divisor_v);
+
+              divisor_v = _mm256_insertf128_ps(
+                  _mm256_castps128_ps256(
+                      _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 4])),
+                  _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 5]),
+                  1);
+              z_bias_v = _mm256_div_ps(z_bias_v, divisor_v);
+
+              divisor_v = _mm256_insertf128_ps(
+                  _mm256_castps128_ps256(
+                      _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 6])),
+                  _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 7]),
+                  1);
+              w_bias_v = _mm256_div_ps(w_bias_v, divisor_v);
+
+            } else if (C_PER_G == 8) {
+              divisor_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 4 + 0]);
+              x_bias_v = _mm256_div_ps(x_bias_v, divisor_v);
+
+              divisor_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 4 + 1]);
+              y_bias_v = _mm256_div_ps(y_bias_v, divisor_v);
+
+              divisor_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 4 + 2]);
+              z_bias_v = _mm256_div_ps(z_bias_v, divisor_v);
+
+              divisor_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 4 + 3]);
+              w_bias_v = _mm256_div_ps(w_bias_v, divisor_v);
+
+            } else {
+              assert(C_PER_G == 16);
+              divisor_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 2 + 0]);
+              x_bias_v = _mm256_div_ps(x_bias_v, divisor_v);
+              y_bias_v = _mm256_div_ps(y_bias_v, divisor_v);
+
+              divisor_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 2 + 1]);
+              z_bias_v = _mm256_div_ps(z_bias_v, divisor_v);
+              w_bias_v = _mm256_div_ps(w_bias_v, divisor_v);
+            }
+          } else {
+            x_bias_v = _mm256_mul_ps(x_bias_v, act_times_w_rcp_v);
+            y_bias_v = _mm256_mul_ps(y_bias_v, act_times_w_rcp_v);
+            z_bias_v = _mm256_mul_ps(z_bias_v, act_times_w_rcp_v);
+            w_bias_v = _mm256_mul_ps(w_bias_v, act_times_w_rcp_v);
+          }
+          xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+          yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
+          zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
+          wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
+        } else {
+          x_v = _mm256_add_epi32(
+              x_v,
+              _mm256_loadu_si256(
+                  reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN)));
+          y_v = _mm256_add_epi32(
+              y_v,
+              _mm256_loadu_si256(
+                  reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN)));
+          z_v = _mm256_add_epi32(
+              z_v,
+              _mm256_loadu_si256(
+                  reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
+          w_v = _mm256_add_epi32(
+              w_v,
+              _mm256_loadu_si256(
+                  reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+          xf_v = _mm256_cvtepi32_ps(x_v);
+          yf_v = _mm256_cvtepi32_ps(y_v);
+          zf_v = _mm256_cvtepi32_ps(z_v);
+          wf_v = _mm256_cvtepi32_ps(w_v);
+        }
+      } else {
+        xf_v = _mm256_cvtepi32_ps(x_v);
+        yf_v = _mm256_cvtepi32_ps(y_v);
+        zf_v = _mm256_cvtepi32_ps(z_v);
+        wf_v = _mm256_cvtepi32_ps(w_v);
       }
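The C_PER_G branches above decide which group's act_times_w_scale divides each bias register. One iteration covers four registers of VLEN = 8 floats, i.e. 32 output channels, so the number of distinct groups touched per iteration is 32 / C_PER_G; that is the arithmetic behind the quant_param_idx offsets (+0..7 for C_PER_G == 4, +0..3 for 8, +0..1 for 16). A tiny worked check (illustrative, not from the patch):

```cpp
#include <cstdio>

int main() {
  constexpr int VLEN = 8; // floats per __m256 register
  for (int c_per_g : {4, 8, 16}) {
    // 4 registers * VLEN channels, split among groups of c_per_g channels.
    int groups_per_block = VLEN * 4 / c_per_g;
    std::printf("C_PER_G=%2d -> %d group scale(s) per 32-channel block\n",
                c_per_g, groups_per_block);
  }
  return 0; // prints 8, 4, 2
}
```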

       /*
@@ -1127,17 +1368,13 @@ void requantizeOutputProcessingGConvAvx2(
        */
       __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;
       if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
-        x_scaled_v = _mm256_mul_ps(
-            _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
-        y_scaled_v = _mm256_mul_ps(
-            _mm256_cvtepi32_ps(y_v),
-            _mm256_loadu_ps(r.C_multiplier + j + VLEN));
-        z_scaled_v = _mm256_mul_ps(
-            _mm256_cvtepi32_ps(z_v),
-            _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
-        w_scaled_v = _mm256_mul_ps(
-            _mm256_cvtepi32_ps(w_v),
-            _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
+        x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j));
+        y_scaled_v =
+            _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + VLEN));
+        z_scaled_v =
+            _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
+        w_scaled_v =
+            _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
       } else if (Q_GRAN == QuantizationGranularity::GROUP) {
         if (C_PER_G == 4) {
           multiplier_v = _mm256_insertf128_ps(
@@ -1145,70 +1382,70 @@ void requantizeOutputProcessingGConvAvx2(
               _mm256_castps128_ps256(
                   _mm_set1_ps(r.C_multiplier[quant_param_idx])),
               _mm_set1_ps(r.C_multiplier[quant_param_idx + 1]),
               1);
-          x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+          x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);

           multiplier_v = _mm256_insertf128_ps(
               _mm256_castps128_ps256(
                   _mm_set1_ps(r.C_multiplier[quant_param_idx + 2])),
               _mm_set1_ps(r.C_multiplier[quant_param_idx + 3]),
               1);
-          y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
+          y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);

           multiplier_v = _mm256_insertf128_ps(
               _mm256_castps128_ps256(
                   _mm_set1_ps(r.C_multiplier[quant_param_idx + 4])),
               _mm_set1_ps(r.C_multiplier[quant_param_idx + 5]),
               1);
-          z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
+          z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);

           multiplier_v = _mm256_insertf128_ps(
               _mm256_castps128_ps256(
                   _mm_set1_ps(r.C_multiplier[quant_param_idx + 6])),
               _mm_set1_ps(r.C_multiplier[quant_param_idx + 7]),
               1);
-          w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+          w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
         } else if (C_PER_G == 8) {
           multiplier_v = _mm256_set1_ps(
               r.C_multiplier
                   [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]);
-          x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
-
-          multiplier_v = _mm256_set1_ps(
-              r.C_multiplier
-                  [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
-                   1]);
-          y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
-
-          multiplier_v = _mm256_set1_ps(
-              r.C_multiplier
-                  [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
-                   2]);
-          z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
-
-          multiplier_v = _mm256_set1_ps(
-              r.C_multiplier
-                  [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
-                   3]);
-          w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+          x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+
+          multiplier_v =
+              _mm256_set1_ps(r.C_multiplier
+                                 [quant_param_idx +
+                                  (j - block.col_start) / (VLEN * 4) * 4 + 1]);
+          y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+
+          multiplier_v =
+              _mm256_set1_ps(r.C_multiplier
+                                 [quant_param_idx +
+                                  (j - block.col_start) / (VLEN * 4) * 4 + 2]);
+          z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+
+          multiplier_v =
+              _mm256_set1_ps(r.C_multiplier
+                                 [quant_param_idx +
+                                  (j - block.col_start) / (VLEN * 4) * 4 + 3]);
+          w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
         } else {
           multiplier_v = _mm256_set1_ps(
               r.C_multiplier
                   [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2]);
-          x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
-          y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
-
-          multiplier_v = _mm256_set1_ps(
-              r.C_multiplier
-                  [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2 +
-                   1]);
-          z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
-          w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+          x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+          y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+
+          multiplier_v =
+              _mm256_set1_ps(r.C_multiplier
+                                 [quant_param_idx +
+                                  (j - block.col_start) / (VLEN * 4) * 2 + 1]);
+          z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+          w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
         }
       } else {
-        x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
-        y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
-        z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
-        w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+        x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+        y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+        z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+        w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
       }

       /*
@@ -1279,46 +1516,69 @@ void requantizeOutputProcessingGConvAvx2(
     } // i loop
 }

-#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
-  template void \
-  requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
-      uint8_t * out, \
-      const int32_t* inp, \
-      const block_type_t& block, \
-      int ld_out, \
-      int ld_in, \
-      const requantizationParams_t& r); \
-  template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
-      float* out, \
-      const int32_t* inp, \
-      const block_type_t& block, \
-      int ld_out, \
-      int ld_in, \
-      const requantizationForFloatParams_t& r); \
-  template void \
-  requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>( \
-      uint8_t * out, \
-      const int32_t* inp, \
-      const block_type_t& block, \
-      int ld_out, \
-      int ld_in, \
-      const requantizationParams_t& r); \
-  template void \
-  requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>( \
-      uint8_t * out, \
-      const int32_t* inp, \
-      const block_type_t& block, \
-      int ld_out, \
-      int ld_in, \
-      const requantizationParams_t& r); \
-  template void \
-  requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 16>( \
-      uint8_t * out, \
-      const int32_t* inp, \
-      const block_type_t& block, \
-      int ld_out, \
-      int ld_in, \
-      const requantizationParams_t& r);
+#define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \
+    A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \
+  template void \
+  requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE>( \
+      uint8_t * out, \
+      const int32_t* inp, \
+      const block_type_t& block, \
+      int ld_out, \
+      int ld_in, \
+      const requantizationParams_t<BIAS_TYPE>& r); \
+  template void requantizeOutputProcessingGConvAvx2< \
+      A_SYM, \
+      B_SYM, \
+      Q_GRAN, \
+      BIAS, \
+      RELU, \
+      4, \
+      BIAS_TYPE>( \
+      uint8_t * out, \
+      const int32_t* inp, \
+      const block_type_t& block, \
+      int ld_out, \
+      int ld_in, \
+      const requantizationParams_t<BIAS_TYPE>& r); \
+  template void requantizeOutputProcessingGConvAvx2< \
+      A_SYM, \
+      B_SYM, \
+      Q_GRAN, \
+      BIAS, \
+      RELU, \
+      8, \
+      BIAS_TYPE>( \
+      uint8_t * out, \
+      const int32_t* inp, \
+      const block_type_t& block, \
+      int ld_out, \
+      int ld_in, \
+      const requantizationParams_t<BIAS_TYPE>& r); \
+  template void requantizeOutputProcessingGConvAvx2< \
+      A_SYM, \
+      B_SYM, \
+      Q_GRAN, \
+      BIAS, \
+      RELU, \
+      16, \
+      BIAS_TYPE>( \
+      uint8_t * out, \
+      const int32_t* inp, \
+      const block_type_t& block, \
+      int ld_out, \
+      int ld_in, \
+      const requantizationParams_t<BIAS_TYPE>& r);
+
+#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
+  INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, float) \
+  INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, int32_t) \
+  template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
+      float* out, \
+      const int32_t* inp, \
+      const block_type_t& block, \
+      int ld_out, \
+      int ld_in, \
+      const requantizationForFloatParams_t& r);

 #define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \
   INSTANTIATE_REQUANTIZE(true, B_SYM, Q_GRAN, BIAS, RELU) \
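The macros above stamp out explicit instantiations for every template-parameter combination so the AVX2-specific definitions can stay inside this translation unit. A minimal, self-contained illustration of the pattern (hypothetical names, not FBGEMM's):

```cpp
#include <cstdint>
#include <cstdio>

// Template definition; in FBGEMM the analogue lives in the AVX2-compiled .cc.
template <typename T>
void quantizeBuffer(const float* src, T* dst, int len) {
  for (int i = 0; i < len; ++i) {
    dst[i] = static_cast<T>(src[i]);
  }
}

// Explicit instantiations: callers in other translation units can link
// against these without ever seeing the template body.
template void quantizeBuffer<uint8_t>(const float*, uint8_t*, int);
template void quantizeBuffer<int8_t>(const float*, int8_t*, int);

int main() {
  float src[2] = {1.f, 2.f};
  uint8_t dst[2];
  quantizeBuffer(src, dst, 2);
  std::printf("%d %d\n", dst[0], dst[1]); // prints: 1 2
  return 0;
}
```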