diff options
Diffstat (limited to 'src/GroupwiseConvAcc32Avx2.cc')
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc | 177 |
1 files changed, 89 insertions, 88 deletions
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index b140c83..e789695 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -104,7 +104,7 @@ jit_conv_kernel_fp getOrCreateConvKernel( template <> template <> void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>( - x86::Emitter* a) { + asmjit::X86Emitter* a) { // create 8-bit 1s // i.e., oneReg16BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains // 0x01 and so on @@ -115,7 +115,7 @@ void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>( - x86::Emitter* a) { + asmjit::X86Emitter* a) { // create 16-bit 1s // i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31] // contains 0x0001 and so on @@ -125,11 +125,11 @@ void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>( - x86::Emitter* a, - x86::Ymm destReg) { + asmjit::X86Emitter* a, + asmjit::X86Ymm destReg) { // make destReg all zeros a->vxorps(destReg, destReg, destReg); - x86::Xmm const_reg_xmm = x86::xmm10; + asmjit::X86Xmm const_reg_xmm = x86::xmm10; // move zero point to xmm10 a->movq(const_reg_xmm, a_zero_pt_R_); // make copies of zero point @@ -143,9 +143,9 @@ void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>( - x86::Emitter* a) { - x86::Gp permute_const_reg = a->gpz(12); - x86::Xmm const_reg_xmm = x86::xmm10; + asmjit::X86Emitter* a) { + asmjit::X86Gp permute_const_reg = a->gpzRef(12); + asmjit::X86Xmm const_reg_xmm = x86::xmm10; // We have 1st group in even lanes and 2nd group in odd lanes. // Permute to put 1st group to lower 128-bit and 2nd group in upper // 128-bit. @@ -159,7 +159,8 @@ void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>( template <> template <> -void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) { +void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>( + asmjit::X86Emitter* a) { if (C_per_G_ == 4) { // store with permutation a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_); @@ -170,7 +171,7 @@ void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) { template <> template <> void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>( - x86::Emitter* a, + asmjit::X86Emitter* a, int offset) { // store if (C_per_G_ == 4) { @@ -197,7 +198,7 @@ void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>( - x86::Emitter* a, + asmjit::X86Emitter* a, int c_offset) { // load weights for (int r = 0; r < R_; ++r) { @@ -224,9 +225,9 @@ void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>( - x86::Emitter* a, - x86::Ymm aReg, - x86::Ymm wReg) { + asmjit::X86Emitter* a, + asmjit::X86Ymm aReg, + asmjit::X86Ymm wReg) { a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg); a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_); a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); @@ -235,8 +236,8 @@ void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>( - x86::Emitter* a, - x86::Ymm aReg) { + asmjit::X86Emitter* a, + asmjit::X86Ymm aReg) { a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_); a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_); a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); @@ -245,9 +246,9 @@ void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>( - x86::Emitter* a, - x86::Ymm aReg, - x86::Ymm bReg) { + asmjit::X86Emitter* a, + asmjit::X86Ymm aReg, + asmjit::X86Ymm bReg) { a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_); // Let a[0] denote 0th (LSB) 8-bit of aReg // After vpsadbw, a[0:2] = a[0] + ... + a[7] @@ -266,11 +267,11 @@ void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>( - x86::Emitter* a, - x86::Ymm aReg, - x86::Ymm bReg, - x86::Ymm cReg, - x86::Ymm dReg) { + asmjit::X86Emitter* a, + asmjit::X86Ymm aReg, + asmjit::X86Ymm bReg, + asmjit::X86Ymm cReg, + asmjit::X86Ymm dReg) { a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_); // After vpsadbw, a[0:2] = a[0] + ... + a[7] // a[8:10] = a[8] + ... + a[15] @@ -318,7 +319,7 @@ void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>( - x86::Emitter* a, + asmjit::X86Emitter* a, int act_offset, bool use_scratch_reg1 /*=true*/) { if (use_scratch_reg1) { @@ -384,11 +385,11 @@ void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>( - x86::Emitter* a, + asmjit::X86Emitter* a, int multiplier) { a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier)); // tmpReg1Avx2_ also uses xmm11 - x86::Xmm const_reg_xmm = x86::xmm11; + asmjit::X86Xmm const_reg_xmm = x86::xmm11; a->movq(const_reg_xmm, scratchReg1_); a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm); a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_); @@ -398,7 +399,7 @@ void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>( - x86::Emitter* a, + asmjit::X86Emitter* a, int c_offset) { // top-left corner code if (c_offset == 0) { @@ -558,7 +559,7 @@ void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>( - x86::Emitter* a, + asmjit::X86Emitter* a, int c_offset) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); @@ -625,7 +626,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>( - x86::Emitter* a, + asmjit::X86Emitter* a, int c_offset) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -713,7 +714,7 @@ void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>( - x86::Emitter* a, + asmjit::X86Emitter* a, int c_offset) { // bottom-left corner // we updating the last row @@ -905,7 +906,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genCoreInsts<inst_set_t::avx2>( - x86::Emitter* a, + asmjit::X86Emitter* a, int c_offset) { // main compute asmjit::Label LoopH = a->newLabel(); @@ -1010,9 +1011,9 @@ template <> jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as<x86::Emitter>(); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); #if defined(FBGEMM_LOG_CODE) // log code to a file @@ -1029,16 +1030,16 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( wghts_R_ = a->zsi(); out_acts_R_ = a->zdx(); a_zero_pt_R_ = a->zcx(); - H_R_ = a->gpz(8); - W_R_ = a->gpz(9); - row_offset_R_ = a->gpz(10); + H_R_ = a->gpzRef(8); + W_R_ = a->gpzRef(9); + row_offset_R_ = a->gpzRef(10); // register for temporary use - scratchReg1_ = a->gpz(12); - scratchReg2_ = a->gpz(13); + scratchReg1_ = a->gpzRef(12); + scratchReg2_ = a->gpzRef(13); asmjit::FuncDetail func; - func.init(asmjit::FuncSignatureT< + func.init(asmjit::FuncSignature6< void, uint8_t*, int8_t*, @@ -1047,29 +1048,29 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( int32_t, int32_t>(asmjit::CallConv::kIdHost)); - asmjit::FuncFrame frame; - frame.init(func); + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - - asmjit::FuncArgsAssignment args(&func); + asmjit::FuncArgsMapper args(&func); args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_); - args.updateFuncFrame(frame); - frame.finalize(); + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); createVector16BitOne<inst_set_t::avx2>(a); - loopR1_ = a->gpz(14); - loopR2_ = a->gpz(15); + loopR1_ = a->gpzRef(14); + loopR2_ = a->gpzRef(15); if (!isAZeroPointZero_) { setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_); @@ -1094,7 +1095,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( genCoreInsts<inst_set_t::avx2>(a, c); } - a->emitEpilog(frame); + asmjit::FuncUtils::emitEpilog(a, layout); jit_conv_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); @@ -1116,7 +1117,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( - x86::Emitter* a) { + asmjit::X86Emitter* a) { // top-left corner code // zero out the results register a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); @@ -1212,7 +1213,7 @@ void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( - x86::Emitter* a) { + asmjit::X86Emitter* a) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_)); @@ -1255,7 +1256,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( - x86::Emitter* a) { + asmjit::X86Emitter* a) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -1325,7 +1326,7 @@ void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( - x86::Emitter* a) { + asmjit::X86Emitter* a) { // bottom-left corner // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); @@ -1428,7 +1429,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genRowoffsetCore<inst_set_t::avx2>( - x86::Emitter* a) { + asmjit::X86Emitter* a) { // number of uint8 elements in input channels should be a multiple of 32 assert(C_ % 32 == 0); @@ -1490,9 +1491,9 @@ jit_rowoffset_kernel_fp GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as<x86::Emitter>(); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); #if defined(FBGEMM_LOG_CODE) // log code to a file @@ -1509,45 +1510,45 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( a_zero_pt_R_ = a->zsi(); H_R_ = a->zdx(); W_R_ = a->zcx(); - row_offset_R_ = a->gpz(8); + row_offset_R_ = a->gpzRef(8); // register for temporary use - scratchReg1_ = a->gpz(12); - scratchReg2_ = a->gpz(13); + scratchReg1_ = a->gpzRef(12); + scratchReg2_ = a->gpzRef(13); - loopR1_ = a->gpz(14); - loopR2_ = a->gpz(15); + loopR1_ = a->gpzRef(14); + loopR2_ = a->gpzRef(15); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignatureT<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>( + FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrame frame; - frame.init(func); + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - - asmjit::FuncArgsAssignment args(&func); + asmjit::FuncArgsMapper args(&func); args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_); - args.updateFuncFrame(frame); - frame.finalize(); + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); // This uses xmm10 register temporarily. Should come before // createVector8BitOne if (!isAZeroPointZero_) { // we can use xmm11 because ymm11 is used by tmpReg1Avx2_ - x86::Xmm const_reg_xmm = x86::xmm11; + asmjit::X86Xmm const_reg_xmm = x86::xmm11; a->movq(const_reg_xmm, a_zero_pt_R_); a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm); @@ -1568,7 +1569,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( genRowoffsetCore<inst_set_t::avx2>(a); - a->emitEpilog(frame); + asmjit::FuncUtils::emitEpilog(a, layout); jit_rowoffset_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); |