diff options
author | Jianyu Huang <jianyuhuang@fb.com> | 2019-02-14 02:46:39 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-14 03:14:19 +0300 |
commit | 5ec2641b8a4bb74827f9c46357c69c4af1f6cfe0 (patch) | |
tree | 99582190a2a21c1c00bd07dab0f4fefc9a4c9a5d /src | |
parent | aa88eafebf8df24dfa8c60768b1103f6981f7322 (diff) |
JIT kernel should only handle a small portion of NCB (a multiple of NR) for the last block
Summary:
Before this Diff:
We call the JIT kernel with nc = NCB (packedB_.blockColSize()) instead of nc = leftover size (packedB_.lastBcol()) for the last block of B (diffusion/FBS/browse/master/fbcode/deeplearning/fbgemm/src/ExecuteKernelU8S8.cc;1adfe7977ef7ea2a1aee0ed785bd3fed5b7c4a20$102), which causes unnecessary computation when n is small.
After this Diff:
We call the JIT kernel with only the needed portion of NCB (the leftover size packedB_.lastBcol() rounded up to a multiple of NR) for the last block of B.
The main performance gain is for Acc16, because NCB = 4 * NR for Acc16 and NCB = NR for Acc32 in our current settings (AVX2 and AVX512).
Reviewed By: jspark1105
Differential Revision: D14063628
fbshipit-source-id: 5829d06553daf617e2fefa7d26cb2d761af402c1
Diffstat (limited to 'src')
-rw-r--r-- | src/ExecuteKernelU8S8.cc | 29 | ||||
-rw-r--r-- | src/ExecuteKernelU8S8.h | 1 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16.cc | 13 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16Avx512.cc | 14 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32.cc | 15 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32Avx512.cc | 16 |
6 files changed, 72 insertions, 16 deletions
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc index cdceb63..64bac90 100644 --- a/src/ExecuteKernelU8S8.cc +++ b/src/ExecuteKernelU8S8.cc @@ -53,6 +53,10 @@ ExecuteKernel< int8_t, typename packingAMatrix::accType, inst_set_t::avx512>::NCB; + nrSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512>::NR; } else if (cpuinfo_has_x86_avx2()) { mbSize_ = PackingTraits< int8_t, @@ -62,6 +66,10 @@ ExecuteKernel< int8_t, typename packingAMatrix::accType, inst_set_t::avx2>::NCB; + nrSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx2>::NR; } else { assert(0 && "unsupported architecure"); } @@ -125,6 +133,27 @@ void ExecuteKernel< #endif for (int jb = 0; jb < bColBlocks; ++jb) { + if (jb == bColBlocks - 1) { + int nc = ((packedB_.lastBcol() - 1) / nrSize_ + 1) * nrSize_; + if (nc != nbSize_) { + if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + fn = BaseType::template getOrCreate<inst_set_t::avx512>( + accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_); + } else if (cpuinfo_has_x86_avx2()) { + fn = BaseType::template getOrCreate<inst_set_t::avx2>( + accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_); + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecture"); + return; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } + } + } + bBuf = packedB_.getBuf(jb, kBlock); // prefetch addr of the next packed block of B matrix bBuf_pf = packedB_.getBuf(jb == bColBlocks - 1 ? jb : jb + 1, kBlock); diff --git a/src/ExecuteKernelU8S8.h b/src/ExecuteKernelU8S8.h index 253c4e6..c3f5486 100644 --- a/src/ExecuteKernelU8S8.h +++ b/src/ExecuteKernelU8S8.h @@ -71,6 +71,7 @@ class ExecuteKernel< ///< multiple of N. int mbSize_; ///< block size in the m dimension. int nbSize_; ///< block size in the n dimension. + int nrSize_; ///< register size in the n dimension. 
}; } // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc index b7bc676..3419281 100644 --- a/src/GenerateKernelU8S8S32ACC16.cc +++ b/src/GenerateKernelU8S8S32ACC16.cc @@ -151,6 +151,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( // code_.setLogger(&logger); constexpr int kBlock = PackingTraits<int8_t, int16_t, inst_set_t::avx2>::KCB; + constexpr int nBlock = PackingTraits<int8_t, int16_t, inst_set_t::avx2>::NCB; constexpr int mRegBlockSize = PackingTraits<int8_t, int16_t, inst_set_t::avx2>::MR; // constexpr int nRegBlockSize = @@ -237,8 +238,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( // update buffer_B address for next k iteration a->add( - buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); + buffer_B, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); a->cmp(kIdx, kSize); a->jl(Loopk); @@ -285,8 +288,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( // update buffer_B address for next k iteration a->add( - buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); - // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t)); + buffer_B, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); a->cmp(kIdx, kSize); a->jl(LoopkRem); diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index eeeaea0..ef625e2 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -140,6 +140,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( constexpr int kBlock = PackingTraits<int8_t, 
int16_t, inst_set_t::avx512>::KCB; + constexpr int nBlock = + PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NCB; constexpr int mRegBlockSize = PackingTraits<int8_t, int16_t, inst_set_t::avx512>::MR; // constexpr int nRegBlockSize = @@ -226,8 +228,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( // update buffer_B address for next k iteration a->add( - buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); + buffer_B, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); a->cmp(kIdx, kSize); a->jl(Loopk); @@ -275,8 +279,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( // update buffer_B address for next k iteration a->add( - buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); - // a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); + buffer_B, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave * + // sizeof(int8_t))); a->cmp(kIdx, kSize); a->jl(LoopkRem); diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc index f788872..ab0625c 100644 --- a/src/GenerateKernelU8S8S32ACC32.cc +++ b/src/GenerateKernelU8S8S32ACC32.cc @@ -155,6 +155,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( // code_.setLogger(&logger); constexpr int kBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::KCB; + constexpr int nBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::NCB; constexpr int mRegBlockSize = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::MR; constexpr int row_interleave = @@ -249,8 +250,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( // update buffer_B 
address for next k iteration a->add( - buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); - a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); + buffer_B, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + a->add( + B_pf, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); // a->add(B_pf, 32*sizeof(float)); @@ -301,8 +305,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( // update buffer_B address for next k iteration a->add( - buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); - a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); + buffer_B, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + a->add( + B_pf, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); a->cmp(kIdx, kSize); a->jl(LoopkRem); diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc index 0621bb0..e292fa8 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc @@ -144,6 +144,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( constexpr int kBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx512>::KCB; + constexpr int nBlock = + PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NCB; constexpr int mRegBlockSize = PackingTraits<int8_t, int32_t, inst_set_t::avx512>::MR; constexpr int row_interleave = @@ -239,8 +241,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( // update buffer_B address for next k iteration a->add( - buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); - a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); + buffer_B, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + a->add( + B_pf, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); // 
a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float))); @@ -291,8 +296,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( // update buffer_B address for next k iteration a->add( - buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); - a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t))); + buffer_B, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + a->add( + B_pf, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); a->cmp(kIdx, kSize); a->jl(LoopkRem); |