diff options
author | Jongsoo Park <jongsoo@fb.com> | 2019-02-02 21:32:16 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-02 21:35:25 +0300 |
commit | df7b1c1237c2f4274294ad9136861f30a7234c14 (patch) | |
tree | f8cfb98be892ca5e1d83beee060010868d123b2e | |
parent | a7d921d447564cdc67a07ad79dfc9e0d3dcd7418 (diff) |
gconv optimized for 8 channels per group (#65)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/65
As title says
Reviewed By: jianyuh
Differential Revision: D13834287
fbshipit-source-id: ff174fdfcc27bcc227e435ff27e5c2a7024bf736
-rw-r--r-- | bench/GroupwiseConvRequantizeBenchmark.cc | 5 | ||||
-rw-r--r-- | include/fbgemm/QuantUtilsAvx2.h | 5 | ||||
-rw-r--r-- | src/Fbgemm.cc | 7 | ||||
-rw-r--r-- | src/GroupwiseConv.h | 22 | ||||
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc | 661 | ||||
-rw-r--r-- | src/PackWeightMatrixForGConv.cc | 21 | ||||
-rw-r--r-- | src/QuantUtilsAvx2.cc | 281 | ||||
-rw-r--r-- | test/GConvTest.cc | 10 |
8 files changed, 751 insertions, 261 deletions
diff --git a/bench/GroupwiseConvRequantizeBenchmark.cc b/bench/GroupwiseConvRequantizeBenchmark.cc index 991013f..158ca4f 100644 --- a/bench/GroupwiseConvRequantizeBenchmark.cc +++ b/bench/GroupwiseConvRequantizeBenchmark.cc @@ -58,6 +58,11 @@ void performance_test() { // BatchSize > 1 // conv_param_t<>(2, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, // 1}), + + conv_param_t<>(1, 256, 256, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 256, 256, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(2, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), }; bool flush = true; diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h index 40b830c..04aeba1 100644 --- a/include/fbgemm/QuantUtilsAvx2.h +++ b/include/fbgemm/QuantUtilsAvx2.h @@ -85,8 +85,9 @@ template < bool B_SYMMETRIC, QuantizationGranularity Q_GRAN, bool HAS_BIAS, - bool FUSE_RELU> -FBGEMM_API void requantizeOutputProcessingGConv4Avx2( + bool FUSE_RELU, + int C_PER_G> +FBGEMM_API void requantizeOutputProcessingGConvAvx2( std::uint8_t* out, const std::int32_t* inp, const block_type_t& block, diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc index 9384af6..cb22999 100644 --- a/src/Fbgemm.cc +++ b/src/Fbgemm.cc @@ -197,9 +197,10 @@ FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p) { int C_per_G = conv_p.IC / conv_p.G; int K_per_G = conv_p.OC / conv_p.G; - return (SPATIAL_DIM == 2) && (C_per_G == K_per_G) && (C_per_G == 4) && - (conv_p.G % 8 == 0) && (conv_p.K[0] == conv_p.K[1]) && - (conv_p.K[0] == 3) && (conv_p.pad[0] == 1) && (conv_p.pad[1] == 1) && + return (SPATIAL_DIM == 2) && (C_per_G == K_per_G) && + (C_per_G == 4 || C_per_G == 8) && (conv_p.G % 8 == 0) && + (conv_p.K[0] == conv_p.K[1]) && (conv_p.K[0] == 3) && + (conv_p.pad[0] == 1) && (conv_p.pad[1] == 1) && (conv_p.pad[0] == conv_p.pad[2]) && (conv_p.pad[1] == conv_p.pad[3]) && (conv_p.dilation[0] == 1) && (conv_p.dilation[0] == conv_p.dilation[1]) && (conv_p.stride[0] == 1) && (conv_p.stride[0] == conv_p.stride[1]); diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h index 0d2db3a..b65082f 100644 --- a/src/GroupwiseConv.h +++ b/src/GroupwiseConv.h @@ -142,32 +142,38 @@ class GenConvKernel { gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg); template <inst_set_t instSet> - void genForLoadingWeights(asmjit::X86Emitter* a); + void genForLoadingWeights(asmjit::X86Emitter* a, int c_offset); template <inst_set_t instSet> void genConstForPermutations(asmjit::X86Emitter* a); template <inst_set_t instSet> - void genForTopEdge(asmjit::X86Emitter* a); + void genForTopEdge(asmjit::X86Emitter* a, int c_offset); template <inst_set_t instSet> - void genForLeftEdge(asmjit::X86Emitter* a); + void genForLeftEdge(asmjit::X86Emitter* a, int c_offset); template <inst_set_t instSet> - void genForRightEdge(asmjit::X86Emitter* a); + void genForRightEdge(asmjit::X86Emitter* a, int c_offset); template <inst_set_t instSet> - void genForBottomEdge(asmjit::X86Emitter* a); + void genForBottomEdge(asmjit::X86Emitter* a, int c_offset); template <inst_set_t instSet> - void genCoreInsts(asmjit::X86Emitter* a); + void genCoreInsts(asmjit::X86Emitter* a, int c_offset); template <inst_set_t instSet> - void storeResult(asmjit::X86Emitter* a, int offset = 0); + void storeResult(asmjit::X86Emitter* a); // for Rowoffset kernel + // Add 4 consecutive numbers of 32 uint8 and emit 8 32-bit template <inst_set_t instSet> - void gen8BitSum(asmjit::X86Emitter* a, asmjit::X86Ymm aReg); + void gen8BitSumX4(asmjit::X86Emitter* a, asmjit::X86Ymm aReg); + + // Add 8 consecutive numbers of 64 uint8 and emit 8 32-bit + template <inst_set_t instSet> + void + gen8BitSumX8(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg); // Use scratchReg1_ and tmpReg1Avx2_ internally template <inst_set_t instSet> diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index c85e339..be5145b 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -158,11 +158,12 @@ void GenConvKernel<int32_t>::genConstForPermutations<inst_set_t::avx2>( template <> template <> void GenConvKernel<int32_t>::storeResult<inst_set_t::avx2>( - asmjit::X86Emitter* a, - int offset) { - // store with permutation - a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_); - a->vmovups(x86::dword_ptr(out_acts_R_, offset), resultRegAvx2_); + asmjit::X86Emitter* a) { + if (C_per_G_ == 4) { + // store with permutation + a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_); + } + a->vmovups(x86::dword_ptr(out_acts_R_), resultRegAvx2_); } template <> @@ -171,21 +172,38 @@ void GenConvKernel<int32_t>::storeResultRowoffset<inst_set_t::avx2>( asmjit::X86Emitter* a, int offset) { // store - a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_); + if (C_per_G_ == 4) { + a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_); + } else { + // need to permute due to vphaddd + // 11 01 10 00 + a->vpermq(resultRegAvx2_, resultRegAvx2_, static_cast<asmjit::Imm>(0xd8)); + a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_); + } } template <> template <> void GenConvKernel<int32_t>::genForLoadingWeights<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + asmjit::X86Emitter* a, int c_offset) { // load weights for (int r = 0; r < R_; ++r) { for (int s = 0; s < S_; ++s) { - a->vmovaps( - WRegs_avx2_[r * S_ + s], - x86::dword_ptr( - wghts_R_, - (r * S_ + s) * 2 * K_per_G_ * C_per_G_ * sizeof(int8_t))); + if (C_per_G_ == 4) { + a->vmovaps( + WRegs_avx2_[r * S_ + s], + x86::dword_ptr( + wghts_R_, + (r * S_ + s) * 2 * K_per_G_ * C_per_G_ * sizeof(int8_t))); + } else { + // C_per_G == 8 + a->vmovaps( + WRegs_avx2_[r * S_ + s], + x86::dword_ptr( + wghts_R_, + (((c_offset / 4) * R_ + r) * S_ + s) * K_per_G_ * 4 * + sizeof(int8_t))); + } } } } @@ -203,7 +221,7 @@ void GenConvKernel<int32_t>::gen8bitFMA<inst_set_t::avx2>( template <> template <> -void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>( +void GenConvKernel<int32_t>::gen8BitSumX4<inst_set_t::avx2>( asmjit::X86Emitter* a, asmjit::X86Ymm aReg) { a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_); @@ -213,6 +231,19 @@ void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>( template <> template <> +void GenConvKernel<int32_t>::gen8BitSumX8<inst_set_t::avx2>( + asmjit::X86Emitter* a, + asmjit::X86Ymm aReg, + asmjit::X86Ymm bReg) { + a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_); + a->vpsadbw(aReg, aReg, tmpReg1Avx2_); + a->vpsadbw(bReg, bReg, tmpReg1Avx2_); + a->vphaddd(aReg, aReg, bReg); + a->vpaddd(resultRegAvx2_, aReg, resultRegAvx2_); +} + +template <> +template <> void GenConvKernel<int32_t>::genZeroPtSum<inst_set_t::avx2>( asmjit::X86Emitter* a, int multiplier) { @@ -228,10 +259,14 @@ void GenConvKernel<int32_t>::genZeroPtSum<inst_set_t::avx2>( template <> template <> void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + asmjit::X86Emitter* a, int c_offset) { // top-left corner code - // zero out the results register - a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + if (c_offset == 0) { + // zero out the results register + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + } else { + a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_)); + } for (int r = 0; r < R_; ++r) { int h_in = -H_PAD_ + r; if (h_in >= 0) { @@ -243,10 +278,20 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>( for (int s = 0; s < S_; ++s) { int w_in = -W_PAD_ + s; if (h_in >= 0 && w_in >= 0) { - a->vbroadcastsd( - actRegAvx2_, - x86::dword_ptr( - in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t))); + if (C_per_G_ == 4) { + a->vbroadcastsd( + actRegAvx2_, + x86::dword_ptr( + in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t))); + } else { + a->vbroadcastss( + actRegAvx2_, + x86::word_ptr( + in_acts_R_, + scratchReg1_, + 0, + (w_in * C_ + c_offset) * sizeof(uint8_t))); + } gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]); } else { if (!isZeroPointZero_) { @@ -265,7 +310,11 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>( a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_)); a->bind(LoopTopEdge); // zero out - a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + if (c_offset == 0) { + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + } else { + a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_)); + } if (!isZeroPointZero_) { for (int r = 0; r < H_PAD_; ++r) { for (int s = 0; s < S_; ++s) { @@ -280,10 +329,20 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>( W_R_, static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t))); for (int s = 0; s < S_; ++s) { - a->vbroadcastsd( - actRegAvx2_, - x86::dword_ptr( - in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t))); + if (C_per_G_ == 4) { + a->vbroadcastsd( + actRegAvx2_, + x86::dword_ptr( + in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t))); + } else { + a->vbroadcastss( + actRegAvx2_, + x86::word_ptr( + in_acts_R_, + scratchReg1_, + 0, + (s * C_ + c_offset) * sizeof(uint8_t))); + } gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]); } } @@ -307,7 +366,11 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>( // top-right corner code // zero out - a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + if (c_offset == 0) { + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + } else { + a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_)); + } if (!isZeroPointZero_) { for (int r = 0; r < H_PAD_; ++r) { for (int s = 0; s < S_; ++s) { @@ -327,7 +390,14 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>( a->sub(scratchReg2_, static_cast<asmjit::Imm>(R_ - W_PAD_ - s)); a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); a->add(scratchReg1_, scratchReg2_); - a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_)); + if (C_per_G_ == 4) { + a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_)); + } else { + a->vbroadcastss( + actRegAvx2_, + x86::word_ptr( + in_acts_R_, scratchReg1_, 0, c_offset * sizeof(uint8_t))); + } gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]); } if (!isZeroPointZero_) { @@ -348,13 +418,19 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + asmjit::X86Emitter* a, int c_offset) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_)); a->bind(LoopLeftEdge); - // zero out - a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t))); + a->add(out_acts_R_, scratchReg2_); + if (c_offset == 0) { + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + } else { + a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_)); + } a->mov(scratchReg1_, loopR1_); a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_)); a->imul(scratchReg1_, W_R_); @@ -367,21 +443,28 @@ void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>( } } for (int s = W_PAD_; s < S_; ++s) { - a->vbroadcastsd( - actRegAvx2_, - x86::dword_ptr( - in_acts_R_, - scratchReg1_, - 0, - (s - W_PAD_) * C_ * sizeof(uint8_t))); + if (C_per_G_ == 4) { + a->vbroadcastsd( + actRegAvx2_, + x86::dword_ptr( + in_acts_R_, + scratchReg1_, + 0, + (s - W_PAD_) * C_ * sizeof(uint8_t))); + } else { + a->vbroadcastss( + actRegAvx2_, + x86::word_ptr( + in_acts_R_, + scratchReg1_, + 0, + ((s - W_PAD_) * C_ + c_offset) * sizeof(uint8_t))); + } gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]); } a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); a->add(scratchReg1_, scratchReg2_); } - - a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t))); - a->add(out_acts_R_, scratchReg2_); storeResult<inst_set_t::avx2>(a); a->inc(loopR1_); @@ -401,7 +484,7 @@ void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + asmjit::X86Emitter* a, int c_offset) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -415,8 +498,12 @@ void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>( a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_)); a->bind(LoopRightEdge); - // zero out - a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + if (c_offset == 0) { + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + } else { + a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_)); + } a->mov(scratchReg1_, loopR1_); a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_)); a->imul(scratchReg1_, W_R_); @@ -428,7 +515,14 @@ void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>( a->add(scratchReg1_, scratchReg2_); for (int r = 0; r < R_; ++r) { for (int s = 0; s < S_ - W_PAD_; ++s) { - a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_)); + if (C_per_G_ == 4) { + a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_)); + } else { + a->vbroadcastss( + actRegAvx2_, + x86::word_ptr( + in_acts_R_, scratchReg1_, 0, c_offset * sizeof(uint8_t))); + } gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]); a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); } @@ -477,10 +571,20 @@ void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + asmjit::X86Emitter* a, int c_offset) { // bottom-left corner - // zero out - a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + // we updating the last row + a->mov(scratchReg1_, H_R_); + a->sub(scratchReg1_, 1); + a->imul(scratchReg1_, W_R_); + a->imul(scratchReg1_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t))); + a->add(out_acts_R_, scratchReg1_); + if (c_offset == 0) { + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + } else { + a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_)); + } a->mov(scratchReg1_, H_R_); a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_)); a->imul(scratchReg1_, W_R_); @@ -493,13 +597,23 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( } } for (int s = W_PAD_; s < S_; ++s) { - a->vbroadcastsd( - actRegAvx2_, - x86::dword_ptr( - in_acts_R_, - scratchReg1_, - 0, - (s - W_PAD_) * C_ * sizeof(uint8_t))); + if (C_per_G_ == 4) { + a->vbroadcastsd( + actRegAvx2_, + x86::dword_ptr( + in_acts_R_, + scratchReg1_, + 0, + (s - W_PAD_) * C_ * sizeof(uint8_t))); + } else { + a->vbroadcastss( + actRegAvx2_, + x86::word_ptr( + in_acts_R_, + scratchReg1_, + 0, + ((s - W_PAD_) * C_ + c_offset) * sizeof(uint8_t))); + } gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]); } a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); @@ -514,12 +628,6 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( } } - // we updating the last row - a->mov(scratchReg1_, H_R_); - a->sub(scratchReg1_, 1); - a->imul(scratchReg1_, W_R_); - a->imul(scratchReg1_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t))); - a->add(out_acts_R_, scratchReg1_); // storeResult<inst_set_t::avx2>(a, (H_-1)*W_*K_*sizeof(int32_t)); storeResult<inst_set_t::avx2>(a); a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t))); @@ -528,8 +636,12 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( asmjit::Label LoopBottomEdge = a->newLabel(); a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_)); a->bind(LoopBottomEdge); - // zero out - a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + if (c_offset == 0) { + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + } else { + a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_)); + } a->mov(scratchReg1_, H_R_); a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_)); a->imul(scratchReg1_, W_R_); @@ -537,10 +649,20 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( for (int r = 0; r < R_ - W_PAD_; ++r) { // int h_in = H_-2*H_PAD_ + r; for (int s = 0; s < S_; ++s) { - a->vbroadcastsd( - actRegAvx2_, - x86::dword_ptr( - in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t))); + if (C_per_G_ == 4) { + a->vbroadcastsd( + actRegAvx2_, + x86::dword_ptr( + in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t))); + } else { + a->vbroadcastss( + actRegAvx2_, + x86::word_ptr( + in_acts_R_, + scratchReg1_, + 0, + (s * C_ + c_offset) * sizeof(uint8_t))); + } gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]); } a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); @@ -574,8 +696,12 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( // a->sub(out_acts_R_, (W_ - 2*W_PAD_)*K_*sizeof(int32_t)); // bottom-right corner - // zero out - a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + if (c_offset == 0) { + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + } else { + a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_)); + } // input start point // ((H_-(R_-H_PAD_))*W_+(W_-(S_-W_PAD_)))*C_*sizeof(uint8_t) a->mov(scratchReg1_, H_R_); @@ -586,10 +712,20 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); for (int r = 0; r < R_ - H_PAD_; ++r) { for (int s = 0; s < S_ - W_PAD_; ++s) { - a->vbroadcastsd( - actRegAvx2_, - x86::dword_ptr( - in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t))); + if (C_per_G_ == 4) { + a->vbroadcastsd( + actRegAvx2_, + x86::dword_ptr( + in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t))); + } else { + a->vbroadcastss( + actRegAvx2_, + x86::word_ptr( + in_acts_R_, + scratchReg1_, + 0, + (s * C_ + c_offset) * sizeof(uint8_t))); + } gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]); } a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); @@ -626,7 +762,7 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + asmjit::X86Emitter* a, int c_offset) { // main compute asmjit::Label LoopH = a->newLabel(); asmjit::Label LoopW = a->newLabel(); @@ -646,27 +782,41 @@ void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>( // W loop a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_)); a->bind(LoopW); - // zero out - a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + if (c_offset == 0) { + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + } else { + a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_)); + } // compute on all filters for (int r = 0; r < R_; ++r) { for (int s = 0; s < S_; ++s) { - a->vbroadcastsd( - actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t))); + if (C_per_G_ == 4) { + a->vbroadcastsd( + actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t))); + } else { + a->vbroadcastss( + actRegAvx2_, + x86::word_ptr(in_acts_R_, (s * C_ + c_offset) * sizeof(uint8_t))); + } gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]); } + // advance input pointer by one row a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); a->add(in_acts_R_, scratchReg2_); } + // rewind input pointer a->imul( scratchReg2_, W_R_, static_cast<asmjit::Imm>(R_ * C_ * sizeof(uint8_t))); a->sub(in_acts_R_, scratchReg2_); // a->add(scratchReg1_, C_*sizeof(uint8_t)); + // advance input pointer by one pixel a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); // storeResult<inst_set_t::avx2>(a, (W_+1)*K_*sizeof(int32_t)); storeResult<inst_set_t::avx2>(a); + // advance output pointer by one pixel a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t))); a->inc(loopR2_); @@ -677,6 +827,7 @@ void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>( in_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t))); // a->sub(in_acts_R_, (W_ - 2*W_PAD_)*C_*sizeof(uint8_t)); // a->add(in_acts_R_, W_*C_*sizeof(uint8_t)); + // advance output pointer by padding size a->add( out_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * K_ * sizeof(int32_t))); // a->sub(out_acts_R_, (W_ - 2*W_PAD_)*K_*sizeof(int32_t)); @@ -687,6 +838,28 @@ void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>( a->sub(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_)); a->cmp(loopR1_, scratchReg2_); a->jl(LoopH); + + if (c_offset + 4 < C_per_G_) { + // FIXME : simplify + // reset input pointer + // scratchReg2_ = W_R_ * C_ * (H_R_ - 2 * H_PAD_) + a->mov(scratchReg2_, H_R_); + a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_)); + a->imul(scratchReg2_, W_R_); + a->imul(scratchReg2_, C_); + a->sub(in_acts_R_, scratchReg2_); + + // reset output pointer + assert(K_ == C_); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(sizeof(int32_t))); + a->sub(out_acts_R_, scratchReg2_); + + a->mov(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_)); + a->imul(scratchReg2_, W_R_); + a->add(scratchReg2_, static_cast<asmjit::Imm>(W_PAD_)); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t))); + a->sub(out_acts_R_, scratchReg2_); + } } template <> @@ -759,16 +932,24 @@ jit_conv_kernel_fp GenConvKernel<int32_t>::getOrCreate<inst_set_t::avx2>( setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_); } - genForLoadingWeights<inst_set_t::avx2>(a); - genConstForPermutations<inst_set_t::avx2>(a); - genForTopEdge<inst_set_t::avx2>(a); - genForLeftEdge<inst_set_t::avx2>(a); - genForRightEdge<inst_set_t::avx2>(a); - genForBottomEdge<inst_set_t::avx2>(a); - - genCoreInsts<inst_set_t::avx2>(a); + // Work on 4 input channels at a time. + // The minimum unit should be 4 because instruction sequence in gen8bitFMA + // reduces 4 inputs. + // We can't work on more than 4 input channels because of we can't put too + // many weights in register (we need R S K 4 / 32 registers to store weights + // for 4 input channels). + for (int c = 0; c < C_per_G_; c += 4) { + genForLoadingWeights<inst_set_t::avx2>(a, c); + + genForTopEdge<inst_set_t::avx2>(a, c); + genForLeftEdge<inst_set_t::avx2>(a, c); + genForRightEdge<inst_set_t::avx2>(a, c); + genForBottomEdge<inst_set_t::avx2>(a, c); + + genCoreInsts<inst_set_t::avx2>(a, c); + } asmjit::FuncUtils::emitEpilog(a, layout); @@ -811,9 +992,21 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( actRegAvx2_, x86::dword_ptr( in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t))); - gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + if (C_per_G_ == 4) { + gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_); + } else { + a->vmovups( + stPermRegAvx2_, + x86::dword_ptr( + in_acts_R_, + scratchReg1_, + 0, + (w_in * C_ + 32) * sizeof(uint8_t))); + gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_); + } } } + // store results storeResultRowoffset<inst_set_t::avx2>(a); @@ -840,7 +1033,18 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( actRegAvx2_, x86::dword_ptr( in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t))); - gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + if (C_per_G_ == 4) { + gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_); + } else { + a->vmovups( + stPermRegAvx2_, + x86::dword_ptr( + in_acts_R_, + scratchReg1_, + 0, + (s * C_ + 32) * sizeof(uint8_t))); + gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_); + } } } a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); @@ -879,7 +1083,13 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); a->add(scratchReg1_, scratchReg2_); a->vmovups(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_)); - gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + if (C_per_G_ == 4) { + gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_); + } else { + a->vmovups( + stPermRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_, 0, 32)); + gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_); + } } } @@ -919,7 +1129,18 @@ void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( scratchReg1_, 0, (s - W_PAD_) * C_ * sizeof(uint8_t))); - gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + if (C_per_G_ == 4) { + gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_); + } else { + a->vmovups( + stPermRegAvx2_, + x86::dword_ptr( + in_acts_R_, + scratchReg1_, + 0, + ((s - W_PAD_) * C_ + 32) * sizeof(uint8_t))); + gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_); + } } a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); a->add(scratchReg1_, scratchReg2_); @@ -977,7 +1198,14 @@ void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( for (int r = 0; r < R_; ++r) { for (int s = 0; s < S_ - W_PAD_; ++s) { a->vmovups(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_)); - gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + if (C_per_G_ == 4) { + gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_); + } else { + a->vmovups( + stPermRegAvx2_, + x86::dword_ptr(in_acts_R_, scratchReg1_, 0, 32 * sizeof(uint8_t))); + gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_); + } a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); } @@ -1037,7 +1265,18 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( scratchReg1_, 0, (s - W_PAD_) * C_ * sizeof(uint8_t))); - gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + if (C_per_G_ == 4) { + gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_); + } else { + a->vmovups( + stPermRegAvx2_, + x86::dword_ptr( + in_acts_R_, + scratchReg1_, + 0, + ((s - W_PAD_) * C_ + 32) * sizeof(uint8_t))); + gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_); + } } a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); a->add(scratchReg1_, scratchReg2_); @@ -1072,7 +1311,15 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( actRegAvx2_, x86::dword_ptr( in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t))); - gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + if (C_per_G_ == 4) { + gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_); + } else { + a->vmovups( + stPermRegAvx2_, + x86::dword_ptr( + in_acts_R_, scratchReg1_, 0, (s * C_ + 32) * sizeof(uint8_t))); + gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_); + } } a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); a->add(scratchReg1_, scratchReg2_); @@ -1115,7 +1362,15 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( actRegAvx2_, x86::dword_ptr( in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t))); - gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + if (C_per_G_ == 4) { + gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_); + } else { + a->vmovups( + stPermRegAvx2_, + x86::dword_ptr( + in_acts_R_, scratchReg1_, 0, (s * C_ + 32) * sizeof(uint8_t))); + gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_); + } } a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); a->add(scratchReg1_, scratchReg2_); @@ -1164,7 +1419,14 @@ void GenConvKernel<int32_t>::genRowoffsetCore<inst_set_t::avx2>( for (int s = 0; s < S_; ++s) { a->vmovups( actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t))); - gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + if (C_per_G_ == 4) { + gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_); + } else { + a->vmovups( + stPermRegAvx2_, + x86::dword_ptr(in_acts_R_, (s * C_ + 32) * sizeof(uint8_t))); + gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_); + } } a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); a->add(in_acts_R_, scratchReg2_); @@ -1355,7 +1617,8 @@ void fbgemmGroupwiseConvBase_( reinterpret_cast<float*>(rowOffsetTrDest), ih_iw); int gLimit = gOuter + 8; - for (int g = gOuter; g < gLimit; g += 2) { + int gDelta = C_per_G == 4 ? 2 : 1; + for (int g = gOuter; g < gLimit; g += gDelta) { int32_t* currOutBuf = outBuffer + i * oh_ow * conv_param.OC + g * K_per_G; const uint8_t* actStartGroup = actStartBatch + g * C_per_G; @@ -1370,7 +1633,7 @@ void fbgemmGroupwiseConvBase_( W); // Output processing should be called for each group - for (int j = 0; j < 2; ++j) { + for (int j = 0; j < gDelta; ++j) { // calculateRowOffsets( // conv_param, actStartGroup, rowOffsetBuf, a_zero_point, j); int32_t* rowOffsetForCurG = @@ -1486,6 +1749,7 @@ void fbgemmGroupwiseConv( int thread_id, int num_threads) { typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN> processOutputType; + if (!fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param) || (!cpuinfo_has_x86_avx512f() && !cpuinfo_has_x86_avx2())) { return fbgemmGroupwiseConvBase_< @@ -1544,7 +1808,8 @@ void fbgemmGroupwiseConv( reinterpret_cast<float*>(rowOffsetTrDest), ih_iw); int gLimit = gOuter + 8; - for (int g = gOuter; g < gLimit; g += 2) { + int gDelta = C_per_G == 4 ? 2 : 1; + for (int g = gOuter; g < gLimit; g += gDelta) { int32_t* currOutBuf = outBuffer + (g - gOuter) * K_per_G; const uint8_t* actStartGroup = actStartBatch + g * C_per_G; @@ -1576,72 +1841,160 @@ void fbgemmGroupwiseConv( int ld_out = K_per_G * G; int ld_in = K_per_G * G; - if (a_zero_point == 0) { - if (b_symmetric) { - if (outProcess.getBias() == nullptr) { - requantizeOutputProcessingGConv4Avx2< - true, - true, - Q_GRAN, - false, - FUSE_RELU>(out, inp, block, ld_out, ld_in, r); + if (C_per_G == 4) { + if (a_zero_point == 0) { + if (b_symmetric) { + if (outProcess.getBias() == nullptr) { + requantizeOutputProcessingGConvAvx2< + true, + true, + Q_GRAN, + false, + FUSE_RELU, + 4>(out, inp, block, ld_out, ld_in, r); + } else { + requantizeOutputProcessingGConvAvx2< + true, + true, + Q_GRAN, + true, + FUSE_RELU, + 4>(out, inp, block, ld_out, ld_in, r); + } } else { - requantizeOutputProcessingGConv4Avx2< - true, - true, - Q_GRAN, - true, - FUSE_RELU>(out, inp, block, ld_out, ld_in, r); + if (outProcess.getBias() == nullptr) { + requantizeOutputProcessingGConvAvx2< + true, + false, + Q_GRAN, + false, + FUSE_RELU, + 4>(out, inp, block, ld_out, ld_in, r); + } else { + requantizeOutputProcessingGConvAvx2< + true, + false, + Q_GRAN, + true, + FUSE_RELU, + 4>(out, inp, block, ld_out, ld_in, r); + } } } else { - if (outProcess.getBias() == nullptr) { - requantizeOutputProcessingGConv4Avx2< - true, - false, - Q_GRAN, - false, - FUSE_RELU>(out, inp, block, ld_out, ld_in, r); + if (b_symmetric) { + if (outProcess.getBias() == nullptr) { + requantizeOutputProcessingGConvAvx2< + false, + true, + Q_GRAN, + false, + FUSE_RELU, + 4>(out, inp, block, ld_out, ld_in, r); + } else { + requantizeOutputProcessingGConvAvx2< + false, + true, + Q_GRAN, + true, + FUSE_RELU, + 4>(out, inp, block, ld_out, ld_in, r); + } } else { - requantizeOutputProcessingGConv4Avx2< - true, - false, - Q_GRAN, - true, - FUSE_RELU>(out, inp, block, ld_out, ld_in, r); + if (outProcess.getBias() == nullptr) { + requantizeOutputProcessingGConvAvx2< + false, + false, + Q_GRAN, + false, + FUSE_RELU, + 4>(out, inp, block, ld_out, ld_in, r); + } else { + requantizeOutputProcessingGConvAvx2< + false, + false, + Q_GRAN, + true, + FUSE_RELU, + 4>(out, inp, block, ld_out, ld_in, r); + } } } } else { - if (b_symmetric) { - if (outProcess.getBias() == nullptr) { - requantizeOutputProcessingGConv4Avx2< - false, - true, - Q_GRAN, - false, - FUSE_RELU>(out, inp, block, ld_out, ld_in, r); + if (a_zero_point == 0) { + if (b_symmetric) { + if (outProcess.getBias() == nullptr) { + requantizeOutputProcessingGConvAvx2< + true, + true, + Q_GRAN, + false, + FUSE_RELU, + 8>(out, inp, block, ld_out, ld_in, r); + } else { + requantizeOutputProcessingGConvAvx2< + true, + true, + Q_GRAN, + true, + FUSE_RELU, + 8>(out, inp, block, ld_out, ld_in, r); + } } else { - requantizeOutputProcessingGConv4Avx2< - false, - true, - Q_GRAN, - true, - FUSE_RELU>(out, inp, block, ld_out, ld_in, r); + if (outProcess.getBias() == nullptr) { + requantizeOutputProcessingGConvAvx2< + true, + false, + Q_GRAN, + false, + FUSE_RELU, + 8>(out, inp, block, ld_out, ld_in, r); + } else { + requantizeOutputProcessingGConvAvx2< + true, + false, + Q_GRAN, + true, + FUSE_RELU, + 8>(out, inp, block, ld_out, ld_in, r); + } } } else { - if (outProcess.getBias() == nullptr) { - requantizeOutputProcessingGConv4Avx2< - false, - false, - Q_GRAN, - false, - FUSE_RELU>(out, inp, block, ld_out, ld_in, r); + if (b_symmetric) { + if (outProcess.getBias() == nullptr) { + requantizeOutputProcessingGConvAvx2< + false, + true, + Q_GRAN, + false, + FUSE_RELU, + 8>(out, inp, block, ld_out, ld_in, r); + } else { + requantizeOutputProcessingGConvAvx2< + false, + true, + Q_GRAN, + true, + FUSE_RELU, + 8>(out, inp, block, ld_out, ld_in, r); + } } else { - requantizeOutputProcessingGConv4Avx2< - false, - false, - Q_GRAN, - true, - FUSE_RELU>(out, inp, block, ld_out, ld_in, r); + if (outProcess.getBias() == nullptr) { + requantizeOutputProcessingGConvAvx2< + false, + false, + Q_GRAN, + false, + FUSE_RELU, + 8>(out, inp, block, ld_out, ld_in, r); + } else { + requantizeOutputProcessingGConvAvx2< + false, + false, + Q_GRAN, + true, + FUSE_RELU, + 8>(out, inp, block, ld_out, ld_in, r); + } } } } @@ -1674,7 +2027,7 @@ int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) { int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1]; int C_per_G = conv_param.IC / conv_param.G; int K_per_G = conv_param.OC / conv_param.G; - if (C_per_G == 4 && K_per_G == 4) { + if ((C_per_G == 4 && K_per_G == 4) || (C_per_G == 8 && K_per_G == 8)) { return 2 * 8 * bufferSize; } else { return conv_param.G * bufferSize; @@ -1683,7 +2036,7 @@ int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) { int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1]; int C_per_G = conv_param.IC / conv_param.G; int K_per_G = conv_param.OC / conv_param.G; - if (C_per_G == 4 && K_per_G == 4) { + if ((C_per_G == 4 && K_per_G == 4) || (C_per_G == 8 && K_per_G == 8)) { // row offset is calculated for 8 groups at a time // 2x is needed for transposing return 2 * 8 * bufferSize; diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc index e38fba9..5870fa5 100644 --- a/src/PackWeightMatrixForGConv.cc +++ b/src/PackWeightMatrixForGConv.cc @@ -52,6 +52,7 @@ PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::PackWeightMatrixForGConv( * For IC_per_G == 8, 16, 32 && OC_per_G == 8, 16, 32 there is no need to work * on 2 groups at a time and full SIMD width can be efficiently utilized even * while working on 1 group at a time. + * In this case, the layout is G (C/4) R S K 4 */ template <typename T, typename accT, int SPATIAL_DIM> void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() { @@ -77,11 +78,21 @@ void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() { [(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c] : sdata_ [(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k]; - pdata_ - [(((((g / 2) * R + r) * S + s) * OC_per_G + k) * 2 + - (g % 2)) * - IC_per_G + - c] = b; + if (IC_per_G == 4) { + // For IC_per_G == 4, we need to work on 2 groups at a time + pdata_ + [(((((g / 2) * R + r) * S + s) * OC_per_G + k) * 2 + + (g % 2)) * + IC_per_G + + c] = b; + } else { + pdata_ + [((((g * (IC_per_G / 4) + (c / 4)) * R + r) * S + s) * + OC_per_G + + k) * + 4 + + (c % 4)] = b; + } } } } diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index 7c36f6d..bc41310 100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -662,8 +662,9 @@ template < bool B_SYMMETRIC, QuantizationGranularity Q_GRAN, bool HAS_BIAS, - bool FUSE_RELU> -void requantizeOutputProcessingGConv4Avx2( + bool FUSE_RELU, + int C_PER_G> +void requantizeOutputProcessingGConvAvx2( uint8_t* out, const int32_t* inp, const block_type_t& block, @@ -744,82 +745,151 @@ void requantizeOutputProcessingGConv4Avx2( } if (!B_SYMMETRIC) { - // Load row_offsets for 2 groups and broadcast by 4 times each because - // we have 4 channels per group. - - // groups 0 and 1 - __m256i row_offset_v = _mm256_insertf128_si256( - _mm256_castsi128_si256( - _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 0])), - _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 1]), - 1); + __m256i row_offset_v; + + // When C_PER_G == 4, we need to handle 2 groups at a time to fully + // utilize 32B AVX2 vector register (C_PER_G * 2 * sizeof(int32_t) == + // 32B) + // When C_PER_G == 8, we just need 1 group at a time on the other hand. + + // Groups 0 and 1 when C_PER_G == 4 + // Group 0 when C_PER_G == 8 + if (C_PER_G == 4) { + // Load row_offsets for 2 groups and broadcast by 4 times each because + // we have 4 channels per group. + // groups 0 and 1 + row_offset_v = _mm256_insertf128_si256( + _mm256_castsi128_si256( + _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 0])), + _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 1]), + 1); + } else { + assert(C_PER_G == 8); + row_offset_v = + _mm256_set1_epi32(r.row_offsets + [(i - block.row_start) * 8 + + (j - block.col_start) / (VLEN * 4) * 4]); + } __m256i B_zero_point_v = _mm256_set1_epi32(r.B_zero_point[0]); if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { B_zero_point_v = _mm256_loadu_si256( reinterpret_cast<const __m256i*>(r.B_zero_point + j)); } else if (Q_GRAN == QuantizationGranularity::GROUP) { - B_zero_point_v = _mm256_insertf128_si256( - _mm256_castsi128_si256( - _mm_set1_epi32(r.B_zero_point[quant_param_idx])), - _mm_set1_epi32(r.B_zero_point[quant_param_idx + 1]), - 1); + if (C_PER_G == 4) { + B_zero_point_v = _mm256_insertf128_si256( + _mm256_castsi128_si256( + _mm_set1_epi32(r.B_zero_point[quant_param_idx])), + _mm_set1_epi32(r.B_zero_point[quant_param_idx + 1]), + 1); + } else { + B_zero_point_v = _mm256_set1_epi32( + r.B_zero_point + [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]); + } } row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v); x_v = _mm256_sub_epi32(x_v, row_offset_v); - // groups 2 and 3 - row_offset_v = _mm256_insertf128_si256( - _mm256_castsi128_si256( - _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 2])), - _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 3]), - 1); + // Groups 2 and 3 when C_PER_G == 4 + // Group 1 when C_PER_G == 8 + if (C_PER_G == 4) { + // + 2 and + 3 here are for groups 2 and 3 + row_offset_v = _mm256_insertf128_si256( + _mm256_castsi128_si256( + _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 2])), + _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 3]), + 1); + } else { + // + 1 here is for group 1 + row_offset_v = _mm256_set1_epi32( + r.row_offsets + [(i - block.row_start) * 8 + + (j - block.col_start) / (VLEN * 4) * 4 + 1]); + } if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { B_zero_point_v = _mm256_loadu_si256( reinterpret_cast<const __m256i*>(r.B_zero_point + j + VLEN)); } else if (Q_GRAN == QuantizationGranularity::GROUP) { - B_zero_point_v = _mm256_insertf128_si256( - _mm256_castsi128_si256( - _mm_set1_epi32(r.B_zero_point[quant_param_idx + 2])), - _mm_set1_epi32(r.B_zero_point[quant_param_idx + 3]), - 1); + if (C_PER_G == 4) { + B_zero_point_v = _mm256_insertf128_si256( + _mm256_castsi128_si256( + _mm_set1_epi32(r.B_zero_point[quant_param_idx + 2])), + _mm_set1_epi32(r.B_zero_point[quant_param_idx + 3]), + 1); + } else { + B_zero_point_v = _mm256_set1_epi32( + r.B_zero_point + [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + + 1]); + } } row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v); y_v = _mm256_sub_epi32(y_v, row_offset_v); - // groups 4 and 5 - row_offset_v = _mm256_insertf128_si256( - _mm256_castsi128_si256( - _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 4])), - _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 5]), - 1); + // Groups 4 and 5 when C_PER_G == 4 + // Group 2 when C_PER_G == 8 + if (C_PER_G == 4) { + row_offset_v = _mm256_insertf128_si256( + _mm256_castsi128_si256( + _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 4])), + _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 5]), + 1); + } else { + row_offset_v = _mm256_set1_epi32( + r.row_offsets + [(i - block.row_start) * 8 + + (j - block.col_start) / (VLEN * 4) * 4 + 2]); + } if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { B_zero_point_v = _mm256_loadu_si256( reinterpret_cast<const __m256i*>(r.B_zero_point + j + 2 * VLEN)); } else if (Q_GRAN == QuantizationGranularity::GROUP) { - B_zero_point_v = _mm256_insertf128_si256( - _mm256_castsi128_si256( - _mm_set1_epi32(r.B_zero_point[quant_param_idx + 4])), - _mm_set1_epi32(r.B_zero_point[quant_param_idx + 5]), - 1); + if (C_PER_G == 4) { + B_zero_point_v = _mm256_insertf128_si256( + _mm256_castsi128_si256( + _mm_set1_epi32(r.B_zero_point[quant_param_idx + 4])), + _mm_set1_epi32(r.B_zero_point[quant_param_idx + 5]), + 1); + } else { + B_zero_point_v = _mm256_set1_epi32( + r.B_zero_point + [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + + 2]); + } } row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v); z_v = _mm256_sub_epi32(z_v, row_offset_v); - // groups 6 and 7 - row_offset_v = _mm256_insertf128_si256( - _mm256_castsi128_si256( - _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 6])), - _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 7]), - 1); + // Groups 6 and 7 when C_PER_G == 4 + // Group 3 when C_PER_G == 8 + if (C_PER_G == 4) { + row_offset_v = _mm256_insertf128_si256( + _mm256_castsi128_si256( + _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 6])), + _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 7]), + 1); + } else { + row_offset_v = _mm256_set1_epi32( + r.row_offsets + [(i - block.row_start) * 8 + + (j - block.col_start) / (VLEN * 4) * 4 + 3]); + } if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) { B_zero_point_v = _mm256_loadu_si256( reinterpret_cast<const __m256i*>(r.B_zero_point + j + 3 * VLEN)); } else if (Q_GRAN == QuantizationGranularity::GROUP) { - B_zero_point_v = _mm256_insertf128_si256( - _mm256_castsi128_si256( - _mm_set1_epi32(r.B_zero_point[quant_param_idx + 6])), - _mm_set1_epi32(r.B_zero_point[quant_param_idx + 7]), - 1); + if (C_PER_G == 4) { + B_zero_point_v = _mm256_insertf128_si256( + _mm256_castsi128_si256( + _mm_set1_epi32(r.B_zero_point[quant_param_idx + 6])), + _mm_set1_epi32(r.B_zero_point[quant_param_idx + 7]), + 1); + } else { + B_zero_point_v = _mm256_set1_epi32( + r.B_zero_point + [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + + 3]); + } } row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v); w_v = _mm256_sub_epi32(w_v, row_offset_v); @@ -868,33 +938,58 @@ void requantizeOutputProcessingGConv4Avx2( _mm256_cvtepi32_ps(w_v), _mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN)); } else if (Q_GRAN == QuantizationGranularity::GROUP) { - multiplier_v = _mm256_insertf128_ps( - _mm256_castps128_ps256( - _mm_set1_ps(r.C_multiplier[quant_param_idx])), - _mm_set1_ps(r.C_multiplier[quant_param_idx + 1]), - 1); - x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + if (C_PER_G == 4) { + multiplier_v = _mm256_insertf128_ps( + _mm256_castps128_ps256( + _mm_set1_ps(r.C_multiplier[quant_param_idx])), + _mm_set1_ps(r.C_multiplier[quant_param_idx + 1]), + 1); + x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); - multiplier_v = _mm256_insertf128_ps( - _mm256_castps128_ps256( - _mm_set1_ps(r.C_multiplier[quant_param_idx + 2])), - _mm_set1_ps(r.C_multiplier[quant_param_idx + 3]), - 1); - y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); + multiplier_v = _mm256_insertf128_ps( + _mm256_castps128_ps256( + _mm_set1_ps(r.C_multiplier[quant_param_idx + 2])), + _mm_set1_ps(r.C_multiplier[quant_param_idx + 3]), + 1); + y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); - multiplier_v = _mm256_insertf128_ps( - _mm256_castps128_ps256( - _mm_set1_ps(r.C_multiplier[quant_param_idx + 4])), - _mm_set1_ps(r.C_multiplier[quant_param_idx + 5]), - 1); - z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); + multiplier_v = _mm256_insertf128_ps( + _mm256_castps128_ps256( + _mm_set1_ps(r.C_multiplier[quant_param_idx + 4])), + _mm_set1_ps(r.C_multiplier[quant_param_idx + 5]), + 1); + z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); - multiplier_v = _mm256_insertf128_ps( - _mm256_castps128_ps256( - _mm_set1_ps(r.C_multiplier[quant_param_idx + 6])), - _mm_set1_ps(r.C_multiplier[quant_param_idx + 7]), - 1); - w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + multiplier_v = _mm256_insertf128_ps( + _mm256_castps128_ps256( + _mm_set1_ps(r.C_multiplier[quant_param_idx + 6])), + _mm_set1_ps(r.C_multiplier[quant_param_idx + 7]), + 1); + w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + } else { + multiplier_v = _mm256_set1_ps( + r.C_multiplier + [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]); + x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); + + multiplier_v = _mm256_set1_ps( + r.C_multiplier + [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + + 1]); + y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); + + multiplier_v = _mm256_set1_ps( + r.C_multiplier + [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + + 2]); + z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v); + + multiplier_v = _mm256_set1_ps( + r.C_multiplier + [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 + + 3]); + w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v); + } } else { x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v); y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v); @@ -970,22 +1065,30 @@ void requantizeOutputProcessingGConv4Avx2( } // i loop } -#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \ - template void \ - requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ - const requantizationParams_t& r); \ - template void \ - requantizeOutputProcessingGConv4Avx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ - uint8_t * out, \ - const int32_t* inp, \ - const block_type_t& block, \ - int ld_out, \ - int ld_in, \ +#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \ + template void \ + requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t& r); \ + template void \ + requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ + const requantizationParams_t& r); \ + template void \ + requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>( \ + uint8_t * out, \ + const int32_t* inp, \ + const block_type_t& block, \ + int ld_out, \ + int ld_in, \ const requantizationParams_t& r); #define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \ diff --git a/test/GConvTest.cc b/test/GConvTest.cc index 34042c6..e1434d2 100644 --- a/test/GConvTest.cc +++ b/test/GConvTest.cc @@ -71,6 +71,16 @@ static vector<conv_param_t<>> GetShapes_() { conv_param_t<>(1, 128, 128, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(2, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + + conv_param_t<>(1, 64, 64, {3, 3}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 64, 64, {4, 4}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 64, 64, {3, 5}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 64, 64, {5, 3}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 16, 16, {5, 5}, 2, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 256, 256, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 256, 256, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(1, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), + conv_param_t<>(2, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), }; return shapes; } |