Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaya S Khudia <dskhudia@fb.com>2019-03-21 20:03:36 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-21 20:07:54 +0300
commit452627c5f29412528c26b57880f27914b1068d6e (patch)
tree3d7f6af4a78a15e72212e4a25916d238f0c9699a
parentd53c0220cf1749802736bba192c5e37f430df7a0 (diff)
Allocate some registers for B matrix loading and reuse loaded results
Summary: Instead of loading B matrix values with every vpmaddubsw instruction, load once and reuse. The downside is we need to use some register for holding these B matrix values which could have been otherwise used for C accumulations. Reviewed By: jianyuh Differential Revision: D14529495 fbshipit-source-id: 54bd4bcdcf14ac2f25a433ac60bfc08b7359453f
-rw-r--r--include/fbgemm/PackingTraits-inl.h4
-rw-r--r--src/GenerateKernel.h13
-rw-r--r--src/GenerateKernelU8S8S32ACC16Avx512.cc13
3 files changed, 24 insertions, 6 deletions
diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h
index 465c498..6bf34d5 100644
--- a/include/fbgemm/PackingTraits-inl.h
+++ b/include/fbgemm/PackingTraits-inl.h
@@ -186,7 +186,7 @@ struct PackingTraits<
std::int16_t,
inst_set_t::avx512,
typename std::enable_if<is_8bit<T>::value>::type> {
- static constexpr int MR{7}; ///< Register block for M dimension
+ static constexpr int MR{6}; ///< Register block for M dimension
static constexpr int NR{
128}; ///< Register block for N dimension;
///< Must be a multiple of 32 because 32*ROW_INTERLEAVE int8
@@ -200,7 +200,7 @@ struct PackingTraits<
///< B matrix.
static constexpr int MCB{
- 56}; ///< Cache block for M dimension (multiple of MR).
+ 60}; ///< Cache block for M dimension (multiple of MR).
static constexpr int NCB{
128}; ///< Cache block for N dimension (multiple of NR).
static constexpr int KCB{256}; ///< Cache block for K dimension.
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
index 7c0368b..7d8ac05 100644
--- a/src/GenerateKernel.h
+++ b/src/GenerateKernel.h
@@ -59,7 +59,15 @@ class CodeGenBase {
x86::zmm15, x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19,
x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, x86::zmm24,
x86::zmm25, x86::zmm26, x86::zmm27,
- } {
+ },
+ AllRegs_avx512_{x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3,
+ x86::zmm4, x86::zmm5, x86::zmm6, x86::zmm7,
+ x86::zmm8, x86::zmm9, x86::zmm10, x86::zmm11,
+ x86::zmm12, x86::zmm13, x86::zmm14, x86::zmm15,
+ x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19,
+ x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23,
+ x86::zmm24, x86::zmm25, x86::zmm26, x86::zmm27,
+ x86::zmm28, x86::zmm29, x86::zmm30, x86::zmm31} {
// vector width in bits
if (cpuinfo_initialize()) {
if (fbgemmHasAvx512Support()) {
@@ -159,6 +167,9 @@ class CodeGenBase {
CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
asmjit::X86Zmm
CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel.
+ asmjit::X86Zmm
+ AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
+
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index 6f3f276..9bf2eea 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -54,13 +54,20 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
asmjit::X86Zmm tmpReg = x86::zmm30;
+ // We start allocating BRegs from zmm27 and then allocate zmm26 and so on.
+ for (int j = 0; j < colRegs; ++j) {
+ a->vmovups(
+ AllRegs_avx512_[27 - j],
+ x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ }
+
for (int i = 0; i < rowRegs; ++i) {
// broadcast A
a->vpbroadcastw(
AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
for (int j = 0; j < colRegs; ++j) {
a->vpmaddubsw(
- tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ tmpReg, AReg, AllRegs_avx512_[27-j]);
a->vpaddsw(
CRegs_avx512_[i * leadingDimCRegAssign + j],
tmpReg,
@@ -160,9 +167,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
int maxMRegs = mRegBlockSize;
int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
assert(
- maxMRegs * maxNRegs <= 28 &&
+ maxMRegs * maxNRegs <= 24 &&
"MR*(NR*ROW_INTERLEAVE*8/512) \
- must be <= 28(available registers constraint)");
+ must be <= 24(available registers constraint)");
int mRegBlocks = mc / mRegBlockSize;
int mRegBlocksRem = mc % mRegBlockSize;