author     Daya Khudia <dskhudia@fb.com>  2019-09-13 23:35:17 +0300
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-09-13 23:38:40 +0300
commit     c8cac64995d8d8af871e461affbf505ac7fce4d8
tree       164c78e0b7f1b8a8148eb79ffb861a5c050f251c
parent     ea787e8278744ab4c7d6c4ee42a050bb1c76ef88
add missing instantiation for float bias for gconv (#127)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/127
Float bias was going through a slow path: the groupwise convolution requantization was missing a specialization for BIAS_TYPE = float. This adds the missing specialization so float bias takes the vectorized AVX2 path.
Reviewed By: protonu, jianyuh
Differential Revision: D17346881
fbshipit-source-id: dd6b40d80c3c429b438ea6b4e1520b935e582c4a
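
With BIAS_TYPE = float, the bias is in real-valued units rather than in the int32 accumulator's scale, so the requantization divides it by act_times_w_scale before adding it to the accumulator. A minimal scalar sketch of that path, assuming the common setup where C_multiplier equals act_times_w_scale / output_scale, and omitting the row/column zero-point corrections the real kernel applies (names are illustrative, not FBGEMM API):

// A minimal scalar sketch of the fast path this change enables; names and
// the helper are illustrative, not FBGEMM API. It assumes the common setup
// where C_multiplier == act_times_w_scale / output_scale.
#include <algorithm>
#include <cmath>
#include <cstdint>

std::uint8_t requantize_with_float_bias(
    std::int32_t acc,        // int32 accumulator from the quantized conv
    float bias,              // float bias: the BIAS_TYPE = float case
    float act_times_w_scale, // activation_scale * weight_scale (per tensor,
                             // group, or channel, depending on Q_GRAN)
    float C_multiplier,      // output requantization multiplier
    std::int32_t C_zero_point) {
  // A float bias is in real-valued units, so divide by act_times_w_scale to
  // bring it into the accumulator's scale before adding it in ...
  float xf = static_cast<float>(acc) + bias / act_times_w_scale;
  // ... then rescale to the output grid, round, shift, and saturate to uint8.
  long q = std::lrintf(xf * C_multiplier) + C_zero_point;
  return static_cast<std::uint8_t>(std::min(255L, std::max(0L, q)));
}

The AVX2 kernel in QuantUtilsAvx2.cc below performs this same arithmetic on four __m256 vectors (32 output channels) per iteration.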
 include/fbgemm/Fbgemm.h       |   5
 src/GroupwiseConvAcc32Avx2.cc |   7
 src/QuantUtilsAvx2.cc         | 117
 3 files changed, 95 insertions(+), 34 deletions(-)
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 2f73de4..4efd181 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -1388,7 +1388,8 @@ template <
     typename outType,
     bool FUSE_RELU,
     QuantizationGranularity Q_GRAN,
-    int SPATIAL_DIM = 2>
+    int SPATIAL_DIM = 2,
+    typename BIAS_TYPE = std::int32_t>
 FBGEMM_API void fbgemmGroupwiseConv(
     const conv_param_t<SPATIAL_DIM>& conv_param,
     const std::uint8_t* activations,
@@ -1397,7 +1398,7 @@ FBGEMM_API void fbgemmGroupwiseConv(
     packed_W& packed_weights,
     outType* out,
     std::int32_t* outBuffer,
-    const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess,
+    const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
     int thread_id,
     int num_threads);
 
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index 4ba3549..d1e0fdd 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -1769,7 +1769,8 @@ template <
     typename outType,
     bool FUSE_RELU,
     QuantizationGranularity Q_GRAN,
-    int SPATIAL_DIM>
+    int SPATIAL_DIM,
+    typename BIAS_TYPE>
 void fbgemmGroupwiseConv(
     const conv_param_t<SPATIAL_DIM>& conv_param,
     const std::uint8_t* activations,
@@ -1778,10 +1779,10 @@ void fbgemmGroupwiseConv(
     packed_W& packed_weights,
     outType* out,
     int32_t* outBuffer,
-    const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess,
+    const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
     int thread_id,
     int num_threads) {
-  typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN> processOutputType;
+  typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE> processOutputType;
 
   if (!cpuinfo_initialize()) {
     throw std::runtime_error("Failed to initialize cpuinfo!");
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index c5ef6ba..c50f6d9 100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -1222,41 +1222,100 @@ void requantizeOutputProcessingGConvAvx2(
       __m256 xf_v, yf_v, zf_v, wf_v;
       if (HAS_BIAS) {
         if (is_same<BIAS_TYPE, float>::value) {
-          __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+          __m256 x_bias_v = _mm256_loadu_ps(
+              reinterpret_cast<const float*>(r.bias + j + 0 * VLEN));
+          __m256 y_bias_v = _mm256_loadu_ps(
+              reinterpret_cast<const float*>(r.bias + j + 1 * VLEN));
+          __m256 z_bias_v = _mm256_loadu_ps(
+              reinterpret_cast<const float*>(r.bias + j + 2 * VLEN));
+          __m256 w_bias_v = _mm256_loadu_ps(
+              reinterpret_cast<const float*>(r.bias + j + 3 * VLEN));
           if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
             x_bias_v = _mm256_div_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
-                _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
+                x_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
             y_bias_v = _mm256_div_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
-                _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
+                y_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
             z_bias_v = _mm256_div_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
-                _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
+                z_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
             w_bias_v = _mm256_div_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
-                _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+                w_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+          } else if (Q_GRAN == QuantizationGranularity::GROUP) {
+            __m256 diviser_v;
+            if (C_PER_G == 4) {
+              diviser_v = _mm256_insertf128_ps(
+                  _mm256_castps128_ps256(
+                      _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 0])),
+                  _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 1]),
+                  1);
+              x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+
+              diviser_v = _mm256_insertf128_ps(
+                  _mm256_castps128_ps256(
+                      _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 2])),
+                  _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 3]),
+                  1);
+              y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+              diviser_v = _mm256_insertf128_ps(
+                  _mm256_castps128_ps256(
+                      _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 4])),
+                  _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 5]),
+                  1);
+              z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+
+              diviser_v = _mm256_insertf128_ps(
+                  _mm256_castps128_ps256(
+                      _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 6])),
+                  _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 7]),
+                  1);
+              w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+
+            } else if (C_PER_G == 8) {
+              diviser_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 4 + 0]);
+              x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+
+              diviser_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 4 + 1]);
+              y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+              diviser_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 4 + 2]);
+              z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+
+              diviser_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 4 + 3]);
+              w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+
+            } else {
+              assert(C_PER_G == 16);
+              diviser_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 2 + 0]);
+              x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+              y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+              diviser_v = _mm256_set1_ps(
+                  r.act_times_w_scale
+                      [quant_param_idx +
+                       (j - block.col_start) / (VLEN * 4) * 2 + 1]);
+              z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+              w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+            }
           } else {
-            x_bias_v = _mm256_mul_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
-                act_times_w_rcp_v);
-            y_bias_v = _mm256_mul_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
-                act_times_w_rcp_v);
-            z_bias_v = _mm256_mul_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
-                act_times_w_rcp_v);
-            w_bias_v = _mm256_mul_ps(
-                _mm256_loadu_ps(
-                    reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
-                act_times_w_rcp_v);
+            x_bias_v = _mm256_mul_ps(x_bias_v, act_times_w_rcp_v);
+            y_bias_v = _mm256_mul_ps(y_bias_v, act_times_w_rcp_v);
+            z_bias_v = _mm256_mul_ps(z_bias_v, act_times_w_rcp_v);
+            w_bias_v = _mm256_mul_ps(w_bias_v, act_times_w_rcp_v);
           }
           xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
           yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
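
For GROUP granularity, the 32 channels handled per loop iteration (four __m256 vectors of VLEN = 8 floats) can span several groups, so the kernel picks a per-group divisor: one per 128-bit half when C_PER_G == 4, one per vector when C_PER_G == 8, one per vector pair when C_PER_G == 16. A scalar model of the per-channel rule this implements (a sketch, not FBGEMM code):

// Scalar model of the divisor selection in the GROUP branch above (a sketch,
// not FBGEMM code): channel c of the block belongs to group c / C_PER_G, and
// its float bias is divided by that group's act_times_w_scale.
#include <cstddef>
#include <vector>

std::vector<float> scale_bias_per_group(
    const std::vector<float>& bias,              // one entry per output channel
    const std::vector<float>& act_times_w_scale, // one entry per group
    int C_PER_G) {                               // channels per group
  std::vector<float> scaled(bias.size());
  for (std::size_t c = 0; c < bias.size(); ++c) {
    scaled[c] = bias[c] / act_times_w_scale[c / C_PER_G];
  }
  return scaled;
}

This is why the C_PER_G == 4 case needs _mm256_insertf128_ps to splice two different group scales into one vector, while C_PER_G == 16 can reuse a single broadcast across two vectors.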