diff options
author | Jianyu Huang <jianyuhuang@fb.com> | 2019-08-06 21:55:17 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-08-06 21:59:00 +0300 |
commit | cf34b9a26b609109b18d6498f0608faddb7a911b (patch) | |
tree | 1ceaddaf942edb9debcafad7491b750fc3a5f066 /src/GroupwiseConv.h | |
parent | d8b3323668fdd15dc70e9cb43ab16e96f4846eeb (diff) |
Back out "[fbgemm] Integrate VNNI into FBGEMM master branch"
Summary:
Original commit changeset: fcaa13cc3159
ASMJIT requires the CMake version to be 3.8.
However, FBGEMM and PyTorch only need the CMake version to be 3.5+.
This mismatch caused the build failure in FBGEMM:
https://circleci.com/gh/pytorch/FBGEMM/122#build-timing/containers/0
Reviewed By: dskhudia
Differential Revision: D16670547
fbshipit-source-id: 506714c3db1cb82cf98895f58f82f235128f5285
Diffstat (limited to 'src/GroupwiseConv.h')
-rw-r--r-- | src/GroupwiseConv.h | 100 |
1 files changed, 52 insertions, 48 deletions
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h index 4c5eea5..1e6324e 100644 --- a/src/GroupwiseConv.h +++ b/src/GroupwiseConv.h @@ -128,58 +128,60 @@ class GenConvKernel { const conv_param_t<SPATIAL_DIM>& conv_param); template <inst_set_t instSet> - void createVector16BitOne(x86::Emitter* a); + void createVector16BitOne(asmjit::X86Emitter* a); template <inst_set_t instSet> - void createVector8BitOne(x86::Emitter* a); + void createVector8BitOne(asmjit::X86Emitter* a); template <inst_set_t instSet> - void setToZeroPt(x86::Emitter* a, x86::Ymm destReg); + void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg); template <inst_set_t instSet> - void gen8bitFMA(x86::Emitter* a, x86::Ymm aReg, x86::Ymm wReg); + void + gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg); template <inst_set_t instSet> - void genForLoadingWeights(x86::Emitter* a, int c_offset); + void genForLoadingWeights(asmjit::X86Emitter* a, int c_offset); template <inst_set_t instSet> - void genConstForPermutations(x86::Emitter* a); + void genConstForPermutations(asmjit::X86Emitter* a); template <inst_set_t instSet> - void genForTopEdge(x86::Emitter* a, int c_offset); + void genForTopEdge(asmjit::X86Emitter* a, int c_offset); template <inst_set_t instSet> - void genForLeftEdge(x86::Emitter* a, int c_offset); + void genForLeftEdge(asmjit::X86Emitter* a, int c_offset); template <inst_set_t instSet> - void genForRightEdge(x86::Emitter* a, int c_offset); + void genForRightEdge(asmjit::X86Emitter* a, int c_offset); template <inst_set_t instSet> - void genForBottomEdge(x86::Emitter* a, int c_offset); + void genForBottomEdge(asmjit::X86Emitter* a, int c_offset); template <inst_set_t instSet> - void genCoreInsts(x86::Emitter* a, int c_offset); + void genCoreInsts(asmjit::X86Emitter* a, int c_offset); template <inst_set_t instSet> - void storeResult(x86::Emitter* a); + void storeResult(asmjit::X86Emitter* a); // for Rowoffset kernel // Add 4 consecutive numbers of 32 uint8 
and emit 8 32-bit template <inst_set_t instSet> - void gen8BitSumX4(x86::Emitter* a, x86::Ymm aReg); + void gen8BitSumX4(asmjit::X86Emitter* a, asmjit::X86Ymm aReg); // Add 8 consecutive numbers of 64 uint8 and emit 8 32-bit template <inst_set_t instSet> - void gen8BitSumX8(x86::Emitter* a, x86::Ymm aReg, x86::Ymm bReg); + void + gen8BitSumX8(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg); // Add 16 consecutive numbers of 128 uint8 and emit 8 32-bit template <inst_set_t instSet> void gen8BitSumX16( - x86::Emitter* a, - x86::Ymm aReg, - x86::Ymm bReg, - x86::Ymm cReg, - x86::Ymm dReg); + asmjit::X86Emitter* a, + asmjit::X86Ymm aReg, + asmjit::X86Ymm bReg, + asmjit::X86Ymm cReg, + asmjit::X86Ymm dReg); // Generate instruction sequence that loads 8-bit values and sum them up. // Depending on C_per_G_, this function dispatches to gen8BitSumX4/8/16 @@ -189,33 +191,35 @@ class GenConvKernel { // Internally, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0, 1], tmpReg1Avx2_, // and resultRegAvx2_ are used. 
template <inst_set_t instSet> - void - gen8BitSum(x86::Emitter* a, int act_offset, bool use_scratch_reg1 = true); + void gen8BitSum( + asmjit::X86Emitter* a, + int act_offset, + bool use_scratch_reg1 = true); // Use scratchReg1_ and tmpReg1Avx2_ internally template <inst_set_t instSet> - void genZeroPtSum(x86::Emitter* a, int multiplier); + void genZeroPtSum(asmjit::X86Emitter* a, int multiplier); template <inst_set_t instSet> - void genForTopEdgeRowoffset(x86::Emitter* a); + void genForTopEdgeRowoffset(asmjit::X86Emitter* a); template <inst_set_t instSet> - void genForLeftEdgeRowoffset(x86::Emitter* a); + void genForLeftEdgeRowoffset(asmjit::X86Emitter* a); template <inst_set_t instSet> - void genForRightEdgeRowoffset(x86::Emitter* a); + void genForRightEdgeRowoffset(asmjit::X86Emitter* a); template <inst_set_t instSet> - void genForBottomEdgeRowoffset(x86::Emitter* a); + void genForBottomEdgeRowoffset(asmjit::X86Emitter* a); template <inst_set_t instSet> - void genRowoffsetCorners(x86::Emitter* a); + void genRowoffsetCorners(asmjit::X86Emitter* a); template <inst_set_t instSet> - void genRowoffsetCore(x86::Emitter* a); + void genRowoffsetCore(asmjit::X86Emitter* a); template <inst_set_t instSet> - void storeResultRowoffset(x86::Emitter* a, int offset = 0); + void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0); static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit. static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit. @@ -230,30 +234,30 @@ class GenConvKernel { int vectorWidth_; ///< Vector width in bits. int VLEN_; ///< Vector width in elements. // avx2 specific - x86::Ymm + asmjit::X86Ymm WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel. 
- x86::Ymm zeroPTRegAvx2_; - x86::Ymm tmpReg1Avx2_; - x86::Ymm stPermRegAvx2_; - x86::Ymm actRegAvx2_; - x86::Ymm resultRegAvx2_; - x86::Ymm oneReg8BitAvx2_; - x86::Ymm oneReg16BitAvx2_; + asmjit::X86Ymm zeroPTRegAvx2_; + asmjit::X86Ymm tmpReg1Avx2_; + asmjit::X86Ymm stPermRegAvx2_; + asmjit::X86Ymm actRegAvx2_; + asmjit::X86Ymm resultRegAvx2_; + asmjit::X86Ymm oneReg8BitAvx2_; + asmjit::X86Ymm oneReg16BitAvx2_; // arguments to the function created - x86::Gp in_acts_R_; - x86::Gp wghts_R_; - x86::Gp out_acts_R_; - x86::Gp a_zero_pt_R_; - x86::Gp H_R_; - x86::Gp W_R_; - x86::Gp row_offset_R_; + asmjit::X86Gp in_acts_R_; + asmjit::X86Gp wghts_R_; + asmjit::X86Gp out_acts_R_; + asmjit::X86Gp a_zero_pt_R_; + asmjit::X86Gp H_R_; + asmjit::X86Gp W_R_; + asmjit::X86Gp row_offset_R_; // Used registers - x86::Gp loopR1_; - x86::Gp loopR2_; - x86::Gp scratchReg1_; - x86::Gp scratchReg2_; + asmjit::X86Gp loopR1_; + asmjit::X86Gp loopR2_; + asmjit::X86Gp scratchReg1_; + asmjit::X86Gp scratchReg2_; // Other parameters bool isAZeroPointZero_; |