Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJongsoo Park <jongsoo@fb.com>2019-02-02 21:32:16 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-02 21:35:25 +0300
commita7d921d447564cdc67a07ad79dfc9e0d3dcd7418 (patch)
treee14bd4cc18a5a3e290918c77d3d3ce23d3423331 /src
parentef0ad4c0bb21248f276d8a8d380dfc9e37b9b141 (diff)
minor optimization in handling zero points for row offset (#63)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/63 During row_offset compute, pre-multiply zero point by C_per_G . Instead of generating multiple vpaddd instructions for adding zero points, we just multiply by a constant. Reviewed By: dskhudia Differential Revision: D13833686 fbshipit-source-id: a42a447955380f6cfde3fbae20ec16d47423bdd6
Diffstat (limited to 'src')
-rw-r--r--src/GroupwiseConv.h4
-rw-r--r--src/GroupwiseConvAcc32Avx2.cc133
2 files changed, 58 insertions, 79 deletions
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
index a46a895..0d2db3a 100644
--- a/src/GroupwiseConv.h
+++ b/src/GroupwiseConv.h
@@ -169,6 +169,10 @@ class GenConvKernel {
template <inst_set_t instSet>
void gen8BitSum(asmjit::X86Emitter* a, asmjit::X86Ymm aReg);
+ // Use scratchReg1_ and tmpReg1Avx2_ internally
+ template <inst_set_t instSet>
+ void genZeroPtSum(asmjit::X86Emitter* a, int multiplier);
+
template <inst_set_t instSet>
void genForTopEdgeRowoffset(asmjit::X86Emitter* a);
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index 7906c04..c85e339 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -213,6 +213,20 @@ void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>(
template <>
template <>
+void GenConvKernel<int32_t>::genZeroPtSum<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int multiplier) {
+ a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier));
+ // tmpReg1Avx2_ also uses xmm11
+ asmjit::X86Xmm const_reg_xmm = x86::xmm11;
+ a->movq(const_reg_xmm, scratchReg1_);
+ a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm);
+ a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_);
+ a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
+}
+
+template <>
+template <>
void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
asmjit::X86Emitter* a) {
// top-left corner code
@@ -782,27 +796,22 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
// top-left corner code
// zero out the results register
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
- for (int r = 0; r < R_; ++r) {
+ if (!isZeroPointZero_) {
+ genZeroPtSum<inst_set_t::avx2>(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_));
+ }
+ for (int r = H_PAD_; r < R_; ++r) {
int h_in = -H_PAD_ + r;
- if (h_in >= 0) {
- a->imul(
- scratchReg1_,
- W_R_,
- static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
- }
- for (int s = 0; s < S_; ++s) {
+ a->imul(
+ scratchReg1_,
+ W_R_,
+ static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+ for (int s = W_PAD_; s < S_; ++s) {
int w_in = -W_PAD_ + s;
- if (h_in >= 0 && w_in >= 0) {
- a->vmovups(
- actRegAvx2_,
- x86::dword_ptr(
- in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
- gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
- } else {
- if (!isZeroPointZero_) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
+ a->vmovups(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
}
}
// store results
@@ -818,11 +827,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
if (!isZeroPointZero_) {
- for (int r = 0; r < H_PAD_; ++r) {
- for (int s = 0; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
+ genZeroPtSum<inst_set_t::avx2>(a, H_PAD_ * S_);
}
for (int r = H_PAD_; r < R_; ++r) {
int h_in = -H_PAD_ + r;
@@ -860,11 +865,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
if (!isZeroPointZero_) {
- for (int r = 0; r < H_PAD_; ++r) {
- for (int s = 0; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
+ genZeroPtSum<inst_set_t::avx2>(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_));
}
for (int r = H_PAD_; r < R_; ++r) {
int h_in = -H_PAD_ + r;
@@ -880,11 +881,6 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
a->vmovups(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
}
- if (!isZeroPointZero_) {
- for (int s = S_ - W_PAD_; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
}
// store results
@@ -907,16 +903,14 @@ void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
a->bind(LoopLeftEdge);
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ genZeroPtSum<inst_set_t::avx2>(a, R_ * W_PAD_);
+ }
a->mov(scratchReg1_, loopR1_);
a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
a->imul(scratchReg1_, W_R_);
a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
for (int r = 0; r < R_; ++r) {
- if (!isZeroPointZero_) {
- for (int s = 0; s < W_PAD_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
for (int s = W_PAD_; s < S_; ++s) {
a->vmovups(
actRegAvx2_,
@@ -968,6 +962,9 @@ void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
a->bind(LoopRightEdge);
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ genZeroPtSum<inst_set_t::avx2>(a, R_ * W_PAD_);
+ }
a->mov(scratchReg1_, loopR1_);
a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
a->imul(scratchReg1_, W_R_);
@@ -979,16 +976,10 @@ void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
a->add(scratchReg1_, scratchReg2_);
for (int r = 0; r < R_; ++r) {
for (int s = 0; s < S_ - W_PAD_; ++s) {
- a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
a->vmovups(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
}
- if (!isZeroPointZero_) {
- for (int s = S_ - W_PAD_; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
a->sub(
scratchReg1_,
@@ -1030,16 +1021,14 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
// bottom-left corner
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ genZeroPtSum<inst_set_t::avx2>(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_));
+ }
a->mov(scratchReg1_, H_R_);
a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
a->imul(scratchReg1_, W_R_);
a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
for (int r = 0; r < R_ - H_PAD_; ++r) {
- if (!isZeroPointZero_) {
- for (int s = 0; s < W_PAD_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
for (int s = W_PAD_; s < S_; ++s) {
a->vmovups(
actRegAvx2_,
@@ -1053,13 +1042,6 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
}
- if (!isZeroPointZero_) {
- for (int r = R_ - H_PAD_; r < R_; ++r) {
- for (int s = 0; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
- }
// we updating the last row
a->mov(scratchReg1_, H_R_);
@@ -1076,6 +1058,9 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
a->bind(LoopBottomEdge);
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ genZeroPtSum<inst_set_t::avx2>(a, H_PAD_ * S_);
+ }
a->mov(scratchReg1_, H_R_);
a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
a->imul(scratchReg1_, W_R_);
@@ -1093,14 +1078,6 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
a->add(scratchReg1_, scratchReg2_);
}
- if (!isZeroPointZero_) {
- for (int r = R_ - W_PAD_; r < R_; ++r) {
- for (int s = 0; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
- }
-
a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
// storeResult<inst_set_t::avx2>(a, ((H_-1)*W_+1)*8*sizeof(int32_t));
storeResultRowoffset<inst_set_t::avx2>(a);
@@ -1121,6 +1098,9 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
// bottom-right corner
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ genZeroPtSum<inst_set_t::avx2>(a, R_ * S_ - (R_ - H_PAD_) * (S_ - W_PAD_));
+ }
// input start point
// ((H_-(R_-H_PAD_))*W_+(W_-(S_-W_PAD_)))*C_*sizeof(uint8_t)
a->mov(scratchReg1_, H_R_);
@@ -1139,19 +1119,6 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
}
a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
a->add(scratchReg1_, scratchReg2_);
- if (!isZeroPointZero_) {
- for (int s = S_ - W_PAD_; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
- }
-
- if (!isZeroPointZero_) {
- for (int r = R_ - H_PAD_; r < R_; ++r) {
- for (int s = 0; s < S_; ++s) {
- gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
- }
- }
}
storeResultRowoffset<inst_set_t::avx2>(a);
@@ -1288,7 +1255,15 @@ GenConvKernel<int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
// This uses xmm10 register temporarily. Should come before
// createVector8BitOne
if (!isZeroPointZero_) {
- setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ // we can use xmm11 because ymm11 is used by tmpReg1Avx2_
+ asmjit::X86Xmm const_reg_xmm = x86::xmm11;
+ a->movq(const_reg_xmm, a_zero_pt_R_);
+ a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm);
+
+ a->mov(scratchReg1_, static_cast<asmjit::Imm>(C_per_G_));
+ a->movq(const_reg_xmm, scratchReg1_);
+ a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm);
+ a->vpmulld(zeroPTRegAvx2_, zeroPTRegAvx2_, tmpReg1Avx2_);
}
createVector16BitOne<inst_set_t::avx2>(a);