diff options
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC16.cc')
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16.cc | 91 |
1 files changed, 44 insertions, 47 deletions
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc index 718b883..1e7e081 100644 --- a/src/GenerateKernelU8S8S32ACC16.cc +++ b/src/GenerateKernelU8S8S32ACC16.cc @@ -31,7 +31,7 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { @@ -53,18 +53,18 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp /* unused (reserved for prefetching)*/, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp /* unused (reserved for prefetching)*/, int rowRegs, int colRegs, int lda, int leadingDimCReg) { // used for matrix A - asmjit::X86Ymm AReg = x86::ymm12; + x86::Ymm AReg = x86::ymm12; - asmjit::X86Ymm tmpReg = x86::ymm14; + x86::Ymm tmpReg = x86::ymm14; for (int i = 0; i < rowRegs; ++i) { // broadcast A @@ -95,15 +95,15 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, bool accum, int leadingDimCReg) { - asmjit::X86Xmm extractDest128 = x86::xmm15; - asmjit::X86Ymm extractDest256 = x86::ymm15; + x86::Xmm extractDest128 = x86::xmm15; + x86::Ymm extractDest256 = x86::ymm15; for (int i = 0; i < rowRegs; ++i) { a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t))); @@ -112,7 +112,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< a->vextracti128( extractDest128, CRegs_avx2_[i * leadingDimCReg + j], idx); a->vpmovsxwd(extractDest256, extractDest128); - asmjit::X86Mem destAddr = x86::dword_ptr( + x86::Mem destAddr = x86::dword_ptr( a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t)); if (accum) { a->vpaddd(extractDest256, extractDest256, destAddr); @@ -176,9 +176,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( return codeCache_[kernelSig]; } code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // generated code logging @@ -207,46 +207,45 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( //"nc must be equal to the number of register blocks"); // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + asmjit::FuncFrame frame; + frame.init(func); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); - asmjit::FuncArgsMapper args(&func); + asmjit::FuncArgsAssignment args(&func); args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFrameInfo(ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); - - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); asmjit::Label Loopk = a->newLabel(); asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - // asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp kIdx = a->gpzRef(14); + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + // x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp kIdx = a->gpz(14); int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; if (mRegBlocks > 0) { @@ -289,8 +288,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( a->jl(Loopk); // store C matrix - storeCRegs<inst_set_t::avx2>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum); + storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); // increment A for next block a->sub(buffer_A, kSize); @@ -340,11 +338,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( a->jl(LoopkRem); // store C matrix - storeCRegs<inst_set_t::avx2>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum); + storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_micro_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); |