Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/GroupwiseConvAcc32Avx2.cc')
-rw-r--r--src/GroupwiseConvAcc32Avx2.cc177
1 files changed, 88 insertions, 89 deletions
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index e789695..b140c83 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -104,7 +104,7 @@ jit_conv_kernel_fp getOrCreateConvKernel(
template <>
template <>
void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// create 8-bit 1s
// i.e., oneReg16BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains
// 0x01 and so on
@@ -115,7 +115,7 @@ void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// create 16-bit 1s
// i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31]
// contains 0x0001 and so on
@@ -125,11 +125,11 @@ void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm destReg) {
+ x86::Emitter* a,
+ x86::Ymm destReg) {
// make destReg all zeros
a->vxorps(destReg, destReg, destReg);
- asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+ x86::Xmm const_reg_xmm = x86::xmm10;
// move zero point to xmm10
a->movq(const_reg_xmm, a_zero_pt_R_);
// make copies of zero point
@@ -143,9 +143,9 @@ void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
- asmjit::X86Gp permute_const_reg = a->gpzRef(12);
- asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+ x86::Emitter* a) {
+ x86::Gp permute_const_reg = a->gpz(12);
+ x86::Xmm const_reg_xmm = x86::xmm10;
// We have 1st group in even lanes and 2nd group in odd lanes.
// Permute to put 1st group to lower 128-bit and 2nd group in upper
// 128-bit.
@@ -159,8 +159,7 @@ void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) {
if (C_per_G_ == 4) {
// store with permutation
a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_);
@@ -171,7 +170,7 @@ void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int offset) {
// store
if (C_per_G_ == 4) {
@@ -198,7 +197,7 @@ void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// load weights
for (int r = 0; r < R_; ++r) {
@@ -225,9 +224,9 @@ void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm wReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm wReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg);
a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_);
a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
@@ -236,8 +235,8 @@ void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_);
a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_);
a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
@@ -246,9 +245,9 @@ void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm bReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm bReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
// Let a[0] denote 0th (LSB) 8-bit of aReg
// After vpsadbw, a[0:2] = a[0] + ... + a[7]
@@ -267,11 +266,11 @@ void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm bReg,
- asmjit::X86Ymm cReg,
- asmjit::X86Ymm dReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm bReg,
+ x86::Ymm cReg,
+ x86::Ymm dReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
// After vpsadbw, a[0:2] = a[0] + ... + a[7]
// a[8:10] = a[8] + ... + a[15]
@@ -319,7 +318,7 @@ void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int act_offset,
bool use_scratch_reg1 /*=true*/) {
if (use_scratch_reg1) {
@@ -385,11 +384,11 @@ void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int multiplier) {
a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier));
// tmpReg1Avx2_ also uses xmm11
- asmjit::X86Xmm const_reg_xmm = x86::xmm11;
+ x86::Xmm const_reg_xmm = x86::xmm11;
a->movq(const_reg_xmm, scratchReg1_);
a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm);
a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_);
@@ -399,7 +398,7 @@ void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// top-left corner code
if (c_offset == 0) {
@@ -559,7 +558,7 @@ void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
@@ -626,7 +625,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -714,7 +713,7 @@ void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// bottom-left corner
// we updating the last row
@@ -906,7 +905,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genCoreInsts<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// main compute
asmjit::Label LoopH = a->newLabel();
@@ -1011,9 +1010,9 @@ template <>
jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// log code to a file
@@ -1030,16 +1029,16 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
wghts_R_ = a->zsi();
out_acts_R_ = a->zdx();
a_zero_pt_R_ = a->zcx();
- H_R_ = a->gpzRef(8);
- W_R_ = a->gpzRef(9);
- row_offset_R_ = a->gpzRef(10);
+ H_R_ = a->gpz(8);
+ W_R_ = a->gpz(9);
+ row_offset_R_ = a->gpz(10);
// register for temporary use
- scratchReg1_ = a->gpzRef(12);
- scratchReg2_ = a->gpzRef(13);
+ scratchReg1_ = a->gpz(12);
+ scratchReg2_ = a->gpz(13);
asmjit::FuncDetail func;
- func.init(asmjit::FuncSignature6<
+ func.init(asmjit::FuncSignatureT<
void,
uint8_t*,
int8_t*,
@@ -1048,29 +1047,29 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
int32_t,
int32_t>(asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
createVector16BitOne<inst_set_t::avx2>(a);
- loopR1_ = a->gpzRef(14);
- loopR2_ = a->gpzRef(15);
+ loopR1_ = a->gpz(14);
+ loopR2_ = a->gpz(15);
if (!isAZeroPointZero_) {
setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
@@ -1095,7 +1094,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
genCoreInsts<inst_set_t::avx2>(a, c);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_conv_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
@@ -1117,7 +1116,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// top-left corner code
// zero out the results register
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
@@ -1213,7 +1212,7 @@ void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
@@ -1256,7 +1255,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -1326,7 +1325,7 @@ void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// bottom-left corner
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
@@ -1429,7 +1428,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genRowoffsetCore<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// number of uint8 elements in input channels should be a multiple of 32
assert(C_ % 32 == 0);
@@ -1491,9 +1490,9 @@ jit_rowoffset_kernel_fp
GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// log code to a file
@@ -1510,45 +1509,45 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
a_zero_pt_R_ = a->zsi();
H_R_ = a->zdx();
W_R_ = a->zcx();
- row_offset_R_ = a->gpzRef(8);
+ row_offset_R_ = a->gpz(8);
// register for temporary use
- scratchReg1_ = a->gpzRef(12);
- scratchReg2_ = a->gpzRef(13);
+ scratchReg1_ = a->gpz(12);
+ scratchReg2_ = a->gpz(13);
- loopR1_ = a->gpzRef(14);
- loopR2_ = a->gpzRef(15);
+ loopR1_ = a->gpz(14);
+ loopR2_ = a->gpz(15);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
+ FuncSignatureT<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
// This uses xmm10 register temporarily. Should come before
// createVector8BitOne
if (!isAZeroPointZero_) {
// we can use xmm11 because ymm11 is used by tmpReg1Avx2_
- asmjit::X86Xmm const_reg_xmm = x86::xmm11;
+ x86::Xmm const_reg_xmm = x86::xmm11;
a->movq(const_reg_xmm, a_zero_pt_R_);
a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm);
@@ -1569,7 +1568,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
genRowoffsetCore<inst_set_t::avx2>(a);
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_rowoffset_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);