Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/ExecuteKernelU8S8.cc')
-rw-r--r--src/ExecuteKernelU8S8.cc94
1 files changed, 70 insertions, 24 deletions
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index f7292fd..4ae1b50 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -49,7 +49,8 @@ ExecuteKernel<
throw std::runtime_error("Failed to initialize cpuinfo!");
}
if (params) {
- if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support() ||
+ fbgemmHasAvx2Support()) {
mbSize_ = params->MCB;
nbSize_ = params->NCB;
nrMinSize_ = params->NR_MIN;
@@ -59,7 +60,20 @@ ExecuteKernel<
assert(0 && "unsupported architecure");
}
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ mbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::MCB;
+ nbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::NCB;
+ nrMinSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::NR_MIN;
+ } else if (fbgemmHasAvx512Support()) {
mbSize_ = PackingTraits<
int8_t,
typename packingAMatrix::accType,
@@ -118,7 +132,25 @@ void ExecuteKernel<
typename BaseType::jit_micro_kernel_fp fn;
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) {
+ // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>(
+ accum,
+ packed_rows_A,
+ packedB_.blockColSize(),
+ packedA_.numPackedCols(),
+ nbSize_);
+ } else {
+ fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
+ accum,
+ packed_rows_A,
+ packedB_.blockColSize(),
+ packedA_.numPackedCols(),
+ nbSize_);
+ }
+ } else if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum,
packed_rows_A,
@@ -148,7 +180,10 @@ void ExecuteKernel<
if (jb == bColBlocks - 1) {
int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_;
if (nc != nbSize_) {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
+ accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
+ } else if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
} else if (fbgemmHasAvx2Support()) {
@@ -213,7 +248,7 @@ void ExecuteKernel<
int32_t nSize =
C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols();
if (nSize) {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
@@ -238,7 +273,7 @@ void ExecuteKernel<
if (C_buffer_start == C_tile_) {
// When C_tile_ scratchpad was used to avoid accessing memory past
// C_buffer_ .
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
@@ -280,19 +315,23 @@ void ExecuteKernel<
////////////////////////////////////////////////////////////////////////////////
// ReQuantizeOutput
-#define INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN) \
- template class ExecuteKernel< \
- PACK_A<uint8_t, ACC_T>, \
- PackBMatrix<int8_t, ACC_T>, \
- uint8_t, \
- ReQuantizeOutput<RELU, Q_GRAN>>;
+#define INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, BIAS_TYPE) \
+ template class ExecuteKernel< \
+ PACK_A<uint8_t, ACC_T>, \
+ PackBMatrix<int8_t, ACC_T>, \
+ uint8_t, \
+ ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>;
+
+#define INSTANTIATE_REQUANT_BIAS_T(PACK_A, ACC_T, RELU, Q_GRAN) \
+ INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, float); \
+ INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, int32_t);
#define INSTANTIATE_REQUANT_Q_GRANS(PACK_A, ACC_T, RELU) \
- INSTANTIATE_REQUANT_BASE( \
+ INSTANTIATE_REQUANT_BIAS_T( \
PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \
- INSTANTIATE_REQUANT_BASE( \
+ INSTANTIATE_REQUANT_BIAS_T( \
PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \
- INSTANTIATE_REQUANT_BASE( \
+ INSTANTIATE_REQUANT_BIAS_T( \
PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL);
#define INSTANTIATE_REQUANT_RELU(PACK_A, ACC_T) \
@@ -309,21 +348,27 @@ INSTANTIATE_REQUANT_ACC_T(PackAWithRowOffset);
#undef INSTANTIATE_REQUANT_ACC_T
#undef INSTANTIATE_REQUANT_RELU
#undef INSTANTIATE_REQUANT_Q_GRANS
+#undef INSTANTIATE_REQUANT_BIAS_T
#undef INSTANTIATE_REQUANT_BASE
-#define INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
- template class ExecuteKernel< \
- PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
- PackBMatrix<int8_t, ACC_T>, \
- uint8_t, \
- ReQuantizeOutput<RELU, Q_GRAN>>;
+#define INSTANTIATE_IM2COL_REQUANT_BASE( \
+ ACC_T, RELU, SPATIAL_DIM, Q_GRAN, BIAS_TYPE) \
+ template class ExecuteKernel< \
+ PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \
+ PackBMatrix<int8_t, ACC_T>, \
+ uint8_t, \
+ ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>;
+
+#define INSTANTIATE_IM2COL_REQUANT_BIAS_T(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \
+ INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, float); \
+ INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, int32_t);
#define INSTANTIATE_IM2COL_REQUANT_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \
- INSTANTIATE_IM2COL_REQUANT_BASE( \
+ INSTANTIATE_IM2COL_REQUANT_BIAS_T( \
ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \
- INSTANTIATE_IM2COL_REQUANT_BASE( \
+ INSTANTIATE_IM2COL_REQUANT_BIAS_T( \
ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \
- INSTANTIATE_IM2COL_REQUANT_BASE( \
+ INSTANTIATE_IM2COL_REQUANT_BIAS_T( \
ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL);
#define INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM(ACC_T, RELU) \
@@ -340,6 +385,7 @@ INSTANTIATE_IM2COL_REQUANT_RELU(int16_t);
#undef INSTANTIATE_IM2COL_REQUANT_RELU
#undef INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM
#undef INSTANTIATE_IM2COL_REQUANT_Q_GRANS
+#undef INSTANTIATE_IM2COL_REQUANT_BIAS_T
#undef INSTANTIATE_IM2COL_REQUANT_BASE
////////////////////////////////////////////////////////////////////////////////