github.com/marian-nmt/FBGEMM.git
author     Jongsoo Park <jongsoo@fb.com>  2019-02-13 01:35:32 +0300
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-02-13 01:48:03 +0300
commit     86eeae2f917b92126af5cb1d37336f7d503292ee (patch)
tree       cb7af7e0fc924fb27d673aa296c6dd0af3d60e1e
parent     df7b1c1237c2f4274294ad9136861f30a7234c14 (diff)
group conv optimized for 16 channels per group (#68)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/68

Continuing optimizations for group convolution. Even though the op-level speedup for 16 channels per group is lower than in the 4- and 8-channel cases, we get a nice overall speedup in resnext101-32x4d because it has many Conv operators with 16 channels per group.

Reviewed By: protonu

Differential Revision: D13949873

fbshipit-source-id: 1dff4b1acfdabe23616e7df365daf2b7f6e8aea9
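The shapes this targets have 16 channels per group, e.g. 512 input/output channels over 32 groups. A minimal sketch (not part of the patch) of such a shape, using conv_param_t the same way the benchmark below does:

#include "fbgemm/Fbgemm.h"

int main() {
  // 512 channels over 32 groups -> 16 channels per group (a
  // resnext101-32x4d style block). Arguments: MB, IC, OC, {IH, IW}, G,
  // {KH, KW}, {stride_h, stride_w}, {pad_t, pad_l, pad_b, pad_r}.
  fbgemm::conv_param_t<> conv_p(
      1, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1});
  // With this diff, the optimized groupwise kernel also covers the
  // 16-channels-per-group case.
  return fbgemm::fbgemmOptimizedGConv(conv_p) ? 0 : 1;
}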
-rw-r--r--  bench/GroupwiseConvRequantizeBenchmark.cc |  13
-rw-r--r--  src/Fbgemm.cc                             |   2
-rw-r--r--  src/GroupwiseConv.h                       |  22
-rw-r--r--  src/GroupwiseConvAcc32Avx2.cc             | 445
-rw-r--r--  src/QuantUtilsAvx2.cc                     | 126
-rw-r--r--  test/GConvTest.cc                         |  11
6 files changed, 413 insertions(+), 206 deletions(-)
diff --git a/bench/GroupwiseConvRequantizeBenchmark.cc b/bench/GroupwiseConvRequantizeBenchmark.cc
index 158ca4f..4c93f23 100644
--- a/bench/GroupwiseConvRequantizeBenchmark.cc
+++ b/bench/GroupwiseConvRequantizeBenchmark.cc
@@ -59,10 +59,15 @@ void performance_test() {
// conv_param_t<>(2, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1,
// 1}),
- conv_param_t<>(1, 256, 256, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
- conv_param_t<>(1, 256, 256, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
- conv_param_t<>(1, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
- conv_param_t<>(2, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 256, 256, {28, 24}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 256, 256, {24, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 256, 256, {28, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(2, 256, 256, {28, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+
+ conv_param_t<>(1, 512, 512, {14, 12}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 512, 512, {12, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(2, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
};
bool flush = true;
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index cb22999..ab0693a 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -198,7 +198,7 @@ FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p) {
int K_per_G = conv_p.OC / conv_p.G;
return (SPATIAL_DIM == 2) && (C_per_G == K_per_G) &&
- (C_per_G == 4 || C_per_G == 8) && (conv_p.G % 8 == 0) &&
+ (C_per_G == 4 || C_per_G == 8 || C_per_G == 16) && (conv_p.G % 8 == 0) &&
(conv_p.K[0] == conv_p.K[1]) && (conv_p.K[0] == 3) &&
(conv_p.pad[0] == 1) && (conv_p.pad[1] == 1) &&
(conv_p.pad[0] == conv_p.pad[2]) && (conv_p.pad[1] == conv_p.pad[3]) &&
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
index b65082f..3605fcd 100644
--- a/src/GroupwiseConv.h
+++ b/src/GroupwiseConv.h
@@ -175,6 +175,28 @@ class GenConvKernel {
void
gen8BitSumX8(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg);
+ // Add 16 consecutive numbers of 128 uint8s and emit 8 32-bit sums
+ template <inst_set_t instSet>
+ void gen8BitSumX16(
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg,
+ asmjit::X86Ymm bReg,
+ asmjit::X86Ymm cReg,
+ asmjit::X86Ymm dReg);
+
+ // Generate an instruction sequence that loads 8-bit values and sums them up.
+ // Depending on C_per_G_, this function dispatches to gen8BitSumX4/8/16.
+ // This function assumes in_acts_R_ holds the base pointer to the activations,
+ // scratchReg1_ holds a variable offset, and act_offset holds the final
+ // immediate offset.
+ // Internally, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0, 1], tmpReg1Avx2_,
+ // and resultRegAvx2_ are used.
+ template <inst_set_t instSet>
+ void gen8BitSum(
+ asmjit::X86Emitter* a,
+ int act_offset,
+ bool use_scratch_reg1 = true);
+
// Use scratchReg1_ and tmpReg1Avx2_ internally
template <inst_set_t instSet>
void genZeroPtSum(asmjit::X86Emitter* a, int multiplier);
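The dispatch above computes the same quantity for every C_per_G_. A scalar model of one gen8BitSum pass (hypothetical code, for exposition only; the vector version produces these sums in a permuted in-register order that storeResultRowoffset fixes up):

#include <cstdint>

// Scalar model of one gen8BitSum pass: the generated code loads
// 8 * C_per_G consecutive uint8 activations starting at act_offset and
// accumulates one 32-bit sum per group into the row-offset accumulator.
void gen8BitSumScalar(
    const uint8_t* acts,  // in_acts_R_ (+ scratchReg1_) + act_offset
    int C_per_G,          // 4, 8, or 16
    int32_t result[8]) {  // models resultRegAvx2_
  for (int g = 0; g < 8; ++g) {
    int32_t sum = 0;
    for (int c = 0; c < C_per_G; ++c)
      sum += acts[g * C_per_G + c];
    result[g] += sum;
  }
}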
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index be5145b..94ca87c 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -174,10 +174,21 @@ void GenConvKernel<int32_t>::storeResultRowoffset<inst_set_t::avx2>(
// store
if (C_per_G_ == 4) {
a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_);
+ } else if (C_per_G_ == 8) {
+ // need to permute because vphaddd is used in gen8BitSumX8
+ // 11 01 10 00 = 0xd8
+ a->vpermq(resultRegAvx2_, resultRegAvx2_, static_cast<asmjit::Imm>(0xd8));
+ a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_);
} else {
- // need to permute due to vphaddd
- // 11 01 10 00
+ assert(C_per_G_ == 16);
+ // need to permute because vphaddd is used in gen8BitSumX16
+ // a[0:4] = a[0] + ... + a[15], a[4:8] = b[0] + ... + b[15]
+ // a[8:12] = c[0] + ... + c[15], a[12:16] = d[0] + ... + d[15]
a->vpermq(resultRegAvx2_, resultRegAvx2_, static_cast<asmjit::Imm>(0xd8));
+ // 11 01 10 00 = 0xd8
+ // a[0:4] = a[0] + ... + a[15], a[4:8] = a[16] + ... + a[31]
+ // a[8:12] = b[0] + ... + b[15], a[12:16] = b[16] + ... + b[31]
+ a->vpshufd(resultRegAvx2_, resultRegAvx2_, static_cast<asmjit::Imm>(0xd8));
a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_);
}
}
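Concretely, gen8BitSumX16's final vphaddd leaves the eight 32-bit sums as [A_lo, B_lo, C_lo, D_lo | A_hi, B_hi, C_hi, D_hi], where A_lo = a[0] + ... + a[15] and A_hi = a[16] + ... + a[31], and the store needs them in group order [A_lo, A_hi, B_lo, B_hi, ...]. A hypothetical scalar simulation of the two 0xd8 shuffles, just to show the reordering:

#include <cstdio>

int main() {
  // 8 dwords of resultRegAvx2_ after gen8BitSumX16's final vphaddd:
  const char* v[8] = {"A_lo", "B_lo", "C_lo", "D_lo",
                      "A_hi", "B_hi", "C_hi", "D_hi"};
  // vpermq imm 0xd8 = qword order {0, 2, 1, 3}
  const char* t[8];
  int q[4] = {0, 2, 1, 3};
  for (int i = 0; i < 4; ++i) {
    t[2 * i] = v[2 * q[i]];
    t[2 * i + 1] = v[2 * q[i] + 1];
  }
  // vpshufd imm 0xd8 = dword order {0, 2, 1, 3} within each 128-bit lane
  const char* u[8];
  int d[4] = {0, 2, 1, 3};
  for (int lane = 0; lane < 2; ++lane)
    for (int i = 0; i < 4; ++i)
      u[4 * lane + i] = t[4 * lane + d[i]];
  // Prints A_lo A_hi B_lo B_hi C_lo C_hi D_lo D_hi:
  // one row offset per group, in group order.
  for (int i = 0; i < 8; ++i)
    printf("%s ", u[i]);
  printf("\n");
}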
@@ -196,7 +207,7 @@ void GenConvKernel<int32_t>::genForLoadingWeights<inst_set_t::avx2>(
wghts_R_,
(r * S_ + s) * 2 * K_per_G_ * C_per_G_ * sizeof(int8_t)));
} else {
- // C_per_G == 8
+ // C_per_G == 8 or 16
a->vmovaps(
WRegs_avx2_[r * S_ + s],
x86::dword_ptr(
@@ -236,14 +247,140 @@ void GenConvKernel<int32_t>::gen8BitSumX8<inst_set_t::avx2>(
asmjit::X86Ymm aReg,
asmjit::X86Ymm bReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
+ // Let a[0] denote 0th (LSB) 8-bit of aReg
+ // After vpsadbw, a[0:2] = a[0] + ... + a[7]
+ // a[8:10] = a[8] + ... + a[15]
+ // a[16:18] = a[16] + ... + a[23]
+ // a[24:26] = a[24] + ... + a[31]
a->vpsadbw(aReg, aReg, tmpReg1Avx2_);
a->vpsadbw(bReg, bReg, tmpReg1Avx2_);
+ // After vphaddd, a[0:4] = a[0] + ... + a[7], a[4:8] = a[8] + ... + a[15]
+ // a[8:12] = b[0] + ... + b[7], a[12:16] = b[8] + ... + b[15]
+ // ...
a->vphaddd(aReg, aReg, bReg);
a->vpaddd(resultRegAvx2_, aReg, resultRegAvx2_);
}
template <>
template <>
+void GenConvKernel<int32_t>::gen8BitSumX16<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg,
+ asmjit::X86Ymm bReg,
+ asmjit::X86Ymm cReg,
+ asmjit::X86Ymm dReg) {
+ a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
+ // After vpsadbw, a[0:2] = a[0] + ... + a[7]
+ // a[8:10] = a[8] + ... + a[15]
+ // a[16:18] = a[16] + ... + a[23]
+ // a[24:26] = a[24] + ... + a[31]
+ a->vpsadbw(aReg, aReg, tmpReg1Avx2_);
+ // 11 01 10 00 = 0xd8
+ // a[0:4] = a[0] + ... + a[7], a[4:8] = a[8] + ... + a[15]
+ // a[8:16] = zeros
+ a->vpshufd(aReg, aReg, static_cast<asmjit::Imm>(0xd8));
+ a->vpsadbw(bReg, bReg, tmpReg1Avx2_);
+ // 10 00 11 01 = 0x8d
+ // b[0:8] = zeros
+ // b[8:12] = b[0] + ... + b[7], b[12:16] = b[8] + ... + b[15]
+ a->vpshufd(bReg, bReg, static_cast<asmjit::Imm>(0x8d));
+ // a[0:4] = a[0] + ... + a[7], a[4:8] = a[8] + ... + a[15]
+ // a[8:12] = b[0] + ... + b[7], a[12:16] = b[8] + ... + b[15]
+ a->vpaddd(aReg, aReg, bReg);
+
+ // After vpsadbw, c[0:2] = c[0] + ... + c[7]
+ // c[8:10] = c[8] + ... + c[15]
+ // c[16:18] = c[16] + ... + c[23]
+ // c[24:26] = c[24] + ... + c[31]
+ a->vpsadbw(cReg, cReg, tmpReg1Avx2_);
+ // 11 01 10 00 = 0xd8
+ // c[0:4] = c[0] + ... + c[7], c[4:8] = c[8] + ... + c[15]
+ // c[8:16] = zeros
+ a->vpshufd(cReg, cReg, static_cast<asmjit::Imm>(0xd8));
+ a->vpsadbw(dReg, dReg, tmpReg1Avx2_);
+ // 10 00 11 01 = 0x8d
+ // d[0:8] = zeros
+ // d[8:12] = d[0] + ... + d[7], d[12:16] = d[8] + ... + d[15]
+ a->vpshufd(dReg, dReg, static_cast<asmjit::Imm>(0x8d));
+ // c[0:4] = c[0] + ... + c[7], c[4:8] = c[8] + ... + c[15]
+ // c[8:12] = d[0] + ... + d[7], c[12:16] = d[8] + ... + d[15]
+ a->vpaddd(cReg, cReg, dReg);
+
+ // a[0:4] = a[0] + ... + a[15], a[4:8] = b[0] + ... + b[15]
+ // a[8:12] = c[0] + ... + c[15], a[12:16] = d[0] + ... + d[15]
+ a->vphaddd(aReg, aReg, cReg);
+
+ a->vpaddd(resultRegAvx2_, aReg, resultRegAvx2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int act_offset,
+ bool use_scratch_reg1 /*=true*/) {
+ if (use_scratch_reg1) {
+ a->vmovups(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, act_offset * sizeof(uint8_t)));
+ } else {
+ a->vmovups(
+ actRegAvx2_, x86::dword_ptr(in_acts_R_, act_offset * sizeof(uint8_t)));
+ }
+ if (C_per_G_ == 4) {
+ gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
+ } else {
+ if (use_scratch_reg1) {
+ a->vmovups(
+ stPermRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (act_offset + VLEN_) * sizeof(uint8_t)));
+ } else {
+ a->vmovups(
+ stPermRegAvx2_,
+ x86::dword_ptr(in_acts_R_, (act_offset + VLEN_) * sizeof(uint8_t)));
+ }
+ if (C_per_G_ == 8) {
+ gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
+ } else {
+ assert(C_per_G_ == 16);
+ if (use_scratch_reg1) {
+ a->vmovups(
+ WRegs_avx2_[0],
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (act_offset + 2 * VLEN_) * sizeof(uint8_t)));
+ a->vmovups(
+ WRegs_avx2_[1],
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (act_offset + 3 * VLEN_) * sizeof(uint8_t)));
+ } else {
+ a->vmovups(
+ WRegs_avx2_[0],
+ x86::dword_ptr(
+ in_acts_R_, (act_offset + 2 * VLEN_) * sizeof(uint8_t)));
+ a->vmovups(
+ WRegs_avx2_[1],
+ x86::dword_ptr(
+ in_acts_R_, (act_offset + 3 * VLEN_) * sizeof(uint8_t)));
+ }
+ gen8BitSumX16<inst_set_t::avx2>(
+ a, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0], WRegs_avx2_[1]);
+ } // C_per_G_ != 8
+ } // C_per_G_ != 4
+}
+
+template <>
+template <>
void GenConvKernel<int32_t>::genZeroPtSum<inst_set_t::avx2>(
asmjit::X86Emitter* a,
int multiplier) {
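All of these reductions lean on vpsadbw against a zeroed register to sum groups of 8 bytes. A scalar model of that instruction as used here (a sketch for exposition, not FBGEMM code):

#include <cstdint>

// Scalar model of vpsadbw(src, zero) on one ymm register: each 64-bit lane
// receives the sum of absolute differences of its 8 bytes against zero,
// i.e. the plain sum of 8 consecutive uint8 values (|x - 0| == x).
void vpsadbwAgainstZero(const uint8_t src[32], uint64_t dst[4]) {
  for (int lane = 0; lane < 4; ++lane) {
    uint64_t sum = 0;
    for (int b = 0; b < 8; ++b)
      sum += src[lane * 8 + b];
    dst[lane] = sum; // fits in 16 bits; upper bytes of the lane are zero
  }
}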
@@ -840,7 +977,6 @@ void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>(
a->jl(LoopH);
if (c_offset + 4 < C_per_G_) {
- // FIXME : simplify
// reset input pointer
// scratchReg2_ = W_R_ * C_ * (H_R_ - 2 * H_PAD_)
a->mov(scratchReg2_, H_R_);
@@ -988,22 +1124,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
for (int s = W_PAD_; s < S_; ++s) {
int w_in = -W_PAD_ + s;
- a->vmovups(
- actRegAvx2_,
- x86::dword_ptr(
- in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
- if (C_per_G_ == 4) {
- gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
- } else {
- a->vmovups(
- stPermRegAvx2_,
- x86::dword_ptr(
- in_acts_R_,
- scratchReg1_,
- 0,
- (w_in * C_ + 32) * sizeof(uint8_t)));
- gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
- }
+ gen8BitSum<inst_set_t::avx2>(a, w_in * C_);
}
}
@@ -1029,22 +1150,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
W_R_,
static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
for (int s = 0; s < S_; ++s) {
- a->vmovups(
- actRegAvx2_,
- x86::dword_ptr(
- in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
- if (C_per_G_ == 4) {
- gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
- } else {
- a->vmovups(
- stPermRegAvx2_,
- x86::dword_ptr(
- in_acts_R_,
- scratchReg1_,
- 0,
- (s * C_ + 32) * sizeof(uint8_t)));
- gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
- }
+ gen8BitSum<inst_set_t::avx2>(a, s * C_);
}
}
a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
@@ -1082,14 +1188,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
a->sub(scratchReg2_, static_cast<asmjit::Imm>(R_ - W_PAD_ - s));
a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
- a->vmovups(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
- if (C_per_G_ == 4) {
- gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
- } else {
- a->vmovups(
- stPermRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_, 0, 32));
- gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
- }
+ gen8BitSum<inst_set_t::avx2>(a, 0);
}
}
@@ -1122,25 +1221,7 @@ void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
for (int r = 0; r < R_; ++r) {
for (int s = W_PAD_; s < S_; ++s) {
- a->vmovups(
- actRegAvx2_,
- x86::dword_ptr(
- in_acts_R_,
- scratchReg1_,
- 0,
- (s - W_PAD_) * C_ * sizeof(uint8_t)));
- if (C_per_G_ == 4) {
- gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
- } else {
- a->vmovups(
- stPermRegAvx2_,
- x86::dword_ptr(
- in_acts_R_,
- scratchReg1_,
- 0,
- ((s - W_PAD_) * C_ + 32) * sizeof(uint8_t)));
- gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
- }
+ gen8BitSum<inst_set_t::avx2>(a, (s - W_PAD_) * C_);
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
@@ -1197,15 +1278,7 @@ void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
a->add(scratchReg1_, scratchReg2_);
for (int r = 0; r < R_; ++r) {
for (int s = 0; s < S_ - W_PAD_; ++s) {
- a->vmovups(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
- if (C_per_G_ == 4) {
- gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
- } else {
- a->vmovups(
- stPermRegAvx2_,
- x86::dword_ptr(in_acts_R_, scratchReg1_, 0, 32 * sizeof(uint8_t)));
- gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
- }
+ gen8BitSum<inst_set_t::avx2>(a, 0);
a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
}
@@ -1258,25 +1331,7 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
for (int r = 0; r < R_ - H_PAD_; ++r) {
for (int s = W_PAD_; s < S_; ++s) {
- a->vmovups(
- actRegAvx2_,
- x86::dword_ptr(
- in_acts_R_,
- scratchReg1_,
- 0,
- (s - W_PAD_) * C_ * sizeof(uint8_t)));
- if (C_per_G_ == 4) {
- gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
- } else {
- a->vmovups(
- stPermRegAvx2_,
- x86::dword_ptr(
- in_acts_R_,
- scratchReg1_,
- 0,
- ((s - W_PAD_) * C_ + 32) * sizeof(uint8_t)));
- gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
- }
+ gen8BitSum<inst_set_t::avx2>(a, (s - W_PAD_) * C_);
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
@@ -1307,19 +1362,7 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
for (int r = 0; r < R_ - W_PAD_; ++r) {
// int h_in = H_-2*H_PAD_ + r;
for (int s = 0; s < S_; ++s) {
- a->vmovups(
- actRegAvx2_,
- x86::dword_ptr(
- in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
- if (C_per_G_ == 4) {
- gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
- } else {
- a->vmovups(
- stPermRegAvx2_,
- x86::dword_ptr(
- in_acts_R_, scratchReg1_, 0, (s * C_ + 32) * sizeof(uint8_t)));
- gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
- }
+ gen8BitSum<inst_set_t::avx2>(a, s * C_);
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
@@ -1358,19 +1401,7 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
for (int r = 0; r < R_ - H_PAD_; ++r) {
for (int s = 0; s < S_ - W_PAD_; ++s) {
- a->vmovups(
- actRegAvx2_,
- x86::dword_ptr(
- in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
- if (C_per_G_ == 4) {
- gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
- } else {
- a->vmovups(
- stPermRegAvx2_,
- x86::dword_ptr(
- in_acts_R_, scratchReg1_, 0, (s * C_ + 32) * sizeof(uint8_t)));
- gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
- }
+ gen8BitSum<inst_set_t::avx2>(a, s * C_);
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
@@ -1417,16 +1448,7 @@ void GenConvKernel<int32_t>::genRowoffsetCore<inst_set_t::avx2>(
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
for (int r = 0; r < R_; ++r) {
for (int s = 0; s < S_; ++s) {
- a->vmovups(
- actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t)));
- if (C_per_G_ == 4) {
- gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
- } else {
- a->vmovups(
- stPermRegAvx2_,
- x86::dword_ptr(in_acts_R_, (s * C_ + 32) * sizeof(uint8_t)));
- gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
- }
+ gen8BitSum<inst_set_t::avx2>(a, s * C_, false /*use_scratch_reg1*/);
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(in_acts_R_, scratchReg2_);
@@ -1600,14 +1622,10 @@ void fbgemmGroupwiseConvBase_(
for (int i = 0; i < MB; ++i) {
const uint8_t* actStartBatch = activations + i * ih_iw * conv_param.IC;
for (int gOuter = 0; gOuter < G; gOuter += 8) {
- // for C_per_G == 4 and K_per_G == 4, row offset is calcualted for 8
- // groups at a time The result is row offsets in the format IH*IW x G
+ // row offset is calculated for 8 groups at a time.
+ // The result is row offsets in the format IH*IW x G
fpRowoffset(
- actStartBatch + gOuter * C_per_G,
- a_zero_point,
- H,
- W,
- rowOffsetBuf);
+ actStartBatch + gOuter * C_per_G, a_zero_point, H, W, rowOffsetBuf);
// Transpose to get row offsets in the format G x IH*IW
internal::transpose_8x8(
ih_iw,
@@ -1617,20 +1635,30 @@ void fbgemmGroupwiseConvBase_(
reinterpret_cast<float*>(rowOffsetTrDest),
ih_iw);
int gLimit = gOuter + 8;
- int gDelta = C_per_G == 4 ? 2 : 1;
+ // Work on 8 output channels at a time (8 * sizeof(int32_t) == 32B VLEN
+ // of AVX2), and cover multiple groups per iteration if a group doesn't
+ // have enough channels.
+ int gDelta = std::max(8 / C_per_G, 1);
for (int g = gOuter; g < gLimit; g += gDelta) {
int32_t* currOutBuf =
outBuffer + i * oh_ow * conv_param.OC + g * K_per_G;
const uint8_t* actStartGroup = actStartBatch + g * C_per_G;
-
- fpConv(
- actStartGroup,
- packed_weights.getBuf() +
- g * conv_param.K[0] * conv_param.K[1] * K_per_G * C_per_G,
- currOutBuf,
- a_zero_point,
- H,
- W);
+ for (int k = 0; k < K_per_G; k += 8) {
+ // Not to be confused with k above, which refers to output channels:
+ // k0 and k1 are the filter dimensions (commonly 3 and 3).
+ int k0 = conv_param.K[0];
+ int k1 = conv_param.K[1];
+ fpConv(
+ actStartGroup,
+ // packed weight is in G (C/4) R S K 4 layout for IC_per_G >= 8,
+ // and in (G/2) R S K (2C) layout for IC_per_G == 4
+ packed_weights.getBuf() +
+ (g * (C_per_G / 4) * k0 * k1 * K_per_G + k) * 4,
+ currOutBuf + k,
+ a_zero_point,
+ H,
+ W);
+ } // k loop
// Output processing should be called for each group
for (int j = 0; j < gDelta; ++j) {
@@ -1791,14 +1819,10 @@ void fbgemmGroupwiseConv(
for (int i = 0; i < MB; ++i) {
const uint8_t* actStartBatch = activations + i * ih_iw * conv_param.IC;
for (int gOuter = 0; gOuter < G; gOuter += 8) {
- // for C_per_G == 4 and K_per_G == 4, row offset is calcualted for 8
- // groups at a time The result is row offsets in the format IH*IW x G
+ // row offset is calculated for 8 groups at a time.
+ // The result is row offsets in the format IH*IW x G
fpRowoffset(
- actStartBatch + gOuter * C_per_G,
- a_zero_point,
- H,
- W,
- rowOffsetBuf);
+ actStartBatch + gOuter * C_per_G, a_zero_point, H, W, rowOffsetBuf);
// Transpose to get row offsets in the format G x IH*IW
internal::transpose_8x8(
ih_iw,
@@ -1808,20 +1832,31 @@ void fbgemmGroupwiseConv(
reinterpret_cast<float*>(rowOffsetTrDest),
ih_iw);
int gLimit = gOuter + 8;
- int gDelta = C_per_G == 4 ? 2 : 1;
+ // Work on 8 output channels at a time (8 * sizeof(int32_t) == 32B VLEN
+ // of AVX2), and cover multiple groups per iteration if a group doesn't
+ // have enough channels.
+ int gDelta = std::max(8 / C_per_G, 1);
for (int g = gOuter; g < gLimit; g += gDelta) {
+ // Reusing the same region of outBuffer multiple times for locality
int32_t* currOutBuf = outBuffer + (g - gOuter) * K_per_G;
const uint8_t* actStartGroup = actStartBatch + g * C_per_G;
-
- fpConv(
- actStartGroup,
- packed_weights.getBuf() +
- g * conv_param.K[0] * conv_param.K[1] * K_per_G * C_per_G,
- currOutBuf,
- a_zero_point,
- H,
- W);
- }
+ for (int k = 0; k < K_per_G; k += 8) {
+ // Not to be confused with k above, which refers to output channels:
+ // k0 and k1 are the filter dimensions (commonly 3 and 3).
+ int k0 = conv_param.K[0];
+ int k1 = conv_param.K[1];
+ fpConv(
+ actStartGroup,
+ // packed weight is in G (C/4) R S K 4 layout for IC_per_G >= 8,
+ // and in (G/2) R S K (2C) layout for IC_per_G == 4
+ packed_weights.getBuf() +
+ (g * (C_per_G / 4) * k0 * k1 * K_per_G + k) * 4,
+ currOutBuf + k,
+ a_zero_point,
+ H,
+ W);
+ } // k loop
+ } // g loop
bool b_symmetric =
outProcess.getBZeroPoint()[0] == 0 || rowOffsetBuf == nullptr;
@@ -1919,7 +1954,7 @@ void fbgemmGroupwiseConv(
}
}
}
- } else {
+ } else if (C_per_G == 8) {
if (a_zero_point == 0) {
if (b_symmetric) {
if (outProcess.getBias() == nullptr) {
@@ -1997,6 +2032,84 @@ void fbgemmGroupwiseConv(
}
}
}
+ } else {
+ if (a_zero_point == 0) {
+ if (b_symmetric) {
+ if (outProcess.getBias() == nullptr) {
+ requantizeOutputProcessingGConvAvx2<
+ true,
+ true,
+ Q_GRAN,
+ false,
+ FUSE_RELU,
+ 16>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeOutputProcessingGConvAvx2<
+ true,
+ true,
+ Q_GRAN,
+ true,
+ FUSE_RELU,
+ 16>(out, inp, block, ld_out, ld_in, r);
+ }
+ } else {
+ if (outProcess.getBias() == nullptr) {
+ requantizeOutputProcessingGConvAvx2<
+ true,
+ false,
+ Q_GRAN,
+ false,
+ FUSE_RELU,
+ 16>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeOutputProcessingGConvAvx2<
+ true,
+ false,
+ Q_GRAN,
+ true,
+ FUSE_RELU,
+ 16>(out, inp, block, ld_out, ld_in, r);
+ }
+ }
+ } else {
+ if (b_symmetric) {
+ if (outProcess.getBias() == nullptr) {
+ requantizeOutputProcessingGConvAvx2<
+ false,
+ true,
+ Q_GRAN,
+ false,
+ FUSE_RELU,
+ 16>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeOutputProcessingGConvAvx2<
+ false,
+ true,
+ Q_GRAN,
+ true,
+ FUSE_RELU,
+ 16>(out, inp, block, ld_out, ld_in, r);
+ }
+ } else {
+ if (outProcess.getBias() == nullptr) {
+ requantizeOutputProcessingGConvAvx2<
+ false,
+ false,
+ Q_GRAN,
+ false,
+ FUSE_RELU,
+ 16>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeOutputProcessingGConvAvx2<
+ false,
+ false,
+ Q_GRAN,
+ true,
+ FUSE_RELU,
+ 16>(out, inp, block, ld_out, ld_in, r);
+ }
+ }
+ }
}
} // gOuter loop
} // i loop
@@ -2027,7 +2140,8 @@ int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) {
int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
int C_per_G = conv_param.IC / conv_param.G;
int K_per_G = conv_param.OC / conv_param.G;
- if ((C_per_G == 4 && K_per_G == 4) || (C_per_G == 8 && K_per_G == 8)) {
+ if (C_per_G == K_per_G &&
+ (C_per_G == 4 || C_per_G == 8 || C_per_G == 16)) {
return 2 * 8 * bufferSize;
} else {
return conv_param.G * bufferSize;
@@ -2036,7 +2150,8 @@ int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) {
int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
int C_per_G = conv_param.IC / conv_param.G;
int K_per_G = conv_param.OC / conv_param.G;
- if ((C_per_G == 4 && K_per_G == 4) || (C_per_G == 8 && K_per_G == 8)) {
+ if (C_per_G == K_per_G &&
+ (C_per_G == 4 || C_per_G == 8 || C_per_G == 16)) {
// row offset is calculated for 8 groups at a time
// 2x is needed for transposing
return 2 * 8 * bufferSize;
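A worked example of the optimized branch (a sketch): for a 14x14 output map, the buffer holds 2 * 8 * 196 = 3136 int32 entries.

#include <cstdio>

int main() {
  int OUT_DIM[2] = {14, 14};                // e.g. a resnext101-32x4d block
  int bufferSize = OUT_DIM[0] * OUT_DIM[1]; // 196 row offsets per group
  // 8 groups at a time, 2x for the transpose scratch space.
  printf("%d\n", 2 * 8 * bufferSize);       // prints 3136
}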
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index bc41310..be12142 100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -763,12 +763,17 @@ void requantizeOutputProcessingGConvAvx2(
_mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 0])),
_mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 1]),
1);
- } else {
- assert(C_PER_G == 8);
+ } else if (C_PER_G == 8) {
row_offset_v =
_mm256_set1_epi32(r.row_offsets
[(i - block.row_start) * 8 +
(j - block.col_start) / (VLEN * 4) * 4]);
+ } else {
+ assert(C_PER_G == 16);
+ row_offset_v =
+ _mm256_set1_epi32(r.row_offsets
+ [(i - block.row_start) * 8 +
+ (j - block.col_start) / (VLEN * 4) * 2]);
}
__m256i B_zero_point_v = _mm256_set1_epi32(r.B_zero_point[0]);
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
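The index arithmetic differs per C_PER_G because each j iteration covers VLEN * 4 == 32 output channels: 8 groups when C_PER_G == 4, 4 when C_PER_G == 8, and 2 when C_PER_G == 16 (hence the new "* 2" stride). A sketch with a hypothetical helper:

#include <cstdio>

constexpr int VLEN = 8; // int32 lanes per AVX2 vector, as in this file

// Hypothetical helper: index into the transposed row-offset buffer (stride
// of 8 groups per output row i) of the first group covered by the x vector
// of a j iteration.
int firstGroupOffsetIndex(int C_PER_G, int i, int row_start, int j,
                          int col_start) {
  int groupsPer32Cols = 32 / C_PER_G; // 8, 4, or 2
  return (i - row_start) * 8 + (j - col_start) / (VLEN * 4) * groupsPer32Cols;
}

int main() {
  // The second 32-column chunk of row 0 with C_PER_G == 16 reads offsets
  // 2 and 3, matching the "* 2" (and "* 2 + 1") indices in the hunks here.
  printf("%d\n", firstGroupOffsetIndex(16, 0, 0, 32, 0)); // prints 2
}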
@@ -781,10 +786,14 @@ void requantizeOutputProcessingGConvAvx2(
_mm_set1_epi32(r.B_zero_point[quant_param_idx])),
_mm_set1_epi32(r.B_zero_point[quant_param_idx + 1]),
1);
- } else {
+ } else if (C_PER_G == 8) {
B_zero_point_v = _mm256_set1_epi32(
r.B_zero_point
[quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]);
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(
+ r.B_zero_point
+ [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2]);
}
}
row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
@@ -799,12 +808,17 @@ void requantizeOutputProcessingGConvAvx2(
_mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 2])),
_mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 3]),
1);
- } else {
+ } else if (C_PER_G == 8) {
// + 1 here is for group 1
row_offset_v = _mm256_set1_epi32(
r.row_offsets
[(i - block.row_start) * 8 +
(j - block.col_start) / (VLEN * 4) * 4 + 1]);
+ } else if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ row_offset_v =
+ _mm256_set1_epi32(r.row_offsets
+ [(i - block.row_start) * 8 +
+ (j - block.col_start) / (VLEN * 4) * 2]);
}
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
B_zero_point_v = _mm256_loadu_si256(
@@ -816,14 +830,16 @@ void requantizeOutputProcessingGConvAvx2(
_mm_set1_epi32(r.B_zero_point[quant_param_idx + 2])),
_mm_set1_epi32(r.B_zero_point[quant_param_idx + 3]),
1);
- } else {
+ } else if (C_PER_G == 8) {
B_zero_point_v = _mm256_set1_epi32(
r.B_zero_point
[quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
1]);
}
}
- row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
+ if (C_PER_G != 16 || Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
+ }
y_v = _mm256_sub_epi32(y_v, row_offset_v);
// Groups 4 and 5 when C_PER_G == 4
@@ -834,11 +850,16 @@ void requantizeOutputProcessingGConvAvx2(
_mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 4])),
_mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 5]),
1);
- } else {
+ } else if (C_PER_G == 8) {
row_offset_v = _mm256_set1_epi32(
r.row_offsets
[(i - block.row_start) * 8 +
(j - block.col_start) / (VLEN * 4) * 4 + 2]);
+ } else {
+ row_offset_v = _mm256_set1_epi32(
+ r.row_offsets
+ [(i - block.row_start) * 8 +
+ (j - block.col_start) / (VLEN * 4) * 2 + 1]);
}
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
B_zero_point_v = _mm256_loadu_si256(
@@ -850,11 +871,16 @@ void requantizeOutputProcessingGConvAvx2(
_mm_set1_epi32(r.B_zero_point[quant_param_idx + 4])),
_mm_set1_epi32(r.B_zero_point[quant_param_idx + 5]),
1);
- } else {
+ } else if (C_PER_G == 8) {
B_zero_point_v = _mm256_set1_epi32(
r.B_zero_point
[quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
2]);
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(
+ r.B_zero_point
+ [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2 +
+ 1]);
}
}
row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
@@ -868,11 +894,16 @@ void requantizeOutputProcessingGConvAvx2(
_mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 6])),
_mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 7]),
1);
- } else {
+ } else if (C_PER_G == 8) {
row_offset_v = _mm256_set1_epi32(
r.row_offsets
[(i - block.row_start) * 8 +
(j - block.col_start) / (VLEN * 4) * 4 + 3]);
+ } else if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ row_offset_v = _mm256_set1_epi32(
+ r.row_offsets
+ [(i - block.row_start) * 8 +
+ (j - block.col_start) / (VLEN * 4) * 2 + 1]);
}
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
B_zero_point_v = _mm256_loadu_si256(
@@ -884,14 +915,16 @@ void requantizeOutputProcessingGConvAvx2(
_mm_set1_epi32(r.B_zero_point[quant_param_idx + 6])),
_mm_set1_epi32(r.B_zero_point[quant_param_idx + 7]),
1);
- } else {
+ } else if (C_PER_G == 8) {
B_zero_point_v = _mm256_set1_epi32(
r.B_zero_point
[quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
3]);
}
}
- row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
+ if (C_PER_G != 16 || Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+ row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
+ }
w_v = _mm256_sub_epi32(w_v, row_offset_v);
}
if (HAS_BIAS) {
@@ -966,7 +999,7 @@ void requantizeOutputProcessingGConvAvx2(
_mm_set1_ps(r.C_multiplier[quant_param_idx + 7]),
1);
w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
- } else {
+ } else if (C_PER_G == 8) {
multiplier_v = _mm256_set1_ps(
r.C_multiplier
[quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]);
@@ -989,6 +1022,19 @@ void requantizeOutputProcessingGConvAvx2(
[quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
3]);
w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ } else {
+ multiplier_v = _mm256_set1_ps(
+ r.C_multiplier
+ [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2]);
+ x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
+
+ multiplier_v = _mm256_set1_ps(
+ r.C_multiplier
+ [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 2 +
+ 1]);
+ z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
+ w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
}
} else {
x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
@@ -1065,30 +1111,38 @@ void requantizeOutputProcessingGConvAvx2(
} // i loop
}
-#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
- template void \
- requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void \
- requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void \
- requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
+#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
+ template void \
+ requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t& r); \
+ template void \
+ requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t& r); \
+ template void \
+ requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t& r); \
+ template void \
+ requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 16>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
const requantizationParams_t& r);
#define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \
diff --git a/test/GConvTest.cc b/test/GConvTest.cc
index e1434d2..66d61b6 100644
--- a/test/GConvTest.cc
+++ b/test/GConvTest.cc
@@ -69,6 +69,7 @@ static vector<conv_param_t<>> GetShapes_() {
conv_param_t<>(1, 8, 8, {5, 5}, 2, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 128, 128, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ // the line below is from resnext101-32x4d
conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(2, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
@@ -81,6 +82,16 @@ static vector<conv_param_t<>> GetShapes_() {
conv_param_t<>(1, 256, 256, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(2, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+
+ conv_param_t<>(1, 128, 128, {3, 3}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 128, 128, {4, 4}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 128, 128, {3, 5}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 128, 128, {5, 3}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {5, 5}, 2, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 512, 512, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 512, 512, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 512, 512, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(2, 512, 512, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
};
return shapes;
}