github.com/marian-nmt/FBGEMM.git

author    Daya Khudia <dskhudia@fb.com>  2019-09-13 23:35:17 +0300
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-09-13 23:38:40 +0300
commit    c8cac64995d8d8af871e461affbf505ac7fce4d8 (patch)
tree      164c78e0b7f1b8a8148eb79ffb861a5c050f251c
parent    ea787e8278744ab4c7d6c4ee42a050bb1c76ef88 (diff)
add missing instantiation for float bias for gconv (#127)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/127

Float bias was going through a slow path; this adds the missing specialization.

Reviewed By: protonu, jianyuh

Differential Revision: D17346881

fbshipit-source-id: dd6b40d80c3c429b438ea6b4e1520b935e582c4a
-rw-r--r--  include/fbgemm/Fbgemm.h         5
-rw-r--r--  src/GroupwiseConvAcc32Avx2.cc   7
-rw-r--r--  src/QuantUtilsAvx2.cc         117
3 files changed, 95 insertions(+), 34 deletions(-)
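
The gist of the fix, before the diffs: fbgemmGroupwiseConv is a template whose definition lives in a .cc file, so only explicitly instantiated template-argument combinations get the optimized out-of-line code, and no combination existed for a float bias. A self-contained toy of that pattern (the names here are hypothetical, not FBGEMM's):

    #include <cstdint>

    // Toy of the explicit-instantiation pattern behind this fix
    // (hypothetical names). With the kernel defined out of line, only
    // instantiated BIAS_TYPE combinations get the fast definition.
    template <typename BIAS_TYPE>
    void gconvKernel(const BIAS_TYPE* bias, int n);

    template <typename BIAS_TYPE>
    void gconvKernel(const BIAS_TYPE* bias, int n) {
      (void)bias; (void)n;  // ... vectorized groupwise-conv work ...
    }

    // Pre-existing instantiation: int32 bias.
    template void gconvKernel<std::int32_t>(const std::int32_t*, int);
    // The analogue of what this commit adds: float bias.
    template void gconvKernel<float>(const float*, int);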
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 2f73de4..4efd181 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -1388,7 +1388,8 @@ template <
typename outType,
bool FUSE_RELU,
QuantizationGranularity Q_GRAN,
- int SPATIAL_DIM = 2>
+ int SPATIAL_DIM = 2,
+ typename BIAS_TYPE = std::int32_t>
FBGEMM_API void fbgemmGroupwiseConv(
const conv_param_t<SPATIAL_DIM>& conv_param,
const std::uint8_t* activations,
@@ -1397,7 +1398,7 @@ FBGEMM_API void fbgemmGroupwiseConv(
packed_W& packed_weights,
outType* out,
std::int32_t* outBuffer,
- const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess,
+ const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
int thread_id,
int num_threads);
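
Note that BIAS_TYPE is added with a default of std::int32_t, so declarations and call sites that never mention a bias type keep resolving to the same specialization as before. A minimal illustration of that backward-compatibility pattern (toy type, not the real signature):

    #include <cstdint>
    #include <type_traits>

    // Toy: a trailing defaulted template parameter is backward compatible;
    // omitting it names the same specialization as spelling out the default.
    template <int SPATIAL_DIM = 2, typename BIAS_TYPE = std::int32_t>
    struct GConvTag {};

    static_assert(
        std::is_same<GConvTag<2>, GConvTag<2, std::int32_t>>::value,
        "old spelling == new default");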
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index 4ba3549..d1e0fdd 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -1769,7 +1769,8 @@ template <
typename outType,
bool FUSE_RELU,
QuantizationGranularity Q_GRAN,
- int SPATIAL_DIM>
+ int SPATIAL_DIM,
+ typename BIAS_TYPE>
void fbgemmGroupwiseConv(
const conv_param_t<SPATIAL_DIM>& conv_param,
const std::uint8_t* activations,
@@ -1778,10 +1779,10 @@ void fbgemmGroupwiseConv(
packed_W& packed_weights,
outType* out,
int32_t* outBuffer,
- const ReQuantizeOutput<FUSE_RELU, Q_GRAN>& outProcess,
+ const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
int thread_id,
int num_threads) {
- typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN> processOutputType;
+ typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE> processOutputType;
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index c5ef6ba..c50f6d9 100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -1222,41 +1222,100 @@ void requantizeOutputProcessingGConvAvx2(
__m256 xf_v, yf_v, zf_v, wf_v;
if (HAS_BIAS) {
if (is_same<BIAS_TYPE, float>::value) {
- __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+ __m256 x_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN));
+ __m256 y_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN));
+ __m256 z_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN));
+ __m256 w_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN));
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
x_bias_v = _mm256_div_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
- _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
+ x_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
y_bias_v = _mm256_div_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
- _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
+ y_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
z_bias_v = _mm256_div_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
- _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
+ z_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
w_bias_v = _mm256_div_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
- _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+ w_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+ } else if (Q_GRAN == QuantizationGranularity::GROUP) {
+ __m256 diviser_v;
+ if (C_PER_G == 4) {
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 0])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 1]),
+ 1);
+ x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 2])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 3]),
+ 1);
+ y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 4])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 5]),
+ 1);
+ z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 6])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 7]),
+ 1);
+ w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+
+ } else if (C_PER_G == 8) {
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 0]);
+ x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 1]);
+ y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 2]);
+ z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 3]);
+ w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+
+ } else {
+ assert(C_PER_G == 16);
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 2 + 0]);
+ x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+ y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 2 + 1]);
+ z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+ w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+ }
} else {
- x_bias_v = _mm256_mul_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
- act_times_w_rcp_v);
- y_bias_v = _mm256_mul_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
- act_times_w_rcp_v);
- z_bias_v = _mm256_mul_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
- act_times_w_rcp_v);
- w_bias_v = _mm256_mul_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
- act_times_w_rcp_v);
+ x_bias_v = _mm256_mul_ps(x_bias_v, act_times_w_rcp_v);
+ y_bias_v = _mm256_mul_ps(y_bias_v, act_times_w_rcp_v);
+ z_bias_v = _mm256_mul_ps(z_bias_v, act_times_w_rcp_v);
+ w_bias_v = _mm256_mul_ps(w_bias_v, act_times_w_rcp_v);
}
xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
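
To unpack the divisor construction in this hunk: each loop iteration requantizes 32 output channels in four __m256 vectors (x, y, z, w), and a float bias must first be divided by the matching act_times_w_scale so that it lands in the int32 accumulator's domain. For GROUP granularity that scale is per group, so the divisor vectors follow the group layout: with C_PER_G == 4 each 128-bit half of a vector covers one group (hence the _mm256_insertf128_ps of two _mm_set1_ps), with C_PER_G == 8 each full vector covers one group, and with C_PER_G == 16 two consecutive vectors share one group. A scalar model of the same rescaling (a sketch under assumed names, not FBGEMM's reference implementation):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    enum class Granularity { TENSOR, GROUP, OUT_CHANNEL };

    // Rescale a float bias into the int32 accumulator's domain by dividing
    // by the combined scale act_scale * weight_scale that applies to this
    // output channel under the chosen granularity.
    float biasInAccDomain(float bias, const float* act_times_w_scale,
                          int channel, int c_per_g, Granularity g) {
      int idx = 0;                       // TENSOR: one scale for everything
      if (g == Granularity::OUT_CHANNEL) {
        idx = channel;                   // per-channel scale
      } else if (g == Granularity::GROUP) {
        idx = channel / c_per_g;         // per-group scale (C_PER_G = 4/8/16)
      }
      return bias / act_times_w_scale[idx];
    }

    // One output element: add the rescaled bias to the int32 accumulator,
    // then requantize (multiplier plays the role of scale_in / scale_out).
    std::uint8_t requantize(std::int32_t acc, float bias_scaled,
                            float c_multiplier, std::int32_t c_zero_point) {
      float sum = static_cast<float>(acc) + bias_scaled;
      long q = std::lrintf(sum * c_multiplier) + c_zero_point;
      return static_cast<std::uint8_t>(std::min(255L, std::max(0L, q)));
    }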