Diffstat (limited to 'src/QuantUtilsAvx2.cc')
-rw-r--r-- | src/QuantUtilsAvx2.cc | 477
1 file changed, 330 insertions(+), 147 deletions(-)
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index 0643ed6..c5ef6ba 100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -282,14 +282,15 @@ template < bool B_SYMMETRIC, QuantizationGranularity Q_GRAN, bool HAS_BIAS, - bool FUSE_RELU> + bool FUSE_RELU, + typename BIAS_TYPE> void requantizeOutputProcessingAvx2( uint8_t* out, const int32_t* inp, const block_type_t& block, int ld_out, int ld_in, - const requantizationParams_t& r) { + const requantizationParams_t<BIAS_TYPE>& r) { // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c // using AVX2 instructions int quant_param_idx = 0; @@ -300,6 +301,15 @@ void requantizeOutputProcessingAvx2( } __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]); + // Broadcasted reciprocal of act_times_w_scale + __m256 act_times_w_rcp_v; + if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) { + if (is_same<BIAS_TYPE, float>::value) { + act_times_w_rcp_v = + _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]); + } + } + __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0)); __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255)); @@ -409,22 +419,76 @@ void requantizeOutputProcessingAvx2( } w_v = _mm256_sub_epi32(w_v, row_offset_v); } + __m256 xf_v, yf_v, zf_v, wf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32( - x_v, - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); - y_v = _mm256_add_epi32( - y_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + VLEN))); - z_v = _mm256_add_epi32( - z_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); - w_v = _mm256_add_epi32( - w_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v; + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN)); + y_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN)); + z_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN)); + w_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)), + act_times_w_rcp_v); + y_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)), + act_times_w_rcp_v); + z_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)), + act_times_w_rcp_v); + w_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v); + zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v); + wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN))); + y_v = _mm256_add_epi32( + y_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN))); + z_v 
= _mm256_add_epi32( + z_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); + w_v = _mm256_add_epi32( + w_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); } /* @@ -441,22 +505,19 @@ void requantizeOutputProcessingAvx2( */ __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j)); - y_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(y_v), - _mm256_loadu_ps(r.C_multiplier + j + VLEN)); - z_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(z_v), - _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); - w_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(w_v), - _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); + x_scaled_v = + _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j + 0 * VLEN)); + y_scaled_v = + _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + 1 * VLEN)); + z_scaled_v = + _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); + w_scaled_v = + _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } /* @@ -543,18 +604,35 @@ void requantizeOutputProcessingAvx2( } x_v = _mm256_sub_epi32(x_v, row_offset_v); } + __m256 xf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32( - x_v, - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v; + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)), + _mm256_loadu_ps(r.act_times_w_scale + j)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); + xf_v = _mm256_cvtepi32_ps(x_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); } __m256 x_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j)); + x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j)); } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); } __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); @@ -627,17 +705,40 @@ void requantizeOutputProcessingAvx2( } x_v = _mm256_sub_epi32(x_v, row_offset_v); } + + __m256 xf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32(x_v, _mm256_maskload_epi32(r.bias + j, mask_v)); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v; + if (Q_GRAN == 
QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + _mm256_maskload_ps( + reinterpret_cast<const float*>(r.bias + j), mask_v), + _mm256_maskload_ps(r.act_times_w_scale + j, mask_v)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_maskload_ps( + reinterpret_cast<const float*>(r.bias + j), mask_v), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_maskload_epi32( + reinterpret_cast<const int*>(r.bias + j), mask_v)); + xf_v = _mm256_cvtepi32_ps(x_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); } __m256 x_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), - _mm256_maskload_ps(r.C_multiplier + j, mask_v)); + x_scaled_v = + _mm256_mul_ps(xf_v, _mm256_maskload_ps(r.C_multiplier + j, mask_v)); } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); } __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v); @@ -845,14 +946,15 @@ template < QuantizationGranularity Q_GRAN, bool HAS_BIAS, bool FUSE_RELU, - int C_PER_G> + int C_PER_G, + typename BIAS_TYPE> void requantizeOutputProcessingGConvAvx2( uint8_t* out, const int32_t* inp, const block_type_t& block, int ld_out, int ld_in, - const requantizationParams_t& r) { + const requantizationParams_t<BIAS_TYPE>& r) { // Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c // using AVX2 instructions int quant_param_idx = 0; @@ -863,6 +965,14 @@ void requantizeOutputProcessingGConvAvx2( } __m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]); + // Broadcasted reciprocal of act_times_w_scale + __m256 act_times_w_rcp_v; + if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) { + if (is_same<BIAS_TYPE, float>::value) { + act_times_w_rcp_v = + _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]); + } + } __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0)); __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255)); @@ -1109,22 +1219,76 @@ void requantizeOutputProcessingGConvAvx2( } w_v = _mm256_sub_epi32(w_v, row_offset_v); } + __m256 xf_v, yf_v, zf_v, wf_v; if (HAS_BIAS) { - x_v = _mm256_add_epi32( - x_v, - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j))); - y_v = _mm256_add_epi32( - y_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + VLEN))); - z_v = _mm256_add_epi32( - z_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); - w_v = _mm256_add_epi32( - w_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + if (is_same<BIAS_TYPE, float>::value) { + __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v; + if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { + x_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN)); + y_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN)); + z_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN)); + w_bias_v = _mm256_div_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)), + _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN)); + } else { + x_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const 
float*>(r.bias + j + 0 * VLEN)), + act_times_w_rcp_v); + y_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)), + act_times_w_rcp_v); + z_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)), + act_times_w_rcp_v); + w_bias_v = _mm256_mul_ps( + _mm256_loadu_ps( + reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)), + act_times_w_rcp_v); + } + xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v); + yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v); + zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v); + wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v); + } else { + x_v = _mm256_add_epi32( + x_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN))); + y_v = _mm256_add_epi32( + y_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN))); + z_v = _mm256_add_epi32( + z_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN))); + w_v = _mm256_add_epi32( + w_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN))); + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); + } + } else { + xf_v = _mm256_cvtepi32_ps(x_v); + yf_v = _mm256_cvtepi32_ps(y_v); + zf_v = _mm256_cvtepi32_ps(z_v); + wf_v = _mm256_cvtepi32_ps(w_v); } /* @@ -1141,17 +1305,13 @@ void requantizeOutputProcessingGConvAvx2( */ __m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v; if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { - x_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j)); - y_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(y_v), - _mm256_loadu_ps(r.C_multiplier + j + VLEN)); - z_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(z_v), - _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); - w_scaled_v = _mm256_mul_ps( - _mm256_cvtepi32_ps(w_v), - _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); + x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j)); + y_scaled_v = + _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + VLEN)); + z_scaled_v = + _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN)); + w_scaled_v = + _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); } else if (Q_GRAN == QuantizationGranularity::GROUP) { if (C_PER_G == 4) { multiplier_v = _mm256_insertf128_ps( @@ -1159,70 +1319,70 @@ void requantizeOutputProcessingGConvAvx2( _mm_set1_ps(r.C_multiplier[quant_param_idx])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 1]), 1); - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); multiplier_v = _mm256_insertf128_ps( _mm256_castps128_ps256( _mm_set1_ps(r.C_multiplier[quant_param_idx + 2])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 3]), 1); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); multiplier_v = _mm256_insertf128_ps( _mm256_castps128_ps256( _mm_set1_ps(r.C_multiplier[quant_param_idx + 4])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 5]), 1); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); multiplier_v = _mm256_insertf128_ps( _mm256_castps128_ps256( _mm_set1_ps(r.C_multiplier[quant_param_idx + 6])), _mm_set1_ps(r.C_multiplier[quant_param_idx + 7]), 1); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), 
multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } else if (C_PER_G == 8) { multiplier_v = _mm256_set1_ps( r.C_multiplier [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]); - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + - 1]); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + - 2]); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + - 3]); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 1]); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 2]); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 4 + 3]); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } else { multiplier_v = _mm256_set1_ps( r.C_multiplier [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2]); - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - - multiplier_v = _mm256_set1_ps( - r.C_multiplier - [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2 + - 1]); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + + multiplier_v = + _mm256_set1_ps(r.C_multiplier + [quant_param_idx + + (j - block.col_start) / (VLEN * 4) * 2 + 1]); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } } else { - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v); + y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v); + z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v); + w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v); } /* @@ -1293,46 +1453,69 @@ void requantizeOutputProcessingGConvAvx2( } // i loop } -#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \ - template void \ - requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); \ - template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ - float* out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationForFloatParams_t& r); \ - template void \ - requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const 
requantizationParams_t& r); \ - template void \ - requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); \ - template void \ - requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 16>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); +#define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \ + A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \ + template void \ + requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); \ + template void requantizeOutputProcessingGConvAvx2< \ + A_SYM, \ + B_SYM, \ + Q_GRAN, \ + BIAS, \ + RELU, \ + 4, \ + BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); \ + template void requantizeOutputProcessingGConvAvx2< \ + A_SYM, \ + B_SYM, \ + Q_GRAN, \ + BIAS, \ + RELU, \ + 8, \ + BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); \ + template void requantizeOutputProcessingGConvAvx2< \ + A_SYM, \ + B_SYM, \ + Q_GRAN, \ + BIAS, \ + RELU, \ + 16, \ + BIAS_TYPE>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t<BIAS_TYPE>& r); + +#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \ + INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, float) \ + INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, int32_t) \ + template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ + float* out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationForFloatParams_t& r); #define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \ INSTANTIATE_REQUANTIZE(true, B_SYM, Q_GRAN, BIAS, RELU) \ |
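
The functional core of this change is easier to see in scalar form: when BIAS_TYPE is float, the bias stays in real-valued units and is folded into the int32 accumulator as bias / act_times_w_scale (per output channel under OUT_CHANNEL granularity, otherwise via one broadcast value) before the usual output multiplier is applied. Below is a minimal scalar sketch of one output element, assuming the conventional requantization parameters (an output zero point C_zero_point and uint8 saturation, neither of which the hunks above show) and omitting the row/column offset corrections and ReLU fusion.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar model of the BIAS_TYPE = float path above (illustrative only).
//   acc                int32 accumulator, already offset-corrected
//   bias_f             float bias in real units for this output channel
//   act_times_w_scale  activation_scale * weight_scale
//   C_multiplier       typically act_times_w_scale / output_scale
inline uint8_t requantize_float_bias(
    int32_t acc,
    float bias_f,
    float act_times_w_scale,
    float C_multiplier,
    int32_t C_zero_point) {
  // Bring the float bias into accumulator units and add it in float,
  // mirroring the _mm256_cvtepi32_ps / _mm256_add_ps sequence in the diff.
  float xf = static_cast<float>(acc) + bias_f / act_times_w_scale;
  // Scale, round to nearest even (the default behavior of _mm256_cvtps_epi32),
  // add the output zero point, and saturate to uint8.
  int32_t rounded = static_cast<int32_t>(std::nearbyint(xf * C_multiplier));
  int32_t clamped =
      std::min<int32_t>(255, std::max<int32_t>(0, rounded + C_zero_point));
  return static_cast<uint8_t>(clamped);
}

Outside the OUT_CHANNEL granularity, the diff hoists the reciprocal 1 / act_times_w_scale[quant_param_idx] into a broadcast register once, so each bias load costs a multiply instead of a divide; with per-output-channel scales, a full-width _mm256_div_ps against act_times_w_scale[j ...] is used instead.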
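
The same formula gives the relationship between the two bias flavors, which is handy when checking the new float instantiations against the existing int32 ones: an int32 bias is already in accumulator units, while a float bias is in real units. A hypothetical conversion helper (not part of the FBGEMM API) makes this explicit.

#include <cstdint>

// Convert a quantized int32 bias into the real-valued float bias expected by
// the BIAS_TYPE = float instantiations. Since the float path divides the bias
// by act_times_w_scale before adding it to the accumulator, the two
// representations then produce the same output up to float rounding.
inline float to_float_bias(int32_t bias_q, float act_times_w_scale) {
  return static_cast<float>(bias_q) * act_times_w_scale;
}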