From a7d921d447564cdc67a07ad79dfc9e0d3dcd7418 Mon Sep 17 00:00:00 2001 From: Jongsoo Park Date: Sat, 2 Feb 2019 10:32:16 -0800 Subject: minor optimization in handling zero points for row offset (#63) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/63 During row_offset compute, pre-multiply zero point by C_per_G . Instead of generating multiple vpaddd instructions for adding zero points, we just multiply by a constant. Reviewed By: dskhudia Differential Revision: D13833686 fbshipit-source-id: a42a447955380f6cfde3fbae20ec16d47423bdd6 --- src/GroupwiseConv.h | 4 ++ src/GroupwiseConvAcc32Avx2.cc | 133 +++++++++++++++++------------------------- 2 files changed, 58 insertions(+), 79 deletions(-) (limited to 'src') diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h index a46a895..0d2db3a 100644 --- a/src/GroupwiseConv.h +++ b/src/GroupwiseConv.h @@ -169,6 +169,10 @@ class GenConvKernel { template void gen8BitSum(asmjit::X86Emitter* a, asmjit::X86Ymm aReg); + // Use scratchReg1_ and tmpReg1Avx2_ internally + template + void genZeroPtSum(asmjit::X86Emitter* a, int multiplier); + template void genForTopEdgeRowoffset(asmjit::X86Emitter* a); diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index 7906c04..c85e339 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -211,6 +211,20 @@ void GenConvKernel::gen8BitSum( a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); } +template <> +template <> +void GenConvKernel::genZeroPtSum( + asmjit::X86Emitter* a, + int multiplier) { + a->mov(scratchReg1_, static_cast(multiplier)); + // tmpReg1Avx2_ also uses xmm11 + asmjit::X86Xmm const_reg_xmm = x86::xmm11; + a->movq(const_reg_xmm, scratchReg1_); + a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm); + a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_); + a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); +} + template <> template <> void GenConvKernel::genForTopEdge( @@ -782,27 +796,22 @@ void GenConvKernel::genForTopEdgeRowoffset( // top-left corner code // zero out the results register a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); - for (int r = 0; r < R_; ++r) { + if (!isZeroPointZero_) { + genZeroPtSum(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_)); + } + for (int r = H_PAD_; r < R_; ++r) { int h_in = -H_PAD_ + r; - if (h_in >= 0) { - a->imul( - scratchReg1_, - W_R_, - static_cast(h_in * C_ * sizeof(uint8_t))); - } - for (int s = 0; s < S_; ++s) { + a->imul( + scratchReg1_, + W_R_, + static_cast(h_in * C_ * sizeof(uint8_t))); + for (int s = W_PAD_; s < S_; ++s) { int w_in = -W_PAD_ + s; - if (h_in >= 0 && w_in >= 0) { - a->vmovups( - actRegAvx2_, - x86::dword_ptr( - in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t))); - gen8BitSum(a, actRegAvx2_); - } else { - if (!isZeroPointZero_) { - gen8BitSum(a, zeroPTRegAvx2_); - } - } + a->vmovups( + actRegAvx2_, + x86::dword_ptr( + in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t))); + gen8BitSum(a, actRegAvx2_); } } // store results @@ -818,11 +827,7 @@ void GenConvKernel::genForTopEdgeRowoffset( // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); if (!isZeroPointZero_) { - for (int r = 0; r < H_PAD_; ++r) { - for (int s = 0; s < S_; ++s) { - gen8BitSum(a, zeroPTRegAvx2_); - } - } + genZeroPtSum(a, H_PAD_ * S_); } for (int r = H_PAD_; r < R_; ++r) { int h_in = -H_PAD_ + r; @@ -860,11 +865,7 @@ void GenConvKernel::genForTopEdgeRowoffset( // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); if (!isZeroPointZero_) { - for (int r = 0; r < H_PAD_; ++r) { - for (int s = 0; s < S_; ++s) { - gen8BitSum(a, zeroPTRegAvx2_); - } - } + genZeroPtSum(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_)); } for (int r = H_PAD_; r < R_; ++r) { int h_in = -H_PAD_ + r; @@ -880,11 +881,6 @@ void GenConvKernel::genForTopEdgeRowoffset( a->vmovups(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_)); gen8BitSum(a, actRegAvx2_); } - if (!isZeroPointZero_) { - for (int s = S_ - W_PAD_; s < S_; ++s) { - gen8BitSum(a, zeroPTRegAvx2_); - } - } } // store results @@ -907,16 +903,14 @@ void GenConvKernel::genForLeftEdgeRowoffset( a->bind(LoopLeftEdge); // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + if (!isZeroPointZero_) { + genZeroPtSum(a, R_ * W_PAD_); + } a->mov(scratchReg1_, loopR1_); a->sub(scratchReg1_, static_cast(H_PAD_)); a->imul(scratchReg1_, W_R_); a->imul(scratchReg1_, static_cast(C_ * sizeof(uint8_t))); for (int r = 0; r < R_; ++r) { - if (!isZeroPointZero_) { - for (int s = 0; s < W_PAD_; ++s) { - gen8BitSum(a, zeroPTRegAvx2_); - } - } for (int s = W_PAD_; s < S_; ++s) { a->vmovups( actRegAvx2_, @@ -968,6 +962,9 @@ void GenConvKernel::genForRightEdgeRowoffset( a->bind(LoopRightEdge); // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + if (!isZeroPointZero_) { + genZeroPtSum(a, R_ * W_PAD_); + } a->mov(scratchReg1_, loopR1_); a->sub(scratchReg1_, static_cast(H_PAD_)); a->imul(scratchReg1_, W_R_); @@ -979,16 +976,10 @@ void GenConvKernel::genForRightEdgeRowoffset( a->add(scratchReg1_, scratchReg2_); for (int r = 0; r < R_; ++r) { for (int s = 0; s < S_ - W_PAD_; ++s) { - a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_)); a->vmovups(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_)); gen8BitSum(a, actRegAvx2_); a->add(scratchReg1_, static_cast(C_ * sizeof(uint8_t))); } - if (!isZeroPointZero_) { - for (int s = S_ - W_PAD_; s < S_; ++s) { - gen8BitSum(a, zeroPTRegAvx2_); - } - } a->sub( scratchReg1_, @@ -1030,16 +1021,14 @@ void GenConvKernel::genForBottomEdgeRowoffset( // bottom-left corner // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + if (!isZeroPointZero_) { + genZeroPtSum(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_)); + } a->mov(scratchReg1_, H_R_); a->sub(scratchReg1_, static_cast(2 * H_PAD_)); a->imul(scratchReg1_, W_R_); a->imul(scratchReg1_, static_cast(C_ * sizeof(uint8_t))); for (int r = 0; r < R_ - H_PAD_; ++r) { - if (!isZeroPointZero_) { - for (int s = 0; s < W_PAD_; ++s) { - gen8BitSum(a, zeroPTRegAvx2_); - } - } for (int s = W_PAD_; s < S_; ++s) { a->vmovups( actRegAvx2_, @@ -1053,13 +1042,6 @@ void GenConvKernel::genForBottomEdgeRowoffset( a->imul(scratchReg2_, W_R_, static_cast(C_ * sizeof(uint8_t))); a->add(scratchReg1_, scratchReg2_); } - if (!isZeroPointZero_) { - for (int r = R_ - H_PAD_; r < R_; ++r) { - for (int s = 0; s < S_; ++s) { - gen8BitSum(a, zeroPTRegAvx2_); - } - } - } // we updating the last row a->mov(scratchReg1_, H_R_); @@ -1076,6 +1058,9 @@ void GenConvKernel::genForBottomEdgeRowoffset( a->bind(LoopBottomEdge); // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + if (!isZeroPointZero_) { + genZeroPtSum(a, H_PAD_ * S_); + } a->mov(scratchReg1_, H_R_); a->sub(scratchReg1_, static_cast(2 * H_PAD_)); a->imul(scratchReg1_, W_R_); @@ -1093,14 +1078,6 @@ void GenConvKernel::genForBottomEdgeRowoffset( a->add(scratchReg1_, scratchReg2_); } - if (!isZeroPointZero_) { - for (int r = R_ - W_PAD_; r < R_; ++r) { - for (int s = 0; s < S_; ++s) { - gen8BitSum(a, zeroPTRegAvx2_); - } - } - } - a->add(in_acts_R_, static_cast(C_ * sizeof(uint8_t))); // storeResult(a, ((H_-1)*W_+1)*8*sizeof(int32_t)); storeResultRowoffset(a); @@ -1121,6 +1098,9 @@ void GenConvKernel::genForBottomEdgeRowoffset( // bottom-right corner // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + if (!isZeroPointZero_) { + genZeroPtSum(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_)); + } // input start point // ((H_-(R_-H_PAD_))*W_+(W_-(S_-W_PAD_)))*C_*sizeof(uint8_t) a->mov(scratchReg1_, H_R_); @@ -1139,19 +1119,6 @@ void GenConvKernel::genForBottomEdgeRowoffset( } a->imul(scratchReg2_, W_R_, static_cast(C_ * sizeof(uint8_t))); a->add(scratchReg1_, scratchReg2_); - if (!isZeroPointZero_) { - for (int s = S_ - W_PAD_; s < S_; ++s) { - gen8BitSum(a, zeroPTRegAvx2_); - } - } - } - - if (!isZeroPointZero_) { - for (int r = R_ - H_PAD_; r < R_; ++r) { - for (int s = 0; s < S_; ++s) { - gen8BitSum(a, zeroPTRegAvx2_); - } - } } storeResultRowoffset(a); @@ -1288,7 +1255,15 @@ GenConvKernel::getOrCreateRowOffset( // This uses xmm10 register temporarily. Should come before // createVector8BitOne if (!isZeroPointZero_) { - setToZeroPt(a, zeroPTRegAvx2_); + // we can use xmm11 because ymm11 is used by tmpReg1Avx2_ + asmjit::X86Xmm const_reg_xmm = x86::xmm11; + a->movq(const_reg_xmm, a_zero_pt_R_); + a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm); + + a->mov(scratchReg1_, static_cast(C_per_G_)); + a->movq(const_reg_xmm, scratchReg1_); + a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm); + a->vpmulld(zeroPTRegAvx2_, zeroPTRegAvx2_, tmpReg1Avx2_); } createVector16BitOne(a); -- cgit v1.2.3