Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/GroupwiseConv.h')
-rw-r--r-- src/GroupwiseConv.h 100
1 files changed, 52 insertions, 48 deletions
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
index 4c5eea5..1e6324e 100644
--- a/src/GroupwiseConv.h
+++ b/src/GroupwiseConv.h
@@ -128,58 +128,60 @@ class GenConvKernel {
const conv_param_t<SPATIAL_DIM>& conv_param);
template <inst_set_t instSet>
- void createVector16BitOne(x86::Emitter* a);
+ void createVector16BitOne(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void createVector8BitOne(x86::Emitter* a);
+ void createVector8BitOne(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void setToZeroPt(x86::Emitter* a, x86::Ymm destReg);
+ void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg);
template <inst_set_t instSet>
- void gen8bitFMA(x86::Emitter* a, x86::Ymm aReg, x86::Ymm wReg);
+ void
+ gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg);
template <inst_set_t instSet>
- void genForLoadingWeights(x86::Emitter* a, int c_offset);
+ void genForLoadingWeights(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genConstForPermutations(x86::Emitter* a);
+ void genConstForPermutations(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void genForTopEdge(x86::Emitter* a, int c_offset);
+ void genForTopEdge(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForLeftEdge(x86::Emitter* a, int c_offset);
+ void genForLeftEdge(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForRightEdge(x86::Emitter* a, int c_offset);
+ void genForRightEdge(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForBottomEdge(x86::Emitter* a, int c_offset);
+ void genForBottomEdge(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genCoreInsts(x86::Emitter* a, int c_offset);
+ void genCoreInsts(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void storeResult(x86::Emitter* a);
+ void storeResult(asmjit::X86Emitter* a);
// for Rowoffset kernel
// Add 4 consecutive numbers of 32 uint8 and emit 8 32-bit
template <inst_set_t instSet>
- void gen8BitSumX4(x86::Emitter* a, x86::Ymm aReg);
+ void gen8BitSumX4(asmjit::X86Emitter* a, asmjit::X86Ymm aReg);
// Add 8 consecutive numbers of 64 uint8 and emit 8 32-bit
template <inst_set_t instSet>
- void gen8BitSumX8(x86::Emitter* a, x86::Ymm aReg, x86::Ymm bReg);
+ void
+ gen8BitSumX8(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg);
// Add 16 consecutive numbers of 128 uint8 and emit 8 32-bit
template <inst_set_t instSet>
void gen8BitSumX16(
- x86::Emitter* a,
- x86::Ymm aReg,
- x86::Ymm bReg,
- x86::Ymm cReg,
- x86::Ymm dReg);
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg,
+ asmjit::X86Ymm bReg,
+ asmjit::X86Ymm cReg,
+ asmjit::X86Ymm dReg);
// Generate instruction sequence that loads 8-bit values and sum them up.
// Depending on C_per_G_, this function dispatches to gen8BitSumX4/8/16
@@ -189,33 +191,35 @@ class GenConvKernel {
// Internally, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0, 1], tmpReg1Avx2_,
// and resultRegAvx2_ are used.
template <inst_set_t instSet>
- void
- gen8BitSum(x86::Emitter* a, int act_offset, bool use_scratch_reg1 = true);
+ void gen8BitSum(
+ asmjit::X86Emitter* a,
+ int act_offset,
+ bool use_scratch_reg1 = true);
// Use scratchReg1_ and tmpReg1Avx2_ internally
template <inst_set_t instSet>
- void genZeroPtSum(x86::Emitter* a, int multiplier);
+ void genZeroPtSum(asmjit::X86Emitter* a, int multiplier);
template <inst_set_t instSet>
- void genForTopEdgeRowoffset(x86::Emitter* a);
+ void genForTopEdgeRowoffset(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void genForLeftEdgeRowoffset(x86::Emitter* a);
+ void genForLeftEdgeRowoffset(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void genForRightEdgeRowoffset(x86::Emitter* a);
+ void genForRightEdgeRowoffset(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void genForBottomEdgeRowoffset(x86::Emitter* a);
+ void genForBottomEdgeRowoffset(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void genRowoffsetCorners(x86::Emitter* a);
+ void genRowoffsetCorners(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void genRowoffsetCore(x86::Emitter* a);
+ void genRowoffsetCore(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void storeResultRowoffset(x86::Emitter* a, int offset = 0);
+ void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0);
static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
@@ -230,30 +234,30 @@ class GenConvKernel {
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
// avx2 specific
- x86::Ymm
+ asmjit::X86Ymm
WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel.
- x86::Ymm zeroPTRegAvx2_;
- x86::Ymm tmpReg1Avx2_;
- x86::Ymm stPermRegAvx2_;
- x86::Ymm actRegAvx2_;
- x86::Ymm resultRegAvx2_;
- x86::Ymm oneReg8BitAvx2_;
- x86::Ymm oneReg16BitAvx2_;
+ asmjit::X86Ymm zeroPTRegAvx2_;
+ asmjit::X86Ymm tmpReg1Avx2_;
+ asmjit::X86Ymm stPermRegAvx2_;
+ asmjit::X86Ymm actRegAvx2_;
+ asmjit::X86Ymm resultRegAvx2_;
+ asmjit::X86Ymm oneReg8BitAvx2_;
+ asmjit::X86Ymm oneReg16BitAvx2_;
// arguments to the function created
- x86::Gp in_acts_R_;
- x86::Gp wghts_R_;
- x86::Gp out_acts_R_;
- x86::Gp a_zero_pt_R_;
- x86::Gp H_R_;
- x86::Gp W_R_;
- x86::Gp row_offset_R_;
+ asmjit::X86Gp in_acts_R_;
+ asmjit::X86Gp wghts_R_;
+ asmjit::X86Gp out_acts_R_;
+ asmjit::X86Gp a_zero_pt_R_;
+ asmjit::X86Gp H_R_;
+ asmjit::X86Gp W_R_;
+ asmjit::X86Gp row_offset_R_;
// Used registers
- x86::Gp loopR1_;
- x86::Gp loopR2_;
- x86::Gp scratchReg1_;
- x86::Gp scratchReg2_;
+ asmjit::X86Gp loopR1_;
+ asmjit::X86Gp loopR2_;
+ asmjit::X86Gp scratchReg1_;
+ asmjit::X86Gp scratchReg2_;
// Other parameters
bool isAZeroPointZero_;