diff options
Diffstat (limited to 'src/ExecuteKernelU8S8.cc')
-rw-r--r-- | src/ExecuteKernelU8S8.cc | 94 |
1 files changed, 70 insertions, 24 deletions
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc index f7292fd..4ae1b50 100644 --- a/src/ExecuteKernelU8S8.cc +++ b/src/ExecuteKernelU8S8.cc @@ -49,7 +49,8 @@ ExecuteKernel< throw std::runtime_error("Failed to initialize cpuinfo!"); } if (params) { - if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { + if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support() || + fbgemmHasAvx2Support()) { mbSize_ = params->MCB; nbSize_ = params->NCB; nrMinSize_ = params->NR_MIN; @@ -59,7 +60,20 @@ ExecuteKernel< assert(0 && "unsupported architecure"); } } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + mbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::MCB; + nbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::NCB; + nrMinSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::NR_MIN; + } else if (fbgemmHasAvx512Support()) { mbSize_ = PackingTraits< int8_t, typename packingAMatrix::accType, @@ -118,7 +132,25 @@ void ExecuteKernel< typename BaseType::jit_micro_kernel_fp fn; - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) { + // For AVX512VNNI, we redirect int16_t to int32_t accumulation. + CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>( + accum, + packed_rows_A, + packedB_.blockColSize(), + packedA_.numPackedCols(), + nbSize_); + } else { + fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>( + accum, + packed_rows_A, + packedB_.blockColSize(), + packedA_.numPackedCols(), + nbSize_); + } + } else if (fbgemmHasAvx512Support()) { fn = BaseType::template getOrCreate<inst_set_t::avx512>( accum, packed_rows_A, @@ -148,7 +180,10 @@ void ExecuteKernel< if (jb == bColBlocks - 1) { int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_; if (nc != nbSize_) { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>( + accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_); + } else if (fbgemmHasAvx512Support()) { fn = BaseType::template getOrCreate<inst_set_t::avx512>( accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_); } else if (fbgemmHasAvx2Support()) { @@ -213,7 +248,7 @@ void ExecuteKernel< int32_t nSize = C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols(); if (nSize) { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) { // TODO: avx512 path // Currently use avx2 code outputProcess_.template f<inst_set_t::avx2>( @@ -238,7 +273,7 @@ void ExecuteKernel< if (C_buffer_start == C_tile_) { // When C_tile_ scratchpad was used to avoid accessing memory past // C_buffer_ . - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) { // TODO: avx512 path // Currently use avx2 code outputProcess_.template f<inst_set_t::avx2>( @@ -280,19 +315,23 @@ void ExecuteKernel< //////////////////////////////////////////////////////////////////////////////// // ReQuantizeOutput -#define INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN) \ - template class ExecuteKernel< \ - PACK_A<uint8_t, ACC_T>, \ - PackBMatrix<int8_t, ACC_T>, \ - uint8_t, \ - ReQuantizeOutput<RELU, Q_GRAN>>; +#define INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, BIAS_TYPE) \ + template class ExecuteKernel< \ + PACK_A<uint8_t, ACC_T>, \ + PackBMatrix<int8_t, ACC_T>, \ + uint8_t, \ + ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>; + +#define INSTANTIATE_REQUANT_BIAS_T(PACK_A, ACC_T, RELU, Q_GRAN) \ + INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, float); \ + INSTANTIATE_REQUANT_BASE(PACK_A, ACC_T, RELU, Q_GRAN, int32_t); #define INSTANTIATE_REQUANT_Q_GRANS(PACK_A, ACC_T, RELU) \ - INSTANTIATE_REQUANT_BASE( \ + INSTANTIATE_REQUANT_BIAS_T( \ PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \ - INSTANTIATE_REQUANT_BASE( \ + INSTANTIATE_REQUANT_BIAS_T( \ PACK_A, ACC_T, RELU, QuantizationGranularity::GROUP); \ - INSTANTIATE_REQUANT_BASE( \ + INSTANTIATE_REQUANT_BIAS_T( \ PACK_A, ACC_T, RELU, QuantizationGranularity::OUT_CHANNEL); #define INSTANTIATE_REQUANT_RELU(PACK_A, ACC_T) \ @@ -309,21 +348,27 @@ INSTANTIATE_REQUANT_ACC_T(PackAWithRowOffset); #undef INSTANTIATE_REQUANT_ACC_T #undef INSTANTIATE_REQUANT_RELU #undef INSTANTIATE_REQUANT_Q_GRANS +#undef INSTANTIATE_REQUANT_BIAS_T #undef INSTANTIATE_REQUANT_BASE -#define INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \ - template class ExecuteKernel< \ - PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \ - PackBMatrix<int8_t, ACC_T>, \ - uint8_t, \ - ReQuantizeOutput<RELU, Q_GRAN>>; +#define INSTANTIATE_IM2COL_REQUANT_BASE( \ + ACC_T, RELU, SPATIAL_DIM, Q_GRAN, BIAS_TYPE) \ + template class ExecuteKernel< \ + PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>, \ + PackBMatrix<int8_t, ACC_T>, \ + uint8_t, \ + ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>>; + +#define INSTANTIATE_IM2COL_REQUANT_BIAS_T(ACC_T, RELU, SPATIAL_DIM, Q_GRAN) \ + INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, float); \ + INSTANTIATE_IM2COL_REQUANT_BASE(ACC_T, RELU, SPATIAL_DIM, Q_GRAN, int32_t); #define INSTANTIATE_IM2COL_REQUANT_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \ - INSTANTIATE_IM2COL_REQUANT_BASE( \ + INSTANTIATE_IM2COL_REQUANT_BIAS_T( \ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \ - INSTANTIATE_IM2COL_REQUANT_BASE( \ + INSTANTIATE_IM2COL_REQUANT_BIAS_T( \ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::GROUP); \ - INSTANTIATE_IM2COL_REQUANT_BASE( \ + INSTANTIATE_IM2COL_REQUANT_BIAS_T( \ ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::OUT_CHANNEL); #define INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM(ACC_T, RELU) \ @@ -340,6 +385,7 @@ INSTANTIATE_IM2COL_REQUANT_RELU(int16_t); #undef INSTANTIATE_IM2COL_REQUANT_RELU #undef INSTANTIATE_IM2COL_REQUANT_SPATIAL_DIM #undef INSTANTIATE_IM2COL_REQUANT_Q_GRANS +#undef INSTANTIATE_IM2COL_REQUANT_BIAS_T #undef INSTANTIATE_IM2COL_REQUANT_BASE //////////////////////////////////////////////////////////////////////////////// |