author | Jongsoo Park <jongsoo@fb.com> | 2019-02-13 01:35:32 +0300
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-13 01:48:03 +0300
commit | 86eeae2f917b92126af5cb1d37336f7d503292ee (patch)
tree | cb7af7e0fc924fb27d673aa296c6dd0af3d60e1e
parent | df7b1c1237c2f4274294ad9136861f30a7234c14 (diff)
group conv optimized for 16 channels per group (#68)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/68
Continuing optimizations for group convolution. Even though the op-level speedup for 16 channels per group is lower than in the 4- and 8-channel cases, we get a nice overall speedup on resnext101-32x4d because it has many Conv operators with 16 channels per group.
Reviewed By: protonu
Differential Revision: D13949873
fbshipit-source-id: 1dff4b1acfdabe23616e7df365daf2b7f6e8aea9
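For context: resnext101-32x4d uses 32 groups throughout, so its 512-channel 3x3 grouped convolutions have 512 / 32 = 16 input and output channels per group. A minimal sketch, assuming the public fbgemm/Fbgemm.h header, of checking that such a layer now takes the optimized groupwise path (an illustration, not code from the commit):

    #include "fbgemm/Fbgemm.h" // assumed header exposing conv_param_t and fbgemmOptimizedGConv

    int main() {
      // MB=1, IC=OC=512, 14x14 input, G=32, 3x3 filter, stride 1, pad 1:
      // C_per_G = K_per_G = 512 / 32 = 16, the case this commit adds.
      fbgemm::conv_param_t<> conv_p(
          1, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1});
      // Per the src/Fbgemm.cc hunk below, C_per_G == 16 now satisfies the
      // fast-path predicate; previously only 4 and 8 did.
      return fbgemm::fbgemmOptimizedGConv(conv_p) ? 0 : 1;
    }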
-rw-r--r-- | bench/GroupwiseConvRequantizeBenchmark.cc | 13
-rw-r--r-- | src/Fbgemm.cc | 2
-rw-r--r-- | src/GroupwiseConv.h | 22
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc | 445
-rw-r--r-- | src/QuantUtilsAvx2.cc | 126
-rw-r--r-- | test/GConvTest.cc | 11
6 files changed, 413 insertions, 206 deletions
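The core new building block is gen8BitSumX16 in src/GroupwiseConvAcc32Avx2.cc, which the JIT kernel uses to compute row offsets (per-group sums of uint8 activations) when C_per_G == 16: vpsadbw produces four 8-element sums per register, vpshufd with 0xd8/0x8d packs two registers' sums together, vpaddd combines them, and vphaddd pairs adjacent sums into one 32-bit sum per 16 inputs. A standalone intrinsics sketch of the same reduction (our illustration; the commit itself emits these instructions through asmjit):

    #include <immintrin.h>

    // Reduce four YMM registers of uint8 activations (128 values = 8 groups
    // of 16 channels) to eight 32-bit per-group sums, mirroring the
    // vpsadbw/vpshufd/vpaddd/vphaddd sequence gen8BitSumX16 emits.
    static inline __m256i sum16x8u(__m256i a, __m256i b, __m256i c, __m256i d) {
      const __m256i zero = _mm256_setzero_si256();
      // vpsadbw against zero: each 64-bit lane holds the sum of its 8 uint8s.
      a = _mm256_sad_epu8(a, zero);
      b = _mm256_sad_epu8(b, zero);
      c = _mm256_sad_epu8(c, zero);
      d = _mm256_sad_epu8(d, zero);
      // 0xd8 packs a's two sums into dwords 0-1 of each 128-bit half;
      // 0x8d packs b's into dwords 2-3, so a + b interleaves the two.
      a = _mm256_add_epi32(_mm256_shuffle_epi32(a, 0xd8),
                           _mm256_shuffle_epi32(b, 0x8d));
      c = _mm256_add_epi32(_mm256_shuffle_epi32(c, 0xd8),
                           _mm256_shuffle_epi32(d, 0x8d));
      // vphaddd adds adjacent dword pairs: one 32-bit sum per 16 uint8 inputs.
      return _mm256_hadd_epi32(a, c);
    }

The result still needs the vpermq/vpshufd 0xd8 fix-up shown in storeResultRowoffset below, because vphaddd interleaves its two inputs per 128-bit lane.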
diff --git a/bench/GroupwiseConvRequantizeBenchmark.cc b/bench/GroupwiseConvRequantizeBenchmark.cc
index 158ca4f..4c93f23 100644
--- a/bench/GroupwiseConvRequantizeBenchmark.cc
+++ b/bench/GroupwiseConvRequantizeBenchmark.cc
@@ -59,10 +59,15 @@ void performance_test() {
       // conv_param_t<>(2, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1,
       // 1}),
-      conv_param_t<>(1, 256, 256, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
-      conv_param_t<>(1, 256, 256, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
-      conv_param_t<>(1, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
-      conv_param_t<>(2, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(1, 256, 256, {28, 24}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(1, 256, 256, {24, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(1, 256, 256, {28, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(2, 256, 256, {28, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+
+      conv_param_t<>(1, 512, 512, {14, 12}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(1, 512, 512, {12, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(1, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(2, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
   };

   bool flush = true;
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index cb22999..ab0693a 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -198,7 +198,7 @@ FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p) {
   int K_per_G = conv_p.OC / conv_p.G;

   return (SPATIAL_DIM == 2) && (C_per_G == K_per_G) &&
-      (C_per_G == 4 || C_per_G == 8) && (conv_p.G % 8 == 0) &&
+      (C_per_G == 4 || C_per_G == 8 || C_per_G == 16) && (conv_p.G % 8 == 0) &&
       (conv_p.K[0] == conv_p.K[1]) && (conv_p.K[0] == 3) &&
       (conv_p.pad[0] == 1) && (conv_p.pad[1] == 1) &&
       (conv_p.pad[0] == conv_p.pad[2]) && (conv_p.pad[1] == conv_p.pad[3]) &&
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
index b65082f..3605fcd 100644
--- a/src/GroupwiseConv.h
+++ b/src/GroupwiseConv.h
@@ -175,6 +175,28 @@ class GenConvKernel {
   void gen8BitSumX8(
       asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg);

+  // Sum 16 consecutive uint8 values at a time out of 128 inputs and emit
+  // 8 32-bit sums
+  template <inst_set_t instSet>
+  void gen8BitSumX16(
+      asmjit::X86Emitter* a,
+      asmjit::X86Ymm aReg,
+      asmjit::X86Ymm bReg,
+      asmjit::X86Ymm cReg,
+      asmjit::X86Ymm dReg);
+
+  // Generate an instruction sequence that loads 8-bit values and sums them
+  // up. Depending on C_per_G_, this function dispatches to gen8BitSumX4/8/16.
+  // It assumes in_acts_R_ has the base pointer to the activations,
+  // scratchReg1_ has a variable offset, and act_offset has the final
+  // immediate offset.
+  // Internally, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0, 1], tmpReg1Avx2_,
+  // and resultRegAvx2_ are used.
+  template <inst_set_t instSet>
+  void gen8BitSum(
+      asmjit::X86Emitter* a,
+      int act_offset,
+      bool use_scratch_reg1 = true);
+
   // Use scratchReg1_ and tmpReg1Avx2_ internally
   template <inst_set_t instSet>
   void genZeroPtSum(asmjit::X86Emitter* a, int multiplier);
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index be5145b..94ca87c 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -174,10 +174,21 @@ void GenConvKernel<int32_t>::storeResultRowoffset<inst_set_t::avx2>(
   // store
   if (C_per_G_ == 4) {
     a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_);
+  } else if (C_per_G_ == 8) {
+    // need to permute because vphaddd is used in gen8BitSumX8
+    // 11 01 10 00 = 0xd8
+    a->vpermq(resultRegAvx2_, resultRegAvx2_, static_cast<asmjit::Imm>(0xd8));
+    a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_);
   } else {
-    // need to permute due to vphaddd
-    // 11 01 10 00
+    assert(C_per_G_ == 16);
+    // need to permute because vphaddd is used in gen8BitSumX16
+    // a[0:4] = a[0] + ... + a[15], a[4:8] = b[0] + ... + b[15]
+    // a[8:12] = a[16] + ... + a[31], a[12:16] = b[16] + ... + b[31]
     a->vpermq(resultRegAvx2_, resultRegAvx2_, static_cast<asmjit::Imm>(0xd8));
+    // 11 01 10 00 = 0xd8
+    // a[0:4] = a[0] + ... + a[15], a[4:8] = a[16] + ... + a[31]
+    // a[8:12] = b[0] + ... + b[15], a[12:16] = b[16] + ... + b[31]
+    a->vpshufd(resultRegAvx2_, resultRegAvx2_, static_cast<asmjit::Imm>(0xd8));
     a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_);
   }
 }
@@ -196,7 +207,7 @@ void GenConvKernel<int32_t>::genForLoadingWeights<inst_set_t::avx2>(
             wghts_R_,
             (r * S_ + s) * 2 * K_per_G_ * C_per_G_ * sizeof(int8_t)));
       } else {
-        // C_per_G == 8
+        // C_per_G == 8 or 16
         a->vmovaps(
             WRegs_avx2_[r * S_ + s],
             x86::dword_ptr(
@@ -236,14 +247,140 @@ void GenConvKernel<int32_t>::gen8BitSumX8<inst_set_t::avx2>(
     asmjit::X86Ymm aReg,
     asmjit::X86Ymm bReg) {
   a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
+  // Let a[0] denote the 0th (LSB) 8-bit element of aReg
+  // After vpsadbw, a[0:2] = a[0] + ... + a[7]
+  //                a[8:10] = a[8] + ... + a[15]
+  //                a[16:18] = a[16] + ... + a[23]
+  //                a[24:26] = a[24] + ... + a[31]
   a->vpsadbw(aReg, aReg, tmpReg1Avx2_);
   a->vpsadbw(bReg, bReg, tmpReg1Avx2_);
+  // After vphaddd, a[0:4] = a[0] + ... + a[7], a[4:8] = a[8] + ... + a[15]
+  //                a[8:12] = b[0] + ... + b[7], a[12:16] = b[8] + ... + b[15]
+  //                ...
   a->vphaddd(aReg, aReg, bReg);
   a->vpaddd(resultRegAvx2_, aReg, resultRegAvx2_);
 }

 template <>
 template <>
+void GenConvKernel<int32_t>::gen8BitSumX16<inst_set_t::avx2>(
+    asmjit::X86Emitter* a,
+    asmjit::X86Ymm aReg,
+    asmjit::X86Ymm bReg,
+    asmjit::X86Ymm cReg,
+    asmjit::X86Ymm dReg) {
+  a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
+  // After vpsadbw, a[0:2] = a[0] + ... + a[7]
+  //                a[8:10] = a[8] + ... + a[15]
+  //                a[16:18] = a[16] + ... + a[23]
+  //                a[24:26] = a[24] + ... + a[31]
+  a->vpsadbw(aReg, aReg, tmpReg1Avx2_);
+  // 11 01 10 00 = 0xd8
+  // a[0:4] = a[0] + ... + a[7], a[4:8] = a[8] + ... + a[15]
+  // a[8:16] = zeros
+  a->vpshufd(aReg, aReg, static_cast<asmjit::Imm>(0xd8));
+  a->vpsadbw(bReg, bReg, tmpReg1Avx2_);
+  // 10 00 11 01 = 0x8d
+  // b[0:8] = zeros
+  // b[8:12] = b[0] + ... + b[7], b[12:16] = b[8] + ... + b[15]
+  a->vpshufd(bReg, bReg, static_cast<asmjit::Imm>(0x8d));
+  // a[0:4] = a[0] + ... + a[7], a[4:8] = a[8] + ... + a[15]
+  // a[8:12] = b[0] + ... + b[7], a[12:16] = b[8] + ... + b[15]
+  a->vpaddd(aReg, aReg, bReg);
+
+  // After vpsadbw, c[0:2] = c[0] + ... + c[7]
+  //                c[8:10] = c[8] + ... + c[15]
+  //                c[16:18] = c[16] + ... + c[23]
+  //                c[24:26] = c[24] + ... + c[31]
+  a->vpsadbw(cReg, cReg, tmpReg1Avx2_);
+  // 11 01 10 00 = 0xd8
+  // c[0:4] = c[0] + ... + c[7], c[4:8] = c[8] + ... + c[15]
+  // c[8:16] = zeros
+  a->vpshufd(cReg, cReg, static_cast<asmjit::Imm>(0xd8));
+  a->vpsadbw(dReg, dReg, tmpReg1Avx2_);
+  // 10 00 11 01 = 0x8d
+  // d[0:8] = zeros
+  // d[8:12] = d[0] + ... + d[7], d[12:16] = d[8] + ... + d[15]
+  a->vpshufd(dReg, dReg, static_cast<asmjit::Imm>(0x8d));
+  // c[0:4] = c[0] + ... + c[7], c[4:8] = c[8] + ... + c[15]
+  // c[8:12] = d[0] + ... + d[7], c[12:16] = d[8] + ... + d[15]
+  a->vpaddd(cReg, cReg, dReg);
+
+  // a[0:4] = a[0] + ... + a[15], a[4:8] = b[0] + ... + b[15]
+  // a[8:12] = c[0] + ... + c[15], a[12:16] = d[0] + ... + d[15]
+  a->vphaddd(aReg, aReg, cReg);
+
+  a->vpaddd(resultRegAvx2_, aReg, resultRegAvx2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>(
+    asmjit::X86Emitter* a,
+    int act_offset,
+    bool use_scratch_reg1 /*=true*/) {
+  if (use_scratch_reg1) {
+    a->vmovups(
+        actRegAvx2_,
+        x86::dword_ptr(
+            in_acts_R_, scratchReg1_, 0, act_offset * sizeof(uint8_t)));
+  } else {
+    a->vmovups(
+        actRegAvx2_, x86::dword_ptr(in_acts_R_, act_offset * sizeof(uint8_t)));
+  }
+  if (C_per_G_ == 4) {
+    gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
+  } else {
+    if (use_scratch_reg1) {
+      a->vmovups(
+          stPermRegAvx2_,
+          x86::dword_ptr(
+              in_acts_R_,
+              scratchReg1_,
+              0,
+              (act_offset + VLEN_) * sizeof(uint8_t)));
+    } else {
+      a->vmovups(
+          stPermRegAvx2_,
+          x86::dword_ptr(in_acts_R_, (act_offset + VLEN_) * sizeof(uint8_t)));
+    }
+    if (C_per_G_ == 8) {
+      gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
+    } else {
+      assert(C_per_G_ == 16);
+      if (use_scratch_reg1) {
+        a->vmovups(
+            WRegs_avx2_[0],
+            x86::dword_ptr(
+                in_acts_R_,
+                scratchReg1_,
+                0,
+                (act_offset + 2 * VLEN_) * sizeof(uint8_t)));
+        a->vmovups(
+            WRegs_avx2_[1],
+            x86::dword_ptr(
+                in_acts_R_,
+                scratchReg1_,
+                0,
+                (act_offset + 3 * VLEN_) * sizeof(uint8_t)));
+      } else {
+        a->vmovups(
+            WRegs_avx2_[0],
+            x86::dword_ptr(
+                in_acts_R_, (act_offset + 2 * VLEN_) * sizeof(uint8_t)));
+        a->vmovups(
+            WRegs_avx2_[1],
+            x86::dword_ptr(
+                in_acts_R_, (act_offset + 3 * VLEN_) * sizeof(uint8_t)));
+      }
+      gen8BitSumX16<inst_set_t::avx2>(
+          a, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0], WRegs_avx2_[1]);
+    } // C_per_G_ != 8
+  } // C_per_G_ != 4
+}
+
+template <>
+template <>
 void GenConvKernel<int32_t>::genZeroPtSum<inst_set_t::avx2>(
     asmjit::X86Emitter* a,
     int multiplier) {
@@ -840,7 +977,6 @@ void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>(
     a->jl(LoopH);

     if (c_offset + 4 < C_per_G_) {
-      // FIXME : simplify
       // reset input pointer
       // scratchReg2_ = W_R_ * C_ * (H_R_ - 2 * H_PAD_)
       a->mov(scratchReg2_, H_R_);
@@ -988,22 +1124,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
         static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
     for (int s = W_PAD_; s < S_; ++s) {
       int w_in = -W_PAD_ + s;
-      a->vmovups(
-          actRegAvx2_,
-          x86::dword_ptr(
-              in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
-      if (C_per_G_ == 4) {
-        gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
-      } else {
-        a->vmovups(
-            stPermRegAvx2_,
-            x86::dword_ptr(
-                in_acts_R_,
-                scratchReg1_,
-                0,
-                (w_in * C_ + 32) * sizeof(uint8_t)));
-        gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
-      }
+      gen8BitSum<inst_set_t::avx2>(a, w_in * C_);
     }
   }
@@ -1029,22 +1150,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
         W_R_,
         static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
     for (int s = 0; s < S_; ++s) {
-      a->vmovups(
-          actRegAvx2_,
-          x86::dword_ptr(
-              in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
-      if (C_per_G_ == 4) {
-        gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
-      } else {
-        a->vmovups(
-            stPermRegAvx2_,
-            x86::dword_ptr(
-                in_acts_R_,
-                scratchReg1_,
-                0,
-                (s * C_ + 32) * sizeof(uint8_t)));
-        gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
-      }
+      gen8BitSum<inst_set_t::avx2>(a, s * C_);
     }
   }
   a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
@@ -1082,14 +1188,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
       a->sub(scratchReg2_, static_cast<asmjit::Imm>(R_ - W_PAD_ - s));
       a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
       a->add(scratchReg1_, scratchReg2_);
-      a->vmovups(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
-      if (C_per_G_ == 4) {
-        gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
-      } else {
-        a->vmovups(
-            stPermRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_, 0, 32));
-        gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
-      }
+      gen8BitSum<inst_set_t::avx2>(a, 0);
     }
   }
@@ -1122,25 +1221,7 @@ void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
   a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
   for (int r = 0; r < R_; ++r) {
     for (int s = W_PAD_; s < S_; ++s) {
-      a->vmovups(
-          actRegAvx2_,
-          x86::dword_ptr(
-              in_acts_R_,
-              scratchReg1_,
-              0,
-              (s - W_PAD_) * C_ * sizeof(uint8_t)));
-      if (C_per_G_ == 4) {
-        gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
-      } else {
-        a->vmovups(
-            stPermRegAvx2_,
-            x86::dword_ptr(
-                in_acts_R_,
-                scratchReg1_,
-                0,
-                ((s - W_PAD_) * C_ + 32) * sizeof(uint8_t)));
-        gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
-      }
+      gen8BitSum<inst_set_t::avx2>(a, (s - W_PAD_) * C_);
     }
     a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
     a->add(scratchReg1_, scratchReg2_);
@@ -1197,15 +1278,7 @@ void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
   a->add(scratchReg1_, scratchReg2_);
   for (int r = 0; r < R_; ++r) {
     for (int s = 0; s < S_ - W_PAD_; ++s) {
-      a->vmovups(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
-      if (C_per_G_ == 4) {
-        gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
-      } else {
-        a->vmovups(
-            stPermRegAvx2_,
-            x86::dword_ptr(in_acts_R_, scratchReg1_, 0, 32 * sizeof(uint8_t)));
-        gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
-      }
+      gen8BitSum<inst_set_t::avx2>(a, 0);

       a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
     }
@@ -1258,25 +1331,7 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
   a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
   for (int r = 0; r < R_ - H_PAD_; ++r) {
     for (int s = W_PAD_; s < S_; ++s) {
-      a->vmovups(
-          actRegAvx2_,
-          x86::dword_ptr(
-              in_acts_R_,
-              scratchReg1_,
-              0,
-              (s - W_PAD_) * C_ * sizeof(uint8_t)));
-      if (C_per_G_ == 4) {
-        gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
-      } else {
-        a->vmovups(
-            stPermRegAvx2_,
-            x86::dword_ptr(
-                in_acts_R_,
-                scratchReg1_,
-                0,
-                ((s - W_PAD_) * C_ + 32) * sizeof(uint8_t)));
-        gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
-      }
+      gen8BitSum<inst_set_t::avx2>(a, (s - W_PAD_) * C_);
     }
     a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
     a->add(scratchReg1_, scratchReg2_);
@@ -1307,19 +1362,7 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
   for (int r = 0; r < R_ - W_PAD_; ++r) {
     // int h_in = H_-2*H_PAD_ + r;
     for (int s = 0; s < S_; ++s) {
-      a->vmovups(
-          actRegAvx2_,
-          x86::dword_ptr(
-              in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
-      if (C_per_G_ == 4) {
-        gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
-      } else {
-        a->vmovups(
-            stPermRegAvx2_,
-            x86::dword_ptr(
-                in_acts_R_, scratchReg1_, 0, (s * C_ + 32) * sizeof(uint8_t)));
-        gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
-      }
+      gen8BitSum<inst_set_t::avx2>(a, s * C_);
     }
     a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
     a->add(scratchReg1_, scratchReg2_);
@@ -1358,19 +1401,7 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
   a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
   for (int r = 0; r < R_ - H_PAD_; ++r) {
     for (int s = 0; s < S_ - W_PAD_; ++s) {
-      a->vmovups(
-          actRegAvx2_,
-          x86::dword_ptr(
-              in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
-      if (C_per_G_ == 4) {
-        gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
-      } else {
-        a->vmovups(
-            stPermRegAvx2_,
-            x86::dword_ptr(
-                in_acts_R_, scratchReg1_, 0, (s * C_ + 32) * sizeof(uint8_t)));
-        gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
-      }
+      gen8BitSum<inst_set_t::avx2>(a, s * C_);
     }
     a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
     a->add(scratchReg1_, scratchReg2_);
@@ -1417,16 +1448,7 @@ void GenConvKernel<int32_t>::genRowoffsetCore<inst_set_t::avx2>(
   a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
   for (int r = 0; r < R_; ++r) {
     for (int s = 0; s < S_; ++s) {
-      a->vmovups(
-          actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t)));
-      if (C_per_G_ == 4) {
-        gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
-      } else {
-        a->vmovups(
-            stPermRegAvx2_,
-            x86::dword_ptr(in_acts_R_, (s * C_ + 32) * sizeof(uint8_t)));
-        gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
-      }
+      gen8BitSum<inst_set_t::avx2>(a, s * C_, false /*use_scratch_reg1*/);
     }
     a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
     a->add(in_acts_R_, scratchReg2_);
@@ -1600,14 +1622,10 @@ void fbgemmGroupwiseConvBase_(
   for (int i = 0; i < MB; ++i) {
     const uint8_t* actStartBatch = activations + i * ih_iw * conv_param.IC;
     for (int gOuter = 0; gOuter < G; gOuter += 8) {
-      // for C_per_G == 4 and K_per_G == 4, row offset is calcualted for 8
-      // groups at a time The result is row offsets in the format IH*IW x G
+      // row offset is calculated for 8 groups at a time.
+      // The result is row offsets in the format IH*IW x G
       fpRowoffset(
-          actStartBatch + gOuter * C_per_G,
-          a_zero_point,
-          H,
-          W,
-          rowOffsetBuf);
+          actStartBatch + gOuter * C_per_G, a_zero_point, H, W, rowOffsetBuf);
       // Transpose to get row offsets in the format G x IH*IW
       internal::transpose_8x8(
           ih_iw,
@@ -1617,20 +1635,30 @@ void fbgemmGroupwiseConvBase_(
          reinterpret_cast<float*>(rowOffsetTrDest),
          ih_iw);
       int gLimit = gOuter + 8;
-      int gDelta = C_per_G == 4 ? 2 : 1;
+      // Work on 8 output channels at a time (8 * sizeof(int32_t) == 32B VLEN
+      // of AVX2), and fuse multiple groups per call if a single group
+      // doesn't have enough channels.
+      int gDelta = std::max(8 / C_per_G, 1);
       for (int g = gOuter; g < gLimit; g += gDelta) {
         int32_t* currOutBuf =
             outBuffer + i * oh_ow * conv_param.OC + g * K_per_G;
         const uint8_t* actStartGroup = actStartBatch + g * C_per_G;
-
-        fpConv(
-            actStartGroup,
-            packed_weights.getBuf() +
-                g * conv_param.K[0] * conv_param.K[1] * K_per_G * C_per_G,
-            currOutBuf,
-            a_zero_point,
-            H,
-            W);
+        for (int k = 0; k < K_per_G; k += 8) {
+          // Not to be confused with k above, which indexes output channels:
+          // k0 and k1 are the filter dimensions (commonly 3 and 3).
+          int k0 = conv_param.K[0];
+          int k1 = conv_param.K[1];
+          fpConv(
+              actStartGroup,
+              // packed weight is in G (C/4) R S K 4 layout for IC_per_G >= 8,
+              // in (G/2) R S K (2C) for IC_per_G == 4
+              packed_weights.getBuf() +
+                  (g * (C_per_G / 4) * k0 * k1 * K_per_G + k) * 4,
+              currOutBuf + k,
+              a_zero_point,
+              H,
+              W);
+        } // k loop

         // Output processing should be called for each group
         for (int j = 0; j < gDelta; ++j) {
@@ -1791,14 +1819,10 @@ void fbgemmGroupwiseConv(
   for (int i = 0; i < MB; ++i) {
     const uint8_t* actStartBatch = activations + i * ih_iw * conv_param.IC;
     for (int gOuter = 0; gOuter < G; gOuter += 8) {
-      // for C_per_G == 4 and K_per_G == 4, row offset is calcualted for 8
-      // groups at a time The result is row offsets in the format IH*IW x G
+      // row offset is calculated for 8 groups at a time.
+      // The result is row offsets in the format IH*IW x G
       fpRowoffset(
-          actStartBatch + gOuter * C_per_G,
-          a_zero_point,
-          H,
-          W,
-          rowOffsetBuf);
+          actStartBatch + gOuter * C_per_G, a_zero_point, H, W, rowOffsetBuf);
       // Transpose to get row offsets in the format G x IH*IW
       internal::transpose_8x8(
           ih_iw,
@@ -1808,20 +1832,31 @@ void fbgemmGroupwiseConv(
          reinterpret_cast<float*>(rowOffsetTrDest),
          ih_iw);
       int gLimit = gOuter + 8;
-      int gDelta = C_per_G == 4 ? 2 : 1;
+      // Work on 8 output channels at a time (8 * sizeof(int32_t) == 32B VLEN
+      // of AVX2), and fuse multiple groups per call if a single group
+      // doesn't have enough channels.
+      int gDelta = std::max(8 / C_per_G, 1);
       for (int g = gOuter; g < gLimit; g += gDelta) {
+        // Reuse the same region of outBuffer multiple times for locality
         int32_t* currOutBuf = outBuffer + (g - gOuter) * K_per_G;
         const uint8_t* actStartGroup = actStartBatch + g * C_per_G;
-
-        fpConv(
-            actStartGroup,
-            packed_weights.getBuf() +
-                g * conv_param.K[0] * conv_param.K[1] * K_per_G * C_per_G,
-            currOutBuf,
-            a_zero_point,
-            H,
-            W);
-      }
+        for (int k = 0; k < K_per_G; k += 8) {
+          // Not to be confused with k above, which indexes output channels:
+          // k0 and k1 are the filter dimensions (commonly 3 and 3).
+          int k0 = conv_param.K[0];
+          int k1 = conv_param.K[1];
+          fpConv(
+              actStartGroup,
+              // packed weight is in G (C/4) R S K 4 layout for IC_per_G >= 8,
+              // in (G/2) R S K (2C) for IC_per_G == 4
+              packed_weights.getBuf() +
+                  (g * (C_per_G / 4) * k0 * k1 * K_per_G + k) * 4,
+              currOutBuf + k,
+              a_zero_point,
+              H,
+              W);
+        } // k loop
+      } // g loop

       bool b_symmetric =
           outProcess.getBZeroPoint()[0] == 0 || rowOffsetBuf == nullptr;
@@ -1919,7 +1954,7 @@ void fbgemmGroupwiseConv(
             }
           }
         }
-      } else {
+      } else if (C_per_G == 8) {
         if (a_zero_point == 0) {
           if (b_symmetric) {
             if (outProcess.getBias() == nullptr) {
@@ -1997,6 +2032,84 @@ void fbgemmGroupwiseConv(
             }
           }
         }
+      } else {
+        if (a_zero_point == 0) {
+          if (b_symmetric) {
+            if (outProcess.getBias() == nullptr) {
+              requantizeOutputProcessingGConvAvx2<
+                  true,
+                  true,
+                  Q_GRAN,
+                  false,
+                  FUSE_RELU,
+                  16>(out, inp, block, ld_out, ld_in, r);
+            } else {
+              requantizeOutputProcessingGConvAvx2<
+                  true,
+                  true,
+                  Q_GRAN,
+                  true,
+                  FUSE_RELU,
+                  16>(out, inp, block, ld_out, ld_in, r);
+            }
+          } else {
+            if (outProcess.getBias() == nullptr) {
+              requantizeOutputProcessingGConvAvx2<
+                  true,
+                  false,
+                  Q_GRAN,
+                  false,
+                  FUSE_RELU,
+                  16>(out, inp, block, ld_out, ld_in, r);
+            } else {
+              requantizeOutputProcessingGConvAvx2<
+                  true,
+                  false,
+                  Q_GRAN,
+                  true,
+                  FUSE_RELU,
+                  16>(out, inp, block, ld_out, ld_in, r);
+            }
+          }
+        } else {
+          if (b_symmetric) {
+            if (outProcess.getBias() == nullptr) {
+              requantizeOutputProcessingGConvAvx2<
+                  false,
+                  true,
+                  Q_GRAN,
+                  false,
+                  FUSE_RELU,
+                  16>(out, inp, block, ld_out, ld_in, r);
+            } else {
+              requantizeOutputProcessingGConvAvx2<
+                  false,
+                  true,
+                  Q_GRAN,
+                  true,
+                  FUSE_RELU,
+                  16>(out, inp, block, ld_out, ld_in, r);
+            }
+          } else {
+            if (outProcess.getBias() == nullptr) {
+              requantizeOutputProcessingGConvAvx2<
+                  false,
+                  false,
+                  Q_GRAN,
+                  false,
+                  FUSE_RELU,
+                  16>(out, inp, block, ld_out, ld_in, r);
+            } else {
+              requantizeOutputProcessingGConvAvx2<
+                  false,
+                  false,
+                  Q_GRAN,
+                  true,
+                  FUSE_RELU,
+                  16>(out, inp, block, ld_out, ld_in, r);
+            }
+          }
+        }
       }
     } // gOuter loop
   } // i loop
@@ -2027,7 +2140,8 @@ int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) {
     int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
     int C_per_G = conv_param.IC / conv_param.G;
     int K_per_G = conv_param.OC / conv_param.G;
-    if ((C_per_G == 4 && K_per_G == 4) || (C_per_G == 8 && K_per_G == 8)) {
+    if (C_per_G == K_per_G &&
+        (C_per_G == 4 || C_per_G == 8 || C_per_G == 16)) {
      return 2 * 8 * bufferSize;
    } else {
      return conv_param.G * bufferSize;
@@ -2036,7 +2150,8 @@ int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) {
     int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
     int C_per_G = conv_param.IC / conv_param.G;
     int K_per_G = conv_param.OC / conv_param.G;
-    if ((C_per_G == 4 && K_per_G == 4) || (C_per_G == 8 && K_per_G == 8)) {
+    if (C_per_G == K_per_G &&
+        (C_per_G == 4 || C_per_G == 8 || C_per_G == 16)) {
       // row offset is calculated for 8 groups at a time
       // 2x is needed for transposing
       return 2 * 8 * bufferSize;
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index bc41310..be12142 100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -763,12 +763,17 @@ void requantizeOutputProcessingGConvAvx2(
                 _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 0])),
             _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 1]),
             1);
-      } else {
-        assert(C_PER_G == 8);
+      } else if (C_PER_G == 8) {
         row_offset_v =
             _mm256_set1_epi32(r.row_offsets
                                   [(i - block.row_start) * 8 +
                                    (j - block.col_start) / (VLEN * 4) * 4]);
+      } else {
+        assert(C_PER_G == 16);
+        row_offset_v =
+            _mm256_set1_epi32(r.row_offsets
+                                  [(i - block.row_start) * 8 +
+                                   (j - block.col_start) / (VLEN * 4) * 2]);
       }
       __m256i B_zero_point_v = _mm256_set1_epi32(r.B_zero_point[0]);
       if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
@@ -781,10 +786,14 @@ void requantizeOutputProcessingGConvAvx2(
                 _mm_set1_epi32(r.B_zero_point[quant_param_idx])),
             _mm_set1_epi32(r.B_zero_point[quant_param_idx + 1]),
             1);
-      } else {
+      } else if (C_PER_G == 8) {
         B_zero_point_v = _mm256_set1_epi32(
             r.B_zero_point
                 [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]);
+      } else {
+        B_zero_point_v = _mm256_set1_epi32(
+            r.B_zero_point
+                [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2]);
       }
     }
     row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
@@ -799,12 +808,17 @@ void requantizeOutputProcessingGConvAvx2(
                 _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 2])),
             _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 3]),
             1);
-      } else {
+      } else if (C_PER_G == 8) {
         // + 1 here is for group 1
         row_offset_v = _mm256_set1_epi32(
             r.row_offsets
                 [(i - block.row_start) * 8 +
                  (j - block.col_start) / (VLEN * 4) * 4 + 1]);
+      } else if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+        row_offset_v =
+            _mm256_set1_epi32(r.row_offsets
+                                  [(i - block.row_start) * 8 +
+                                   (j - block.col_start) / (VLEN * 4) * 2]);
       }
       if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
         B_zero_point_v = _mm256_loadu_si256(
@@ -816,14 +830,16 @@ void requantizeOutputProcessingGConvAvx2(
                 _mm_set1_epi32(r.B_zero_point[quant_param_idx + 2])),
             _mm_set1_epi32(r.B_zero_point[quant_param_idx + 3]),
             1);
-      } else {
+      } else if (C_PER_G == 8) {
         B_zero_point_v = _mm256_set1_epi32(
             r.B_zero_point
                 [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
                  1]);
       }
     }
-    row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
+    if (C_PER_G != 16 || Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+      row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
+    }
     y_v = _mm256_sub_epi32(y_v, row_offset_v);

     // Groups 4 and 5 when C_PER_G == 4
@@ -834,11 +850,16 @@ void requantizeOutputProcessingGConvAvx2(
                 _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 4])),
             _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 5]),
             1);
-      } else {
+      } else if (C_PER_G == 8) {
         row_offset_v = _mm256_set1_epi32(
             r.row_offsets
                 [(i - block.row_start) * 8 +
                  (j - block.col_start) / (VLEN * 4) * 4 + 2]);
+      } else {
+        row_offset_v = _mm256_set1_epi32(
+            r.row_offsets
+                [(i - block.row_start) * 8 +
+                 (j - block.col_start) / (VLEN * 4) * 2 + 1]);
       }
       if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
         B_zero_point_v = _mm256_loadu_si256(
@@ -850,11 +871,16 @@ void requantizeOutputProcessingGConvAvx2(
                 _mm_set1_epi32(r.B_zero_point[quant_param_idx + 4])),
             _mm_set1_epi32(r.B_zero_point[quant_param_idx + 5]),
             1);
-      } else {
+      } else if (C_PER_G == 8) {
         B_zero_point_v = _mm256_set1_epi32(
             r.B_zero_point
                 [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
                  2]);
+      } else {
+        B_zero_point_v = _mm256_set1_epi32(
+            r.B_zero_point
+                [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2 +
+                 1]);
       }
     }
     row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
@@ -868,11 +894,16 @@ void requantizeOutputProcessingGConvAvx2(
                 _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 6])),
             _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 7]),
             1);
-      } else {
+      } else if (C_PER_G == 8) {
         row_offset_v = _mm256_set1_epi32(
             r.row_offsets
                 [(i - block.row_start) * 8 +
                  (j - block.col_start) / (VLEN * 4) * 4 + 3]);
+      } else if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+        row_offset_v = _mm256_set1_epi32(
+            r.row_offsets
+                [(i - block.row_start) * 8 +
+                 (j - block.col_start) / (VLEN * 4) * 2 + 1]);
       }
       if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
         B_zero_point_v = _mm256_loadu_si256(
@@ -884,14 +915,16 @@ void requantizeOutputProcessingGConvAvx2(
                 _mm_set1_epi32(r.B_zero_point[quant_param_idx + 6])),
             _mm_set1_epi32(r.B_zero_point[quant_param_idx + 7]),
             1);
-      } else {
+      } else if (C_PER_G == 8) {
         B_zero_point_v = _mm256_set1_epi32(
             r.B_zero_point
                 [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
                  3]);
       }
     }
-    row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
+    if (C_PER_G != 16 || Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+      row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
+    }
     w_v = _mm256_sub_epi32(w_v, row_offset_v);
   }
   if (HAS_BIAS) {
@@ -966,7 +999,7 @@ void requantizeOutputProcessingGConvAvx2(
             _mm_set1_ps(r.C_multiplier[quant_param_idx + 7]),
             1);
         w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
-      } else {
+      } else if (C_PER_G == 8) {
         multiplier_v = _mm256_set1_ps(
             r.C_multiplier
                 [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]);
@@ -989,6 +1022,19 @@ void requantizeOutputProcessingGConvAvx2(
                 [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
                  3]);
         w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+      } else {
+        multiplier_v = _mm256_set1_ps(
+            r.C_multiplier
+                [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2]);
+        x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+        y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
+
+        multiplier_v = _mm256_set1_ps(
+            r.C_multiplier
+                [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2 +
+                 1]);
+        z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
+        w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
       }
     } else {
       x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
@@ -1065,30 +1111,38 @@ void requantizeOutputProcessingGConvAvx2(
   } // i loop
 }

-#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
-  template void \
-  requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
-      uint8_t * out, \
-      const int32_t* inp, \
-      const block_type_t& block, \
-      int ld_out, \
-      int ld_in, \
-      const requantizationParams_t& r); \
-  template void \
-  requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>( \
-      uint8_t * out, \
-      const int32_t* inp, \
-      const block_type_t& block, \
-      int ld_out, \
-      int ld_in, \
-      const requantizationParams_t& r); \
-  template void \
-  requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>( \
-      uint8_t * out, \
-      const int32_t* inp, \
-      const block_type_t& block, \
-      int ld_out, \
-      int ld_in, \
+#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU)             \
+  template void                                                              \
+  requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>(          \
+      uint8_t * out,                                                         \
+      const int32_t* inp,                                                    \
+      const block_type_t& block,                                             \
+      int ld_out,                                                            \
+      int ld_in,                                                             \
+      const requantizationParams_t& r);                                      \
+  template void                                                              \
+  requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>(  \
+      uint8_t * out,                                                         \
+      const int32_t* inp,                                                    \
+      const block_type_t& block,                                             \
+      int ld_out,                                                            \
+      int ld_in,                                                             \
+      const requantizationParams_t& r);                                      \
+  template void                                                              \
+  requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>(  \
+      uint8_t * out,                                                         \
+      const int32_t* inp,                                                    \
+      const block_type_t& block,                                             \
+      int ld_out,                                                            \
+      int ld_in,                                                             \
+      const requantizationParams_t& r);                                      \
+  template void                                                              \
+  requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 16>( \
+      uint8_t * out,                                                         \
+      const int32_t* inp,                                                    \
+      const block_type_t& block,                                             \
+      int ld_out,                                                            \
+      int ld_in,                                                             \
       const requantizationParams_t& r);

 #define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \
diff --git a/test/GConvTest.cc b/test/GConvTest.cc
index e1434d2..66d61b6 100644
--- a/test/GConvTest.cc
+++ b/test/GConvTest.cc
@@ -69,6 +69,7 @@ static vector<conv_param_t<>> GetShapes_() {
       conv_param_t<>(1, 8, 8, {5, 5}, 2, {3, 3}, {1, 1}, {1, 1, 1, 1}),
       conv_param_t<>(1, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
       conv_param_t<>(1, 128, 128, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      // the line below is from resnext101-32x4d
       conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
       conv_param_t<>(2, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
@@ -81,6 +82,16 @@ static vector<conv_param_t<>> GetShapes_() {
       conv_param_t<>(1, 256, 256, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
       conv_param_t<>(1, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
       conv_param_t<>(2, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+
+      conv_param_t<>(1, 128, 128, {3, 3}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(1, 128, 128, {4, 4}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(1, 128, 128, {3, 5}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(1, 128, 128, {5, 3}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(1, 32, 32, {5, 5}, 2, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(1, 512, 512, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(1, 512, 512, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(1, 512, 512, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+      conv_param_t<>(2, 512, 512, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
   };
   return shapes;
 }
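The driver-side change is the tiling in fbgemmGroupwiseConvBase_ and fbgemmGroupwiseConv: each JIT kernel call produces 8 int32 outputs (one 32-byte AVX2 vector), so gDelta fuses multiple small groups per call, while 16-channel groups instead step over K_per_G output channels 8 at a time. A self-contained sketch of just that index math (our illustration, reusing the expressions from the diff above):

    #include <algorithm>
    #include <cstdio>

    int main() {
      const int R = 3, S = 3; // filter dims, conv_param.K[0] and K[1]
      for (int C_per_G : {4, 8, 16}) {
        const int K_per_G = C_per_G; // the fast path requires C_per_G == K_per_G
        // 2 groups per kernel call when C_per_G == 4, otherwise 1.
        const int gDelta = std::max(8 / C_per_G, 1);
        for (int k = 0; k < K_per_G; k += 8) {
          // Packed-weight offset for group g and output-channel block k,
          // as in "(g * (C_per_G / 4) * k0 * k1 * K_per_G + k) * 4".
          const int g = 8; // example group index
          const int offset = (g * (C_per_G / 4) * R * S * K_per_G + k) * 4;
          std::printf("C_per_G=%2d gDelta=%d k=%2d weight offset=%d\n",
                      C_per_G, gDelta, k, offset);
        }
      }
      return 0;
    }

For C_per_G == 16 the inner loop runs twice (k = 0 and k = 8), which matches the two row-offset entries per group that requantizeOutputProcessingGConvAvx2<..., 16> indexes with the "* 2" stride in src/QuantUtilsAvx2.cc.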