diff options
author | Jongsoo Park <jongsoo@fb.com> | 2019-02-13 02:29:20 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-13 02:33:33 +0300 |
commit | 66df1a0ccd762e525e319cb579810deade551152 (patch) | |
tree | e91c9bf581a5d121a11fa5ae6c3b7e883dc4defc | |
parent | 86eeae2f917b92126af5cb1d37336f7d503292ee (diff) |
isZeroPointZero_ -> isAZeroPointZero_ (#71)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/71
To distinguish it from B zero point
Reviewed By: dskhudia
Differential Revision: D14021554
fbshipit-source-id: 555fad8342eaaf97a19f22ac0dfbb79df7293ce7
-rw-r--r-- | src/GroupwiseConv.h | 12 | ||||
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc | 50 |
2 files changed, 29 insertions, 33 deletions
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h index 3605fcd..5681359 100644 --- a/src/GroupwiseConv.h +++ b/src/GroupwiseConv.h @@ -39,7 +39,7 @@ using jit_rowoffset_kernel_fp = void (*)( template <typename accT = int32_t> class GenConvKernel { public: - GenConvKernel(const conv_param_t<>& conv_param, std::int32_t zero_point) + GenConvKernel(const conv_param_t<>& conv_param, std::int32_t a_zero_point) : WRegs_avx2_{x86::ymm0, x86::ymm1, x86::ymm2, @@ -74,11 +74,7 @@ class GenConvKernel { // vector width in elements; Each element is int8 or uint8 VLEN_ = vectorWidth_ / 8; - if (zero_point == 0) { - isZeroPointZero_ = true; - } else { - isZeroPointZero_ = false; - } + isAZeroPointZero_ = a_zero_point == 0; G_ = conv_param.G; K_per_G_ = conv_param.OC / conv_param.G; @@ -105,7 +101,7 @@ class GenConvKernel { fileName += "_S-" + std::to_string(S_); fileName += "_PADH-" + std::to_string(H_PAD_); fileName += "_PADW-" + std::to_string(W_PAD_); - fileName += "_isZeroPointZero-" + std::to_string(isZeroPointZero_); + fileName += "_isZeroPointZero-" + std::to_string(isAZeroPointZero_); if (rowOffsetKernel) { fileName += "_rowOffset"; } @@ -261,7 +257,7 @@ class GenConvKernel { asmjit::X86Gp scratchReg2_; // Other parameters - bool isZeroPointZero_; + bool isAZeroPointZero_; // current conv parameters int G_; ///< Number of groups diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index 94ca87c..f2ce07c 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -74,11 +74,11 @@ void calculateRowOffsets( tuple<bool, int, int, int> getKernelSig( const conv_param_t<>& conv_param, - bool isZeroPointZero) { + bool isAZeroPointZero) { int C_per_G = conv_param.IC / conv_param.G; int K_per_G = conv_param.OC / conv_param.G; auto kernelSig = - std::make_tuple(isZeroPointZero, conv_param.G, C_per_G, K_per_G); + std::make_tuple(isAZeroPointZero, conv_param.G, C_per_G, K_per_G); return kernelSig; } @@ -431,7 +431,7 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>( } gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]); } else { - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { gen8bitFMA<inst_set_t::avx2>( a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]); } @@ -452,7 +452,7 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>( } else { a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_)); } - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { for (int r = 0; r < H_PAD_; ++r) { for (int s = 0; s < S_; ++s) { gen8bitFMA<inst_set_t::avx2>(a, zeroPTRegAvx2_, WRegs_avx2_[s]); @@ -508,7 +508,7 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>( } else { a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_)); } - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { for (int r = 0; r < H_PAD_; ++r) { for (int s = 0; s < S_; ++s) { gen8bitFMA<inst_set_t::avx2>( @@ -537,7 +537,7 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>( } gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]); } - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { for (int s = S_ - W_PAD_; s < S_; ++s) { gen8bitFMA<inst_set_t::avx2>( a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]); @@ -573,7 +573,7 @@ void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>( a->imul(scratchReg1_, W_R_); a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); for (int r = 0; r < R_; ++r) { - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { for (int s = 0; s < W_PAD_; ++s) { gen8bitFMA<inst_set_t::avx2>( a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]); @@ -663,7 +663,7 @@ void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>( gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]); a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); } - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { for (int s = S_ - W_PAD_; s < S_; ++s) { gen8bitFMA<inst_set_t::avx2>( a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]); @@ -727,7 +727,7 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( a->imul(scratchReg1_, W_R_); a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); for (int r = 0; r < R_ - H_PAD_; ++r) { - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { for (int s = 0; s < W_PAD_; ++s) { gen8bitFMA<inst_set_t::avx2>( a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]); @@ -756,7 +756,7 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); a->add(scratchReg1_, scratchReg2_); } - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { for (int r = R_ - H_PAD_; r < R_; ++r) { for (int s = 0; s < S_; ++s) { gen8bitFMA<inst_set_t::avx2>( @@ -806,7 +806,7 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( a->add(scratchReg1_, scratchReg2_); } - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { for (int r = R_ - W_PAD_; r < R_; ++r) { for (int s = 0; s < S_; ++s) { gen8bitFMA<inst_set_t::avx2>( @@ -867,7 +867,7 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( } a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); a->add(scratchReg1_, scratchReg2_); - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { for (int s = S_ - W_PAD_; s < S_; ++s) { gen8bitFMA<inst_set_t::avx2>( a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]); @@ -875,7 +875,7 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( } } - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { for (int r = R_ - H_PAD_; r < R_; ++r) { for (int s = 0; s < S_; ++s) { gen8bitFMA<inst_set_t::avx2>( @@ -1064,7 +1064,7 @@ jit_conv_kernel_fp GenConvKernel<int32_t>::getOrCreate<inst_set_t::avx2>( loopR1_ = a->gpzRef(14); loopR2_ = a->gpzRef(15); - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_); } @@ -1095,7 +1095,7 @@ jit_conv_kernel_fp GenConvKernel<int32_t>::getOrCreate<inst_set_t::avx2>( std::cout << "Error: in fn add" << std::endl; return nullptr; } - auto kernelSig = getKernelSig(conv_param, isZeroPointZero_); + auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_); codeCache_[kernelSig] = fn; #if defined(FBGEMM_LOG_CODE) @@ -1113,7 +1113,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( // top-left corner code // zero out the results register a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { genZeroPtSum<inst_set_t::avx2>(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_)); } for (int r = H_PAD_; r < R_; ++r) { @@ -1140,7 +1140,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( a->bind(LoopTopEdge); // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { genZeroPtSum<inst_set_t::avx2>(a, H_PAD_ * S_); } for (int r = H_PAD_; r < R_; ++r) { @@ -1174,7 +1174,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( // top-right corner code // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { genZeroPtSum<inst_set_t::avx2>(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_)); } for (int r = H_PAD_; r < R_; ++r) { @@ -1212,7 +1212,7 @@ void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( a->bind(LoopLeftEdge); // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { genZeroPtSum<inst_set_t::avx2>(a, R_ * W_PAD_); } a->mov(scratchReg1_, loopR1_); @@ -1264,7 +1264,7 @@ void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( a->bind(LoopRightEdge); // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { genZeroPtSum<inst_set_t::avx2>(a, R_ * W_PAD_); } a->mov(scratchReg1_, loopR1_); @@ -1322,7 +1322,7 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( // bottom-left corner // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { genZeroPtSum<inst_set_t::avx2>(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_)); } a->mov(scratchReg1_, H_R_); @@ -1352,7 +1352,7 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( a->bind(LoopBottomEdge); // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { genZeroPtSum<inst_set_t::avx2>(a, H_PAD_ * S_); } a->mov(scratchReg1_, H_R_); @@ -1388,7 +1388,7 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( // bottom-right corner // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { genZeroPtSum<inst_set_t::avx2>(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_)); } // input start point @@ -1538,7 +1538,7 @@ GenConvKernel<int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( // This uses xmm10 register temporarily. Should come before // createVector8BitOne - if (!isZeroPointZero_) { + if (!isAZeroPointZero_) { // we can use xmm11 because ymm11 is used by tmpReg1Avx2_ asmjit::X86Xmm const_reg_xmm = x86::xmm11; a->movq(const_reg_xmm, a_zero_pt_R_); @@ -1569,7 +1569,7 @@ GenConvKernel<int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( std::cout << "Error: in fn add" << std::endl; return nullptr; } - auto kernelSig = getKernelSig(conv_param, isZeroPointZero_); + auto kernelSig = getKernelSig(conv_param, isAZeroPointZero_); codeCacheRowOffset_[kernelSig] = fn; #if defined(FBGEMM_LOG_CODE) |