From 452627c5f29412528c26b57880f27914b1068d6e Mon Sep 17 00:00:00 2001 From: Daya S Khudia Date: Thu, 21 Mar 2019 10:03:36 -0700 Subject: Allocate some registers for B matrix loading and reuse loaded results Summary: Instead of loading B matrix values with every vpmaddubsw instruction, load once and reuse. The downside is we need to use some register for holding these B matrix values which could have been otherwise used for C accumulations. Reviewed By: jianyuh Differential Revision: D14529495 fbshipit-source-id: 54bd4bcdcf14ac2f25a433ac60bfc08b7359453f --- include/fbgemm/PackingTraits-inl.h | 4 ++-- src/GenerateKernel.h | 13 ++++++++++++- src/GenerateKernelU8S8S32ACC16Avx512.cc | 13 ++++++++++--- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h index 465c498..6bf34d5 100644 --- a/include/fbgemm/PackingTraits-inl.h +++ b/include/fbgemm/PackingTraits-inl.h @@ -186,7 +186,7 @@ struct PackingTraits< std::int16_t, inst_set_t::avx512, typename std::enable_if::value>::type> { - static constexpr int MR{7}; ///< Register block for M dimension + static constexpr int MR{6}; ///< Register block for M dimension static constexpr int NR{ 128}; ///< Register block for N dimension; ///< Must be a multiple of 32 because 32*ROW_INTERLEAVE int8 @@ -200,7 +200,7 @@ struct PackingTraits< ///< B matrix. static constexpr int MCB{ - 56}; ///< Cache block for M dimension (multiple of MR). + 60}; ///< Cache block for M dimension (multiple of MR). static constexpr int NCB{ 128}; ///< Cache block for N dimension (multiple of NR). static constexpr int KCB{256}; ///< Cache block for K dimension. diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h index 7c0368b..7d8ac05 100644 --- a/src/GenerateKernel.h +++ b/src/GenerateKernel.h @@ -59,7 +59,15 @@ class CodeGenBase { x86::zmm15, x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19, x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, x86::zmm24, x86::zmm25, x86::zmm26, x86::zmm27, - } { + }, + AllRegs_avx512_{x86::zmm0, x86::zmm1, x86::zmm2, x86::zmm3, + x86::zmm4, x86::zmm5, x86::zmm6, x86::zmm7, + x86::zmm8, x86::zmm9, x86::zmm10, x86::zmm11, + x86::zmm12, x86::zmm13, x86::zmm14, x86::zmm15, + x86::zmm16, x86::zmm17, x86::zmm18, x86::zmm19, + x86::zmm20, x86::zmm21, x86::zmm22, x86::zmm23, + x86::zmm24, x86::zmm25, x86::zmm26, x86::zmm27, + x86::zmm28, x86::zmm29, x86::zmm30, x86::zmm31} { // vector width in bits if (cpuinfo_initialize()) { if (fbgemmHasAvx512Support()) { @@ -159,6 +167,9 @@ class CodeGenBase { CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel. asmjit::X86Zmm CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel. + asmjit::X86Zmm + AllRegs_avx512_[32]; ///< all AVX512 zmm registers. + int vectorWidth_; ///< Vector width in bits. int VLEN_; ///< Vector width in elements. static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit. diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index 6f3f276..9bf2eea 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -54,13 +54,20 @@ void CodeGenBase::genComputeBlock< asmjit::X86Zmm tmpReg = x86::zmm30; + // We start allocating BRegs from zmm27 and then allocate zmm26 and so on. + for (int j = 0; j < colRegs; ++j) { + a->vmovups( + AllRegs_avx512_[27 - j], + x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + } + for (int i = 0; i < rowRegs; ++i) { // broadcast A a->vpbroadcastw( AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); for (int j = 0; j < colRegs; ++j) { a->vpmaddubsw( - tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + tmpReg, AReg, AllRegs_avx512_[27-j]); a->vpaddsw( CRegs_avx512_[i * leadingDimCRegAssign + j], tmpReg, @@ -160,9 +167,9 @@ CodeGenBase::getOrCreate( int maxMRegs = mRegBlockSize; int maxNRegs = nRegBlockSize * row_interleave / VLEN_; assert( - maxMRegs * maxNRegs <= 28 && + maxMRegs * maxNRegs <= 24 && "MR*(NR*ROW_INTERLEAVE*8/512) \ - must be <= 28(available registers constraint)"); + must be <= 24(available registers constraint)"); int mRegBlocks = mc / mRegBlockSize; int mRegBlocksRem = mc % mRegBlockSize; -- cgit v1.2.3