diff options
Diffstat (limited to 'src/GroupwiseConvAcc32Avx2.cc')
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc | 177 |
1 files changed, 88 insertions, 89 deletions
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index e789695..b140c83 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -104,7 +104,7 @@ jit_conv_kernel_fp getOrCreateConvKernel( template <> template <> void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // create 8-bit 1s // i.e., oneReg16BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains // 0x01 and so on @@ -115,7 +115,7 @@ void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // create 16-bit 1s // i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31] // contains 0x0001 and so on @@ -125,11 +125,11 @@ void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm destReg) { + x86::Emitter* a, + x86::Ymm destReg) { // make destReg all zeros a->vxorps(destReg, destReg, destReg); - asmjit::X86Xmm const_reg_xmm = x86::xmm10; + x86::Xmm const_reg_xmm = x86::xmm10; // move zero point to xmm10 a->movq(const_reg_xmm, a_zero_pt_R_); // make copies of zero point @@ -143,9 +143,9 @@ void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>( - asmjit::X86Emitter* a) { - asmjit::X86Gp permute_const_reg = a->gpzRef(12); - asmjit::X86Xmm const_reg_xmm = x86::xmm10; + x86::Emitter* a) { + x86::Gp permute_const_reg = a->gpz(12); + x86::Xmm const_reg_xmm = x86::xmm10; // We have 1st group in even lanes and 2nd group in odd lanes. // Permute to put 1st group to lower 128-bit and 2nd group in upper // 128-bit. @@ -159,8 +159,7 @@ void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>( template <> template <> -void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>( - asmjit::X86Emitter* a) { +void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) { if (C_per_G_ == 4) { // store with permutation a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_); @@ -171,7 +170,7 @@ void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int offset) { // store if (C_per_G_ == 4) { @@ -198,7 +197,7 @@ void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // load weights for (int r = 0; r < R_; ++r) { @@ -225,9 +224,9 @@ void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm wReg) { + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm wReg) { a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg); a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_); a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); @@ -236,8 +235,8 @@ void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg) { + x86::Emitter* a, + x86::Ymm aReg) { a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_); a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_); a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); @@ -246,9 +245,9 @@ void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm bReg) { + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm bReg) { a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_); // Let a[0] denote 0th (LSB) 8-bit of aReg // After vpsadbw, a[0:2] = a[0] + ... + a[7] @@ -267,11 +266,11 @@ void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm bReg, - asmjit::X86Ymm cReg, - asmjit::X86Ymm dReg) { + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm bReg, + x86::Ymm cReg, + x86::Ymm dReg) { a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_); // After vpsadbw, a[0:2] = a[0] + ... + a[7] // a[8:10] = a[8] + ... + a[15] @@ -319,7 +318,7 @@ void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int act_offset, bool use_scratch_reg1 /*=true*/) { if (use_scratch_reg1) { @@ -385,11 +384,11 @@ void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int multiplier) { a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier)); // tmpReg1Avx2_ also uses xmm11 - asmjit::X86Xmm const_reg_xmm = x86::xmm11; + x86::Xmm const_reg_xmm = x86::xmm11; a->movq(const_reg_xmm, scratchReg1_); a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm); a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_); @@ -399,7 +398,7 @@ void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // top-left corner code if (c_offset == 0) { @@ -559,7 +558,7 @@ void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); @@ -626,7 +625,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -714,7 +713,7 @@ void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // bottom-left corner // we updating the last row @@ -906,7 +905,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genCoreInsts<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // main compute asmjit::Label LoopH = a->newLabel(); @@ -1011,9 +1010,9 @@ template <> jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // log code to a file @@ -1030,16 +1029,16 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( wghts_R_ = a->zsi(); out_acts_R_ = a->zdx(); a_zero_pt_R_ = a->zcx(); - H_R_ = a->gpzRef(8); - W_R_ = a->gpzRef(9); - row_offset_R_ = a->gpzRef(10); + H_R_ = a->gpz(8); + W_R_ = a->gpz(9); + row_offset_R_ = a->gpz(10); // register for temporary use - scratchReg1_ = a->gpzRef(12); - scratchReg2_ = a->gpzRef(13); + scratchReg1_ = a->gpz(12); + scratchReg2_ = a->gpz(13); asmjit::FuncDetail func; - func.init(asmjit::FuncSignature6< + func.init(asmjit::FuncSignatureT< void, uint8_t*, int8_t*, @@ -1048,29 +1047,29 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( int32_t, int32_t>(asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15)); + asmjit::FuncFrame frame; + frame.init(func); - asmjit::FuncArgsMapper args(&func); - args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - args.updateFrameInfo(ffi); + asmjit::FuncArgsAssignment args(&func); + args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); createVector16BitOne<inst_set_t::avx2>(a); - loopR1_ = a->gpzRef(14); - loopR2_ = a->gpzRef(15); + loopR1_ = a->gpz(14); + loopR2_ = a->gpz(15); if (!isAZeroPointZero_) { setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_); @@ -1095,7 +1094,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( genCoreInsts<inst_set_t::avx2>(a, c); } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_conv_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); @@ -1117,7 +1116,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // top-left corner code // zero out the results register a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); @@ -1213,7 +1212,7 @@ void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_)); @@ -1256,7 +1255,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -1326,7 +1325,7 @@ void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // bottom-left corner // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); @@ -1429,7 +1428,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genRowoffsetCore<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // number of uint8 elements in input channels should be a multiple of 32 assert(C_ % 32 == 0); @@ -1491,9 +1490,9 @@ jit_rowoffset_kernel_fp GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // log code to a file @@ -1510,45 +1509,45 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( a_zero_pt_R_ = a->zsi(); H_R_ = a->zdx(); W_R_ = a->zcx(); - row_offset_R_ = a->gpzRef(8); + row_offset_R_ = a->gpz(8); // register for temporary use - scratchReg1_ = a->gpzRef(12); - scratchReg2_ = a->gpzRef(13); + scratchReg1_ = a->gpz(12); + scratchReg2_ = a->gpz(13); - loopR1_ = a->gpzRef(14); - loopR2_ = a->gpzRef(15); + loopR1_ = a->gpz(14); + loopR2_ = a->gpz(15); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>( + FuncSignatureT<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15)); + asmjit::FuncFrame frame; + frame.init(func); - asmjit::FuncArgsMapper args(&func); - args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - args.updateFrameInfo(ffi); + asmjit::FuncArgsAssignment args(&func); + args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); // This uses xmm10 register temporarily. Should come before // createVector8BitOne if (!isAZeroPointZero_) { // we can use xmm11 because ymm11 is used by tmpReg1Avx2_ - asmjit::X86Xmm const_reg_xmm = x86::xmm11; + x86::Xmm const_reg_xmm = x86::xmm11; a->movq(const_reg_xmm, a_zero_pt_R_); a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm); @@ -1569,7 +1568,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( genRowoffsetCore<inst_set_t::avx2>(a); - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_rowoffset_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); |