Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/GroupwiseConvAcc32Avx2.cc')
-rw-r--r-- src/GroupwiseConvAcc32Avx2.cc 133
1 files changed, 54 insertions, 79 deletions
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index 7906c04..c85e339 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -213,6 +213,20 @@ void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>(
template <>
template <>
+void GenConvKernel<int32_t>::genZeroPtSum<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int multiplier) {
+ a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier));
+ // tmpReg1Avx2_ also uses xmm11
+ asmjit::X86Xmm const_reg_xmm = x86::xmm11;
+ a->movq(const_reg_xmm, scratchReg1_);
+ a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm);
+ a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_);
+ a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
+}
+
+template <>
+template <>
void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
asmjit::X86Emitter* a) {
// top-left corner code
@@ -782,27 +796,22 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
// top-left corner code
// zero out the results register
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
- for (int r = 0; r < R_; ++r) {
+ if (!isZeroPointZero_) {
+ genZeroPtSum<inst_set_t::avx2>(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_));
+ }
+ for (int r = H_PAD_; r < R_; ++r) {
int h_in = -H_PAD_ + r;
- if (h_in >= 0) {
- a->imul(
- scratchReg1_,
- W_R_,
- static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
- }
- for (int s = 0; s < S_; ++s) {
+ a->imul(
+ scratchReg1_,
+ W_R_,
+ static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+ for (int s = W_PAD_; s < S_; ++s) {
int w_in = -W_PAD_ + s;
- if (h_in >= 0 && w_in >= 0) {
- a->vmovups(
- actRegAvx2_,
- x86::dword_ptr(
- in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
- gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
- } else {
- if (!isZeroPointZero_) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
+ a->vmovups(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
}
}
// store results
@@ -818,11 +827,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
if (!isZeroPointZero_) {
- for (int r = 0; r < H_PAD_; ++r) {
- for (int s = 0; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
+ genZeroPtSum<inst_set_t::avx2>(a, H_PAD_ * S_);
}
for (int r = H_PAD_; r < R_; ++r) {
int h_in = -H_PAD_ + r;
@@ -860,11 +865,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
if (!isZeroPointZero_) {
- for (int r = 0; r < H_PAD_; ++r) {
- for (int s = 0; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
+ genZeroPtSum<inst_set_t::avx2>(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_));
}
for (int r = H_PAD_; r < R_; ++r) {
int h_in = -H_PAD_ + r;
@@ -880,11 +881,6 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
a->vmovups(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
}
- if (!isZeroPointZero_) {
- for (int s = S_ - W_PAD_; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
}
// store results
@@ -907,16 +903,14 @@ void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
a->bind(LoopLeftEdge);
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ genZeroPtSum<inst_set_t::avx2>(a, R_ * W_PAD_);
+ }
a->mov(scratchReg1_, loopR1_);
a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
a->imul(scratchReg1_, W_R_);
a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
for (int r = 0; r < R_; ++r) {
- if (!isZeroPointZero_) {
- for (int s = 0; s < W_PAD_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
for (int s = W_PAD_; s < S_; ++s) {
a->vmovups(
actRegAvx2_,
@@ -968,6 +962,9 @@ void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
a->bind(LoopRightEdge);
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ genZeroPtSum<inst_set_t::avx2>(a, R_ * W_PAD_);
+ }
a->mov(scratchReg1_, loopR1_);
a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
a->imul(scratchReg1_, W_R_);
@@ -979,16 +976,10 @@ void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
a->add(scratchReg1_, scratchReg2_);
for (int r = 0; r < R_; ++r) {
for (int s = 0; s < S_ - W_PAD_; ++s) {
- a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
a->vmovups(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
}
- if (!isZeroPointZero_) {
- for (int s = S_ - W_PAD_; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
a->sub(
scratchReg1_,
@@ -1030,16 +1021,14 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
// bottom-left corner
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ genZeroPtSum<inst_set_t::avx2>(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_));
+ }
a->mov(scratchReg1_, H_R_);
a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
a->imul(scratchReg1_, W_R_);
a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
for (int r = 0; r < R_ - H_PAD_; ++r) {
- if (!isZeroPointZero_) {
- for (int s = 0; s < W_PAD_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
for (int s = W_PAD_; s < S_; ++s) {
a->vmovups(
actRegAvx2_,
@@ -1053,13 +1042,6 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
}
- if (!isZeroPointZero_) {
- for (int r = R_ - H_PAD_; r < R_; ++r) {
- for (int s = 0; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
- }
// we updating the last row
a->mov(scratchReg1_, H_R_);
@@ -1076,6 +1058,9 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
a->bind(LoopBottomEdge);
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ genZeroPtSum<inst_set_t::avx2>(a, H_PAD_ * S_);
+ }
a->mov(scratchReg1_, H_R_);
a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
a->imul(scratchReg1_, W_R_);
@@ -1093,14 +1078,6 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
a->add(scratchReg1_, scratchReg2_);
}
- if (!isZeroPointZero_) {
- for (int r = R_ - W_PAD_; r < R_; ++r) {
- for (int s = 0; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
- }
-
a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
// storeResult<inst_set_t::avx2>(a, ((H_-1)*W_+1)*8*sizeof(int32_t));
storeResultRowoffset<inst_set_t::avx2>(a);
@@ -1121,6 +1098,9 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
// bottom-right corner
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ genZeroPtSum<inst_set_t::avx2>(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_));
+ }
// input start point
// ((H_-(R_-H_PAD_))*W_+(W_-(S_-W_PAD_)))*C_*sizeof(uint8_t)
a->mov(scratchReg1_, H_R_);
@@ -1139,19 +1119,6 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
- if (!isZeroPointZero_) {
- for (int s = S_ - W_PAD_; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
- }
-
- if (!isZeroPointZero_) {
- for (int r = R_ - H_PAD_; r < R_; ++r) {
- for (int s = 0; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
}
storeResultRowoffset<inst_set_t::avx2>(a);
@@ -1288,7 +1255,15 @@ GenConvKernel<int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
// This uses xmm10 register temporarily. Should come before
// createVector8BitOne
if (!isZeroPointZero_) {
- setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ // we can use xmm11 because ymm11 is used by tmpReg1Avx2_
+ asmjit::X86Xmm const_reg_xmm = x86::xmm11;
+ a->movq(const_reg_xmm, a_zero_pt_R_);
+ a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm);
+
+ a->mov(scratchReg1_, static_cast<asmjit::Imm>(C_per_G_));
+ a->movq(const_reg_xmm, scratchReg1_);
+ a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm);
+ a->vpmulld(zeroPTRegAvx2_, zeroPTRegAvx2_, tmpReg1Avx2_);
}
createVector16BitOne<inst_set_t::avx2>(a);