github.com/marian-nmt/FBGEMM.git

author    Jongsoo Park <jongsoo@fb.com>  2019-02-02 21:32:16 +0300
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-02-02 21:35:25 +0300
commit    df7b1c1237c2f4274294ad9136861f30a7234c14 (patch)
tree      f8cfb98be892ca5e1d83beee060010868d123b2e
parent    a7d921d447564cdc67a07ad79dfc9e0d3dcd7418 (diff)
gconv optimized for 8 channels per group (#65)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/65

As title says

Reviewed By: jianyuh

Differential Revision: D13834287

fbshipit-source-id: ff174fdfcc27bcc227e435ff27e5c2a7024bf736
-rw-r--r--  bench/GroupwiseConvRequantizeBenchmark.cc |   5
-rw-r--r--  include/fbgemm/QuantUtilsAvx2.h           |   5
-rw-r--r--  src/Fbgemm.cc                             |   7
-rw-r--r--  src/GroupwiseConv.h                       |  22
-rw-r--r--  src/GroupwiseConvAcc32Avx2.cc             | 661
-rw-r--r--  src/PackWeightMatrixForGConv.cc           |  21
-rw-r--r--  src/QuantUtilsAvx2.cc                     | 281
-rw-r--r--  test/GConvTest.cc                         |  10
8 files changed, 751 insertions(+), 261 deletions(-)
diff --git a/bench/GroupwiseConvRequantizeBenchmark.cc b/bench/GroupwiseConvRequantizeBenchmark.cc
index 991013f..158ca4f 100644
--- a/bench/GroupwiseConvRequantizeBenchmark.cc
+++ b/bench/GroupwiseConvRequantizeBenchmark.cc
@@ -58,6 +58,11 @@ void performance_test() {
// BatchSize > 1
// conv_param_t<>(2, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1,
// 1}),
+
+ conv_param_t<>(1, 256, 256, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 256, 256, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(2, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
};
bool flush = true;
diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h
index 40b830c..04aeba1 100644
--- a/include/fbgemm/QuantUtilsAvx2.h
+++ b/include/fbgemm/QuantUtilsAvx2.h
@@ -85,8 +85,9 @@ template <
bool B_SYMMETRIC,
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
- bool FUSE_RELU>
-FBGEMM_API void requantizeOutputProcessingGConv4Avx2(
+ bool FUSE_RELU,
+ int C_PER_G>
+FBGEMM_API void requantizeOutputProcessingGConvAvx2(
std::uint8_t* out,
const std::int32_t* inp,
const block_type_t& block,
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index 9384af6..cb22999 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -197,9 +197,10 @@ FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p) {
int C_per_G = conv_p.IC / conv_p.G;
int K_per_G = conv_p.OC / conv_p.G;
- return (SPATIAL_DIM == 2) && (C_per_G == K_per_G) && (C_per_G == 4) &&
- (conv_p.G % 8 == 0) && (conv_p.K[0] == conv_p.K[1]) &&
- (conv_p.K[0] == 3) && (conv_p.pad[0] == 1) && (conv_p.pad[1] == 1) &&
+ return (SPATIAL_DIM == 2) && (C_per_G == K_per_G) &&
+ (C_per_G == 4 || C_per_G == 8) && (conv_p.G % 8 == 0) &&
+ (conv_p.K[0] == conv_p.K[1]) && (conv_p.K[0] == 3) &&
+ (conv_p.pad[0] == 1) && (conv_p.pad[1] == 1) &&
(conv_p.pad[0] == conv_p.pad[2]) && (conv_p.pad[1] == conv_p.pad[3]) &&
(conv_p.dilation[0] == 1) && (conv_p.dilation[0] == conv_p.dilation[1]) &&
(conv_p.stride[0] == 1) && (conv_p.stride[0] == conv_p.stride[1]);
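
Note: a minimal sketch (not part of this commit) of a shape that satisfies the relaxed fbgemmOptimizedGConv() predicate above. The conv_param_t<> constructor arguments follow the (MB, IC, OC, {IH, IW}, G, {KH, KW}, {stride}, {pad}) order used by the benchmark and test shapes in this diff; the fbgemm/Fbgemm.h include is an assumption.

    #include <cassert>
    #include "fbgemm/Fbgemm.h" // assumed header for conv_param_t / fbgemmOptimizedGConv

    int main() {
      // IC = OC = 256 with G = 32 gives C_per_G = K_per_G = 256 / 32 = 8,
      // the case this commit adds to the optimized groupwise path.
      fbgemm::conv_param_t<> conv_p(
          1, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1});
      assert(fbgemm::fbgemmOptimizedGConv(conv_p));
      return 0;
    }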
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
index 0d2db3a..b65082f 100644
--- a/src/GroupwiseConv.h
+++ b/src/GroupwiseConv.h
@@ -142,32 +142,38 @@ class GenConvKernel {
gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg);
template <inst_set_t instSet>
- void genForLoadingWeights(asmjit::X86Emitter* a);
+ void genForLoadingWeights(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
void genConstForPermutations(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void genForTopEdge(asmjit::X86Emitter* a);
+ void genForTopEdge(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForLeftEdge(asmjit::X86Emitter* a);
+ void genForLeftEdge(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForRightEdge(asmjit::X86Emitter* a);
+ void genForRightEdge(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForBottomEdge(asmjit::X86Emitter* a);
+ void genForBottomEdge(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genCoreInsts(asmjit::X86Emitter* a);
+ void genCoreInsts(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void storeResult(asmjit::X86Emitter* a, int offset = 0);
+ void storeResult(asmjit::X86Emitter* a);
// for Rowoffset kernel
+ // Add each run of 4 consecutive uint8s (32 inputs) and emit 8 32-bit sums
template <inst_set_t instSet>
- void gen8BitSum(asmjit::X86Emitter* a, asmjit::X86Ymm aReg);
+ void gen8BitSumX4(asmjit::X86Emitter* a, asmjit::X86Ymm aReg);
+
+ // Add each run of 8 consecutive uint8s (64 inputs across two registers)
+ // and emit 8 32-bit sums
+ template <inst_set_t instSet>
+ void
+ gen8BitSumX8(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg);
// Use scratchReg1_ and tmpReg1Avx2_ internally
template <inst_set_t instSet>
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index c85e339..be5145b 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -158,11 +158,12 @@ void GenConvKernel<int32_t>::genConstForPermutations<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<int32_t>::storeResult<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- int offset) {
- // store with permutation
- a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_);
- a->vmovups(x86::dword_ptr(out_acts_R_, offset), resultRegAvx2_);
+ asmjit::X86Emitter* a) {
+ if (C_per_G_ == 4) {
+ // store with permutation
+ a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_);
+ }
+ a->vmovups(x86::dword_ptr(out_acts_R_), resultRegAvx2_);
}
template <>
@@ -171,21 +172,38 @@ void GenConvKernel<int32_t>::storeResultRowoffset<inst_set_t::avx2>(
asmjit::X86Emitter* a,
int offset) {
// store
- a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_);
+ if (C_per_G_ == 4) {
+ a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_);
+ } else {
+ // need to permute because vphaddd interleaves the two 128-bit lanes
+ // 0xd8 = 11 01 10 00 picks the 64-bit lanes in the order 0, 2, 1, 3
+ a->vpermq(resultRegAvx2_, resultRegAvx2_, static_cast<asmjit::Imm>(0xd8));
+ a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_);
+ }
}
template <>
template <>
void GenConvKernel<int32_t>::genForLoadingWeights<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ asmjit::X86Emitter* a, int c_offset) {
// load weights
for (int r = 0; r < R_; ++r) {
for (int s = 0; s < S_; ++s) {
- a->vmovaps(
- WRegs_avx2_[r * S_ + s],
- x86::dword_ptr(
- wghts_R_,
- (r * S_ + s) * 2 * K_per_G_ * C_per_G_ * sizeof(int8_t)));
+ if (C_per_G_ == 4) {
+ a->vmovaps(
+ WRegs_avx2_[r * S_ + s],
+ x86::dword_ptr(
+ wghts_R_,
+ (r * S_ + s) * 2 * K_per_G_ * C_per_G_ * sizeof(int8_t)));
+ } else {
+ // C_per_G == 8
+ a->vmovaps(
+ WRegs_avx2_[r * S_ + s],
+ x86::dword_ptr(
+ wghts_R_,
+ (((c_offset / 4) * R_ + r) * S_ + s) * K_per_G_ * 4 *
+ sizeof(int8_t)));
+ }
}
}
}
@@ -203,7 +221,7 @@ void GenConvKernel<int32_t>::gen8bitFMA<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>(
+void GenConvKernel<int32_t>::gen8BitSumX4<inst_set_t::avx2>(
asmjit::X86Emitter* a,
asmjit::X86Ymm aReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_);
@@ -213,6 +231,19 @@ void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>(
template <>
template <>
+void GenConvKernel<int32_t>::gen8BitSumX8<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg,
+ asmjit::X86Ymm bReg) {
+ a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
+ a->vpsadbw(aReg, aReg, tmpReg1Avx2_);
+ a->vpsadbw(bReg, bReg, tmpReg1Avx2_);
+ a->vphaddd(aReg, aReg, bReg);
+ a->vpaddd(resultRegAvx2_, aReg, resultRegAvx2_);
+}
+
+template <>
+template <>
void GenConvKernel<int32_t>::genZeroPtSum<inst_set_t::avx2>(
asmjit::X86Emitter* a,
int multiplier) {
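
Note: the asmjit sequence in gen8BitSumX8 above (vxorps, vpsadbw x2, vphaddd, vpaddd) is easier to follow as intrinsics. Below is a standalone sketch of the same reduction (an illustration, not code from this commit); it also shows why storeResultRowoffset() has to fix the lane order with vpermq 0xd8: vphaddd interleaves sums from the two source registers within each 128-bit half.

    #include <immintrin.h>

    // Accumulate 8 32-bit sums, one per run of 8 consecutive uint8s taken
    // from two 32-byte inputs (64 uint8s total) -- a sketch of gen8BitSumX8.
    static inline __m256i sum8BitX8(__m256i acc, __m256i a, __m256i b) {
      const __m256i zero = _mm256_setzero_si256(); // vxorps
      // vpsadbw vs. zero: each 64-bit lane gets the sum of its 8 uint8s.
      __m256i a_sums = _mm256_sad_epu8(a, zero);
      __m256i b_sums = _mm256_sad_epu8(b, zero);
      // vphaddd pairs 32-bit lanes within each 128-bit half, yielding
      // {a0, a1, b0, b1 | a2, a3, b2, b3}; vpermq 0xd8 later restores
      // {a0, a1, a2, a3, b0, b1, b2, b3}.
      __m256i sums = _mm256_hadd_epi32(a_sums, b_sums);
      return _mm256_add_epi32(acc, sums); // vpaddd into the result register
    }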
@@ -228,10 +259,14 @@ void GenConvKernel<int32_t>::genZeroPtSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ asmjit::X86Emitter* a, int c_offset) {
// top-left corner code
- // zero out the results register
- a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (c_offset == 0) {
+ // zero out the results register
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ } else {
+ a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_));
+ }
for (int r = 0; r < R_; ++r) {
int h_in = -H_PAD_ + r;
if (h_in >= 0) {
@@ -243,10 +278,20 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
for (int s = 0; s < S_; ++s) {
int w_in = -W_PAD_ + s;
if (h_in >= 0 && w_in >= 0) {
- a->vbroadcastsd(
- actRegAvx2_,
- x86::dword_ptr(
- in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
+ if (C_per_G_ == 4) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
+ } else {
+ a->vbroadcastss(
+ actRegAvx2_,
+ x86::word_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (w_in * C_ + c_offset) * sizeof(uint8_t)));
+ }
gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
} else {
if (!isZeroPointZero_) {
@@ -265,7 +310,11 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
a->bind(LoopTopEdge);
// zero out
- a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (c_offset == 0) {
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ } else {
+ a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_));
+ }
if (!isZeroPointZero_) {
for (int r = 0; r < H_PAD_; ++r) {
for (int s = 0; s < S_; ++s) {
@@ -280,10 +329,20 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
W_R_,
static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
for (int s = 0; s < S_; ++s) {
- a->vbroadcastsd(
- actRegAvx2_,
- x86::dword_ptr(
- in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ if (C_per_G_ == 4) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ } else {
+ a->vbroadcastss(
+ actRegAvx2_,
+ x86::word_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (s * C_ + c_offset) * sizeof(uint8_t)));
+ }
gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
}
}
@@ -307,7 +366,11 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
// top-right corner code
// zero out
- a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (c_offset == 0) {
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ } else {
+ a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_));
+ }
if (!isZeroPointZero_) {
for (int r = 0; r < H_PAD_; ++r) {
for (int s = 0; s < S_; ++s) {
@@ -327,7 +390,14 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
a->sub(scratchReg2_, static_cast<asmjit::Imm>(R_ - W_PAD_ - s));
a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
- a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
+ if (C_per_G_ == 4) {
+ a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
+ } else {
+ a->vbroadcastss(
+ actRegAvx2_,
+ x86::word_ptr(
+ in_acts_R_, scratchReg1_, 0, c_offset * sizeof(uint8_t)));
+ }
gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
}
if (!isZeroPointZero_) {
@@ -348,13 +418,19 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ asmjit::X86Emitter* a, int c_offset) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
a->bind(LoopLeftEdge);
- // zero out
- a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->add(out_acts_R_, scratchReg2_);
+ if (c_offset == 0) {
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ } else {
+ a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_));
+ }
a->mov(scratchReg1_, loopR1_);
a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
a->imul(scratchReg1_, W_R_);
@@ -367,21 +443,28 @@ void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>(
}
}
for (int s = W_PAD_; s < S_; ++s) {
- a->vbroadcastsd(
- actRegAvx2_,
- x86::dword_ptr(
- in_acts_R_,
- scratchReg1_,
- 0,
- (s - W_PAD_) * C_ * sizeof(uint8_t)));
+ if (C_per_G_ == 4) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (s - W_PAD_) * C_ * sizeof(uint8_t)));
+ } else {
+ a->vbroadcastss(
+ actRegAvx2_,
+ x86::word_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ ((s - W_PAD_) * C_ + c_offset) * sizeof(uint8_t)));
+ }
gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
}
-
- a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
- a->add(out_acts_R_, scratchReg2_);
storeResult<inst_set_t::avx2>(a);
a->inc(loopR1_);
@@ -401,7 +484,7 @@ void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ asmjit::X86Emitter* a, int c_offset) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -415,8 +498,12 @@ void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>(
a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
a->bind(LoopRightEdge);
- // zero out
- a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (c_offset == 0) {
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ } else {
+ a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_));
+ }
a->mov(scratchReg1_, loopR1_);
a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
a->imul(scratchReg1_, W_R_);
@@ -428,7 +515,14 @@ void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>(
a->add(scratchReg1_, scratchReg2_);
for (int r = 0; r < R_; ++r) {
for (int s = 0; s < S_ - W_PAD_; ++s) {
- a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
+ if (C_per_G_ == 4) {
+ a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
+ } else {
+ a->vbroadcastss(
+ actRegAvx2_,
+ x86::word_ptr(
+ in_acts_R_, scratchReg1_, 0, c_offset * sizeof(uint8_t)));
+ }
gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
}
@@ -477,10 +571,20 @@ void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ asmjit::X86Emitter* a, int c_offset) {
// bottom-left corner
- // zero out
- a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ // we are updating the last row
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, 1);
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->add(out_acts_R_, scratchReg1_);
+ if (c_offset == 0) {
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ } else {
+ a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_));
+ }
a->mov(scratchReg1_, H_R_);
a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
a->imul(scratchReg1_, W_R_);
@@ -493,13 +597,23 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>(
}
}
for (int s = W_PAD_; s < S_; ++s) {
- a->vbroadcastsd(
- actRegAvx2_,
- x86::dword_ptr(
- in_acts_R_,
- scratchReg1_,
- 0,
- (s - W_PAD_) * C_ * sizeof(uint8_t)));
+ if (C_per_G_ == 4) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (s - W_PAD_) * C_ * sizeof(uint8_t)));
+ } else {
+ a->vbroadcastss(
+ actRegAvx2_,
+ x86::word_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ ((s - W_PAD_) * C_ + c_offset) * sizeof(uint8_t)));
+ }
gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
@@ -514,12 +628,6 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>(
}
}
- // we updating the last row
- a->mov(scratchReg1_, H_R_);
- a->sub(scratchReg1_, 1);
- a->imul(scratchReg1_, W_R_);
- a->imul(scratchReg1_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
- a->add(out_acts_R_, scratchReg1_);
// storeResult<inst_set_t::avx2>(a, (H_-1)*W_*K_*sizeof(int32_t));
storeResult<inst_set_t::avx2>(a);
a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
@@ -528,8 +636,12 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>(
asmjit::Label LoopBottomEdge = a->newLabel();
a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
a->bind(LoopBottomEdge);
- // zero out
- a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (c_offset == 0) {
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ } else {
+ a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_));
+ }
a->mov(scratchReg1_, H_R_);
a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
a->imul(scratchReg1_, W_R_);
@@ -537,10 +649,20 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>(
for (int r = 0; r < R_ - W_PAD_; ++r) {
// int h_in = H_-2*H_PAD_ + r;
for (int s = 0; s < S_; ++s) {
- a->vbroadcastsd(
- actRegAvx2_,
- x86::dword_ptr(
- in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ if (C_per_G_ == 4) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ } else {
+ a->vbroadcastss(
+ actRegAvx2_,
+ x86::word_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (s * C_ + c_offset) * sizeof(uint8_t)));
+ }
gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
@@ -574,8 +696,12 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>(
// a->sub(out_acts_R_, (W_ - 2*W_PAD_)*K_*sizeof(int32_t));
// bottom-right corner
- // zero out
- a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (c_offset == 0) {
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ } else {
+ a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_));
+ }
// input start point
// ((H_-(R_-H_PAD_))*W_+(W_-(S_-W_PAD_)))*C_*sizeof(uint8_t)
a->mov(scratchReg1_, H_R_);
@@ -586,10 +712,20 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>(
a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
for (int r = 0; r < R_ - H_PAD_; ++r) {
for (int s = 0; s < S_ - W_PAD_; ++s) {
- a->vbroadcastsd(
- actRegAvx2_,
- x86::dword_ptr(
- in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ if (C_per_G_ == 4) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ } else {
+ a->vbroadcastss(
+ actRegAvx2_,
+ x86::word_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (s * C_ + c_offset) * sizeof(uint8_t)));
+ }
gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
@@ -626,7 +762,7 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ asmjit::X86Emitter* a, int c_offset) {
// main compute
asmjit::Label LoopH = a->newLabel();
asmjit::Label LoopW = a->newLabel();
@@ -646,27 +782,41 @@ void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>(
// W loop
a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
a->bind(LoopW);
- // zero out
- a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (c_offset == 0) {
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ } else {
+ a->vmovups(resultRegAvx2_, x86::dword_ptr(out_acts_R_));
+ }
// compute on all filters
for (int r = 0; r < R_; ++r) {
for (int s = 0; s < S_; ++s) {
- a->vbroadcastsd(
- actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t)));
+ if (C_per_G_ == 4) {
+ a->vbroadcastsd(
+ actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t)));
+ } else {
+ a->vbroadcastss(
+ actRegAvx2_,
+ x86::word_ptr(in_acts_R_, (s * C_ + c_offset) * sizeof(uint8_t)));
+ }
gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
}
+ // advance input pointer by one row
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(in_acts_R_, scratchReg2_);
}
+ // rewind input pointer
a->imul(
scratchReg2_, W_R_, static_cast<asmjit::Imm>(R_ * C_ * sizeof(uint8_t)));
a->sub(in_acts_R_, scratchReg2_);
// a->add(scratchReg1_, C_*sizeof(uint8_t));
+ // advance input pointer by one pixel
a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
// storeResult<inst_set_t::avx2>(a, (W_+1)*K_*sizeof(int32_t));
storeResult<inst_set_t::avx2>(a);
+ // advance output pointer by one pixel
a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
a->inc(loopR2_);
@@ -677,6 +827,7 @@ void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>(
in_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
// a->sub(in_acts_R_, (W_ - 2*W_PAD_)*C_*sizeof(uint8_t));
// a->add(in_acts_R_, W_*C_*sizeof(uint8_t));
+ // advance output pointer by padding size
a->add(
out_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * K_ * sizeof(int32_t)));
// a->sub(out_acts_R_, (W_ - 2*W_PAD_)*K_*sizeof(int32_t));
@@ -687,6 +838,28 @@ void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>(
a->sub(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
a->cmp(loopR1_, scratchReg2_);
a->jl(LoopH);
+
+ if (c_offset + 4 < C_per_G_) {
+ // FIXME : simplify
+ // reset input pointer
+ // scratchReg2_ = W_R_ * C_ * (H_R_ - 2 * H_PAD_)
+ a->mov(scratchReg2_, H_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, C_);
+ a->sub(in_acts_R_, scratchReg2_);
+
+ // reset output pointer
+ assert(K_ == C_);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(sizeof(int32_t)));
+ a->sub(out_acts_R_, scratchReg2_);
+
+ a->mov(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->imul(scratchReg2_, W_R_);
+ a->add(scratchReg2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->sub(out_acts_R_, scratchReg2_);
+ }
}
template <>
@@ -759,16 +932,24 @@ jit_conv_kernel_fp GenConvKernel<int32_t>::getOrCreate<inst_set_t::avx2>(
setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
}
- genForLoadingWeights<inst_set_t::avx2>(a);
-
genConstForPermutations<inst_set_t::avx2>(a);
- genForTopEdge<inst_set_t::avx2>(a);
- genForLeftEdge<inst_set_t::avx2>(a);
- genForRightEdge<inst_set_t::avx2>(a);
- genForBottomEdge<inst_set_t::avx2>(a);
-
- genCoreInsts<inst_set_t::avx2>(a);
+ // Work on 4 input channels at a time.
+ // The minimum unit should be 4 because the instruction sequence in
+ // gen8bitFMA reduces 4 inputs at a time.
+ // We can't work on more than 4 input channels at a time because we can't
+ // fit more weights in registers (we need R * S * K * 4 / 32 registers to
+ // store the weights for 4 input channels).
+ for (int c = 0; c < C_per_G_; c += 4) {
+ genForLoadingWeights<inst_set_t::avx2>(a, c);
+
+ genForTopEdge<inst_set_t::avx2>(a, c);
+ genForLeftEdge<inst_set_t::avx2>(a, c);
+ genForRightEdge<inst_set_t::avx2>(a, c);
+ genForBottomEdge<inst_set_t::avx2>(a, c);
+
+ genCoreInsts<inst_set_t::avx2>(a, c);
+ }
asmjit::FuncUtils::emitEpilog(a, layout);
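
Note: a worked instance of the register-budget comment above, under the assumption R = S = 3 and K_per_G = 8 (the only filter size this path accepts, per fbgemmOptimizedGConv). One 4-input-channel slice of weights occupies R * S * K * 4 int8s at 32 per ymm register:

    // Compile-time sketch (assumed R = S = 3, K_per_G = 8): the weight
    // registers alone take 9 of the 16 available ymm registers, so the
    // kernel iterates over C_per_G in steps of 4 instead of holding all
    // 8 input channels' weights at once.
    constexpr int R = 3, S = 3, K_per_G = 8;
    static_assert(R * S * K_per_G * 4 / 32 == 9, "9 ymm registers of weights");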
@@ -811,9 +992,21 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
actRegAvx2_,
x86::dword_ptr(
in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
- gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ if (C_per_G_ == 4) {
+ gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
+ } else {
+ a->vmovups(
+ stPermRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (w_in * C_ + 32) * sizeof(uint8_t)));
+ gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
+ }
}
}
+
// store results
storeResultRowoffset<inst_set_t::avx2>(a);
@@ -840,7 +1033,18 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
actRegAvx2_,
x86::dword_ptr(
in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
- gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ if (C_per_G_ == 4) {
+ gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
+ } else {
+ a->vmovups(
+ stPermRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (s * C_ + 32) * sizeof(uint8_t)));
+ gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
+ }
}
}
a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
@@ -879,7 +1083,13 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
a->vmovups(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
- gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ if (C_per_G_ == 4) {
+ gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
+ } else {
+ a->vmovups(
+ stPermRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_, 0, 32));
+ gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
+ }
}
}
@@ -919,7 +1129,18 @@ void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
scratchReg1_,
0,
(s - W_PAD_) * C_ * sizeof(uint8_t)));
- gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ if (C_per_G_ == 4) {
+ gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
+ } else {
+ a->vmovups(
+ stPermRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ ((s - W_PAD_) * C_ + 32) * sizeof(uint8_t)));
+ gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
+ }
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
@@ -977,7 +1198,14 @@ void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
for (int r = 0; r < R_; ++r) {
for (int s = 0; s < S_ - W_PAD_; ++s) {
a->vmovups(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
- gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ if (C_per_G_ == 4) {
+ gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
+ } else {
+ a->vmovups(
+ stPermRegAvx2_,
+ x86::dword_ptr(in_acts_R_, scratchReg1_, 0, 32 * sizeof(uint8_t)));
+ gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
+ }
a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
}
@@ -1037,7 +1265,18 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
scratchReg1_,
0,
(s - W_PAD_) * C_ * sizeof(uint8_t)));
- gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ if (C_per_G_ == 4) {
+ gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
+ } else {
+ a->vmovups(
+ stPermRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ ((s - W_PAD_) * C_ + 32) * sizeof(uint8_t)));
+ gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
+ }
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
@@ -1072,7 +1311,15 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
actRegAvx2_,
x86::dword_ptr(
in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
- gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ if (C_per_G_ == 4) {
+ gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
+ } else {
+ a->vmovups(
+ stPermRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, (s * C_ + 32) * sizeof(uint8_t)));
+ gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
+ }
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
@@ -1115,7 +1362,15 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
actRegAvx2_,
x86::dword_ptr(
in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
- gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ if (C_per_G_ == 4) {
+ gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
+ } else {
+ a->vmovups(
+ stPermRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, (s * C_ + 32) * sizeof(uint8_t)));
+ gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
+ }
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
@@ -1164,7 +1419,14 @@ void GenConvKernel<int32_t>::genRowoffsetCore<inst_set_t::avx2>(
for (int s = 0; s < S_; ++s) {
a->vmovups(
actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t)));
- gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ if (C_per_G_ == 4) {
+ gen8BitSumX4<inst_set_t::avx2>(a, actRegAvx2_);
+ } else {
+ a->vmovups(
+ stPermRegAvx2_,
+ x86::dword_ptr(in_acts_R_, (s * C_ + 32) * sizeof(uint8_t)));
+ gen8BitSumX8<inst_set_t::avx2>(a, actRegAvx2_, stPermRegAvx2_);
+ }
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(in_acts_R_, scratchReg2_);
@@ -1355,7 +1617,8 @@ void fbgemmGroupwiseConvBase_(
reinterpret_cast<float*>(rowOffsetTrDest),
ih_iw);
int gLimit = gOuter + 8;
- for (int g = gOuter; g < gLimit; g += 2) {
+ int gDelta = C_per_G == 4 ? 2 : 1;
+ for (int g = gOuter; g < gLimit; g += gDelta) {
int32_t* currOutBuf =
outBuffer + i * oh_ow * conv_param.OC + g * K_per_G;
const uint8_t* actStartGroup = actStartBatch + g * C_per_G;
@@ -1370,7 +1633,7 @@ void fbgemmGroupwiseConvBase_(
W);
// Output processing should be called for each group
- for (int j = 0; j < 2; ++j) {
+ for (int j = 0; j < gDelta; ++j) {
// calculateRowOffsets(
// conv_param, actStartGroup, rowOffsetBuf, a_zero_point, j);
int32_t* rowOffsetForCurG =
@@ -1486,6 +1749,7 @@ void fbgemmGroupwiseConv(
int thread_id,
int num_threads) {
typedef ReQuantizeOutput<FUSE_RELU, Q_GRAN> processOutputType;
+
if (!fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param) ||
(!cpuinfo_has_x86_avx512f() && !cpuinfo_has_x86_avx2())) {
return fbgemmGroupwiseConvBase_<
@@ -1544,7 +1808,8 @@ void fbgemmGroupwiseConv(
reinterpret_cast<float*>(rowOffsetTrDest),
ih_iw);
int gLimit = gOuter + 8;
- for (int g = gOuter; g < gLimit; g += 2) {
+ int gDelta = C_per_G == 4 ? 2 : 1;
+ for (int g = gOuter; g < gLimit; g += gDelta) {
int32_t* currOutBuf = outBuffer + (g - gOuter) * K_per_G;
const uint8_t* actStartGroup = actStartBatch + g * C_per_G;
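
Note: the gDelta choice above keeps each kernel invocation's int32 accumulators exactly one 32-byte ymm register wide, since K_per_G equals C_per_G on this path. A compile-time sketch (illustration only):

    #include <cstdint>

    // K_per_G output channels per group, gDelta groups per kernel call.
    static_assert(4 * 2 * sizeof(std::int32_t) == 32, "C_per_G == 4: 2 groups/call");
    static_assert(8 * 1 * sizeof(std::int32_t) == 32, "C_per_G == 8: 1 group/call");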
@@ -1576,72 +1841,160 @@ void fbgemmGroupwiseConv(
int ld_out = K_per_G * G;
int ld_in = K_per_G * G;
- if (a_zero_point == 0) {
- if (b_symmetric) {
- if (outProcess.getBias() == nullptr) {
- requantizeOutputProcessingGConv4Avx2<
- true,
- true,
- Q_GRAN,
- false,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ if (C_per_G == 4) {
+ if (a_zero_point == 0) {
+ if (b_symmetric) {
+ if (outProcess.getBias() == nullptr) {
+ requantizeOutputProcessingGConvAvx2<
+ true,
+ true,
+ Q_GRAN,
+ false,
+ FUSE_RELU,
+ 4>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeOutputProcessingGConvAvx2<
+ true,
+ true,
+ Q_GRAN,
+ true,
+ FUSE_RELU,
+ 4>(out, inp, block, ld_out, ld_in, r);
+ }
} else {
- requantizeOutputProcessingGConv4Avx2<
- true,
- true,
- Q_GRAN,
- true,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ if (outProcess.getBias() == nullptr) {
+ requantizeOutputProcessingGConvAvx2<
+ true,
+ false,
+ Q_GRAN,
+ false,
+ FUSE_RELU,
+ 4>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeOutputProcessingGConvAvx2<
+ true,
+ false,
+ Q_GRAN,
+ true,
+ FUSE_RELU,
+ 4>(out, inp, block, ld_out, ld_in, r);
+ }
}
} else {
- if (outProcess.getBias() == nullptr) {
- requantizeOutputProcessingGConv4Avx2<
- true,
- false,
- Q_GRAN,
- false,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ if (b_symmetric) {
+ if (outProcess.getBias() == nullptr) {
+ requantizeOutputProcessingGConvAvx2<
+ false,
+ true,
+ Q_GRAN,
+ false,
+ FUSE_RELU,
+ 4>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeOutputProcessingGConvAvx2<
+ false,
+ true,
+ Q_GRAN,
+ true,
+ FUSE_RELU,
+ 4>(out, inp, block, ld_out, ld_in, r);
+ }
} else {
- requantizeOutputProcessingGConv4Avx2<
- true,
- false,
- Q_GRAN,
- true,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ if (outProcess.getBias() == nullptr) {
+ requantizeOutputProcessingGConvAvx2<
+ false,
+ false,
+ Q_GRAN,
+ false,
+ FUSE_RELU,
+ 4>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeOutputProcessingGConvAvx2<
+ false,
+ false,
+ Q_GRAN,
+ true,
+ FUSE_RELU,
+ 4>(out, inp, block, ld_out, ld_in, r);
+ }
}
}
} else {
- if (b_symmetric) {
- if (outProcess.getBias() == nullptr) {
- requantizeOutputProcessingGConv4Avx2<
- false,
- true,
- Q_GRAN,
- false,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ if (a_zero_point == 0) {
+ if (b_symmetric) {
+ if (outProcess.getBias() == nullptr) {
+ requantizeOutputProcessingGConvAvx2<
+ true,
+ true,
+ Q_GRAN,
+ false,
+ FUSE_RELU,
+ 8>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeOutputProcessingGConvAvx2<
+ true,
+ true,
+ Q_GRAN,
+ true,
+ FUSE_RELU,
+ 8>(out, inp, block, ld_out, ld_in, r);
+ }
} else {
- requantizeOutputProcessingGConv4Avx2<
- false,
- true,
- Q_GRAN,
- true,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ if (outProcess.getBias() == nullptr) {
+ requantizeOutputProcessingGConvAvx2<
+ true,
+ false,
+ Q_GRAN,
+ false,
+ FUSE_RELU,
+ 8>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeOutputProcessingGConvAvx2<
+ true,
+ false,
+ Q_GRAN,
+ true,
+ FUSE_RELU,
+ 8>(out, inp, block, ld_out, ld_in, r);
+ }
}
} else {
- if (outProcess.getBias() == nullptr) {
- requantizeOutputProcessingGConv4Avx2<
- false,
- false,
- Q_GRAN,
- false,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ if (b_symmetric) {
+ if (outProcess.getBias() == nullptr) {
+ requantizeOutputProcessingGConvAvx2<
+ false,
+ true,
+ Q_GRAN,
+ false,
+ FUSE_RELU,
+ 8>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeOutputProcessingGConvAvx2<
+ false,
+ true,
+ Q_GRAN,
+ true,
+ FUSE_RELU,
+ 8>(out, inp, block, ld_out, ld_in, r);
+ }
} else {
- requantizeOutputProcessingGConv4Avx2<
- false,
- false,
- Q_GRAN,
- true,
- FUSE_RELU>(out, inp, block, ld_out, ld_in, r);
+ if (outProcess.getBias() == nullptr) {
+ requantizeOutputProcessingGConvAvx2<
+ false,
+ false,
+ Q_GRAN,
+ false,
+ FUSE_RELU,
+ 8>(out, inp, block, ld_out, ld_in, r);
+ } else {
+ requantizeOutputProcessingGConvAvx2<
+ false,
+ false,
+ Q_GRAN,
+ true,
+ FUSE_RELU,
+ 8>(out, inp, block, ld_out, ld_in, r);
+ }
}
}
}
@@ -1674,7 +2027,7 @@ int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) {
int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
int C_per_G = conv_param.IC / conv_param.G;
int K_per_G = conv_param.OC / conv_param.G;
- if (C_per_G == 4 && K_per_G == 4) {
+ if ((C_per_G == 4 && K_per_G == 4) || (C_per_G == 8 && K_per_G == 8)) {
return 2 * 8 * bufferSize;
} else {
return conv_param.G * bufferSize;
@@ -1683,7 +2036,7 @@ int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) {
int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
int C_per_G = conv_param.IC / conv_param.G;
int K_per_G = conv_param.OC / conv_param.G;
- if (C_per_G == 4 && K_per_G == 4) {
+ if ((C_per_G == 4 && K_per_G == 4) || (C_per_G == 8 && K_per_G == 8)) {
// row offset is calculated for 8 groups at a time
// 2x is needed for transposing
return 2 * 8 * bufferSize;
diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc
index e38fba9..5870fa5 100644
--- a/src/PackWeightMatrixForGConv.cc
+++ b/src/PackWeightMatrixForGConv.cc
@@ -52,6 +52,7 @@ PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::PackWeightMatrixForGConv(
* For IC_per_G == 8, 16, 32 && OC_per_G == 8, 16, 32 there is no need to work
* on 2 groups at a time and full SIMD width can be efficiently utilized even
* while working on 1 group at a time.
+ * In this case, the layout is G (C/4) R S K 4
*/
template <typename T, typename accT, int SPATIAL_DIM>
void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() {
@@ -77,11 +78,21 @@ void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() {
[(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c]
: sdata_
[(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k];
- pdata_
- [(((((g / 2) * R + r) * S + s) * OC_per_G + k) * 2 +
- (g % 2)) *
- IC_per_G +
- c] = b;
+ if (IC_per_G == 4) {
+ // For IC_per_G == 4, we need to work on 2 groups at a time
+ pdata_
+ [(((((g / 2) * R + r) * S + s) * OC_per_G + k) * 2 +
+ (g % 2)) *
+ IC_per_G +
+ c] = b;
+ } else {
+ pdata_
+ [((((g * (IC_per_G / 4) + (c / 4)) * R + r) * S + s) *
+ OC_per_G +
+ k) *
+ 4 +
+ (c % 4)] = b;
+ }
}
}
}
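
Note: a small index sketch (illustration, not code from this commit) of the two packed layouts pack() writes above: (G/2, R, S, K, 2, C) when IC_per_G == 4 (two groups interleaved per row) and (G, C/4, R, S, K, 4) when IC_per_G == 8, matching the "G (C/4) R S K 4" layout comment.

    // Destination index of weight element (g, r, s, k, c) in the packed
    // buffer; mirrors the two indexing expressions in pack().
    inline int packedIndex(int g, int r, int s, int k, int c,
                           int R, int S, int OC_per_G, int IC_per_G) {
      if (IC_per_G == 4) {
        // (G/2, R, S, K, 2, C): groups 2g and 2g+1 share a 32-byte row.
        return (((((g / 2) * R + r) * S + s) * OC_per_G + k) * 2 + (g % 2)) *
            IC_per_G + c;
      }
      // (G, C/4, R, S, K, 4): one group at a time, in 4-channel slices.
      return ((((g * (IC_per_G / 4) + c / 4) * R + r) * S + s) * OC_per_G + k) *
          4 + (c % 4);
    }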
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index 7c36f6d..bc41310 100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -662,8 +662,9 @@ template <
bool B_SYMMETRIC,
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
- bool FUSE_RELU>
-void requantizeOutputProcessingGConv4Avx2(
+ bool FUSE_RELU,
+ int C_PER_G>
+void requantizeOutputProcessingGConvAvx2(
uint8_t* out,
const int32_t* inp,
const block_type_t& block,
@@ -744,82 +745,151 @@ void requantizeOutputProcessingGConv4Avx2(
}
if (!B_SYMMETRIC) {
- // Load row_offsets for 2 groups and broadcast by 4 times each because
- // we have 4 channels per group.
-
- // groups 0 and 1
- __m256i row_offset_v = _mm256_insertf128_si256(
- _mm256_castsi128_si256(
- _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 0])),
- _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 1]),
- 1);
+ __m256i row_offset_v;
+
+ // When C_PER_G == 4, we need to handle 2 groups at a time to fully
+ // utilize the 32B AVX2 vector register (C_PER_G * 2 * sizeof(int32_t) ==
+ // 32B).
+ // When C_PER_G == 8, on the other hand, 1 group at a time is enough.
+
+ // Groups 0 and 1 when C_PER_G == 4
+ // Group 0 when C_PER_G == 8
+ if (C_PER_G == 4) {
+ // Load row_offsets for 2 groups and broadcast by 4 times each because
+ // we have 4 channels per group.
+ // groups 0 and 1
+ row_offset_v = _mm256_insertf128_si256(
+ _mm256_castsi128_si256(
+ _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 0])),
+ _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 1]),
+ 1);
+ } else {
+ assert(C_PER_G == 8);
+ row_offset_v =
+ _mm256_set1_epi32(r.row_offsets
+ [(i - block.row_start) * 8 +
+ (j - block.col_start) / (VLEN * 4) * 4]);
+ }
__m256i B_zero_point_v = _mm256_set1_epi32(r.B_zero_point[0]);
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
B_zero_point_v = _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(r.B_zero_point + j));
} else if (Q_GRAN == QuantizationGranularity::GROUP) {
- B_zero_point_v = _mm256_insertf128_si256(
- _mm256_castsi128_si256(
- _mm_set1_epi32(r.B_zero_point[quant_param_idx])),
- _mm_set1_epi32(r.B_zero_point[quant_param_idx + 1]),
- 1);
+ if (C_PER_G == 4) {
+ B_zero_point_v = _mm256_insertf128_si256(
+ _mm256_castsi128_si256(
+ _mm_set1_epi32(r.B_zero_point[quant_param_idx])),
+ _mm_set1_epi32(r.B_zero_point[quant_param_idx + 1]),
+ 1);
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(
+ r.B_zero_point
+ [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]);
+ }
}
row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
x_v = _mm256_sub_epi32(x_v, row_offset_v);
- // groups 2 and 3
- row_offset_v = _mm256_insertf128_si256(
- _mm256_castsi128_si256(
- _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 2])),
- _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 3]),
- 1);
+ // Groups 2 and 3 when C_PER_G == 4
+ // Group 1 when C_PER_G == 8
+ if (C_PER_G == 4) {
+ // + 2 and + 3 here are for groups 2 and 3
+ row_offset_v = _mm256_insertf128_si256(
+ _mm256_castsi128_si256(
+ _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 2])),
+ _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 3]),
+ 1);
+ } else {
+ // + 1 here is for group 1
+ row_offset_v = _mm256_set1_epi32(
+ r.row_offsets
+ [(i - block.row_start) * 8 +
+ (j - block.col_start) / (VLEN * 4) * 4 + 1]);
+ }
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
B_zero_point_v = _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(r.B_zero_point + j + VLEN));
} else if (Q_GRAN == QuantizationGranularity::GROUP) {
- B_zero_point_v = _mm256_insertf128_si256(
- _mm256_castsi128_si256(
- _mm_set1_epi32(r.B_zero_point[quant_param_idx + 2])),
- _mm_set1_epi32(r.B_zero_point[quant_param_idx + 3]),
- 1);
+ if (C_PER_G == 4) {
+ B_zero_point_v = _mm256_insertf128_si256(
+ _mm256_castsi128_si256(
+ _mm_set1_epi32(r.B_zero_point[quant_param_idx + 2])),
+ _mm_set1_epi32(r.B_zero_point[quant_param_idx + 3]),
+ 1);
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(
+ r.B_zero_point
+ [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
+ 1]);
+ }
}
row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
y_v = _mm256_sub_epi32(y_v, row_offset_v);
- // groups 4 and 5
- row_offset_v = _mm256_insertf128_si256(
- _mm256_castsi128_si256(
- _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 4])),
- _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 5]),
- 1);
+ // Groups 4 and 5 when C_PER_G == 4
+ // Group 2 when C_PER_G == 8
+ if (C_PER_G == 4) {
+ row_offset_v = _mm256_insertf128_si256(
+ _mm256_castsi128_si256(
+ _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 4])),
+ _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 5]),
+ 1);
+ } else {
+ row_offset_v = _mm256_set1_epi32(
+ r.row_offsets
+ [(i - block.row_start) * 8 +
+ (j - block.col_start) / (VLEN * 4) * 4 + 2]);
+ }
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
B_zero_point_v = _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(r.B_zero_point + j + 2 * VLEN));
} else if (Q_GRAN == QuantizationGranularity::GROUP) {
- B_zero_point_v = _mm256_insertf128_si256(
- _mm256_castsi128_si256(
- _mm_set1_epi32(r.B_zero_point[quant_param_idx + 4])),
- _mm_set1_epi32(r.B_zero_point[quant_param_idx + 5]),
- 1);
+ if (C_PER_G == 4) {
+ B_zero_point_v = _mm256_insertf128_si256(
+ _mm256_castsi128_si256(
+ _mm_set1_epi32(r.B_zero_point[quant_param_idx + 4])),
+ _mm_set1_epi32(r.B_zero_point[quant_param_idx + 5]),
+ 1);
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(
+ r.B_zero_point
+ [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
+ 2]);
+ }
}
row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
z_v = _mm256_sub_epi32(z_v, row_offset_v);
- // groups 6 and 7
- row_offset_v = _mm256_insertf128_si256(
- _mm256_castsi128_si256(
- _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 6])),
- _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 7]),
- 1);
+ // Groups 6 and 7 when C_PER_G == 4
+ // Group 3 when C_PER_G == 8
+ if (C_PER_G == 4) {
+ row_offset_v = _mm256_insertf128_si256(
+ _mm256_castsi128_si256(
+ _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 6])),
+ _mm_set1_epi32(r.row_offsets[(i - block.row_start) * 8 + 7]),
+ 1);
+ } else {
+ row_offset_v = _mm256_set1_epi32(
+ r.row_offsets
+ [(i - block.row_start) * 8 +
+ (j - block.col_start) / (VLEN * 4) * 4 + 3]);
+ }
if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
B_zero_point_v = _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(r.B_zero_point + j + 3 * VLEN));
} else if (Q_GRAN == QuantizationGranularity::GROUP) {
- B_zero_point_v = _mm256_insertf128_si256(
- _mm256_castsi128_si256(
- _mm_set1_epi32(r.B_zero_point[quant_param_idx + 6])),
- _mm_set1_epi32(r.B_zero_point[quant_param_idx + 7]),
- 1);
+ if (C_PER_G == 4) {
+ B_zero_point_v = _mm256_insertf128_si256(
+ _mm256_castsi128_si256(
+ _mm_set1_epi32(r.B_zero_point[quant_param_idx + 6])),
+ _mm_set1_epi32(r.B_zero_point[quant_param_idx + 7]),
+ 1);
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(
+ r.B_zero_point
+ [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
+ 3]);
+ }
}
row_offset_v = _mm256_mullo_epi32(row_offset_v, B_zero_point_v);
w_v = _mm256_sub_epi32(w_v, row_offset_v);
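
Note: the two row-offset broadcast patterns above differ only in how many groups share a 32-byte vector. A sketch (illustration only) of the resulting register contents, using the same intrinsics as the diff:

    #include <immintrin.h>

    // C_PER_G == 4: row offsets of two adjacent groups -> {r0 x4 | r1 x4}.
    static inline __m256i rowOffsetsTwoGroups(int r0, int r1) {
      return _mm256_insertf128_si256(
          _mm256_castsi128_si256(_mm_set1_epi32(r0)), _mm_set1_epi32(r1), 1);
    }

    // C_PER_G == 8: one group fills the whole vector -> {r0 x8}.
    static inline __m256i rowOffsetsOneGroup(int r0) {
      return _mm256_set1_epi32(r0);
    }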
@@ -868,33 +938,58 @@ void requantizeOutputProcessingGConv4Avx2(
_mm256_cvtepi32_ps(w_v),
_mm256_loadu_ps(r.C_multiplier + j + 3 * VLEN));
} else if (Q_GRAN == QuantizationGranularity::GROUP) {
- multiplier_v = _mm256_insertf128_ps(
- _mm256_castps128_ps256(
- _mm_set1_ps(r.C_multiplier[quant_param_idx])),
- _mm_set1_ps(r.C_multiplier[quant_param_idx + 1]),
- 1);
- x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+ if (C_PER_G == 4) {
+ multiplier_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.C_multiplier[quant_param_idx])),
+ _mm_set1_ps(r.C_multiplier[quant_param_idx + 1]),
+ 1);
+ x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- multiplier_v = _mm256_insertf128_ps(
- _mm256_castps128_ps256(
- _mm_set1_ps(r.C_multiplier[quant_param_idx + 2])),
- _mm_set1_ps(r.C_multiplier[quant_param_idx + 3]),
- 1);
- y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
+ multiplier_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.C_multiplier[quant_param_idx + 2])),
+ _mm_set1_ps(r.C_multiplier[quant_param_idx + 3]),
+ 1);
+ y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
- multiplier_v = _mm256_insertf128_ps(
- _mm256_castps128_ps256(
- _mm_set1_ps(r.C_multiplier[quant_param_idx + 4])),
- _mm_set1_ps(r.C_multiplier[quant_param_idx + 5]),
- 1);
- z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
+ multiplier_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.C_multiplier[quant_param_idx + 4])),
+ _mm_set1_ps(r.C_multiplier[quant_param_idx + 5]),
+ 1);
+ z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- multiplier_v = _mm256_insertf128_ps(
- _mm256_castps128_ps256(
- _mm_set1_ps(r.C_multiplier[quant_param_idx + 6])),
- _mm_set1_ps(r.C_multiplier[quant_param_idx + 7]),
- 1);
- w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ multiplier_v = _mm256_insertf128_ps(
+ _mm256_castps128_ps256(
+ _mm_set1_ps(r.C_multiplier[quant_param_idx + 6])),
+ _mm_set1_ps(r.C_multiplier[quant_param_idx + 7]),
+ 1);
+ w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ } else {
+ multiplier_v = _mm256_set1_ps(
+ r.C_multiplier
+ [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4]);
+ x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
+
+ multiplier_v = _mm256_set1_ps(
+ r.C_multiplier
+ [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
+ 1]);
+ y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
+
+ multiplier_v = _mm256_set1_ps(
+ r.C_multiplier
+ [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
+ 2]);
+ z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
+
+ multiplier_v = _mm256_set1_ps(
+ r.C_multiplier
+ [quant_param_idx + (j - block.col_start) / (VLEN * 4) * 4 +
+ 3]);
+ w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
+ }
} else {
x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
@@ -970,22 +1065,30 @@ void requantizeOutputProcessingGConv4Avx2(
} // i loop
}
-#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
- template void \
- requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
- const requantizationParams_t& r); \
- template void \
- requantizeOutputProcessingGConv4Avx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
- uint8_t * out, \
- const int32_t* inp, \
- const block_type_t& block, \
- int ld_out, \
- int ld_in, \
+#define INSTANTIATE_REQUANTIZE(A_SYM, B_SYM, Q_GRAN, BIAS, RELU) \
+ template void \
+ requantizeOutputProcessingAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t& r); \
+ template void \
+ requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 4>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
+ const requantizationParams_t& r); \
+ template void \
+ requantizeOutputProcessingGConvAvx2<A_SYM, B_SYM, Q_GRAN, BIAS, RELU, 8>( \
+ uint8_t * out, \
+ const int32_t* inp, \
+ const block_type_t& block, \
+ int ld_out, \
+ int ld_in, \
const requantizationParams_t& r);
#define INSTANTIATE_A_SYM(B_SYM, Q_GRAN, BIAS, RELU) \
diff --git a/test/GConvTest.cc b/test/GConvTest.cc
index 34042c6..e1434d2 100644
--- a/test/GConvTest.cc
+++ b/test/GConvTest.cc
@@ -71,6 +71,16 @@ static vector<conv_param_t<>> GetShapes_() {
conv_param_t<>(1, 128, 128, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(2, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+
+ conv_param_t<>(1, 64, 64, {3, 3}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 64, 64, {4, 4}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 64, 64, {3, 5}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 64, 64, {5, 3}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 16, 16, {5, 5}, 2, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 256, 256, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 256, 256, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(2, 256, 256, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
};
return shapes;
}