Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaya S Khudia <dskhudia@fb.com>2019-03-21 20:03:36 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-21 20:07:54 +0300
commitf65f0ebe54f0512d8f42ee10025b596e3f42e0b8 (patch)
tree8a80b9de7c8d5ae034d707b27ac7c84cecd83d0d
parent452627c5f29412528c26b57880f27914b1068d6e (diff)
Improves small N cases back to what they were
Summary: In D14507536 and D14516232 small N cases suffered if we increased the NR. This fixes those cases. Reviewed By: jianyuh Differential Revision: D14529494 fbshipit-source-id: 6f53797948de760d6ed24b767cbbe8d27768660f
-rw-r--r--include/fbgemm/PackingTraits-inl.h8
-rw-r--r--src/ExecuteKernelU8S8.cc8
-rw-r--r--src/ExecuteKernelU8S8.h2
-rw-r--r--src/GenerateKernelU8S8S32ACC16Avx512.cc8
-rw-r--r--src/GenerateKernelU8S8S32ACC32Avx512.cc8
5 files changed, 27 insertions, 7 deletions
diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h
index 6bf34d5..5b50bc9 100644
--- a/include/fbgemm/PackingTraits-inl.h
+++ b/include/fbgemm/PackingTraits-inl.h
@@ -154,6 +154,10 @@ struct PackingTraits<
inst_set_t::avx512,
typename std::enable_if<is_8bit<T>::value>::type> {
static constexpr int MR{14}; ///< Register block for M dimension.
+ static constexpr int NR_MIN{
+ 16}; ///< Minimum register block for N dimension.
+ ///< 16 because 16*ROW_INTERLEAVE int8 elements
+ ///< completely fill a 512-bit wide vector.
static constexpr int NR{
32}; ///< Register block for N dimension.
///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements
@@ -187,6 +191,10 @@ struct PackingTraits<
inst_set_t::avx512,
typename std::enable_if<is_8bit<T>::value>::type> {
static constexpr int MR{6}; ///< Register block for M dimension
+ static constexpr int NR_MIN{
+ 32}; ///< Minimum register block for N dimension;
+ ///< 32 because 32*ROW_INTERLEAVE int8 elements
+ ///< completely fill a 512-bit wide vector.
static constexpr int NR{
128}; ///< Register block for N dimension;
///< Must be a multiple of 32 because 32*ROW_INTERLEAVE int8
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index f2b028d..9b0ea41 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -51,10 +51,10 @@ ExecuteKernel<
int8_t,
typename packingAMatrix::accType,
inst_set_t::avx512>::NCB;
- nrSize_ = PackingTraits<
+ nrMinSize_ = PackingTraits<
int8_t,
typename packingAMatrix::accType,
- inst_set_t::avx512>::NR;
+ inst_set_t::avx512>::NR_MIN;
} else if (fbgemmHasAvx2Support()) {
mbSize_ = PackingTraits<
int8_t,
@@ -64,7 +64,7 @@ ExecuteKernel<
int8_t,
typename packingAMatrix::accType,
inst_set_t::avx2>::NCB;
- nrSize_ = PackingTraits<
+ nrMinSize_ = PackingTraits<
int8_t,
typename packingAMatrix::accType,
inst_set_t::avx2>::NR;
@@ -132,7 +132,7 @@ void ExecuteKernel<
for (int jb = 0; jb < bColBlocks; ++jb) {
if (jb == bColBlocks - 1) {
- int nc = ((packedB_.lastBcol() - 1) / nrSize_ + 1) * nrSize_;
+ int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_;
if (nc != nbSize_) {
if (cpuinfo_initialize()) {
if (fbgemmHasAvx512Support()) {
diff --git a/src/ExecuteKernelU8S8.h b/src/ExecuteKernelU8S8.h
index 8d999ff..b56f54c 100644
--- a/src/ExecuteKernelU8S8.h
+++ b/src/ExecuteKernelU8S8.h
@@ -69,7 +69,7 @@ class ExecuteKernel<
///< multiple of N.
int mbSize_; ///< block size in the m dimension.
int nbSize_; ///< block size in the n dimension.
- int nrSize_; ///< register size in the n dimension.
+ int nrMinSize_; ///< minimum register size in the n dimension.
};
} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index 9bf2eea..2ded242 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -159,11 +159,13 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
PackingTraits<int8_t, int16_t, inst_set_t::avx512>::MR;
constexpr int nRegBlockSize =
PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR;
+ constexpr int nRegBlockSizeMin =
+ PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR_MIN;
constexpr int row_interleave =
PackingTraits<int8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE;
assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSize == 0 && "nc must be a multiple of NR");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
int maxMRegs = mRegBlockSize;
int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
assert(
@@ -285,6 +287,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
a->cmp(jIdx, jLoopTrips);
a->jl(LoopNBlocks);
@@ -365,6 +369,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
a->cmp(jIdx, jLoopTrips);
a->jl(LoopNRem);
}
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index 0dcc321..333aa9d 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -159,11 +159,13 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
PackingTraits<int8_t, int32_t, inst_set_t::avx512>::MR;
constexpr int nRegBlockSize =
PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NR;
+ constexpr int nRegBlockSizeMin =
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NR_MIN;
constexpr int row_interleave =
PackingTraits<int8_t, int32_t, inst_set_t::avx512>::ROW_INTERLEAVE;
assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSize == 0 && "nc must be a multiple of NR");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
int maxMRegs = mRegBlockSize;
int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
assert(
@@ -301,6 +303,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
a->cmp(jIdx, jLoopTrips);
a->jl(LoopNBlocks);
@@ -382,6 +386,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
a->cmp(jIdx, jLoopTrips);
a->jl(LoopNRem);
}