github.com/marian-nmt/FBGEMM.git
Diffstat (limited to 'src/QuantUtilsAvx2.cc')
-rw-r--r--  src/QuantUtilsAvx2.cc  117
1 file changed, 88 insertions(+), 29 deletions(-)
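
The hunk below hoists the four float-bias vector loads out of the per-granularity branches (they were previously re-loaded inside each branch) and adds handling for QuantizationGranularity::GROUP: with C_PER_G == 4 each 128-bit half of a __m256 covers a different group, so the divisor is assembled from two broadcasts via _mm256_insertf128_ps; with C_PER_G == 8 one group scale covers a whole vector; with C_PER_G == 16 two consecutive vectors share one scale. A minimal scalar sketch of the underlying bias scaling follows; the names scale_bias, c_per_g, and channel are illustrative assumptions, not FBGEMM identifiers, and block-offset bookkeeping (quant_param_idx, block.col_start) is simplified away.

// Scalar sketch of the bias scaling that the hunk below vectorizes.
// Names and indexing are simplified for illustration, not FBGEMM's API.
#include <cstddef>

enum class QuantGranularity { TENSOR, GROUP, OUT_CHANNEL };

// A float bias must be brought into the int32 accumulator domain, i.e. divided
// by act_scale * weight_scale. Which scale applies depends on the granularity.
inline float scale_bias(
    float bias,
    std::size_t channel,              // output-channel offset within the block
    std::size_t c_per_g,              // channels per group (4, 8, or 16 here)
    QuantGranularity gran,
    const float* act_times_w_scale) { // per-tensor, per-group, or per-channel
  switch (gran) {
    case QuantGranularity::OUT_CHANNEL:
      return bias / act_times_w_scale[channel];            // one scale per channel
    case QuantGranularity::GROUP:
      return bias / act_times_w_scale[channel / c_per_g];  // one scale per group
    default: // TENSOR
      return bias / act_times_w_scale[0];                  // single shared scale
  }
}

The AVX2 path performs the same arithmetic eight lanes at a time, which is why C_PER_G == 4 needs two half-vector broadcasts per __m256 while C_PER_G == 16 reuses one broadcast across two consecutive vectors.
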
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index c5ef6ba..c50f6d9 100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -1222,41 +1222,100 @@ void requantizeOutputProcessingGConvAvx2(
__m256 xf_v, yf_v, zf_v, wf_v;
if (HAS_BIAS) {
if (is_same<BIAS_TYPE, float>::value) {
- __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+ __m256 x_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN));
+ __m256 y_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN));
+ __m256 z_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN));
+ __m256 w_bias_v = _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN));
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
x_bias_v = _mm256_div_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
- _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
+ x_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
y_bias_v = _mm256_div_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
- _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
+ y_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
z_bias_v = _mm256_div_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
- _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
+ z_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
w_bias_v = _mm256_div_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
- _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+ w_bias_v, _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+ } else if (Q_GRAN == QuantizationGranularity::GROUP) {
+ __m256 diviser_v;
+ if (C_PER_G == 4) {
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 0])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 1]),
+ 1);
+ x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 2])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 3]),
+ 1);
+ y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 4])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 5]),
+ 1);
+ z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+
+ diviser_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 6])),
+ _mm_set1_ps(r.act_times_w_scale[quant_param_idx + 7]),
+ 1);
+ w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+
+ } else if (C_PER_G == 8) {
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 0]);
+ x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 1]);
+ y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 2]);
+ z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 3]);
+ w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+
+ } else {
+ assert(C_PER_G == 16);
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 2 + 0]);
+ x_bias_v = _mm256_div_ps(x_bias_v, diviser_v);
+ y_bias_v = _mm256_div_ps(y_bias_v, diviser_v);
+
+ diviser_v = _mm256_set1_ps(
+ r.act_times_w_scale
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 2 + 1]);
+ z_bias_v = _mm256_div_ps(z_bias_v, diviser_v);
+ w_bias_v = _mm256_div_ps(w_bias_v, diviser_v);
+ }
} else {
- x_bias_v = _mm256_mul_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
- act_times_w_rcp_v);
- y_bias_v = _mm256_mul_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
- act_times_w_rcp_v);
- z_bias_v = _mm256_mul_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
- act_times_w_rcp_v);
- w_bias_v = _mm256_mul_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
- act_times_w_rcp_v);
+ x_bias_v = _mm256_mul_ps(x_bias_v, act_times_w_rcp_v);
+ y_bias_v = _mm256_mul_ps(y_bias_v, act_times_w_rcp_v);
+ z_bias_v = _mm256_mul_ps(z_bias_v, act_times_w_rcp_v);
+ w_bias_v = _mm256_mul_ps(w_bias_v, act_times_w_rcp_v);
}
xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);