Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/GroupwiseConvAcc32Avx2.cc')
-rw-r--r--src/GroupwiseConvAcc32Avx2.cc177
1 files changed, 89 insertions, 88 deletions
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index b140c83..e789695 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -104,7 +104,7 @@ jit_conv_kernel_fp getOrCreateConvKernel(
template <>
template <>
void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
- x86::Emitter* a) {
+ asmjit::X86Emitter* a) {
// create 8-bit 1s
// i.e., oneReg16BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains
// 0x01 and so on
@@ -115,7 +115,7 @@ void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
- x86::Emitter* a) {
+ asmjit::X86Emitter* a) {
// create 16-bit 1s
// i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31]
// contains 0x0001 and so on
@@ -125,11 +125,11 @@ void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
- x86::Emitter* a,
- x86::Ymm destReg) {
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm destReg) {
// make destReg all zeros
a->vxorps(destReg, destReg, destReg);
- x86::Xmm const_reg_xmm = x86::xmm10;
+ asmjit::X86Xmm const_reg_xmm = x86::xmm10;
// move zero point to xmm10
a->movq(const_reg_xmm, a_zero_pt_R_);
// make copies of zero point
@@ -143,9 +143,9 @@ void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
- x86::Emitter* a) {
- x86::Gp permute_const_reg = a->gpz(12);
- x86::Xmm const_reg_xmm = x86::xmm10;
+ asmjit::X86Emitter* a) {
+ asmjit::X86Gp permute_const_reg = a->gpzRef(12);
+ asmjit::X86Xmm const_reg_xmm = x86::xmm10;
// We have 1st group in even lanes and 2nd group in odd lanes.
// Permute to put 1st group to lower 128-bit and 2nd group in upper
// 128-bit.
@@ -159,7 +159,8 @@ void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) {
+void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
if (C_per_G_ == 4) {
// store with permutation
a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_);
@@ -170,7 +171,7 @@ void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) {
template <>
template <>
void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int offset) {
// store
if (C_per_G_ == 4) {
@@ -197,7 +198,7 @@ void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int c_offset) {
// load weights
for (int r = 0; r < R_; ++r) {
@@ -224,9 +225,9 @@ void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
- x86::Emitter* a,
- x86::Ymm aReg,
- x86::Ymm wReg) {
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg,
+ asmjit::X86Ymm wReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg);
a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_);
a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
@@ -235,8 +236,8 @@ void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
- x86::Emitter* a,
- x86::Ymm aReg) {
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_);
a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_);
a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
@@ -245,9 +246,9 @@ void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
- x86::Emitter* a,
- x86::Ymm aReg,
- x86::Ymm bReg) {
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg,
+ asmjit::X86Ymm bReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
// Let a[0] denote 0th (LSB) 8-bit of aReg
// After vpsadbw, a[0:2] = a[0] + ... + a[7]
@@ -266,11 +267,11 @@ void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
- x86::Emitter* a,
- x86::Ymm aReg,
- x86::Ymm bReg,
- x86::Ymm cReg,
- x86::Ymm dReg) {
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg,
+ asmjit::X86Ymm bReg,
+ asmjit::X86Ymm cReg,
+ asmjit::X86Ymm dReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
// After vpsadbw, a[0:2] = a[0] + ... + a[7]
// a[8:10] = a[8] + ... + a[15]
@@ -318,7 +319,7 @@ void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int act_offset,
bool use_scratch_reg1 /*=true*/) {
if (use_scratch_reg1) {
@@ -384,11 +385,11 @@ void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int multiplier) {
a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier));
// tmpReg1Avx2_ also uses xmm11
- x86::Xmm const_reg_xmm = x86::xmm11;
+ asmjit::X86Xmm const_reg_xmm = x86::xmm11;
a->movq(const_reg_xmm, scratchReg1_);
a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm);
a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_);
@@ -398,7 +399,7 @@ void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int c_offset) {
// top-left corner code
if (c_offset == 0) {
@@ -558,7 +559,7 @@ void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int c_offset) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
@@ -625,7 +626,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int c_offset) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -713,7 +714,7 @@ void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int c_offset) {
// bottom-left corner
// we updating the last row
@@ -905,7 +906,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genCoreInsts<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int c_offset) {
// main compute
asmjit::Label LoopH = a->newLabel();
@@ -1010,9 +1011,9 @@ template <>
jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
#if defined(FBGEMM_LOG_CODE)
// log code to a file
@@ -1029,16 +1030,16 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
wghts_R_ = a->zsi();
out_acts_R_ = a->zdx();
a_zero_pt_R_ = a->zcx();
- H_R_ = a->gpz(8);
- W_R_ = a->gpz(9);
- row_offset_R_ = a->gpz(10);
+ H_R_ = a->gpzRef(8);
+ W_R_ = a->gpzRef(9);
+ row_offset_R_ = a->gpzRef(10);
// register for temporary use
- scratchReg1_ = a->gpz(12);
- scratchReg2_ = a->gpz(13);
+ scratchReg1_ = a->gpzRef(12);
+ scratchReg2_ = a->gpzRef(13);
asmjit::FuncDetail func;
- func.init(asmjit::FuncSignatureT<
+ func.init(asmjit::FuncSignature6<
void,
uint8_t*,
int8_t*,
@@ -1047,29 +1048,29 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
int32_t,
int32_t>(asmjit::CallConv::kIdHost));
- asmjit::FuncFrame frame;
- frame.init(func);
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
+ asmjit::FuncArgsMapper args(&func);
args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
- args.updateFuncFrame(frame);
- frame.finalize();
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
createVector16BitOne<inst_set_t::avx2>(a);
- loopR1_ = a->gpz(14);
- loopR2_ = a->gpz(15);
+ loopR1_ = a->gpzRef(14);
+ loopR2_ = a->gpzRef(15);
if (!isAZeroPointZero_) {
setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
@@ -1094,7 +1095,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
genCoreInsts<inst_set_t::avx2>(a, c);
}
- a->emitEpilog(frame);
+ asmjit::FuncUtils::emitEpilog(a, layout);
jit_conv_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
@@ -1116,7 +1117,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
- x86::Emitter* a) {
+ asmjit::X86Emitter* a) {
// top-left corner code
// zero out the results register
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
@@ -1212,7 +1213,7 @@ void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
- x86::Emitter* a) {
+ asmjit::X86Emitter* a) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
@@ -1255,7 +1256,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
- x86::Emitter* a) {
+ asmjit::X86Emitter* a) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -1325,7 +1326,7 @@ void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
- x86::Emitter* a) {
+ asmjit::X86Emitter* a) {
// bottom-left corner
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
@@ -1428,7 +1429,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genRowoffsetCore<inst_set_t::avx2>(
- x86::Emitter* a) {
+ asmjit::X86Emitter* a) {
// number of uint8 elements in input channels should be a multiple of 32
assert(C_ % 32 == 0);
@@ -1490,9 +1491,9 @@ jit_rowoffset_kernel_fp
GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
#if defined(FBGEMM_LOG_CODE)
// log code to a file
@@ -1509,45 +1510,45 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
a_zero_pt_R_ = a->zsi();
H_R_ = a->zdx();
W_R_ = a->zcx();
- row_offset_R_ = a->gpz(8);
+ row_offset_R_ = a->gpzRef(8);
// register for temporary use
- scratchReg1_ = a->gpz(12);
- scratchReg2_ = a->gpz(13);
+ scratchReg1_ = a->gpzRef(12);
+ scratchReg2_ = a->gpzRef(13);
- loopR1_ = a->gpz(14);
- loopR2_ = a->gpz(15);
+ loopR1_ = a->gpzRef(14);
+ loopR2_ = a->gpzRef(15);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignatureT<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
+ FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrame frame;
- frame.init(func);
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
+ asmjit::FuncArgsMapper args(&func);
args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);
- args.updateFuncFrame(frame);
- frame.finalize();
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
// This uses xmm10 register temporarily. Should come before
// createVector8BitOne
if (!isAZeroPointZero_) {
// we can use xmm11 because ymm11 is used by tmpReg1Avx2_
- x86::Xmm const_reg_xmm = x86::xmm11;
+ asmjit::X86Xmm const_reg_xmm = x86::xmm11;
a->movq(const_reg_xmm, a_zero_pt_R_);
a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm);
@@ -1568,7 +1569,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
genRowoffsetCore<inst_set_t::avx2>(a);
- a->emitEpilog(frame);
+ asmjit::FuncUtils::emitEpilog(a, layout);
jit_rowoffset_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);