diff options
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC16Avx512.cc')
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16Avx512.cc | 94 |
1 files changed, 47 insertions, 47 deletions
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index a49e440..c95757b 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -19,7 +19,7 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< inst_set_t::avx512>( - x86::Emitter* a, + asmjit::X86Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { @@ -41,18 +41,18 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< inst_set_t::avx512>( - x86::Emitter* a, - x86::Gp buffer_A, - x86::Gp buffer_B, - x86::Gp /* unused (reserved for prefetching)*/, + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp /* unused (reserved for prefetching)*/, int rowRegs, int colRegs, int lda, int leadingDimCReg) { // used for matrix A - x86::Zmm AReg = x86::zmm29; + asmjit::X86Zmm AReg = x86::zmm29; - x86::Zmm tmpReg = x86::zmm30; + asmjit::X86Zmm tmpReg = x86::zmm30; // We start allocating BRegs from zmm27 and then allocate zmm26 and so on. for (int j = 0; j < colRegs; ++j) { @@ -66,7 +66,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< a->vpbroadcastw( AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); for (int j = 0; j < colRegs; ++j) { - a->vpmaddubsw(tmpReg, AReg, AllRegs_avx512_[27 - j]); + a->vpmaddubsw( + tmpReg, AReg, AllRegs_avx512_[27-j]); a->vpaddsw( CRegs_avx512_[i * leadingDimCReg + j], tmpReg, @@ -89,16 +90,15 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< inst_set_t::avx512>( - x86::Emitter* a, + asmjit::X86Emitter* a, int rowRegs, int colRegs, - x86::Gp C_Offset, - x86::Gp ldcReg, - + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, bool accum, int leadingDimCReg) { - x86::Ymm extractDest256 = x86::ymm31; - x86::Zmm extractDest512 = x86::zmm31; + asmjit::X86Ymm extractDest256 = x86::ymm31; + asmjit::X86Zmm extractDest512 = x86::zmm31; for (int i = 0; i < rowRegs; ++i) { a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t))); @@ -107,7 +107,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< a->vextracti32x8( extractDest256, CRegs_avx512_[i * leadingDimCReg + j], idx); a->vpmovsxwd(extractDest512, extractDest256); - x86::Mem destAddr = x86::dword_ptr( + asmjit::X86Mem destAddr = x86::dword_ptr( a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t)); if (accum) { a->vpaddd(extractDest512, extractDest512, destAddr); @@ -172,9 +172,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( } code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as<x86::Emitter>(); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); #if defined(FBGEMM_LOG_CODE) // generated code logging @@ -209,49 +209,49 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrame frame; - frame.init(func); - - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - asmjit::FuncArgsAssignment args(&func); + asmjit::FuncArgsMapper args(&func); args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFuncFrame(frame); - frame.finalize(); + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); asmjit::Label LoopMBlocks = a->newLabel(); asmjit::Label LoopNBlocks = a->newLabel(); asmjit::Label Loopk = a->newLabel(); - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - // x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp jIdx = a->gpz(14); - x86::Gp kIdx = a->gpz(15); + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + // asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + asmjit::X86Gp jIdx = a->gpzRef(14); + asmjit::X86Gp kIdx = a->gpzRef(15); // save B_buffer address a->mov(buffer_B_saved, buffer_B); @@ -407,7 +407,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( a->jl(LoopNRem); } - a->emitEpilog(frame); + asmjit::FuncUtils::emitEpilog(a, layout); jit_micro_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); |