github.com/marian-nmt/FBGEMM.git
Diffstat (limited to 'src/QuantUtilsAvx2.cc')
-rw-r--r--  src/QuantUtilsAvx2.cc  477
1 file changed, 330 insertions(+), 147 deletions(-)
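
This commit templatizes the AVX2 requantization kernels on a new BIAS_TYPE parameter, so the bias can be either int32_t (added directly to the int32 accumulator, as before) or float (divided by act_times_w_scale, the product of the activation and weight scales, then added after the accumulator is converted to float). A minimal scalar sketch of the per-element math is below; requantize_scalar is a hypothetical helper, not part of FBGEMM, the relu handling is simplified, and row/column offsets are assumed to be already subtracted:

// Hedged scalar sketch (hypothetical helper, not FBGEMM API): the
// per-element math that the templatized kernels in this diff implement
// with AVX2 vectors.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <type_traits>

template <typename BIAS_TYPE>
uint8_t requantize_scalar(
    int32_t acc,              // int32 accumulator, offsets already removed
    BIAS_TYPE bias,           // int32 bias is pre-quantized; float bias is not
    float act_times_w_scale,  // activation_scale * weight_scale
    float c_multiplier,       // roughly act_times_w_scale / output_scale
    int32_t c_zero_point,
    bool fuse_relu) {
  float accf;
  if (std::is_same<BIAS_TYPE, float>::value) {
    // A float bias lives in the real-valued domain, so divide it by
    // act_times_w_scale to move it into the accumulator's domain; the
    // vector code multiplies by a broadcast reciprocal instead when the
    // granularity is not OUT_CHANNEL.
    accf = static_cast<float>(acc) +
        static_cast<float>(bias) / act_times_w_scale;
  } else {
    // An int32 bias is already in the accumulator's domain.
    accf = static_cast<float>(acc + static_cast<int32_t>(bias));
  }
  // Scale to the output domain, round to nearest even (matching
  // _mm256_cvtps_epi32), shift by the output zero point, and clamp.
  int32_t rounded =
      static_cast<int32_t>(std::nearbyintf(accf * c_multiplier)) +
      c_zero_point;
  int32_t lo = fuse_relu ? c_zero_point : 0;
  return static_cast<uint8_t>(
      std::min<int32_t>(255, std::max<int32_t>(lo, rounded)));
}

For OUT_CHANNEL granularity the kernels below instead divide by a per-channel load of r.act_times_w_scale[j]; for coarser granularities they multiply by a single broadcast reciprocal (act_times_w_rcp_v).
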
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index 0643ed6..c5ef6ba 100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -282,14 +282,15 @@ template <
bool B_SYMMETRIC,
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
- bool FUSE_RELU>
+ bool FUSE_RELU,
+ typename BIAS_TYPE>
void requantizeOutputProcessingAvx2(
uint8_t* out,
const int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
- const requantizationParams_t& r) {
+ const requantizationParams_t<BIAS_TYPE>& r) {
// Adaptation of the implementation at QNNPACK/src/requantization/fp32-sse2.c
// using AVX2 instructions
int quant_param_idx = 0;
@@ -300,6 +301,15 @@ void requantizeOutputProcessingAvx2(
}
__m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);
+ // Broadcasted reciprocal of act_times_w_scale
+ __m256 act_times_w_rcp_v;
+ if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) {
+ if (is_same<BIAS_TYPE, float>::value) {
+ act_times_w_rcp_v =
+ _mm256_set1_ps(1.0f / r.act_times_w_scale[quant_param_idx]);
+ }
+ }
+
__m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
__m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
@@ -409,22 +419,76 @@ void requantizeOutputProcessingAvx2(
}
w_v = _mm256_sub_epi32(w_v, row_offset_v);
}
+ __m256 xf_v, yf_v, zf_v, wf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(
- x_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
- y_v = _mm256_add_epi32(
- y_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + VLEN)));
- z_v = _mm256_add_epi32(
- z_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
- w_v = _mm256_add_epi32(
- w_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
+ y_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
+ z_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
+ w_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
+ act_times_w_rcp_v);
+ y_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
+ act_times_w_rcp_v);
+ z_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
+ act_times_w_rcp_v);
+ w_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
+ zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
+ wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN)));
+ y_v = _mm256_add_epi32(
+ y_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN)));
+ z_v = _mm256_add_epi32(
+ z_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
+ w_v = _mm256_add_epi32(
+ w_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
}
/*
@@ -441,22 +505,19 @@ void requantizeOutputProcessingAvx2(
*/
__m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
- y_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(y_v),
- _mm256_loadu_ps(r.C_multiplier + j + VLEN));
- z_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(z_v),
- _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
- w_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(w_v),
- _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
+ x_scaled_v =
+ _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j + 0 * VLEN));
+ y_scaled_v =
+ _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + 1 * VLEN));
+ z_scaled_v =
+ _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
+ w_scaled_v =
+ _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
}
/*
@@ -543,18 +604,35 @@ void requantizeOutputProcessingAvx2(
}
x_v = _mm256_sub_epi32(x_v, row_offset_v);
}
+ __m256 xf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(
- x_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)),
+ _mm256_loadu_ps(r.act_times_w_scale + j));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
}
__m256 x_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
+ x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j));
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
}
__m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
@@ -627,17 +705,40 @@ void requantizeOutputProcessingAvx2(
}
x_v = _mm256_sub_epi32(x_v, row_offset_v);
}
+
+ __m256 xf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(x_v, _mm256_maskload_epi32(r.bias + j, mask_v));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_maskload_ps(
+ reinterpret_cast<const float*>(r.bias + j), mask_v),
+ _mm256_maskload_ps(r.act_times_w_scale + j, mask_v));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_maskload_ps(
+ reinterpret_cast<const float*>(r.bias + j), mask_v),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_maskload_epi32(
+ reinterpret_cast<const int*>(r.bias + j), mask_v));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
}
__m256 x_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v),
- _mm256_maskload_ps(r.C_multiplier + j, mask_v));
+ x_scaled_v =
+ _mm256_mul_ps(xf_v, _mm256_maskload_ps(r.C_multiplier + j, mask_v));
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
}
__m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
@@ -845,14 +946,15 @@ template <
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
bool FUSE_RELU,
- int C_PER_G>
+ int C_PER_G,
+ typename BIAS_TYPE>
void requantizeOutputProcessingGConvAvx2(
uint8_t* out,
const int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
- const requantizationParams_t& r) {
+ const requantizationParams_t<BIAS_TYPE>& r) {
// Adaptation of the implementation at QNNPACK/src/requantization/fp32-sse2.c
// using AVX2 instructions
int quant_param_idx = 0;
@@ -863,6 +965,14 @@ void requantizeOutputProcessingGConvAvx2(
}
__m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);
+ // Broadcasted reciprocal of act_times_w_scale
+ __m256 act_times_w_rcp_v;
+ if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) {
+ if (is_same<BIAS_TYPE, float>::value) {
+ act_times_w_rcp_v =
+ _mm256_set1_ps(1.0f / r.act_times_w_scale[quant_param_idx]);
+ }
+ }
__m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
__m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
@@ -1109,22 +1219,76 @@ void requantizeOutputProcessingGConvAvx2(
}
w_v = _mm256_sub_epi32(w_v, row_offset_v);
}
+ __m256 xf_v, yf_v, zf_v, wf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(
- x_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
- y_v = _mm256_add_epi32(
- y_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + VLEN)));
- z_v = _mm256_add_epi32(
- z_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
- w_v = _mm256_add_epi32(
- w_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
+ y_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
+ z_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
+ w_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
+ act_times_w_rcp_v);
+ y_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
+ act_times_w_rcp_v);
+ z_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
+ act_times_w_rcp_v);
+ w_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
+ zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
+ wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN)));
+ y_v = _mm256_add_epi32(
+ y_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN)));
+ z_v = _mm256_add_epi32(
+ z_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
+ w_v = _mm256_add_epi32(
+ w_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
}
/*
@@ -1141,17 +1305,13 @@ void requantizeOutputProcessingGConvAvx2(
*/
__m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
- y_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(y_v),
- _mm256_loadu_ps(r.C_multiplier + j + VLEN));
- z_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(z_v),
- _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
- w_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(w_v),
- _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
+ x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j));
+ y_scaled_v =
+ _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + VLEN));
+ z_scaled_v =
+ _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
+ w_scaled_v =
+ _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
} else if (Q_GRAN == QuantizationGranularity::GROUP) {
if (C_PER_G == 4) {
multiplier_v = _mm256_insertf128_ps(
@@ -1159,70 +1319,70 @@ void requantizeOutputProcessingGConvAvx2(
_mm_set1_ps(r.C_multiplier[quant_param_idx])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 1]),
1);
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
multiplier_v = _mm256_insertf128_ps(
_mm256_castps128_ps256(
_mm_set1_ps(r.C_multiplier[quant_param_idx + 2])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 3]),
1);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
multiplier_v = _mm256_insertf128_ps(
_mm256_castps128_ps256(
_mm_set1_ps(r.C_multiplier[quant_param_idx + 4])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 5]),
1);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
multiplier_v = _mm256_insertf128_ps(
_mm256_castps128_ps256(
_mm_set1_ps(r.C_multiplier[quant_param_idx + 6])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 7]),
1);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
} else if (C_PER_G == 8) {
multiplier_v = _mm256_set1_ps(
r.C_multiplier
[quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]);
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
- 1]);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
- 2]);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
- 3]);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 1]);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 2]);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 3]);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
} else {
multiplier_v = _mm256_set1_ps(
r.C_multiplier
[quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2]);
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2 +
- 1]);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 2 + 1]);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
}
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
}
/*
@@ -1293,46 +1453,69 @@ void requantizeOutputProcessingGConvAvx2(
} // i loop
}
-#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
- template void \
- requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
- float* out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationForFloatParams_t& r); \
- template void \
- requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void \
- requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void \
- requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 16>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r);
+#define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \
+ A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \
+ template void \
+ requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r); \
+ template void requantizeOutputProcessingGConvAvx2< \
+ A_SYM, \
+ B_SYM, \
+ Q_GRAN, \
+ BIAS, \
+ RELU, \
+ 4, \
+ BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r); \
+ template void requantizeOutputProcessingGConvAvx2< \
+ A_SYM, \
+ B_SYM, \
+ Q_GRAN, \
+ BIAS, \
+ RELU, \
+ 8, \
+ BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r); \
+ template void requantizeOutputProcessingGConvAvx2< \
+ A_SYM, \
+ B_SYM, \
+ Q_GRAN, \
+ BIAS, \
+ RELU, \
+ 16, \
+ BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r);
+
+#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
+ INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, float) \
+ INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, int32_t) \
+ template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
+ float* out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationForFloatParams_t& r);
#define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \
INSTANTIATE_REQUANTIZE(true, B_SYM, Q_GRAN, BIAS, RELU) \