diff options
author | Daya Khudia <dskhudia@fb.com> | 2019-09-11 21:47:58 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-09-11 21:52:07 +0300 |
commit | 637288bff9972c02e72341d6a60fdf9bab1dce7e (patch) | |
tree | 3844552832c8527c5bfcb04d87b9b4132fc5bc8e | |
parent | 415035019ccbca2b11b62f1503fdd61e8bc59b10 (diff) |
ReQuantization with FP32 bias
Summary:
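In eager mode, quantizing the bias with input_scale * weight_scale is problematic (see the following doc). This change lets requantization take an unquantized FP32 bias instead: the bias is divided by act_times_w_scale (activation_scale * weight_scale) and added to the accumulator in floating point before the output multiplier is applied.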
https://fb.quip.com/ru2eAqzsjwXc
Reviewed By: jianyuh
Differential Revision: D16948098
fbshipit-source-id: ff2c2bc560c2c14da1941d65a15c96e18f407569
-rw-r--r-- | include/fbgemm/Fbgemm.h | 20 | ||||
-rw-r--r-- | include/fbgemm/OutputProcessing-inl.h | 36 | ||||
-rw-r--r-- | include/fbgemm/QuantUtilsAvx2.h | 10 | ||||
-rw-r--r-- | include/fbgemm/UtilsAvx2.h | 5 | ||||
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc | 20 | ||||
-rw-r--r-- | src/QuantUtilsAvx2.cc | 477 |
6 files changed, 390 insertions, 178 deletions
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 0b7bf1f..2f73de4 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -1128,6 +1128,7 @@ class FBGEMM_API DoSConvOnInpBuffer { template < bool FUSE_RELU, QuantizationGranularity Q_GRAN = QuantizationGranularity::TENSOR, + typename BIAS_TYPE = std::int32_t, typename outT = std::uint8_t, typename inT = std::int32_t, typename nextOPType = DoNothing<outT, outT>> @@ -1135,6 +1136,7 @@ class FBGEMM_API ReQuantizeOutput { public: static constexpr int RELU_FUSED = FUSE_RELU; static constexpr QuantizationGranularity QGRANType = Q_GRAN; + using BIAS_T = BIAS_TYPE; using outType = outT; using inpType = inT; /** @@ -1155,6 +1157,8 @@ class FBGEMM_API ReQuantizeOutput { * See PackedRequantizeTest.cc for an example. * TODO: if Aq_zero_point == 0, allow passing nullptr. * @params bias can be nullptr otherwise the length should be nCol + * @params act_times_w_scale activation_scale * weight_scale. This is only + * used if bias is unquantized (i.e., float). 
*/ ReQuantizeOutput( nextOPType& nextop, @@ -1164,9 +1168,10 @@ class FBGEMM_API ReQuantizeOutput { const std::int32_t* Bq_zero_point, const std::int32_t* row_offsets, const std::int32_t* col_offsets, - const std::int32_t* bias, + const BIAS_T* bias, std::uint32_t nCol, - int groups = 1) + int groups = 1, + const float* act_times_w_scale = nullptr) : nextop_(nextop), C_multiplier_(C_multiplier), C_zero_point_(C_zero_point), @@ -1176,7 +1181,8 @@ class FBGEMM_API ReQuantizeOutput { q_col_offsets_(col_offsets), bias_(bias), ncols_(nCol), - groups_(groups) {} + groups_(groups), + act_times_w_scale_(act_times_w_scale) {} template <inst_set_t instSet> inline int f( @@ -1204,12 +1210,15 @@ class FBGEMM_API ReQuantizeOutput { const std::int32_t* getColOffsets() const { return q_col_offsets_; } - const std::int32_t* getBias() const { + const BIAS_T* getBias() const { return bias_; } std::uint32_t getNCols() const { return ncols_; } + const float* getActWScale() const { + return act_times_w_scale_; + } void setRowOffsets(const std::int32_t* row_offsets) { q_row_offsets_ = row_offsets; @@ -1223,9 +1232,10 @@ class FBGEMM_API ReQuantizeOutput { const std::int32_t* Bq_zero_point_; const std::int32_t* q_row_offsets_; const std::int32_t* q_col_offsets_; - const std::int32_t* bias_; + const BIAS_T* bias_; std::uint32_t ncols_; int groups_; + const float* act_times_w_scale_; }; /** diff --git a/include/fbgemm/OutputProcessing-inl.h b/include/fbgemm/OutputProcessing-inl.h index d984c60..04ae100 100644 --- a/include/fbgemm/OutputProcessing-inl.h +++ b/include/fbgemm/OutputProcessing-inl.h @@ -59,11 +59,13 @@ inline int DoSConvOnInpBuffer<outT, inT, nextOPType>::f( template < bool FUSE_RELU, QuantizationGranularity Q_GRAN, + typename BIAS_TYPE, typename outT, typename inT, typename nextOPType> template <inst_set_t instSet> -inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f( +inline int +ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE, outT, inT, nextOPType>::f( 
outT* out, const inT* inp, const block_type_t& block, @@ -98,11 +100,20 @@ inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f( raw -= q_row_offsets_[i - block.row_start] * Bq_zero_point_[Bq_zero_point_idx]; } + float raw_f; if (bias_) { - raw += bias_[j]; + if (std::is_same<BIAS_TYPE, float>::value) { + raw_f = raw; + raw_f += bias_[j] / act_times_w_scale_[Bq_zero_point_idx]; + } else { + raw += bias_[j]; + raw_f = raw; + } + } else { + raw_f = raw; } - float ab = raw * C_multiplier_[Bq_zero_point_idx]; + float ab = raw_f * C_multiplier_[Bq_zero_point_idx]; long rounded = std::lrintf(ab) + C_zero_point_; out[i * ld_out + j] = std::max( @@ -115,15 +126,16 @@ inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f( Bq_zero_point_[0] == 0) || q_row_offsets_ == nullptr; - requantizationParams_t r = {Aq_zero_point_, - Bq_zero_point_, - C_zero_point_, - C_multiplier_, - q_row_offsets_, - q_col_offsets_, - bias_, - ncols_, - groups_}; + requantizationParams_t<BIAS_TYPE> r = {Aq_zero_point_, + Bq_zero_point_, + C_zero_point_, + C_multiplier_, + q_row_offsets_, + q_col_offsets_, + bias_, + ncols_, + groups_, + act_times_w_scale_}; if (Aq_zero_point_ == 0) { if (b_symmetric) { diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h index a001004..c7f3f35 100644 --- a/include/fbgemm/QuantUtilsAvx2.h +++ b/include/fbgemm/QuantUtilsAvx2.h @@ -72,14 +72,15 @@ template < bool B_SYMMETRIC, QuantizationGranularity Q_GRAN, bool HAS_BIAS, - bool FUSE_RELU> + bool FUSE_RELU, + typename BIAS_TYPE = std::int32_t> FBGEMM_API void requantizeOutputProcessingAvx2( std::uint8_t* out, const std::int32_t* inp, const block_type_t& block, int ld_out, int ld_in, - const requantizationParams_t& r); + const requantizationParams_t<BIAS_TYPE>& r); template < bool A_SYMMETRIC, @@ -87,14 +88,15 @@ template < QuantizationGranularity Q_GRAN, bool HAS_BIAS, bool FUSE_RELU, - int C_PER_G> + int C_PER_G, + typename BIAS_TYPE = std::int32_t> 
FBGEMM_API void requantizeOutputProcessingGConvAvx2( std::uint8_t* out, const std::int32_t* inp, const block_type_t& block, int ld_out, int ld_in, - const requantizationParams_t& r); + const requantizationParams_t<BIAS_TYPE>& r); template < bool A_SYMMETRIC, diff --git a/include/fbgemm/UtilsAvx2.h b/include/fbgemm/UtilsAvx2.h index 082edc1..3bac909 100644 --- a/include/fbgemm/UtilsAvx2.h +++ b/include/fbgemm/UtilsAvx2.h @@ -44,16 +44,19 @@ struct block_type_t { * QuantUtilsAvx2.h as it combines all the parameters needed for various * quantization granularities */ +template<typename BIAS_TYPE = std::int32_t> struct requantizationParams_t { + using BIAS_T = BIAS_TYPE; std::int32_t A_zero_point; const std::int32_t* B_zero_point; std::int32_t C_zero_point; const float* C_multiplier; const std::int32_t* row_offsets; const std::int32_t* col_offsets; - const std::int32_t* bias; + const BIAS_T* bias; std::uint32_t ncols; int groups; + const float* act_times_w_scale; }; /** diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index ef4ba7b..40f3fba 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -1872,15 +1872,17 @@ void fbgemmGroupwiseConv( outProcess.getBZeroPoint()[0] == 0) || rowOffsetBuf == nullptr; - requantizationParams_t r = {a_zero_point, - outProcess.getBZeroPoint(), - outProcess.getCZeroPoint(), - outProcess.getCMultiplier(), - rowOffsetBuf, - outProcess.getColOffsets(), - outProcess.getBias(), - outProcess.getNCols(), - G}; + requantizationParams_t<typename processOutputType::BIAS_T> r = { + a_zero_point, + outProcess.getBZeroPoint(), + outProcess.getCZeroPoint(), + outProcess.getCMultiplier(), + rowOffsetBuf, + outProcess.getColOffsets(), + outProcess.getBias(), + outProcess.getNCols(), + G, + outProcess.getActWScale()}; const std::int32_t* inp = outBuffer; block_type_t block{i * oh_ow, oh_ow, gOuter * K_per_G, 8 * K_per_G}; diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index 0643ed6..c5ef6ba 
100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -282,14 +282,15 @@ template < bool B_SYMMETRIC, QuantizationGranularity Q_GRAN, bool HAS_BIAS, - bool FUSE_RELU> + bool FUSE_RELU, + typename BIAS_TYPE> void requantizeOutputProcessingAvx2( uint8_t* out, const int32_t* inp, const block_type_t& block, int ld_out, int ld_in, - const requantizationParams_t& r) { + const requantizationParams_t<BIAS_TYPE>& r) { // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c // using AVX2 instructions int quant_param_idx = 0; @@ -300,6 +301,15 @@ void requantizeOutputProcessingAvx2( } __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]); + // Broadcasted reciprocal of act_times_w_scale + __m256 act_times_w_rcp_v; + if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) { + if (is_same<BIAS_TYPE, float>::value) { + act_times_w_rcp_v = + _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]); + } + } + __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0)); __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255)); @@ -409,22 +419,76 @@ void requantizeOutputProcessingAvx2( } w_v = _mm256_sub_epi32(w_v, row_offset_v); } + __m256 xf_v, yf_v, zf_v, wf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32( - x_v, - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); - y_v = _mm256_add_epi32( - y_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + VLEN))); - z_v = _mm256_add_epi32( - z_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); - w_v = _mm256_add_epi32( - w_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v; + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN)); + y_bias_v = 
_mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN)); + z_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN)); + w_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)), + act_times_w_rcp_v); + y_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)), + act_times_w_rcp_v); + z_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)), + act_times_w_rcp_v); + w_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v); + zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v); + wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN))); + y_v = _mm256_add_epi32( + y_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN))); + z_v = _mm256_add_epi32( + z_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); + w_v = _mm256_add_epi32( + w_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); } /* @@ 
-441,22 +505,19 @@ void requantizeOutputProcessingAvx2( */ __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j)); - y_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(y_v), - _mm256_loadu_ps(r.C_multiplier + j + VLEN)); - z_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(z_v), - _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); - w_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(w_v), - _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); + x_scaled_v = + _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j + 0 * VLEN)); + y_scaled_v = + _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + 1 * VLEN)); + z_scaled_v = + _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); + w_scaled_v = + _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } /* @@ -543,18 +604,35 @@ void requantizeOutputProcessingAvx2( } x_v = _mm256_sub_epi32(x_v, row_offset_v); } + __m256 xf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32( - x_v, - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v; + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)), + _mm256_loadu_ps(r.act_times_w_scale + j)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_loadu_ps(reinterpret_cast<const 
float*>(r.bias + j)), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); + xf_v = _mm256_cvtepi32_ps(x_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); } __m256 x_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j)); + x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j)); } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); } __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); @@ -627,17 +705,40 @@ void requantizeOutputProcessingAvx2( } x_v = _mm256_sub_epi32(x_v, row_offset_v); } + + __m256 xf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32(x_v, _mm256_maskload_epi32(r.bias + j, mask_v)); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v; + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + _mm256_maskload_ps( + reinterpret_cast<const float*>(r.bias + j), mask_v), + _mm256_maskload_ps(r.act_times_w_scale + j, mask_v)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_maskload_ps( + reinterpret_cast<const float*>(r.bias + j), mask_v), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_maskload_epi32( + reinterpret_cast<const int*>(r.bias + j), mask_v)); + xf_v = _mm256_cvtepi32_ps(x_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); } __m256 x_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), - _mm256_maskload_ps(r.C_multiplier + j, mask_v)); + x_scaled_v = + _mm256_mul_ps(xf_v, _mm256_maskload_ps(r.C_multiplier + j, mask_v)); } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + x_scaled_v = 
_mm256_mul_ps(xf_v, multiplier_v); } __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); @@ -845,14 +946,15 @@ template < QuantizationGranularity Q_GRAN, bool HAS_BIAS, bool FUSE_RELU, - int C_PER_G> + int C_PER_G, + typename BIAS_TYPE> void requantizeOutputProcessingGConvAvx2( uint8_t* out, const int32_t* inp, const block_type_t& block, int ld_out, int ld_in, - const requantizationParams_t& r) { + const requantizationParams_t<BIAS_TYPE>& r) { // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c // using AVX2 instructions int quant_param_idx = 0; @@ -863,6 +965,14 @@ void requantizeOutputProcessingGConvAvx2( } __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]); + // Broadcasted reciprocal of act_times_w_scale + __m256 act_times_w_rcp_v; + if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) { + if (is_same<BIAS_TYPE, float>::value) { + act_times_w_rcp_v = + _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]); + } + } __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0)); __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255)); @@ -1109,22 +1219,76 @@ void requantizeOutputProcessingGConvAvx2( } w_v = _mm256_sub_epi32(w_v, row_offset_v); } + __m256 xf_v, yf_v, zf_v, wf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32( - x_v, - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); - y_v = _mm256_add_epi32( - y_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + VLEN))); - z_v = _mm256_add_epi32( - z_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); - w_v = _mm256_add_epi32( - w_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v; + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + 
j + 0 * VLEN)); + y_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN)); + z_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN)); + w_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)), + act_times_w_rcp_v); + y_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)), + act_times_w_rcp_v); + z_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)), + act_times_w_rcp_v); + w_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v); + zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v); + wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN))); + y_v = _mm256_add_epi32( + y_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN))); + z_v = _mm256_add_epi32( + z_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); + w_v = _mm256_add_epi32( + w_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = 
_mm256_cvtepi32_ps(w_v); } /* @@ -1141,17 +1305,13 @@ void requantizeOutputProcessingGConvAvx2( */ __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j)); - y_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(y_v), - _mm256_loadu_ps(r.C_multiplier + j + VLEN)); - z_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(z_v), - _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); - w_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(w_v), - _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); + x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j)); + y_scaled_v = + _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + VLEN)); + z_scaled_v = + _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); + w_scaled_v = + _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); } else if (Q_GRAN == QuantizationGranularity::GROUP) { if (C_PER_G == 4) { multiplier_v = _mm256_insertf128_ps( @@ -1159,70 +1319,70 @@ void requantizeOutputProcessingGConvAvx2( _mm_set1_ps(r.C_multiplier[quant_param_idx])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 1]), 1); - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); multiplier_v = _mm256_insertf128_ps( _mm256_castps128_ps256( _mm_set1_ps(r.C_multiplier[quant_param_idx + 2])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 3]), 1); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); multiplier_v = _mm256_insertf128_ps( _mm256_castps128_ps256( _mm_set1_ps(r.C_multiplier[quant_param_idx + 4])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 5]), 1); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); multiplier_v = _mm256_insertf128_ps( _mm256_castps128_ps256( 
_mm_set1_ps(r.C_multiplier[quant_param_idx + 6])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 7]), 1); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } else if (C_PER_G == 8) { multiplier_v = _mm256_set1_ps( r.C_multiplier [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]); - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + - 1]); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + - 2]); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + - 3]); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 1]); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 2]); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 3]); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } else { multiplier_v = _mm256_set1_ps( r.C_multiplier [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2]); - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2 + - 1]); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), 
multiplier_v); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 2 + 1]); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } /* @@ -1293,46 +1453,69 @@ void requantizeOutputProcessingGConvAvx2( } // i loop } -#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \ - template void \ - requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); \ - template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ - float* out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationForFloatParams_t& r); \ - template void \ - requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); \ - template void \ - requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); \ - template void \ - 
requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 16>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); +#define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \ + A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \ + template void \ + requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); \ + template void requantizeOutputProcessingGConvAvx2< \ + A_SYM, \ + B_SYM, \ + Q_GRAN, \ + BIAS, \ + RELU, \ + 4, \ + BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); \ + template void requantizeOutputProcessingGConvAvx2< \ + A_SYM, \ + B_SYM, \ + Q_GRAN, \ + BIAS, \ + RELU, \ + 8, \ + BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); \ + template void requantizeOutputProcessingGConvAvx2< \ + A_SYM, \ + B_SYM, \ + Q_GRAN, \ + BIAS, \ + RELU, \ + 16, \ + BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); + +#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \ + INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, float) \ + INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, int32_t) \ + template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ + float* out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationForFloatParams_t& r); #define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \ INSTANTIATE_REQUANTIZE(true, B_SYM, Q_GRAN, BIAS, RELU) \ |