diff options
author | Daya S Khudia <dskhudia@fb.com> | 2019-03-21 20:03:36 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-03-21 20:07:54 +0300 |
commit | f65f0ebe54f0512d8f42ee10025b596e3f42e0b8 (patch) | |
tree | 8a80b9de7c8d5ae034d707b27ac7c84cecd83d0d | |
parent | 452627c5f29412528c26b57880f27914b1068d6e (diff) |
Improves small N cases back to what they were
Summary: In D14507536 and D14516232 small N cases suffered if we increased the NR. This fixes those cases.
Reviewed By: jianyuh
Differential Revision: D14529494
fbshipit-source-id: 6f53797948de760d6ed24b767cbbe8d27768660f
-rw-r--r-- | include/fbgemm/PackingTraits-inl.h | 8 | ||||
-rw-r--r-- | src/ExecuteKernelU8S8.cc | 8 | ||||
-rw-r--r-- | src/ExecuteKernelU8S8.h | 2 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16Avx512.cc | 8 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32Avx512.cc | 8 |
5 files changed, 27 insertions, 7 deletions
diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h index 6bf34d5..5b50bc9 100644 --- a/include/fbgemm/PackingTraits-inl.h +++ b/include/fbgemm/PackingTraits-inl.h @@ -154,6 +154,10 @@ struct PackingTraits< inst_set_t::avx512, typename std::enable_if<is_8bit<T>::value>::type> { static constexpr int MR{14}; ///< Register block for M dimension. + static constexpr int NR_MIN{ + 16}; ///< Minimum register block for N dimension. + ///< 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. static constexpr int NR{ 32}; ///< Register block for N dimension. ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements @@ -187,6 +191,10 @@ struct PackingTraits< inst_set_t::avx512, typename std::enable_if<is_8bit<T>::value>::type> { static constexpr int MR{6}; ///< Register block for M dimension + static constexpr int NR_MIN{ + 32}; ///< Minimum register block for N dimension; + ///< 32 because 32*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. static constexpr int NR{ 128}; ///< Register block for N dimension; ///< Must be a multiple of 32 because 32*ROW_INTERLEAVE int8 diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc index f2b028d..9b0ea41 100644 --- a/src/ExecuteKernelU8S8.cc +++ b/src/ExecuteKernelU8S8.cc @@ -51,10 +51,10 @@ ExecuteKernel< int8_t, typename packingAMatrix::accType, inst_set_t::avx512>::NCB; - nrSize_ = PackingTraits< + nrMinSize_ = PackingTraits< int8_t, typename packingAMatrix::accType, - inst_set_t::avx512>::NR; + inst_set_t::avx512>::NR_MIN; } else if (fbgemmHasAvx2Support()) { mbSize_ = PackingTraits< int8_t, @@ -64,7 +64,7 @@ ExecuteKernel< int8_t, typename packingAMatrix::accType, inst_set_t::avx2>::NCB; - nrSize_ = PackingTraits< + nrMinSize_ = PackingTraits< int8_t, typename packingAMatrix::accType, inst_set_t::avx2>::NR; @@ -132,7 +132,7 @@ void ExecuteKernel< for (int jb = 0; jb < bColBlocks; ++jb) { if (jb == bColBlocks - 1) { - int nc = ((packedB_.lastBcol() - 1) / nrSize_ + 1) * nrSize_; + int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_; if (nc != nbSize_) { if (cpuinfo_initialize()) { if (fbgemmHasAvx512Support()) { diff --git a/src/ExecuteKernelU8S8.h b/src/ExecuteKernelU8S8.h index 8d999ff..b56f54c 100644 --- a/src/ExecuteKernelU8S8.h +++ b/src/ExecuteKernelU8S8.h @@ -69,7 +69,7 @@ class ExecuteKernel< ///< multiple of N. int mbSize_; ///< block size in the m dimension. int nbSize_; ///< block size in the n dimension. - int nrSize_; ///< register size in the n dimension. + int nrMinSize_; ///< minimum register size in the n dimension. }; } // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index 9bf2eea..2ded242 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -159,11 +159,13 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( PackingTraits<int8_t, int16_t, inst_set_t::avx512>::MR; constexpr int nRegBlockSize = PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR; + constexpr int nRegBlockSizeMin = + PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR_MIN; constexpr int row_interleave = PackingTraits<int8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE; assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - assert(nc % nRegBlockSize == 0 && "nc must be a multiple of NR"); + assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); int maxMRegs = mRegBlockSize; int maxNRegs = nRegBlockSize * row_interleave / VLEN_; assert( @@ -285,6 +287,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; a->cmp(jIdx, jLoopTrips); a->jl(LoopNBlocks); @@ -365,6 +369,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; a->cmp(jIdx, jLoopTrips); a->jl(LoopNRem); } diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc index 0dcc321..333aa9d 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc @@ -159,11 +159,13 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( PackingTraits<int8_t, int32_t, inst_set_t::avx512>::MR; constexpr int nRegBlockSize = PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NR; + constexpr int nRegBlockSizeMin = + PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NR_MIN; constexpr int row_interleave = PackingTraits<int8_t, int32_t, inst_set_t::avx512>::ROW_INTERLEAVE; assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - assert(nc % nRegBlockSize == 0 && "nc must be a multiple of NR"); + assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); int maxMRegs = mRegBlockSize; int maxNRegs = nRegBlockSize * row_interleave / VLEN_; assert( @@ -301,6 +303,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; a->cmp(jIdx, jLoopTrips); a->jl(LoopNBlocks); @@ -382,6 +386,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; a->cmp(jIdx, jLoopTrips); a->jl(LoopNRem); } |