github.com/marian-nmt/FBGEMM.git
author     Daya Khudia <dskhudia@fb.com>   2019-09-11 21:47:58 +0300
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>   2019-09-11 21:52:07 +0300
commit     637288bff9972c02e72341d6a60fdf9bab1dce7e (patch)
tree       3844552832c8527c5bfcb04d87b9b4132fc5bc8e
parent     415035019ccbca2b11b62f1503fdd61e8bc59b10 (diff)

ReQuantization with FP32 bias

Summary: There is an issue in eager mode if we quantize bias using input_scale*weight_scale. See the following doc. https://fb.quip.com/ru2eAqzsjwXc

Reviewed By: jianyuh

Differential Revision: D16948098

fbshipit-source-id: ff2c2bc560c2c14da1941d65a15c96e18f407569
-rw-r--r--  include/fbgemm/Fbgemm.h               |  20
-rw-r--r--  include/fbgemm/OutputProcessing-inl.h |  36
-rw-r--r--  include/fbgemm/QuantUtilsAvx2.h       |  10
-rw-r--r--  include/fbgemm/UtilsAvx2.h            |   5
-rw-r--r--  src/GroupwiseConvAcc32Avx2.cc         |  20
-rw-r--r--  src/QuantUtilsAvx2.cc                 | 477
6 files changed, 390 insertions(+), 178 deletions(-)
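
For context on the change below: instead of pre-quantizing the bias with input_scale*weight_scale (the source of the eager-mode issue mentioned in the summary), callers can now pass an unquantized fp32 bias together with act_times_w_scale, and requantization folds the bias in by dividing it by that scale. A minimal scalar sketch of the new path, assuming per-tensor granularity; the helper itself is illustrative, not library code:

// Scalar sketch of requantization with an unquantized (fp32) bias.
// Assumes per-tensor granularity; names mirror the diff below.
#include <algorithm>
#include <cmath>
#include <cstdint>

inline std::uint8_t requantize_with_fp32_bias(
    std::int32_t raw,            // int32 accumulator, zero-point terms removed
    float bias,                  // unquantized fp32 bias for this column
    float act_times_w_scale,     // activation_scale * weight_scale
    float C_multiplier,          // typically act_times_w_scale / output_scale
    std::int32_t C_zero_point) {
  // Divide the fp32 bias by act_times_w_scale so it lands in the same scale
  // as the int32 accumulator, then requantize as before.
  float raw_f = static_cast<float>(raw) + bias / act_times_w_scale;
  long rounded = std::lrintf(raw_f * C_multiplier) + C_zero_point;
  return static_cast<std::uint8_t>(std::max(0L, std::min(255L, rounded)));
}
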
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 0b7bf1f..2f73de4 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -1128,6 +1128,7 @@ class FBGEMM_API DoSConvOnInpBuffer {
template <
bool FUSE_RELU,
QuantizationGranularity Q_GRAN = QuantizationGranularity::TENSOR,
+ typename BIAS_TYPE = std::int32_t,
typename outT = std::uint8_t,
typename inT = std::int32_t,
typename nextOPType = DoNothing<outT, outT>>
@@ -1135,6 +1136,7 @@ class FBGEMM_API ReQuantizeOutput {
public:
static constexpr int RELU_FUSED = FUSE_RELU;
static constexpr QuantizationGranularity QGRANType = Q_GRAN;
+ using BIAS_T = BIAS_TYPE;
using outType = outT;
using inpType = inT;
/**
@@ -1155,6 +1157,8 @@ class FBGEMM_API ReQuantizeOutput {
* See PackedRequantizeTest.cc for an example.
* TODO: if Aq_zero_point == 0, allow passing nullptr.
* @params bias can be nullptr otherwise the length should be nCol
+ * @params act_times_w_scale activation_scale * weight_scale. This is only
+ * used if bias is unquantized (i.e., float).
*/
ReQuantizeOutput(
nextOPType& nextop,
@@ -1164,9 +1168,10 @@ class FBGEMM_API ReQuantizeOutput {
const std::int32_t* Bq_zero_point,
const std::int32_t* row_offsets,
const std::int32_t* col_offsets,
- const std::int32_t* bias,
+ const BIAS_T* bias,
std::uint32_t nCol,
- int groups = 1)
+ int groups = 1,
+ const float* act_times_w_scale = nullptr)
: nextop_(nextop),
C_multiplier_(C_multiplier),
C_zero_point_(C_zero_point),
@@ -1176,7 +1181,8 @@ class FBGEMM_API ReQuantizeOutput {
q_col_offsets_(col_offsets),
bias_(bias),
ncols_(nCol),
- groups_(groups) {}
+ groups_(groups),
+ act_times_w_scale_(act_times_w_scale) {}
template <inst_set_t instSet>
inline int f(
@@ -1204,12 +1210,15 @@ class FBGEMM_API ReQuantizeOutput {
const std::int32_t* getColOffsets() const {
return q_col_offsets_;
}
- const std::int32_t* getBias() const {
+ const BIAS_T* getBias() const {
return bias_;
}
std::uint32_t getNCols() const {
return ncols_;
}
+ const float* getActWScale() const {
+ return act_times_w_scale_;
+ }
void setRowOffsets(const std::int32_t* row_offsets) {
q_row_offsets_ = row_offsets;
@@ -1223,9 +1232,10 @@ class FBGEMM_API ReQuantizeOutput {
const std::int32_t* Bq_zero_point_;
const std::int32_t* q_row_offsets_;
const std::int32_t* q_col_offsets_;
- const std::int32_t* bias_;
+ const BIAS_T* bias_;
std::uint32_t ncols_;
int groups_;
+ const float* act_times_w_scale_;
};
/**
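
A hedged usage sketch of the extended ReQuantizeOutput above with a float bias; the DoNothing next-op and wrapper function are illustrative, and the leading constructor arguments (nextop, C_multiplier, C_zero_point, Aq_zero_point) are assumed to keep their existing order:

// Illustrative construction of ReQuantizeOutput with BIAS_TYPE = float.
#include <cstdint>
#include "fbgemm/Fbgemm.h"

using namespace fbgemm;

void buildRequantOp(
    const float* C_multiplier,
    std::int32_t C_zero_point,
    std::int32_t Aq_zero_point,
    const std::int32_t* Bq_zero_point,
    const std::int32_t* row_offsets,
    const std::int32_t* col_offsets,
    const float* bias_fp32,           // unquantized bias, length nCol
    std::uint32_t nCol,
    const float* act_times_w_scale) { // activation_scale * weight_scale
  DoNothing<std::uint8_t, std::uint8_t> doNothingObj{};
  ReQuantizeOutput<false /*FUSE_RELU*/,
                   QuantizationGranularity::TENSOR,
                   float /*BIAS_TYPE*/>
      outputProc(doNothingObj,
                 C_multiplier,
                 C_zero_point,
                 Aq_zero_point,
                 Bq_zero_point,
                 row_offsets,
                 col_offsets,
                 bias_fp32,
                 nCol,
                 1 /*groups*/,
                 act_times_w_scale);
  (void)outputProc; // would normally be passed on to fbgemmPacked(...)
}
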
diff --git a/include/fbgemm/OutputProcessing-inl.h b/include/fbgemm/OutputProcessing-inl.h
index d984c60..04ae100 100644
--- a/include/fbgemm/OutputProcessing-inl.h
+++ b/include/fbgemm/OutputProcessing-inl.h
@@ -59,11 +59,13 @@ inline int DoSConvOnInpBuffer<outT, inT, nextOPType>::f(
template <
bool FUSE_RELU,
QuantizationGranularity Q_GRAN,
+ typename BIAS_TYPE,
typename outT,
typename inT,
typename nextOPType>
template <inst_set_t instSet>
-inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
+inline int
+ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE, outT, inT, nextOPType>::f(
outT* out,
const inT* inp,
const block_type_t& block,
@@ -98,11 +100,20 @@ inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
raw -= q_row_offsets_[i - block.row_start] *
Bq_zero_point_[Bq_zero_point_idx];
}
+ float raw_f;
if (bias_) {
- raw += bias_[j];
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ raw_f = raw;
+ raw_f += bias_[j] / act_times_w_scale_[Bq_zero_point_idx];
+ } else {
+ raw += bias_[j];
+ raw_f = raw;
+ }
+ } else {
+ raw_f = raw;
}
- float ab = raw * C_multiplier_[Bq_zero_point_idx];
+ float ab = raw_f * C_multiplier_[Bq_zero_point_idx];
long rounded = std::lrintf(ab) + C_zero_point_;
out[i * ld_out + j] = std::max(
@@ -115,15 +126,16 @@ inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
Bq_zero_point_[0] == 0) ||
q_row_offsets_ == nullptr;
- requantizationParams_t r = {Aq_zero_point_,
- Bq_zero_point_,
- C_zero_point_,
- C_multiplier_,
- q_row_offsets_,
- q_col_offsets_,
- bias_,
- ncols_,
- groups_};
+ requantizationParams_t<BIAS_TYPE> r = {Aq_zero_point_,
+ Bq_zero_point_,
+ C_zero_point_,
+ C_multiplier_,
+ q_row_offsets_,
+ q_col_offsets_,
+ bias_,
+ ncols_,
+ groups_,
+ act_times_w_scale_};
if (Aq_zero_point_ == 0) {
if (b_symmetric) {
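
The scalar path above selects the per-column index through Bq_zero_point_idx; for per-output-channel (OUT_CHANNEL) granularity the bias, act_times_w_scale, and C_multiplier are all indexed by the output column, which is what the AVX2 kernels further below vectorize. A scalar sketch of that case, illustrative rather than the library's exact control flow:

// Scalar sketch of the OUT_CHANNEL case: per-column fp32 bias and scales.
#include <algorithm>
#include <cmath>
#include <cstdint>

inline std::uint8_t requantize_out_channel_fp32_bias(
    std::int32_t raw,                // int32 accumulator for column j
    int j,                           // output column (channel) index
    const float* bias_fp32,          // per-channel unquantized bias
    const float* act_times_w_scale,  // per-channel activation_scale * weight_scale
    const float* C_multiplier,       // per-channel requantization multiplier
    std::int32_t C_zero_point) {
  float raw_f =
      static_cast<float>(raw) + bias_fp32[j] / act_times_w_scale[j];
  long rounded = std::lrintf(raw_f * C_multiplier[j]) + C_zero_point;
  return static_cast<std::uint8_t>(std::max(0L, std::min(255L, rounded)));
}
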
diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h
index a001004..c7f3f35 100644
--- a/include/fbgemm/QuantUtilsAvx2.h
+++ b/include/fbgemm/QuantUtilsAvx2.h
@@ -72,14 +72,15 @@ template <
bool B_SYMMETRIC,
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
- bool FUSE_RELU>
+ bool FUSE_RELU,
+ typename BIAS_TYPE = std::int32_t>
FBGEMM_API void requantizeOutputProcessingAvx2(
std::uint8_t* out,
const std::int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
- const requantizationParams_t& r);
+ const requantizationParams_t<BIAS_TYPE>& r);
template <
bool A_SYMMETRIC,
@@ -87,14 +88,15 @@ template <
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
bool FUSE_RELU,
- int C_PER_G>
+ int C_PER_G,
+ typename BIAS_TYPE = std::int32_t>
FBGEMM_API void requantizeOutputProcessingGConvAvx2(
std::uint8_t* out,
const std::int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
- const requantizationParams_t& r);
+ const requantizationParams_t<BIAS_TYPE>& r);
template <
bool A_SYMMETRIC,
diff --git a/include/fbgemm/UtilsAvx2.h b/include/fbgemm/UtilsAvx2.h
index 082edc1..3bac909 100644
--- a/include/fbgemm/UtilsAvx2.h
+++ b/include/fbgemm/UtilsAvx2.h
@@ -44,16 +44,19 @@ struct block_type_t {
* QuantUtilsAvx2.h as it combines all the parameters needed for various
* quantization granularities
*/
+template<typename BIAS_TYPE = std::int32_t>
struct requantizationParams_t {
+ using BIAS_T = BIAS_TYPE;
std::int32_t A_zero_point;
const std::int32_t* B_zero_point;
std::int32_t C_zero_point;
const float* C_multiplier;
const std::int32_t* row_offsets;
const std::int32_t* col_offsets;
- const std::int32_t* bias;
+ const BIAS_T* bias;
std::uint32_t ncols;
int groups;
+ const float* act_times_w_scale;
};
/**
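
A minimal sketch of aggregate-initializing the now-templated parameter struct for an fp32 bias; field order follows the struct above, and the wrapper function is illustrative:

// Illustrative aggregate initialization of requantizationParams_t<float>.
#include <cstdint>
#include "fbgemm/UtilsAvx2.h"

fbgemm::requantizationParams_t<float> makeParams(
    std::int32_t A_zero_point,
    const std::int32_t* B_zero_point,
    std::int32_t C_zero_point,
    const float* C_multiplier,
    const std::int32_t* row_offsets,
    const std::int32_t* col_offsets,
    const float* bias_fp32,
    std::uint32_t ncols,
    int groups,
    const float* act_times_w_scale) {
  return {A_zero_point, B_zero_point, C_zero_point, C_multiplier,
          row_offsets,  col_offsets,  bias_fp32,    ncols,
          groups,       act_times_w_scale};
}
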
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index ef4ba7b..40f3fba 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -1872,15 +1872,17 @@ void fbgemmGroupwiseConv(
outProcess.getBZeroPoint()[0] == 0) ||
rowOffsetBuf == nullptr;
- requantizationParams_t r = {a_zero_point,
- outProcess.getBZeroPoint(),
- outProcess.getCZeroPoint(),
- outProcess.getCMultiplier(),
- rowOffsetBuf,
- outProcess.getColOffsets(),
- outProcess.getBias(),
- outProcess.getNCols(),
- G};
+ requantizationParams_t<typename processOutputType::BIAS_T> r = {
+ a_zero_point,
+ outProcess.getBZeroPoint(),
+ outProcess.getCZeroPoint(),
+ outProcess.getCMultiplier(),
+ rowOffsetBuf,
+ outProcess.getColOffsets(),
+ outProcess.getBias(),
+ outProcess.getNCols(),
+ G,
+ outProcess.getActWScale()};
const std::int32_t* inp = outBuffer;
block_type_t block{i * oh_ow, oh_ow, gOuter * K_per_G, 8 * K_per_G};
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index 0643ed6..c5ef6ba 100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -282,14 +282,15 @@ template <
bool B_SYMMETRIC,
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
- bool FUSE_RELU>
+ bool FUSE_RELU,
+ typename BIAS_TYPE>
void requantizeOutputProcessingAvx2(
uint8_t* out,
const int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
- const requantizationParams_t& r) {
+ const requantizationParams_t<BIAS_TYPE>& r) {
// Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c
// using AVX2 instructions
int quant_param_idx = 0;
@@ -300,6 +301,15 @@ void requantizeOutputProcessingAvx2(
}
__m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);
+ // Broadcasted reciprocal of act_times_w_scale
+ __m256 act_times_w_rcp_v;
+ if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) {
+ if (is_same<BIAS_TYPE, float>::value) {
+ act_times_w_rcp_v =
+ _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]);
+ }
+ }
+
__m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
__m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
@@ -409,22 +419,76 @@ void requantizeOutputProcessingAvx2(
}
w_v = _mm256_sub_epi32(w_v, row_offset_v);
}
+ __m256 xf_v, yf_v, zf_v, wf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(
- x_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
- y_v = _mm256_add_epi32(
- y_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + VLEN)));
- z_v = _mm256_add_epi32(
- z_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
- w_v = _mm256_add_epi32(
- w_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
+ y_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
+ z_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
+ w_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
+ act_times_w_rcp_v);
+ y_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
+ act_times_w_rcp_v);
+ z_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
+ act_times_w_rcp_v);
+ w_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
+ zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
+ wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN)));
+ y_v = _mm256_add_epi32(
+ y_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN)));
+ z_v = _mm256_add_epi32(
+ z_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
+ w_v = _mm256_add_epi32(
+ w_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
}
/*
@@ -441,22 +505,19 @@ void requantizeOutputProcessingAvx2(
*/
__m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
- y_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(y_v),
- _mm256_loadu_ps(r.C_multiplier + j + VLEN));
- z_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(z_v),
- _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
- w_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(w_v),
- _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
+ x_scaled_v =
+ _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j + 0 * VLEN));
+ y_scaled_v =
+ _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + 1 * VLEN));
+ z_scaled_v =
+ _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
+ w_scaled_v =
+ _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
}
/*
@@ -543,18 +604,35 @@ void requantizeOutputProcessingAvx2(
}
x_v = _mm256_sub_epi32(x_v, row_offset_v);
}
+ __m256 xf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(
- x_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)),
+ _mm256_loadu_ps(r.act_times_w_scale + j));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(reinterpret_cast<const float*>(r.bias + j)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
}
__m256 x_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
+ x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j));
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
}
__m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
@@ -627,17 +705,40 @@ void requantizeOutputProcessingAvx2(
}
x_v = _mm256_sub_epi32(x_v, row_offset_v);
}
+
+ __m256 xf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(x_v, _mm256_maskload_epi32(r.bias + j, mask_v));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_maskload_ps(
+ reinterpret_cast<const float*>(r.bias + j), mask_v),
+ _mm256_maskload_ps(r.act_times_w_scale + j, mask_v));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_maskload_ps(
+ reinterpret_cast<const float*>(r.bias + j), mask_v),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_maskload_epi32(
+ reinterpret_cast<const int*>(r.bias + j), mask_v));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
}
__m256 x_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v),
- _mm256_maskload_ps(r.C_multiplier + j, mask_v));
+ x_scaled_v =
+ _mm256_mul_ps(xf_v, _mm256_maskload_ps(r.C_multiplier + j, mask_v));
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
}
__m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
@@ -845,14 +946,15 @@ template <
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
bool FUSE_RELU,
- int C_PER_G>
+ int C_PER_G,
+ typename BIAS_TYPE>
void requantizeOutputProcessingGConvAvx2(
uint8_t* out,
const int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
- const requantizationParams_t& r) {
+ const requantizationParams_t<BIAS_TYPE>& r) {
// Adoption of implementation at QNNPACK/src/requantization/fp32-sse2.c
// using AVX2 instructions
int quant_param_idx = 0;
@@ -863,6 +965,14 @@ void requantizeOutputProcessingGConvAvx2(
}
__m256 multiplier_v = _mm256_set1_ps(r.C_multiplier[quant_param_idx]);
+ // Broadcasted reciprocal of act_times_w_scale
+ __m256 act_times_w_rcp_v;
+ if (!(Q_GRAN == QuantizationGranularity::OUT_CHANNEL)) {
+ if (is_same<BIAS_TYPE, float>::value) {
+ act_times_w_rcp_v =
+ _mm256_set1_ps(1.0 / r.act_times_w_scale[quant_param_idx]);
+ }
+ }
__m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
__m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
@@ -1109,22 +1219,76 @@ void requantizeOutputProcessingGConvAvx2(
}
w_v = _mm256_sub_epi32(w_v, row_offset_v);
}
+ __m256 xf_v, yf_v, zf_v, wf_v;
if (HAS_BIAS) {
- x_v = _mm256_add_epi32(
- x_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(r.bias + j)));
- y_v = _mm256_add_epi32(
- y_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + VLEN)));
- z_v = _mm256_add_epi32(
- z_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
- w_v = _mm256_add_epi32(
- w_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ if (is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+ if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 0 * VLEN));
+ y_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 1 * VLEN));
+ z_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 2 * VLEN));
+ w_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
+ _mm256_loadu_ps(r.act_times_w_scale + j + 3 * VLEN));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 0 * VLEN)),
+ act_times_w_rcp_v);
+ y_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 1 * VLEN)),
+ act_times_w_rcp_v);
+ z_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 2 * VLEN)),
+ act_times_w_rcp_v);
+ w_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(r.bias + j + 3 * VLEN)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
+ zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
+ wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 0 * VLEN)));
+ y_v = _mm256_add_epi32(
+ y_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 1 * VLEN)));
+ z_v = _mm256_add_epi32(
+ z_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 2 * VLEN)));
+ w_v = _mm256_add_epi32(
+ w_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(r.bias + j + 3 * VLEN)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
}
/*
@@ -1141,17 +1305,13 @@ void requantizeOutputProcessingGConvAvx2(
*/
__m256 x_scaled_v, y_scaled_v, z_scaled_v, w_scaled_v;
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- x_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(x_v), _mm256_loadu_ps(r.C_multiplier + j));
- y_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(y_v),
- _mm256_loadu_ps(r.C_multiplier + j + VLEN));
- z_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(z_v),
- _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
- w_scaled_v = _mm256_mul_ps(
- _mm256_cvtepi32_ps(w_v),
- _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
+ x_scaled_v = _mm256_mul_ps(xf_v, _mm256_loadu_ps(r.C_multiplier + j));
+ y_scaled_v =
+ _mm256_mul_ps(yf_v, _mm256_loadu_ps(r.C_multiplier + j + VLEN));
+ z_scaled_v =
+ _mm256_mul_ps(zf_v, _mm256_loadu_ps(r.C_multiplier + j + 2 * VLEN));
+ w_scaled_v =
+ _mm256_mul_ps(wf_v, _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
} else if (Q_GRAN == QuantizationGranularity::GROUP) {
if (C_PER_G == 4) {
multiplier_v = _mm256_insertf128_ps(
@@ -1159,70 +1319,70 @@ void requantizeOutputProcessingGConvAvx2(
_mm_set1_ps(r.C_multiplier[quant_param_idx])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 1]),
1);
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
multiplier_v = _mm256_insertf128_ps(
_mm256_castps128_ps256(
_mm_set1_ps(r.C_multiplier[quant_param_idx + 2])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 3]),
1);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
multiplier_v = _mm256_insertf128_ps(
_mm256_castps128_ps256(
_mm_set1_ps(r.C_multiplier[quant_param_idx + 4])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 5]),
1);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
multiplier_v = _mm256_insertf128_ps(
_mm256_castps128_ps256(
_mm_set1_ps(r.C_multiplier[quant_param_idx + 6])),
_mm_set1_ps(r.C_multiplier[quant_param_idx + 7]),
1);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
} else if (C_PER_G == 8) {
multiplier_v = _mm256_set1_ps(
r.C_multiplier
[quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]);
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
- 1]);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
- 2]);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
- 3]);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 1]);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 2]);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 4 + 3]);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
} else {
multiplier_v = _mm256_set1_ps(
r.C_multiplier
[quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2]);
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
-
- multiplier_v = _mm256_set1_ps(
- r.C_multiplier
- [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2 +
- 1]);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+
+ multiplier_v =
+ _mm256_set1_ps(r.C_multiplier
+ [quant_param_idx +
+ (j - block.col_start) / (VLEN * 4) * 2 + 1]);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
}
} else {
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+ z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
}
/*
@@ -1293,46 +1453,69 @@ void requantizeOutputProcessingGConvAvx2(
} // i loop
}
-#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
- template void \
- requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
- float* out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationForFloatParams_t& r); \
- template void \
- requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void \
- requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void \
- requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 16>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r);
+#define INSTANTIATE_REQUANTIZE_BIAS_TYPE( \
+ A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE) \
+ template void \
+ requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r); \
+ template void requantizeOutputProcessingGConvAvx2< \
+ A_SYM, \
+ B_SYM, \
+ Q_GRAN, \
+ BIAS, \
+ RELU, \
+ 4, \
+ BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r); \
+ template void requantizeOutputProcessingGConvAvx2< \
+ A_SYM, \
+ B_SYM, \
+ Q_GRAN, \
+ BIAS, \
+ RELU, \
+ 8, \
+ BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r); \
+ template void requantizeOutputProcessingGConvAvx2< \
+ A_SYM, \
+ B_SYM, \
+ Q_GRAN, \
+ BIAS, \
+ RELU, \
+ 16, \
+ BIAS_TYPE>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t<BIAS_TYPE>& r);
+
+#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
+ INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, float) \
+ INSTANTIATE_REQUANTIZE_BIAS_TYPE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU, int32_t) \
+ template void requantizeForFloatAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
+ float* out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationForFloatParams_t& r);
#define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \
INSTANTIATE_REQUANTIZE(true, B_SYM, Q_GRAN, BIAS, RELU) \