diff options
author | Jianyu Huang <jianyuhuang@fb.com> | 2019-08-06 19:35:42 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-08-06 19:50:51 +0300 |
commit | d8b3323668fdd15dc70e9cb43ab16e96f4846eeb (patch) | |
tree | d48a6818c14575d92e68bf1ffb621d646a6c893e | |
parent | 0d5d057ca941ebb511bdc6178fc26c23e6c4a953 (diff) |
Integrate VNNI into FBGEMM master branch (#113)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/113
Adding the VNNI support in FBGEMM.
Reviewed By: dskhudia
Differential Revision: D16276574
fbshipit-source-id: 832ccdb27339489ebc138f3b2678e53d107c1b79
-rw-r--r-- | CMakeLists.txt | 2 | ||||
-rw-r--r-- | include/fbgemm/PackingTraits-inl.h | 50 | ||||
-rw-r--r-- | include/fbgemm/Utils.h | 7 | ||||
-rw-r--r-- | src/ExecuteKernelU8S8.cc | 47 | ||||
-rw-r--r-- | src/Fbgemm.cc | 18 | ||||
-rw-r--r-- | src/GenerateKernel.h | 30 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16.cc | 91 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16Avx512.cc | 94 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc | 102 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32.cc | 87 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32Avx512.cc | 95 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc | 431 | ||||
-rw-r--r-- | src/GroupwiseConv.h | 100 | ||||
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc | 177 | ||||
-rw-r--r-- | src/PackAMatrix.cc | 10 | ||||
-rw-r--r-- | src/PackAWithIm2Col.cc | 14 | ||||
-rw-r--r-- | src/PackAWithQuantRowOffset.cc | 14 | ||||
-rw-r--r-- | src/PackAWithRowOffset.cc | 14 | ||||
-rw-r--r-- | src/PackBMatrix.cc | 25 | ||||
-rw-r--r-- | src/PackMatrix.cc | 9 | ||||
-rw-r--r-- | src/PackWeightMatrixForGConv.cc | 8 | ||||
-rw-r--r-- | src/Utils.cc | 3 | ||||
-rw-r--r-- | test/GConvTest.cc | 4 | ||||
m--------- | third_party/asmjit | 0 |
24 files changed, 1054 insertions, 378 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index b575e17..817f699 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,8 +33,10 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc src/FbgemmI8Spmdm.cc src/GenerateKernelU8S8S32ACC16.cc src/GenerateKernelU8S8S32ACC16Avx512.cc + src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc src/GenerateKernelU8S8S32ACC32.cc src/GenerateKernelU8S8S32ACC32Avx512.cc + src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc src/GroupwiseConvAcc32Avx2.cc src/PackAMatrix.cc src/PackAWithIm2Col.cc diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h index 76eb425..baccfad 100644 --- a/include/fbgemm/PackingTraits-inl.h +++ b/include/fbgemm/PackingTraits-inl.h @@ -222,3 +222,53 @@ struct PackingTraits< 128}; ///< Cache block for N dimension (multiple of NR). static constexpr int KCB{256}; ///< Cache block for K dimension. }; + +/** + * @brief Helper struct to type specialize for int16_t and int32_t together. + */ +template <typename T> +struct is_16or32bit { + static constexpr bool value = + std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value; +}; + +/** + * @brief Packing parameter specialization for accumulation into 32-bit/16-bit + * integers. + * + * Since there is no int16_t accumulation for AVX512 VNNI, we redirect int16_t + * to int32_t accumulation and use the same blocking parameters as int32_t. + * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx512_vnni. + */ +template <typename T, typename accT> +struct PackingTraits< + T, + accT, + inst_set_t::avx512_vnni, + typename std::enable_if< + is_8bit<T>::value && is_16or32bit<accT>::value>::type> { + static constexpr int MR{8}; ///< Register block for M dimension. + static constexpr int NR_MIN{ + 16}; ///< Minimum register block for N dimension. + ///< 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. + static constexpr int NR{ + 32}; ///< Register block for N dimension. + ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. Total registers used for + ///< N dimension: NR*ROW_INTERLEAVE*8/VLEN. We use MR x + ///< NR*ROW_INTERLEAVE*8/VLEN zmm registers + ///< for C accumulations. + + static constexpr int ROW_INTERLEAVE{ + 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix. + + static constexpr int MCB{ + 128}; ///< Cache block for M dimension (multiple of MR). + static constexpr int NCB{ + 32}; ///< Cache block for N dimension (multiple of NR). + static constexpr int KCB{256}; ///< Cache block for K dimension. +}; diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h index 107cf07..3f8522b 100644 --- a/include/fbgemm/Utils.h +++ b/include/fbgemm/Utils.h @@ -29,7 +29,7 @@ enum class matrix_op_t { NoTranspose, Transpose }; /** * @brief Typed enum for supported instruction sets. */ -enum class inst_set_t { anyarch, avx2, avx512 }; +enum class inst_set_t { anyarch, avx2, avx512, avx512_vnni }; /** * @brief Typed enum for optimized paths for convolutions @@ -100,6 +100,11 @@ FBGEMM_API bool fbgemmHasAvx512Support(); FBGEMM_API bool fbgemmHasAvx2Support(); /** + * @brief Are we running on a AVX512_VNNI supported cpu? + */ +FBGEMM_API bool fbgemmHasAvx512VnniSupport(); + +/** * @brief Helper struct to enable autotuning of FBGEMM packing and kernels. * * This structure is optional. If not used, the default values for these diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc index f7292fd..0a4ff55 100644 --- a/src/ExecuteKernelU8S8.cc +++ b/src/ExecuteKernelU8S8.cc @@ -49,7 +49,8 @@ ExecuteKernel< throw std::runtime_error("Failed to initialize cpuinfo!"); } if (params) { - if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { + if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support() || + fbgemmHasAvx2Support()) { mbSize_ = params->MCB; nbSize_ = params->NCB; nrMinSize_ = params->NR_MIN; @@ -59,7 +60,20 @@ ExecuteKernel< assert(0 && "unsupported architecure"); } } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + mbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::MCB; + nbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::NCB; + nrMinSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::NR_MIN; + } else if (fbgemmHasAvx512Support()) { mbSize_ = PackingTraits< int8_t, typename packingAMatrix::accType, @@ -118,7 +132,25 @@ void ExecuteKernel< typename BaseType::jit_micro_kernel_fp fn; - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) { + // For AVX512VNNI, we redirect int16_t to int32_t accumulation. + CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>( + accum, + packed_rows_A, + packedB_.blockColSize(), + packedA_.numPackedCols(), + nbSize_); + } else { + fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>( + accum, + packed_rows_A, + packedB_.blockColSize(), + packedA_.numPackedCols(), + nbSize_); + } + } else if (fbgemmHasAvx512Support()) { fn = BaseType::template getOrCreate<inst_set_t::avx512>( accum, packed_rows_A, @@ -148,7 +180,10 @@ void ExecuteKernel< if (jb == bColBlocks - 1) { int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_; if (nc != nbSize_) { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>( + accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_); + } else if (fbgemmHasAvx512Support()) { fn = BaseType::template getOrCreate<inst_set_t::avx512>( accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_); } else if (fbgemmHasAvx2Support()) { @@ -213,7 +248,7 @@ void ExecuteKernel< int32_t nSize = C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols(); if (nSize) { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) { // TODO: avx512 path // Currently use avx2 code outputProcess_.template f<inst_set_t::avx2>( @@ -238,7 +273,7 @@ void ExecuteKernel< if (C_buffer_start == C_tile_) { // When C_tile_ scratchpad was used to avoid accessing memory past // C_buffer_ . - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) { // TODO: avx512 path // Currently use avx2 code outputProcess_.template f<inst_set_t::avx2>( diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc index 0f2f6fb..4f7026f 100644 --- a/src/Fbgemm.cc +++ b/src/Fbgemm.cc @@ -48,7 +48,8 @@ void fbgemmPacked( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -62,7 +63,20 @@ void fbgemmPacked( MR = blocking_params->MR; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + MCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::MCB; + KCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::KCB; + MR = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::MR; + } else if (fbgemmHasAvx512Support()) { MCB = PackingTraits< typename packingAMatrix::inpType, typename packingAMatrix::accType, diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h index dccdfc5..e52097e 100644 --- a/src/GenerateKernel.h +++ b/src/GenerateKernel.h @@ -18,7 +18,7 @@ namespace fbgemm { namespace x86 = asmjit::x86; /** - * @brief AVX2/AVX512 JIT assembly code generator. + * @brief AVX2/AVX512/AVX512VNNI JIT assembly code generator. * @tparam TA Type of matrix A. * @tparam TB Type of matrix B. * @tparam TC Type of matrix C. @@ -104,7 +104,7 @@ class CodeGenBase { */ template <inst_set_t instSet> void initCRegs( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, int leadingDimCRegAssign = 4); @@ -114,10 +114,10 @@ class CodeGenBase { */ template <inst_set_t instSet> void genComputeBlock( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp B_pf, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, int rowRegs, int colRegs, int lda, @@ -129,11 +129,11 @@ class CodeGenBase { */ template <inst_set_t instSet> void storeCRegs( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, bool accum, int leadingDimCRegAssign = 4); @@ -168,7 +168,9 @@ class CodeGenBase { fileName += "_MR-" + std::to_string(MR); fileName += "_NR-" + std::to_string(NR); fileName += "_NR_MIN-" + std::to_string(NR_MIN); - if (instSet == inst_set_t::avx512) { + if (instSet == inst_set_t::avx512_vnni) { + fileName += "_avx512vnni"; + } else if (instSet == inst_set_t::avx512) { fileName += "_avx512"; } else if (instSet == inst_set_t::avx2) { fileName += "_avx2"; @@ -178,12 +180,10 @@ class CodeGenBase { } private: - asmjit::X86Ymm - CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel. - asmjit::X86Zmm + x86::Ymm CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel. + x86::Zmm CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel. - asmjit::X86Zmm - AllRegs_avx512_[32]; ///< all AVX512 zmm registers. + x86::Zmm AllRegs_avx512_[32]; ///< all AVX512 zmm registers. int vectorWidth_; ///< Vector width in bits. int VLEN_; ///< Vector width in elements. diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc index 718b883..1e7e081 100644 --- a/src/GenerateKernelU8S8S32ACC16.cc +++ b/src/GenerateKernelU8S8S32ACC16.cc @@ -31,7 +31,7 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { @@ -53,18 +53,18 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp /* unused (reserved for prefetching)*/, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp /* unused (reserved for prefetching)*/, int rowRegs, int colRegs, int lda, int leadingDimCReg) { // used for matrix A - asmjit::X86Ymm AReg = x86::ymm12; + x86::Ymm AReg = x86::ymm12; - asmjit::X86Ymm tmpReg = x86::ymm14; + x86::Ymm tmpReg = x86::ymm14; for (int i = 0; i < rowRegs; ++i) { // broadcast A @@ -95,15 +95,15 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, bool accum, int leadingDimCReg) { - asmjit::X86Xmm extractDest128 = x86::xmm15; - asmjit::X86Ymm extractDest256 = x86::ymm15; + x86::Xmm extractDest128 = x86::xmm15; + x86::Ymm extractDest256 = x86::ymm15; for (int i = 0; i < rowRegs; ++i) { a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t))); @@ -112,7 +112,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< a->vextracti128( extractDest128, CRegs_avx2_[i * leadingDimCReg + j], idx); a->vpmovsxwd(extractDest256, extractDest128); - asmjit::X86Mem destAddr = x86::dword_ptr( + x86::Mem destAddr = x86::dword_ptr( a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t)); if (accum) { a->vpaddd(extractDest256, extractDest256, destAddr); @@ -176,9 +176,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( return codeCache_[kernelSig]; } code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // generated code logging @@ -207,46 +207,45 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( //"nc must be equal to the number of register blocks"); // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + asmjit::FuncFrame frame; + frame.init(func); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); - asmjit::FuncArgsMapper args(&func); + asmjit::FuncArgsAssignment args(&func); args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFrameInfo(ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); - - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); asmjit::Label Loopk = a->newLabel(); asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - // asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp kIdx = a->gpzRef(14); + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + // x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp kIdx = a->gpz(14); int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; if (mRegBlocks > 0) { @@ -289,8 +288,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( a->jl(Loopk); // store C matrix - storeCRegs<inst_set_t::avx2>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum); + storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); // increment A for next block a->sub(buffer_A, kSize); @@ -340,11 +338,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( a->jl(LoopkRem); // store C matrix - storeCRegs<inst_set_t::avx2>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum); + storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_micro_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index c95757b..a49e440 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -19,7 +19,7 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< inst_set_t::avx512>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { @@ -41,18 +41,18 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< inst_set_t::avx512>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp /* unused (reserved for prefetching)*/, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp /* unused (reserved for prefetching)*/, int rowRegs, int colRegs, int lda, int leadingDimCReg) { // used for matrix A - asmjit::X86Zmm AReg = x86::zmm29; + x86::Zmm AReg = x86::zmm29; - asmjit::X86Zmm tmpReg = x86::zmm30; + x86::Zmm tmpReg = x86::zmm30; // We start allocating BRegs from zmm27 and then allocate zmm26 and so on. for (int j = 0; j < colRegs; ++j) { @@ -66,8 +66,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< a->vpbroadcastw( AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); for (int j = 0; j < colRegs; ++j) { - a->vpmaddubsw( - tmpReg, AReg, AllRegs_avx512_[27-j]); + a->vpmaddubsw(tmpReg, AReg, AllRegs_avx512_[27 - j]); a->vpaddsw( CRegs_avx512_[i * leadingDimCReg + j], tmpReg, @@ -90,15 +89,16 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< inst_set_t::avx512>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum, int leadingDimCReg) { - asmjit::X86Ymm extractDest256 = x86::ymm31; - asmjit::X86Zmm extractDest512 = x86::zmm31; + x86::Ymm extractDest256 = x86::ymm31; + x86::Zmm extractDest512 = x86::zmm31; for (int i = 0; i < rowRegs; ++i) { a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t))); @@ -107,7 +107,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< a->vextracti32x8( extractDest256, CRegs_avx512_[i * leadingDimCReg + j], idx); a->vpmovsxwd(extractDest512, extractDest256); - asmjit::X86Mem destAddr = x86::dword_ptr( + x86::Mem destAddr = x86::dword_ptr( a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t)); if (accum) { a->vpaddd(extractDest512, extractDest512, destAddr); @@ -172,9 +172,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( } code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // generated code logging @@ -209,49 +209,49 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::FuncFrame frame; + frame.init(func); - asmjit::FuncArgsMapper args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - args.updateFrameInfo(ffi); + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); asmjit::Label LoopMBlocks = a->newLabel(); asmjit::Label LoopNBlocks = a->newLabel(); asmjit::Label Loopk = a->newLabel(); - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - // asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp jIdx = a->gpzRef(14); - asmjit::X86Gp kIdx = a->gpzRef(15); + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + // x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); // save B_buffer address a->mov(buffer_B_saved, buffer_B); @@ -407,7 +407,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( a->jl(LoopNRem); } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_micro_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc new file mode 100644 index 0000000..f559aba --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc @@ -0,0 +1,102 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <iostream> +#include "GenerateKernel.h" + +namespace fbgemm { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit + * Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCReg) { + assert(0 && "Accumulation to int16_t is not available for VNNI!"); + + // For AVX512VNNI, redirect to int32_t accumulation. + CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + codeObj.initCRegs<inst_set_t::avx512_vnni>( + a, rowRegs, colRegs, leadingDimCReg); +} + +/** + * Generate AVX512 instructions for computing block in the rank-k update of + * 16-bit Accmulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp /* unused (reserved for prefetching)*/, + int rowRegs, + int colRegs, + int lda, + int leadingDimCReg) { + assert(0 && "Accumulation to int16_t is not available for VNNI!"); + + // For AVX512VNNI, redirect to int32_t accumulation. + CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + codeObj.genComputeBlock<inst_set_t::avx512_vnni>( + a, buffer_A, buffer_B, buffer_B, rowRegs, colRegs, lda, leadingDimCReg); +} + +/** + * Generate AVX512 instructions for storing the C registers back to the memory + * in 16-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + int rowRegs, + int colRegs, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum, + int leadingDimCReg) { + assert(0 && "Accumulation to int16_t is not available for VNNI!"); + + // For AVX512VNNI, redirect to int32_t accumulation. + CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + codeObj.storeCRegs<inst_set_t::avx512_vnni>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, leadingDimCReg); +} + +/** + * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel. + * + */ +template <> +template <> +CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp +CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate< + inst_set_t::avx512_vnni>( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + assert(0 && "Accumulation to int16_t is not available for VNNI!"); + + // For AVX512VNNI, redirect to int32_t accumulation. + CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + return codeObj.getOrCreate<inst_set_t::avx512_vnni>(accum, mc, nc, kc, kc); +} + +} // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc index 58643ad..6b54743 100644 --- a/src/GenerateKernelU8S8S32ACC32.cc +++ b/src/GenerateKernelU8S8S32ACC32.cc @@ -31,7 +31,7 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { @@ -53,25 +53,25 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp B_pf, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, int rowRegs, int colRegs, int lda, int leadingDimCReg) { // used for matrix A - asmjit::X86Ymm AReg = x86::ymm12; + x86::Ymm AReg = x86::ymm12; // used for matrix B - asmjit::X86Ymm BReg = x86::ymm13; + x86::Ymm BReg = x86::ymm13; // Contains 16-bit 1s - asmjit::X86Ymm oneReg = x86::ymm15; + x86::Ymm oneReg = x86::ymm15; // temporary register - asmjit::X86Ymm res1 = x86::ymm14; + x86::Ymm res1 = x86::ymm14; for (int j = 0; j < colRegs; ++j) { // load B @@ -99,11 +99,11 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, bool accum, int leadingDimCReg) { for (int i = 0; i < rowRegs; ++i) { @@ -177,9 +177,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( return codeCache_[kernelSig]; } code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // generated code logging FILE* codeLogfile = fopen( @@ -205,49 +205,48 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + asmjit::FuncFrame frame; + frame.init(func); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); - asmjit::FuncArgsMapper args(&func); + asmjit::FuncArgsAssignment args(&func); args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFrameInfo(ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); - - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); asmjit::Label Loopk = a->newLabel(); asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp kIdx = a->gpzRef(14); - // asmjit::X86Gp B_pf = a->gpzRef(8); + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp kIdx = a->gpz(14); + // x86::Gp B_pf = a->gpz(8); - asmjit::X86Ymm oneReg = x86::ymm15; + x86::Ymm oneReg = x86::ymm15; // create 16-bit 1s // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 // and so on @@ -358,7 +357,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_micro_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc index 12243ee..fe35627 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc @@ -19,7 +19,7 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< inst_set_t::avx512>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { @@ -41,25 +41,25 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< inst_set_t::avx512>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp B_pf, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, int rowRegs, int colRegs, int lda, int leadingDimCReg) { // used for matrix A - asmjit::X86Zmm AReg = x86::zmm31; + x86::Zmm AReg = x86::zmm31; // used for matrix B - asmjit::X86Zmm BReg = x86::zmm30; + x86::Zmm BReg = x86::zmm30; // Contains 16-bit 1s - asmjit::X86Zmm oneReg = x86::zmm29; + x86::Zmm oneReg = x86::zmm29; // temporary register - asmjit::X86Zmm res1 = x86::zmm28; + x86::Zmm res1 = x86::zmm28; for (int j = 0; j < colRegs; ++j) { // load B @@ -87,18 +87,17 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< inst_set_t::avx512>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, bool accum, int leadingDimCReg) { for (int i = 0; i < rowRegs; ++i) { if (i != 0) { a->add(C_Offset, ldcReg); - } - else { + } else { a->mov(C_Offset, static_cast<asmjit::Imm>(0)); } for (int j = 0; j < colRegs; ++j) { @@ -168,9 +167,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( return codeCache_[kernelSig]; } code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // generated code logging @@ -205,52 +204,52 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::FuncFrame frame; + frame.init(func); - asmjit::FuncArgsMapper args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - args.updateFrameInfo(ffi); + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); asmjit::Label LoopMBlocks = a->newLabel(); asmjit::Label LoopNBlocks = a->newLabel(); asmjit::Label Loopk = a->newLabel(); - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp jIdx = a->gpzRef(14); - asmjit::X86Gp kIdx = a->gpzRef(15); - // asmjit::X86Gp B_pf = a->gpzRef(8); + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); + // x86::Gp B_pf = a->gpz(8); - asmjit::X86Zmm oneReg = x86::zmm29; + x86::Zmm oneReg = x86::zmm29; // create 16-bit 1s // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 // and so on @@ -420,7 +419,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( a->jl(LoopNRem); } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_micro_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc new file mode 100644 index 0000000..8ae0745 --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc @@ -0,0 +1,431 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <iostream> +#include "GenerateKernel.h" + +namespace fbgemm { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit + * Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCReg) { + for (int i = 0; i < rowRegs; ++i) { + for (int j = 0; j < colRegs; ++j) { + a->vxorps( + CRegs_avx512_[i * leadingDimCReg + j], + CRegs_avx512_[i * leadingDimCReg + j], + CRegs_avx512_[i * leadingDimCReg + j]); + } + } +} + +/** + * Generate AVX512 instructions for computing block in the rank-k update of + * 32-bit Accmulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, + int rowRegs, + int colRegs, + int lda, + int leadingDimCReg) { + // used for matrix A + x86::Zmm AReg = x86::zmm31; + + // used for matrix B + x86::Zmm BReg = x86::zmm30; + + for (int j = 0; j < colRegs; ++j) { + // load B + a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + // load A, broadcast and fmas + for (int i = 0; i < rowRegs; ++i) { + a->vpbroadcastd( + AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); + a->vpdpbusd(CRegs_avx512_[i * leadingDimCReg + j], AReg, BReg); + } + a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); + } +} + +/** + * Generate AVX512 instructions for storing the C registers back to the memory + * in 32-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + int rowRegs, + int colRegs, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum, + int leadingDimCReg) { + for (int i = 0; i < rowRegs; ++i) { + if (i != 0) { + a->add(C_Offset, ldcReg); + } else { + a->mov(C_Offset, static_cast<asmjit::Imm>(0)); + } + for (int j = 0; j < colRegs; ++j) { + if (accum) { + a->vpaddd( + CRegs_avx512_[i * leadingDimCReg + j], + CRegs_avx512_[i * leadingDimCReg + j], + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t))); + } + a->vmovups( + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)), + CRegs_avx512_[i * leadingDimCReg + j]); + } + } +} + +/** + * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel. + * + */ +template <> +template <> +CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp +CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate< + inst_set_t::avx512_vnni>( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + std::tuple<bool, int, int, int, int, int, int, int> kernelSig; + int kBlock; + int nBlock; + int mRegBlockSize; + int nRegBlockSize; + int nRegBlockSizeMin; + int row_interleave; + + if (blocking_params) { + kBlock = blocking_params->KCB; + nBlock = blocking_params->NCB; + mRegBlockSize = blocking_params->MR; + nRegBlockSize = blocking_params->NR; + nRegBlockSizeMin = blocking_params->NR_MIN; + row_interleave = blocking_params->ROW_INTERLEAVE; + } else { + kBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::KCB; + nBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NCB; + mRegBlockSize = + PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::MR; + nRegBlockSize = + PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR; + nRegBlockSizeMin = + PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR_MIN; + row_interleave = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>:: + ROW_INTERLEAVE; + } + + kernelSig = std::make_tuple( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin); + + if (codeCache_.find(kernelSig) != codeCache_.end()) { + return codeCache_[kernelSig]; + } + code_.reset(false); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); + +#if defined(FBGEMM_LOG_CODE) + // generated code logging + FILE* codeLogfile = fopen( + getCodeLoggingFile<inst_set_t::avx512_vnni>( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code_.setLogger(codeLogger); + } +#endif + + assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); + assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); + int maxMRegs = mRegBlockSize; + int maxNRegs = nRegBlockSize * row_interleave / VLEN_; + assert( + maxMRegs * maxNRegs <= 28 && + "MR*(NR*ROW_INTERLEAVE*8/512) \ + must be <= 28(available registers constraint)"); + + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); + + asmjit::FuncDetail func; + func.init( + asmjit:: + FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label LoopMBlocks = a->newLabel(); + asmjit::Label LoopNBlocks = a->newLabel(); + asmjit::Label Loopk = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); + // x86::Gp B_pf = a->gpz(8); + + x86::Zmm oneReg = x86::zmm29; + // create 16-bit 1s + // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 + // and so on + // a->vpcmpeqw(oneReg, oneReg, oneReg); + a->vpternlogd(oneReg, oneReg, oneReg, 0xff); + a->vpsrlw(oneReg, oneReg, 15); + a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + + int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + int colRegs = std::min(currColRegs, maxNRegs); + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + a->bind(LoopMBlocks); + a->inc(iIdx); + a->mov(jIdx, 0); + + a->bind(LoopNBlocks); + a->inc(jIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512_vnni>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add( + buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add( + buffer_B, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + a->add( + B_pf, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + + // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx512_vnni>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul( + C_Offset, + jIdx, + static_cast<asmjit::Imm>( + nRegBlockSize * row_interleave * sizeof(int8_t))); + a->add(buffer_B, C_Offset); + a->mov(B_pf, B_pf_saved); + a->add(B_pf, C_Offset); + + // increment C for next B block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNBlocks); + + // increment A for next block + a->add( + buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + + // increment C for next A block + a->sub( + CBase, + static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t))); + a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); + a->add(CBase, C_Offset); + + // reset B + a->mov(buffer_B, buffer_B_saved); + a->mov(B_pf, B_pf_saved); + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopNRem = a->newLabel(); + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + a->mov(jIdx, 0); + a->bind(LoopNRem); + a->inc(jIdx); + + // init C registers + initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512_vnni>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add( + buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add( + buffer_B, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + a->add( + B_pf, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // reset A + a->sub(buffer_A, kSize); + // B for next block + // using C_Offset as temp reg + a->imul( + C_Offset, + jIdx, + static_cast<asmjit::Imm>( + nRegBlockSize * row_interleave * sizeof(int8_t))); + a->mov(buffer_B, buffer_B_saved); + a->add(buffer_B, C_Offset); + a->mov(B_pf, B_pf_saved); + a->add(B_pf, C_Offset); + + // store C matrix + storeCRegs<inst_set_t::avx512_vnni>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + + // increment C for next B block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNRem); + } + + a->emitEpilog(frame); + + jit_micro_kernel_fp fn; + asmjit::Error err = rt_.add(&fn, &code_); + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + codeCache_[kernelSig] = fn; + +#if defined(FBGEMM_LOG_CODE) + fclose(codeLogfile); + delete codeLogger; +#endif + + return fn; +} + +} // namespace fbgemm diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h index 1e6324e..4c5eea5 100644 --- a/src/GroupwiseConv.h +++ b/src/GroupwiseConv.h @@ -128,60 +128,58 @@ class GenConvKernel { const conv_param_t<SPATIAL_DIM>& conv_param); template <inst_set_t instSet> - void createVector16BitOne(asmjit::X86Emitter* a); + void createVector16BitOne(x86::Emitter* a); template <inst_set_t instSet> - void createVector8BitOne(asmjit::X86Emitter* a); + void createVector8BitOne(x86::Emitter* a); template <inst_set_t instSet> - void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg); + void setToZeroPt(x86::Emitter* a, x86::Ymm destReg); template <inst_set_t instSet> - void - gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg); + void gen8bitFMA(x86::Emitter* a, x86::Ymm aReg, x86::Ymm wReg); template <inst_set_t instSet> - void genForLoadingWeights(asmjit::X86Emitter* a, int c_offset); + void genForLoadingWeights(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genConstForPermutations(asmjit::X86Emitter* a); + void genConstForPermutations(x86::Emitter* a); template <inst_set_t instSet> - void genForTopEdge(asmjit::X86Emitter* a, int c_offset); + void genForTopEdge(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genForLeftEdge(asmjit::X86Emitter* a, int c_offset); + void genForLeftEdge(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genForRightEdge(asmjit::X86Emitter* a, int c_offset); + void genForRightEdge(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genForBottomEdge(asmjit::X86Emitter* a, int c_offset); + void genForBottomEdge(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genCoreInsts(asmjit::X86Emitter* a, int c_offset); + void genCoreInsts(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void storeResult(asmjit::X86Emitter* a); + void storeResult(x86::Emitter* a); // for Rowoffset kernel // Add 4 consecutive numbers of 32 uint8 and emit 8 32-bit template <inst_set_t instSet> - void gen8BitSumX4(asmjit::X86Emitter* a, asmjit::X86Ymm aReg); + void gen8BitSumX4(x86::Emitter* a, x86::Ymm aReg); // Add 8 consecutive numbers of 64 uint8 and emit 8 32-bit template <inst_set_t instSet> - void - gen8BitSumX8(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg); + void gen8BitSumX8(x86::Emitter* a, x86::Ymm aReg, x86::Ymm bReg); // Add 16 consecutive numbers of 128 uint8 and emit 8 32-bit template <inst_set_t instSet> void gen8BitSumX16( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm bReg, - asmjit::X86Ymm cReg, - asmjit::X86Ymm dReg); + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm bReg, + x86::Ymm cReg, + x86::Ymm dReg); // Generate instruction sequence that loads 8-bit values and sum them up. // Depending on C_per_G_, this function dispatches to gen8BitSumX4/8/16 @@ -191,35 +189,33 @@ class GenConvKernel { // Internally, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0, 1], tmpReg1Avx2_, // and resultRegAvx2_ are used. template <inst_set_t instSet> - void gen8BitSum( - asmjit::X86Emitter* a, - int act_offset, - bool use_scratch_reg1 = true); + void + gen8BitSum(x86::Emitter* a, int act_offset, bool use_scratch_reg1 = true); // Use scratchReg1_ and tmpReg1Avx2_ internally template <inst_set_t instSet> - void genZeroPtSum(asmjit::X86Emitter* a, int multiplier); + void genZeroPtSum(x86::Emitter* a, int multiplier); template <inst_set_t instSet> - void genForTopEdgeRowoffset(asmjit::X86Emitter* a); + void genForTopEdgeRowoffset(x86::Emitter* a); template <inst_set_t instSet> - void genForLeftEdgeRowoffset(asmjit::X86Emitter* a); + void genForLeftEdgeRowoffset(x86::Emitter* a); template <inst_set_t instSet> - void genForRightEdgeRowoffset(asmjit::X86Emitter* a); + void genForRightEdgeRowoffset(x86::Emitter* a); template <inst_set_t instSet> - void genForBottomEdgeRowoffset(asmjit::X86Emitter* a); + void genForBottomEdgeRowoffset(x86::Emitter* a); template <inst_set_t instSet> - void genRowoffsetCorners(asmjit::X86Emitter* a); + void genRowoffsetCorners(x86::Emitter* a); template <inst_set_t instSet> - void genRowoffsetCore(asmjit::X86Emitter* a); + void genRowoffsetCore(x86::Emitter* a); template <inst_set_t instSet> - void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0); + void storeResultRowoffset(x86::Emitter* a, int offset = 0); static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit. static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit. @@ -234,30 +230,30 @@ class GenConvKernel { int vectorWidth_; ///< Vector width in bits. int VLEN_; ///< Vector width in elements. // avx2 specific - asmjit::X86Ymm + x86::Ymm WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel. - asmjit::X86Ymm zeroPTRegAvx2_; - asmjit::X86Ymm tmpReg1Avx2_; - asmjit::X86Ymm stPermRegAvx2_; - asmjit::X86Ymm actRegAvx2_; - asmjit::X86Ymm resultRegAvx2_; - asmjit::X86Ymm oneReg8BitAvx2_; - asmjit::X86Ymm oneReg16BitAvx2_; + x86::Ymm zeroPTRegAvx2_; + x86::Ymm tmpReg1Avx2_; + x86::Ymm stPermRegAvx2_; + x86::Ymm actRegAvx2_; + x86::Ymm resultRegAvx2_; + x86::Ymm oneReg8BitAvx2_; + x86::Ymm oneReg16BitAvx2_; // arguments to the function created - asmjit::X86Gp in_acts_R_; - asmjit::X86Gp wghts_R_; - asmjit::X86Gp out_acts_R_; - asmjit::X86Gp a_zero_pt_R_; - asmjit::X86Gp H_R_; - asmjit::X86Gp W_R_; - asmjit::X86Gp row_offset_R_; + x86::Gp in_acts_R_; + x86::Gp wghts_R_; + x86::Gp out_acts_R_; + x86::Gp a_zero_pt_R_; + x86::Gp H_R_; + x86::Gp W_R_; + x86::Gp row_offset_R_; // Used registers - asmjit::X86Gp loopR1_; - asmjit::X86Gp loopR2_; - asmjit::X86Gp scratchReg1_; - asmjit::X86Gp scratchReg2_; + x86::Gp loopR1_; + x86::Gp loopR2_; + x86::Gp scratchReg1_; + x86::Gp scratchReg2_; // Other parameters bool isAZeroPointZero_; diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index e789695..b140c83 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -104,7 +104,7 @@ jit_conv_kernel_fp getOrCreateConvKernel( template <> template <> void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // create 8-bit 1s // i.e., oneReg16BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains // 0x01 and so on @@ -115,7 +115,7 @@ void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // create 16-bit 1s // i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31] // contains 0x0001 and so on @@ -125,11 +125,11 @@ void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm destReg) { + x86::Emitter* a, + x86::Ymm destReg) { // make destReg all zeros a->vxorps(destReg, destReg, destReg); - asmjit::X86Xmm const_reg_xmm = x86::xmm10; + x86::Xmm const_reg_xmm = x86::xmm10; // move zero point to xmm10 a->movq(const_reg_xmm, a_zero_pt_R_); // make copies of zero point @@ -143,9 +143,9 @@ void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>( - asmjit::X86Emitter* a) { - asmjit::X86Gp permute_const_reg = a->gpzRef(12); - asmjit::X86Xmm const_reg_xmm = x86::xmm10; + x86::Emitter* a) { + x86::Gp permute_const_reg = a->gpz(12); + x86::Xmm const_reg_xmm = x86::xmm10; // We have 1st group in even lanes and 2nd group in odd lanes. // Permute to put 1st group to lower 128-bit and 2nd group in upper // 128-bit. @@ -159,8 +159,7 @@ void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>( template <> template <> -void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>( - asmjit::X86Emitter* a) { +void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) { if (C_per_G_ == 4) { // store with permutation a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_); @@ -171,7 +170,7 @@ void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int offset) { // store if (C_per_G_ == 4) { @@ -198,7 +197,7 @@ void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // load weights for (int r = 0; r < R_; ++r) { @@ -225,9 +224,9 @@ void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm wReg) { + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm wReg) { a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg); a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_); a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); @@ -236,8 +235,8 @@ void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg) { + x86::Emitter* a, + x86::Ymm aReg) { a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_); a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_); a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); @@ -246,9 +245,9 @@ void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm bReg) { + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm bReg) { a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_); // Let a[0] denote 0th (LSB) 8-bit of aReg // After vpsadbw, a[0:2] = a[0] + ... + a[7] @@ -267,11 +266,11 @@ void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm bReg, - asmjit::X86Ymm cReg, - asmjit::X86Ymm dReg) { + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm bReg, + x86::Ymm cReg, + x86::Ymm dReg) { a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_); // After vpsadbw, a[0:2] = a[0] + ... + a[7] // a[8:10] = a[8] + ... + a[15] @@ -319,7 +318,7 @@ void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int act_offset, bool use_scratch_reg1 /*=true*/) { if (use_scratch_reg1) { @@ -385,11 +384,11 @@ void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int multiplier) { a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier)); // tmpReg1Avx2_ also uses xmm11 - asmjit::X86Xmm const_reg_xmm = x86::xmm11; + x86::Xmm const_reg_xmm = x86::xmm11; a->movq(const_reg_xmm, scratchReg1_); a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm); a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_); @@ -399,7 +398,7 @@ void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // top-left corner code if (c_offset == 0) { @@ -559,7 +558,7 @@ void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); @@ -626,7 +625,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -714,7 +713,7 @@ void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // bottom-left corner // we updating the last row @@ -906,7 +905,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genCoreInsts<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // main compute asmjit::Label LoopH = a->newLabel(); @@ -1011,9 +1010,9 @@ template <> jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // log code to a file @@ -1030,16 +1029,16 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( wghts_R_ = a->zsi(); out_acts_R_ = a->zdx(); a_zero_pt_R_ = a->zcx(); - H_R_ = a->gpzRef(8); - W_R_ = a->gpzRef(9); - row_offset_R_ = a->gpzRef(10); + H_R_ = a->gpz(8); + W_R_ = a->gpz(9); + row_offset_R_ = a->gpz(10); // register for temporary use - scratchReg1_ = a->gpzRef(12); - scratchReg2_ = a->gpzRef(13); + scratchReg1_ = a->gpz(12); + scratchReg2_ = a->gpz(13); asmjit::FuncDetail func; - func.init(asmjit::FuncSignature6< + func.init(asmjit::FuncSignatureT< void, uint8_t*, int8_t*, @@ -1048,29 +1047,29 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( int32_t, int32_t>(asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15)); + asmjit::FuncFrame frame; + frame.init(func); - asmjit::FuncArgsMapper args(&func); - args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - args.updateFrameInfo(ffi); + asmjit::FuncArgsAssignment args(&func); + args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); createVector16BitOne<inst_set_t::avx2>(a); - loopR1_ = a->gpzRef(14); - loopR2_ = a->gpzRef(15); + loopR1_ = a->gpz(14); + loopR2_ = a->gpz(15); if (!isAZeroPointZero_) { setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_); @@ -1095,7 +1094,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( genCoreInsts<inst_set_t::avx2>(a, c); } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_conv_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); @@ -1117,7 +1116,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // top-left corner code // zero out the results register a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); @@ -1213,7 +1212,7 @@ void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_)); @@ -1256,7 +1255,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -1326,7 +1325,7 @@ void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // bottom-left corner // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); @@ -1429,7 +1428,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genRowoffsetCore<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // number of uint8 elements in input channels should be a multiple of 32 assert(C_ % 32 == 0); @@ -1491,9 +1490,9 @@ jit_rowoffset_kernel_fp GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // log code to a file @@ -1510,45 +1509,45 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( a_zero_pt_R_ = a->zsi(); H_R_ = a->zdx(); W_R_ = a->zcx(); - row_offset_R_ = a->gpzRef(8); + row_offset_R_ = a->gpz(8); // register for temporary use - scratchReg1_ = a->gpzRef(12); - scratchReg2_ = a->gpzRef(13); + scratchReg1_ = a->gpz(12); + scratchReg2_ = a->gpz(13); - loopR1_ = a->gpzRef(14); - loopR2_ = a->gpzRef(15); + loopR1_ = a->gpz(14); + loopR2_ = a->gpz(15); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>( + FuncSignatureT<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15)); + asmjit::FuncFrame frame; + frame.init(func); - asmjit::FuncArgsMapper args(&func); - args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - args.updateFrameInfo(ffi); + asmjit::FuncArgsAssignment args(&func); + args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); // This uses xmm10 register temporarily. Should come before // createVector8BitOne if (!isAZeroPointZero_) { // we can use xmm11 because ymm11 is used by tmpReg1Avx2_ - asmjit::X86Xmm const_reg_xmm = x86::xmm11; + x86::Xmm const_reg_xmm = x86::xmm11; a->movq(const_reg_xmm, a_zero_pt_R_); a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm); @@ -1569,7 +1568,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( genRowoffsetCore<inst_set_t::avx2>(a); - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_rowoffset_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc index 143e11d..5fabf97 100644 --- a/src/PackAMatrix.cc +++ b/src/PackAMatrix.cc @@ -34,7 +34,8 @@ PackAMatrix<T, accT>::PackAMatrix( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -43,7 +44,12 @@ PackAMatrix<T, accT>::PackAMatrix( BaseType::bcol_ = params->KCB; row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; row_interleave_B_ = diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc index d731654..2aca27d 100644 --- a/src/PackAWithIm2Col.cc +++ b/src/PackAWithIm2Col.cc @@ -49,7 +49,8 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -58,7 +59,12 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col( BaseType::bcol_ = params->KCB; row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; row_interleave_B_ = @@ -478,7 +484,9 @@ int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize( if (params) { return params->MCB; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + } else if (fbgemmHasAvx512Support()) { return PackingTraits<T, accT, inst_set_t::avx512>::MCB; } else if (fbgemmHasAvx2Support()) { return PackingTraits<T, accT, inst_set_t::avx2>::MCB; diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc index 0e5c598..0af05e8 100644 --- a/src/PackAWithQuantRowOffset.cc +++ b/src/PackAWithQuantRowOffset.cc @@ -45,7 +45,8 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -54,7 +55,12 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset( BaseType::bcol_ = params->KCB; row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; row_interleave_B_ = @@ -199,7 +205,9 @@ int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize( if (params) { return params->MCB; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + } else if (fbgemmHasAvx512Support()) { return PackingTraits<T, accT, inst_set_t::avx512>::MCB; } else if (fbgemmHasAvx2Support()) { return PackingTraits<T, accT, inst_set_t::avx2>::MCB; diff --git a/src/PackAWithRowOffset.cc b/src/PackAWithRowOffset.cc index 733bf5c..e84c67b 100644 --- a/src/PackAWithRowOffset.cc +++ b/src/PackAWithRowOffset.cc @@ -39,7 +39,8 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -48,7 +49,12 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset( BaseType::bcol_ = params->KCB; row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; row_interleave_B_ = @@ -189,7 +195,9 @@ int PackAWithRowOffset<T, accT>::rowOffsetBufferSize( if (params) { return params->MCB; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + } else if (fbgemmHasAvx512Support()) { return PackingTraits<T, accT, inst_set_t::avx512>::MCB; } else if (fbgemmHasAvx2Support()) { return PackingTraits<T, accT, inst_set_t::avx2>::MCB; diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc index 0990edb..bf43fab 100644 --- a/src/PackBMatrix.cc +++ b/src/PackBMatrix.cc @@ -188,7 +188,8 @@ PackBMatrix<T, accT>::PackBMatrix( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -197,7 +198,12 @@ PackBMatrix<T, accT>::PackBMatrix( BaseType::bcol_ = params->NCB; row_interleave_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::NCB; + row_interleave_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB; row_interleave_ = @@ -317,14 +323,16 @@ void PackBMatrix<T, accT>::pack_unpack_( } template <typename T, typename accT> -void PackBMatrix<T, accT>::pack(const block_type_t& block, - const BlockingFactors* params) { +void PackBMatrix<T, accT>::pack( + const block_type_t& block, + const BlockingFactors* params) { pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true, params); } template <typename T, typename accT> -void PackBMatrix<T, accT>::unpack(T* origin_buf, - const BlockingFactors* params) { +void PackBMatrix<T, accT>::unpack( + T* origin_buf, + const BlockingFactors* params) { block_type_t blockB{BaseType::packedRowStart(), BaseType::numPackedRows(), BaseType::packedColStart(), @@ -352,8 +360,9 @@ int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const { } template <typename T, typename accT> -void PackBMatrix<T, accT>::printPackedMatrix(std::string name, - const BlockingFactors* params) { +void PackBMatrix<T, accT>::printPackedMatrix( + std::string name, + const BlockingFactors* params) { std::cout << name << ":" << "[" << BaseType::numPackedRows() << ", " << BaseType::numPackedCols() << "]" << std::endl; diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc index c7503dd..ff7b842 100644 --- a/src/PackMatrix.cc +++ b/src/PackMatrix.cc @@ -36,7 +36,8 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -46,7 +47,11 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize( NCB = params->NCB; KCB = params->KCB; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + MCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::MCB; + NCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::NCB; + KCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::KCB; + } else if (fbgemmHasAvx512Support()) { MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB; NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB; KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB; diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc index ba6adf3..f6ad59e 100644 --- a/src/PackWeightMatrixForGConv.cc +++ b/src/PackWeightMatrixForGConv.cc @@ -106,7 +106,7 @@ inline int PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::packed_index_( * on 2 groups at a time and full SIMD width can be efficiently utilized even * while working on 1 group at a time. * In this case, the layout is G (C/4) R S K 4 -*/ + */ template <typename T, typename accT, int SPATIAL_DIM> void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_( @@ -148,9 +148,9 @@ void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_( if (ispack) { transposeConvWeights(conv_param_, src, dst); } else { - // TODO: Wrap this as a inverseTransposeConvWeights()? - // For unpack & transposed, call transposeConvWeights() - // G (R S C/G) K/G => G K/G (R S C/G) + // TODO: Wrap this as a inverseTransposeConvWeights()? + // For unpack & transposed, call transposeConvWeights() + // G (R S C/G) K/G => G K/G (R S C/G) for (int r = 0; r < R; ++r) { for (int s = 0; s < S; ++s) { for (int k = 0; k < OC_per_G; ++k) { diff --git a/src/Utils.cc b/src/Utils.cc index 0fa620d..5214e41 100644 --- a/src/Utils.cc +++ b/src/Utils.cc @@ -202,4 +202,7 @@ bool fbgemmHasAvx2Support() { return (cpuinfo_has_x86_avx2()); } +bool fbgemmHasAvx512VnniSupport() { + return (cpuinfo_has_x86_avx512vnni()); +} } // namespace fbgemm diff --git a/test/GConvTest.cc b/test/GConvTest.cc index 0074535..8c1fb82 100644 --- a/test/GConvTest.cc +++ b/test/GConvTest.cc @@ -465,8 +465,8 @@ TEST_P(fbgemmGConvPackTest, PackUnpackTest) { for (int i = 0; i < weight_len; ++i) { EXPECT_EQ(Bint8.data()[i], unpack_buf.data()[i]) << "Pack/Unpack results differ at index " << i - << ", Reference: " << static_cast<int> (Bint8.data()[i]) - << ", Pack-Unpacked: " << static_cast<int> (unpack_buf.data()[i]); + << ", Reference: " << static_cast<int>(Bint8.data()[i]) + << ", Pack-Unpacked: " << static_cast<int>(unpack_buf.data()[i]); } } // for each shape } diff --git a/third_party/asmjit b/third_party/asmjit -Subproject 673dcefaa048c5f5a2bf8b85daf8f7b9978d018 +Subproject 5d40561d14f93dc45613bfa03155d1dfb4f5825 |