github.com/marian-nmt/FBGEMM.git

author     Jongsoo Park <jongsoo@fb.com>  2019-03-13 06:14:32 +0300
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-03-13 06:17:49 +0300
commit     6011ce3b0c1fccee549e85b37e475c7a734ad742 (patch)
tree       7089177b6c7da36c2582da1cf9b42eca9dfb2ea7 /include/fbgemm
parent     50b43162fd1742122d01f2704945c78f13e0d73e (diff)
optimize requantize for float out processing (#85)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/85

Optimizing the performance of output processing when the output is dequantized right away.

Reviewed By: protonu

Differential Revision: D14433141

fbshipit-source-id: f99a8d82000c43e554461acf036462a4e8f7e300
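For context, the reference ("anyarch") path in this diff dequantizes each
int32 GEMM accumulator back to float in scalar code, and the commit's point
is that the AVX2/AVX512 path now performs the same arithmetic vectorized via
the new requantizeForFloatAvx2. Below is a minimal scalar sketch of that
per-element computation, using hypothetical standalone names rather than
fbgemm's own:

#include <algorithm>
#include <cstdint>

// Dequantize one int32 accumulator: undo both zero-point contributions,
// rescale to float, add bias, and optionally fuse a ReLU.
float dequantize_one(
    std::int32_t acc,          // raw int32 accumulator from the GEMM core
    std::int32_t A_zero_point, // zero point of quantized matrix A
    std::int32_t B_zero_point, // zero point of quantized matrix B
    std::int32_t row_offset,   // precomputed sum of the A row
    std::int32_t col_offset,   // precomputed sum of the B column
    float A_scale,
    float B_scale,
    float bias,
    bool fuse_relu) {
  acc -= A_zero_point * col_offset; // A's zero point times column offset
  acc -= row_offset * B_zero_point; // row offset times B's zero point
  float res = static_cast<float>(acc) * A_scale * B_scale + bias;
  return fuse_relu ? std::max(0.0f, res) : res;
}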
Diffstat (limited to 'include/fbgemm')
-rw-r--r--  include/fbgemm/OutputProcessing-inl.h  274
-rw-r--r--  include/fbgemm/QuantUtilsAvx2.h         14
-rw-r--r--  include/fbgemm/UtilsAvx2.h              15
3 files changed, 207 insertions(+), 96 deletions(-)
diff --git a/include/fbgemm/OutputProcessing-inl.h b/include/fbgemm/OutputProcessing-inl.h
index 9485b18..d984c60 100644
--- a/include/fbgemm/OutputProcessing-inl.h
+++ b/include/fbgemm/OutputProcessing-inl.h
@@ -77,7 +77,7 @@ inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
block.col_size <= ncol_per_group &&
"ReQuantizeOutput should be called at most 1 group at a time.");
int g = block.col_start / ncol_per_group;
- if (instSet == inst_set_t::anyarch) {
+ if (instSet == inst_set_t::anyarch || !std::is_same<outT, uint8_t>::value) {
for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
inT raw = inp[(i - block.row_start) * ld_in + (j - block.col_start)];
@@ -111,88 +111,84 @@ inline int ReQuantizeOutput<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
}
}
} else if (instSet == inst_set_t::avx2 || instSet == inst_set_t::avx512) {
- if (std::is_same<outT, uint8_t>::value) {
- bool b_symmetric = (Q_GRAN == QuantizationGranularity::TENSOR &&
- Bq_zero_point_[0] == 0) ||
- q_row_offsets_ == nullptr;
+ bool b_symmetric = (Q_GRAN == QuantizationGranularity::TENSOR &&
+ Bq_zero_point_[0] == 0) ||
+ q_row_offsets_ == nullptr;
- requantizationParams_t r = {Aq_zero_point_,
- Bq_zero_point_,
- C_zero_point_,
- C_multiplier_,
- q_row_offsets_,
- q_col_offsets_,
- bias_,
- ncols_,
- groups_};
+ requantizationParams_t r = {Aq_zero_point_,
+ Bq_zero_point_,
+ C_zero_point_,
+ C_multiplier_,
+ q_row_offsets_,
+ q_col_offsets_,
+ bias_,
+ ncols_,
+ groups_};
- if (Aq_zero_point_ == 0) {
- if (b_symmetric) {
- if (bias_ == nullptr) {
- requantizeOutputProcessingAvx2<
- true,
- true,
- Q_GRAN,
- false,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
- } else {
- requantizeOutputProcessingAvx2<true, true, Q_GRAN, true, FUSE_RELU>(
- out, inp, block, ld_out, ld_in, r);
- }
+ if (Aq_zero_point_ == 0) {
+ if (b_symmetric) {
+ if (bias_ == nullptr) {
+ requantizeOutputProcessingAvx2<
+ true,
+ true,
+ Q_GRAN,
+ false,
+ FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
} else {
- if (bias_ == nullptr) {
- requantizeOutputProcessingAvx2<
- true,
- false,
- Q_GRAN,
- false,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
- } else {
- requantizeOutputProcessingAvx2<
- true,
- false,
- Q_GRAN,
- true,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
- }
+ requantizeOutputProcessingAvx2<true, true, Q_GRAN, true, FUSE_RELU>(
+ out, inp, block, ld_out, ld_in, r);
}
} else {
- if (b_symmetric) {
- if (bias_ == nullptr) {
- requantizeOutputProcessingAvx2<
- false,
- true,
- Q_GRAN,
- false,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
- } else {
- requantizeOutputProcessingAvx2<
- false,
- true,
- Q_GRAN,
- true,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
- }
+ if (bias_ == nullptr) {
+ requantizeOutputProcessingAvx2<
+ true,
+ false,
+ Q_GRAN,
+ false,
+ FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
} else {
- if (bias_ == nullptr) {
- requantizeOutputProcessingAvx2<
- false,
- false,
- Q_GRAN,
- false,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
- } else {
- requantizeOutputProcessingAvx2<
- false,
- false,
- Q_GRAN,
- true,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
- }
+ requantizeOutputProcessingAvx2<
+ true,
+ false,
+ Q_GRAN,
+ true,
+ FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
}
}
} else {
- assert(0 && "Not supported yet");
+ if (b_symmetric) {
+ if (bias_ == nullptr) {
+ requantizeOutputProcessingAvx2<
+ false,
+ true,
+ Q_GRAN,
+ false,
+ FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeOutputProcessingAvx2<
+ false,
+ true,
+ Q_GRAN,
+ true,
+ FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ }
+ } else {
+ if (bias_ == nullptr) {
+ requantizeOutputProcessingAvx2<
+ false,
+ false,
+ Q_GRAN,
+ false,
+ FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeOutputProcessingAvx2<
+ false,
+ false,
+ Q_GRAN,
+ true,
+ FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ }
+ }
}
} else {
assert(0 && "Not supported yet");
@@ -224,33 +220,119 @@ inline int ReQuantizeForFloat<FUSE_RELU, Q_GRAN, outT, inT, nextOPType>::f(
block.col_size <= ncol_per_group &&
"ReQuantizeOutput should be called at most 1 group at a time.");
int g = block.col_start / ncol_per_group;
- for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
- for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
- inT raw = inp[(i - block.row_start) * ld_in + j - block.col_start];
- if (Aq_zero_point_) {
- raw -= Aq_zero_point_ * q_col_offsets_[j];
+ if (instSet == inst_set_t::anyarch || !std::is_same<outT, float>::value) {
+ for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
+ for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
+ inT raw = inp[(i - block.row_start) * ld_in + j - block.col_start];
+ if (Aq_zero_point_) {
+ raw -= Aq_zero_point_ * q_col_offsets_[j];
+ }
+ int Bq_zero_point_idx;
+ if (Q_GRAN == QuantizationGranularity::TENSOR) {
+ Bq_zero_point_idx = 0;
+ } else if (Q_GRAN == QuantizationGranularity::GROUP) {
+ Bq_zero_point_idx = g;
+ } else if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ Bq_zero_point_idx = j;
+ } else {
+ assert(false && "unknown quantization granularity");
+ }
+ if (q_row_offsets_) {
+ raw -= q_row_offsets_[i - block.row_start] *
+ Bq_zero_point_[Bq_zero_point_idx];
+ }
+ float res = raw * Aq_scale_ * Bq_scale_[Bq_zero_point_idx];
+ if (bias_) {
+ res += bias_[j];
+ }
+ out[i * ld_out + j] = res;
+ if (FUSE_RELU) {
+ out[i * ld_out + j] = std::max<outT>(0.0f, out[i * ld_out + j]);
+ }
}
- int Bq_zero_point_idx;
- if (Q_GRAN == QuantizationGranularity::TENSOR) {
- Bq_zero_point_idx = 0;
- } else if (Q_GRAN == QuantizationGranularity::GROUP) {
- Bq_zero_point_idx = g;
- } else if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
- Bq_zero_point_idx = j;
+ }
+ } else if (instSet == inst_set_t::avx2 || instSet == inst_set_t::avx512) {
+ bool b_symmetric = (Q_GRAN == QuantizationGranularity::TENSOR &&
+ Bq_zero_point_[0] == 0) ||
+ q_row_offsets_ == nullptr;
+
+ requantizationForFloatParams_t r = {Aq_zero_point_,
+ Bq_zero_point_,
+ Aq_scale_,
+ Bq_scale_,
+ q_row_offsets_,
+ q_col_offsets_,
+ bias_,
+ ncols_,
+ groups_};
+
+ if (Aq_zero_point_ == 0) {
+ if (b_symmetric) {
+ if (bias_ == nullptr) {
+ requantizeForFloatAvx2<
+ true,
+ true,
+ Q_GRAN,
+ false,
+ FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeForFloatAvx2<true, true, Q_GRAN, true, FUSE_RELU>(
+ out, inp, block, ld_out, ld_in, r);
+ }
} else {
- assert(false && "unknown quantization granularity");
- }
- raw -= q_row_offsets_[i - block.row_start] *
- Bq_zero_point_[Bq_zero_point_idx];
- float res = raw * Aq_scale_ * Bq_scale_[Bq_zero_point_idx];
- if (bias_) {
- res += bias_[j];
+ if (bias_ == nullptr) {
+ requantizeForFloatAvx2<
+ true,
+ false,
+ Q_GRAN,
+ false,
+ FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeForFloatAvx2<
+ true,
+ false,
+ Q_GRAN,
+ true,
+ FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ }
}
- out[i * ld_out + j] = res;
- if (FUSE_RELU) {
- out[i * ld_out + j] = std::max<outT>(0.0f, out[i * ld_out + j]);
+ } else {
+ if (b_symmetric) {
+ if (bias_ == nullptr) {
+ requantizeForFloatAvx2<
+ false,
+ true,
+ Q_GRAN,
+ false,
+ FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeForFloatAvx2<
+ false,
+ true,
+ Q_GRAN,
+ true,
+ FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ }
+ } else {
+ if (bias_ == nullptr) {
+ requantizeForFloatAvx2<
+ false,
+ false,
+ Q_GRAN,
+ false,
+ FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeForFloatAvx2<
+ false,
+ false,
+ Q_GRAN,
+ true,
+ FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ }
}
}
+ } else {
+ assert(0 && "Not supported yet");
}
return nextop_.template f<instSet>(out, out, block, ld_out, ld_out);
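The if/else ladders above follow FBGEMM's usual pattern of promoting runtime
flags (Aq_zero_point_ == 0, b_symmetric, bias_ == nullptr) to compile-time
template booleans, so each instantiation's inner loop carries no branches.
Here is a condensed, self-contained sketch of the idiom with two flags and a
hypothetical kernel, not fbgemm's own:

#include <cstdint>

// Hot loop: A_SYMMETRIC and HAS_BIAS are compile-time constants here, so
// the compiler folds the dead branches away and the loop can vectorize
// without per-element checks.
template <bool A_SYMMETRIC, bool HAS_BIAS>
void requantize_kernel(float* out, const std::int32_t* inp, int n,
                       std::int32_t A_zero_point, std::int32_t col_offset,
                       float scale, const float* bias) {
  for (int i = 0; i < n; ++i) {
    std::int32_t acc = inp[i];
    if (!A_SYMMETRIC) {
      acc -= A_zero_point * col_offset; // eliminated when A_SYMMETRIC
    }
    float res = static_cast<float>(acc) * scale;
    if (HAS_BIAS) {
      res += bias[i]; // eliminated when !HAS_BIAS
    }
    out[i] = res;
  }
}

// Dispatch: each flag is tested once, then control jumps into the matching
// specialization, the same shape as the ladders in the diff above.
void requantize_dispatch(float* out, const std::int32_t* inp, int n,
                         std::int32_t A_zero_point, std::int32_t col_offset,
                         float scale, const float* bias) {
  if (A_zero_point == 0) {
    if (bias == nullptr) {
      requantize_kernel<true, false>(out, inp, n, A_zero_point, col_offset,
                                     scale, bias);
    } else {
      requantize_kernel<true, true>(out, inp, n, A_zero_point, col_offset,
                                    scale, bias);
    }
  } else {
    if (bias == nullptr) {
      requantize_kernel<false, false>(out, inp, n, A_zero_point, col_offset,
                                      scale, bias);
    } else {
      requantize_kernel<false, true>(out, inp, n, A_zero_point, col_offset,
                                     scale, bias);
    }
  }
}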
diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h
index 04aeba1..47f33a8 100644
--- a/include/fbgemm/QuantUtilsAvx2.h
+++ b/include/fbgemm/QuantUtilsAvx2.h
@@ -95,4 +95,18 @@ FBGEMM_API void requantizeOutputProcessingGConvAvx2(
int ld_in,
const requantizationParams_t& r);
+template <
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ QuantizationGranularity Q_GRAN,
+ bool HAS_BIAS,
+ bool FUSE_RELU>
+FBGEMM_API void requantizeForFloatAvx2(
+ float* out,
+ const std::int32_t* inp,
+ const block_type_t& block,
+ int ld_out,
+ int ld_in,
+ const requantizationForFloatParams_t& r);
+
} // namespace fbgemm
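A hypothetical call site for the new declaration, with illustrative values
that are not taken from the commit. It assumes block_type_t and
QuantizationGranularity come in through the headers touched by this diff and
that the program links against fbgemm:

#include <cstdint>
#include "fbgemm/QuantUtilsAvx2.h"
#include "fbgemm/UtilsAvx2.h"

using namespace fbgemm;

// Requantize an int32 GEMM result block straight to float: tensor-level
// granularity, asymmetric A, symmetric B, a bias vector, and fused ReLU.
void requantize_block_to_float(
    float* out, const std::int32_t* inp, int rows, int cols,
    const std::int32_t* row_offsets, const std::int32_t* col_offsets,
    const float* bias) {
  std::int32_t B_zero_point = 0; // symmetric B: zero point is 0
  float B_scale = 0.05f;         // illustrative scale
  requantizationForFloatParams_t r = {
      /*A_zero_point=*/3, &B_zero_point,
      /*A_scale=*/0.02f,  &B_scale,
      row_offsets, col_offsets, bias,
      /*ncols=*/static_cast<std::uint32_t>(cols),
      /*groups=*/1};
  block_type_t block{/*row_start=*/0, rows, /*col_start=*/0, cols};
  requantizeForFloatAvx2<
      /*A_SYMMETRIC=*/false,
      /*B_SYMMETRIC=*/true,
      QuantizationGranularity::TENSOR,
      /*HAS_BIAS=*/true,
      /*FUSE_RELU=*/true>(out, inp, block, /*ld_out=*/cols, /*ld_in=*/cols, r);
}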
diff --git a/include/fbgemm/UtilsAvx2.h b/include/fbgemm/UtilsAvx2.h
index 53fb39d..082edc1 100644
--- a/include/fbgemm/UtilsAvx2.h
+++ b/include/fbgemm/UtilsAvx2.h
@@ -56,4 +56,19 @@ struct requantizationParams_t {
int groups;
};
+/**
+ * @brief A struct to represent all the parameters for requantizing for floats.
+ */
+struct requantizationForFloatParams_t {
+ std::int32_t A_zero_point;
+ const std::int32_t* B_zero_point;
+ float A_scale;
+ const float* B_scale;
+ const std::int32_t* row_offsets;
+ const std::int32_t* col_offsets;
+ const float* bias;
+ std::uint32_t ncols;
+ int groups;
+};
+
} // namespace fbgemm
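How many entries the B_zero_point and B_scale arrays in this struct carry
depends on the quantization granularity, following the indexing in the
reference loop above: one entry for TENSOR, one per group for GROUP, one per
output column for OUT_CHANNEL. A self-contained sketch of that indexing rule,
using a hypothetical enum rather than fbgemm's:

#include <cassert>

// Hypothetical stand-in for fbgemm's QuantizationGranularity.
enum class Granularity { TENSOR, GROUP, OUT_CHANNEL };

// Index into the per-B quantization parameter arrays (zero points and
// scales) for a given group and output column.
int b_quant_param_index(Granularity gran, int group, int col) {
  switch (gran) {
    case Granularity::TENSOR:
      return 0;     // a single zero point / scale for the whole tensor
    case Granularity::GROUP:
      return group; // one per group
    case Granularity::OUT_CHANNEL:
      return col;   // one per output column
  }
  assert(false && "unknown quantization granularity");
  return 0;
}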