Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJianyu Huang <jianyuhuang@fb.com>2019-02-14 02:46:39 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-14 03:14:19 +0300
commit5ec2641b8a4bb74827f9c46357c69c4af1f6cfe0 (patch)
tree99582190a2a21c1c00bd07dab0f4fefc9a4c9a5d /src
parentaa88eafebf8df24dfa8c60768b1103f6981f7322 (diff)
JIT kernel should only handle a small portion of NCB for the last block: multiple of NR
Summary: Before this Diff: we pass into the JIT kernel with nc = NCB ( packedB_.blockColSize() ) instead of nc = leftover size (packedB_.lastBcol() ) for the last block of B (diffusion/FBS/browse/master/fbcode/deeplearning/fbgemm/src/ExecuteKernelU8S8.cc;1adfe7977ef7ea2a1aee0ed785bd3fed5b7c4a20$102), which cause the additional computation when n is small. After this Diff: we pass into the JIT kernel with a small portion of NCB (still multiple of NR) for the last block of B. The main performance gain is for Acc16, because NCB = 4 * NR for Acc16 and NCB = NR for Acc32 in our current settings (AVX2 and AVX512). Reviewed By: jspark1105 Differential Revision: D14063628 fbshipit-source-id: 5829d06553daf617e2fefa7d26cb2d761af402c1
Diffstat (limited to 'src')
-rw-r--r--src/ExecuteKernelU8S8.cc29
-rw-r--r--src/ExecuteKernelU8S8.h1
-rw-r--r--src/GenerateKernelU8S8S32ACC16.cc13
-rw-r--r--src/GenerateKernelU8S8S32ACC16Avx512.cc14
-rw-r--r--src/GenerateKernelU8S8S32ACC32.cc15
-rw-r--r--src/GenerateKernelU8S8S32ACC32Avx512.cc16
6 files changed, 72 insertions, 16 deletions
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index cdceb63..64bac90 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -53,6 +53,10 @@ ExecuteKernel<
int8_t,
typename packingAMatrix::accType,
inst_set_t::avx512>::NCB;
+ nrSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512>::NR;
} else if (cpuinfo_has_x86_avx2()) {
mbSize_ = PackingTraits<
int8_t,
@@ -62,6 +66,10 @@ ExecuteKernel<
int8_t,
typename packingAMatrix::accType,
inst_set_t::avx2>::NCB;
+ nrSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx2>::NR;
} else {
assert(0 && "unsupported architecure");
}
@@ -125,6 +133,27 @@ void ExecuteKernel<
#endif
for (int jb = 0; jb < bColBlocks; ++jb) {
+ if (jb == bColBlocks - 1) {
+ int nc = ((packedB_.lastBcol() - 1) / nrSize_ + 1) * nrSize_;
+ if (nc != nbSize_) {
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ fn = BaseType::template getOrCreate<inst_set_t::avx512>(
+ accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
+ } else if (cpuinfo_has_x86_avx2()) {
+ fn = BaseType::template getOrCreate<inst_set_t::avx2>(
+ accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecture");
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+ }
+ }
+
bBuf = packedB_.getBuf(jb, kBlock);
// prefetch addr of the next packed block of B matrix
bBuf_pf = packedB_.getBuf(jb == bColBlocks - 1 ? jb : jb + 1, kBlock);
diff --git a/src/ExecuteKernelU8S8.h b/src/ExecuteKernelU8S8.h
index 253c4e6..c3f5486 100644
--- a/src/ExecuteKernelU8S8.h
+++ b/src/ExecuteKernelU8S8.h
@@ -71,6 +71,7 @@ class ExecuteKernel<
///< multiple of N.
int mbSize_; ///< block size in the m dimension.
int nbSize_; ///< block size in the n dimension.
+ int nrSize_; ///< register size in the n dimension.
};
} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index b7bc676..3419281 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -151,6 +151,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
// code_.setLogger(&logger);
constexpr int kBlock = PackingTraits<int8_t, int16_t, inst_set_t::avx2>::KCB;
+ constexpr int nBlock = PackingTraits<int8_t, int16_t, inst_set_t::avx2>::NCB;
constexpr int mRegBlockSize =
PackingTraits<int8_t, int16_t, inst_set_t::avx2>::MR;
// constexpr int nRegBlockSize =
@@ -237,8 +238,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
// update buffer_B address for next k iteration
a->add(
- buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
a->cmp(kIdx, kSize);
a->jl(Loopk);
@@ -285,8 +288,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
// update buffer_B address for next k iteration
a->add(
- buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
- // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
a->cmp(kIdx, kSize);
a->jl(LoopkRem);
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index eeeaea0..ef625e2 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -140,6 +140,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
constexpr int kBlock =
PackingTraits<int8_t, int16_t, inst_set_t::avx512>::KCB;
+ constexpr int nBlock =
+ PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NCB;
constexpr int mRegBlockSize =
PackingTraits<int8_t, int16_t, inst_set_t::avx512>::MR;
// constexpr int nRegBlockSize =
@@ -226,8 +228,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
// update buffer_B address for next k iteration
a->add(
- buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
a->cmp(kIdx, kSize);
a->jl(Loopk);
@@ -275,8 +279,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
// update buffer_B address for next k iteration
a->add(
- buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
a->cmp(kIdx, kSize);
a->jl(LoopkRem);
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index f788872..ab0625c 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -155,6 +155,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
// code_.setLogger(&logger);
constexpr int kBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::KCB;
+ constexpr int nBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::NCB;
constexpr int mRegBlockSize =
PackingTraits<int8_t, int32_t, inst_set_t::avx2>::MR;
constexpr int row_interleave =
@@ -249,8 +250,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
// update buffer_B address for next k iteration
a->add(
- buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
- a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
// a->add(B_pf, 32*sizeof(float));
@@ -301,8 +305,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
// update buffer_B address for next k iteration
a->add(
- buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
- a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
a->cmp(kIdx, kSize);
a->jl(LoopkRem);
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index 0621bb0..e292fa8 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -144,6 +144,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
constexpr int kBlock =
PackingTraits<int8_t, int32_t, inst_set_t::avx512>::KCB;
+ constexpr int nBlock =
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NCB;
constexpr int mRegBlockSize =
PackingTraits<int8_t, int32_t, inst_set_t::avx512>::MR;
constexpr int row_interleave =
@@ -239,8 +241,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
// update buffer_B address for next k iteration
a->add(
- buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
- a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
// a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
@@ -291,8 +296,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
// update buffer_B address for next k iteration
a->add(
- buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
- a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
a->cmp(kIdx, kSize);
a->jl(LoopkRem);