diff options
author | Protonu Basu <protonu@fb.com> | 2019-04-02 15:22:44 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-02 15:28:21 +0300 |
commit | f12ec122be12b0647ada3ff2c374cca57aa4ae95 (patch) | |
tree | 43584749ec09d493ea3a3ec04e407c2b88e8c76c /src | |
parent | d8e0d440ef80362a786f4ebb68cf1b393c33b52d (diff) |
Exposing tuning parameters in FBGEMM (MCB, NCB, KCB, MR, NR, Row Interleave) (#90)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/90
Exposing tuning parameters in FBGEMM (MCB, NCB, KCB, MR, NR, Row Interleave)
Reviewed By: dskhudia
Differential Revision: D14358148
fbshipit-source-id: 783fb4653fd696dbbd4075ad56cb8682db3011a5
Diffstat (limited to 'src')
-rw-r--r-- | src/ExecuteKernelGeneric.h | 3 | ||||
-rw-r--r-- | src/ExecuteKernelU8S8.cc | 68 | ||||
-rw-r--r-- | src/ExecuteKernelU8S8.h | 4 | ||||
-rw-r--r-- | src/Fbgemm.cc | 110 | ||||
-rw-r--r-- | src/GenerateKernel.h | 26 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16.cc | 62 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16Avx512.cc | 65 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32.cc | 60 | ||||
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32Avx512.cc | 65 | ||||
-rw-r--r-- | src/PackAMatrix.cc | 43 | ||||
-rw-r--r-- | src/PackAWithIm2Col.cc | 53 | ||||
-rw-r--r-- | src/PackAWithQuantRowOffset.cc | 61 | ||||
-rw-r--r-- | src/PackAWithRowOffset.cc | 50 | ||||
-rw-r--r-- | src/PackBMatrix.cc | 39 | ||||
-rw-r--r-- | src/PackMatrix.cc | 44 |
15 files changed, 516 insertions, 237 deletions
diff --git a/src/ExecuteKernelGeneric.h b/src/ExecuteKernelGeneric.h index 667b0ef..ce9a7bb 100644 --- a/src/ExecuteKernelGeneric.h +++ b/src/ExecuteKernelGeneric.h @@ -40,7 +40,8 @@ class ExecuteKernel : public CodeGenBase< int32_t ldc, const processOutputType& outputProcess, int thread_id, - int num_threads); + int num_threads, + const BlockingFactors* params = nullptr); void execute(int kBlock); private: diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc index 9b0ea41..4175d65 100644 --- a/src/ExecuteKernelU8S8.cc +++ b/src/ExecuteKernelU8S8.cc @@ -33,8 +33,11 @@ ExecuteKernel< int32_t ldc, const processOutputType& outputProcess, int thread_id, - int num_threads) - : packedA_(packA), + int num_threads, + const BlockingFactors* params) + : CodeGenBase<uint8_t, int8_t, int32_t, typename packingAMatrix::accType>( + params), + packedA_(packA), packedB_(packB), matC_(matC), C_buffer_(C_buffer), @@ -42,34 +45,41 @@ ExecuteKernel< outputProcess_(outputProcess), thread_id_(thread_id), num_threads_(num_threads) { - if (fbgemmHasAvx512Support()) { - mbSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx512>::MCB; - nbSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx512>::NCB; - nrMinSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx512>::NR_MIN; - } else if (fbgemmHasAvx2Support()) { - mbSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx2>::MCB; - nbSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx2>::NCB; - nrMinSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx2>::NR; + if (params) { + mbSize_ = params->MCB; + nbSize_ = params->NCB; + nrMinSize_ = params->NR_MIN; + nrSize_ = params->NR; } else { - assert(0 && "unsupported architecure"); + if (fbgemmHasAvx512Support()) { + mbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512>::MCB; + nbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512>::NCB; + nrMinSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512>::NR_MIN; + } else if (fbgemmHasAvx2Support()) { + mbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx2>::MCB; + nbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx2>::NCB; + nrMinSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx2>::NR; + } else { + assert(0 && "unsupported architecure"); + } } C_tile_ = new int32_t[mbSize_ * nbSize_]; } diff --git a/src/ExecuteKernelU8S8.h b/src/ExecuteKernelU8S8.h index b56f54c..bb20134 100644 --- a/src/ExecuteKernelU8S8.h +++ b/src/ExecuteKernelU8S8.h @@ -44,7 +44,8 @@ class ExecuteKernel< int32_t ldc, const processOutputType& outputProcess, int thread_id, - int num_threads); + int num_threads, + const BlockingFactors* params = nullptr); void execute(int kBlock); ~ExecuteKernel() { @@ -70,6 +71,7 @@ class ExecuteKernel< int mbSize_; ///< block size in the m dimension. int nbSize_; ///< block size in the n dimension. int nrMinSize_; ///< minimum register size in the n dimension. + int nrSize_; ///< register size in the n dimension. }; } // namespace fbgemm diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc index a90dd2d..a40f38a 100644 --- a/src/Fbgemm.cc +++ b/src/Fbgemm.cc @@ -36,7 +36,8 @@ void fbgemmPacked( uint32_t ldc, const processOutputType& outProcess, int thread_id, - int num_threads) { + int num_threads, + const BlockingFactors* blocking_params) { static_assert( std::is_same< typename packingAMatrix::accType, @@ -48,36 +49,43 @@ void fbgemmPacked( // Run time CPU detection if (cpuinfo_initialize()) { - if (fbgemmHasAvx512Support()) { - MCB = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx512>::MCB; - KCB = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx512>::KCB; - MR = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx512>::MR; - } else if (fbgemmHasAvx2Support()) { - MCB = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx2>::MCB; - KCB = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx2>::KCB; - MR = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx2>::MR; + if (blocking_params) { + MCB = blocking_params->MCB; + KCB = blocking_params->KCB; + MR = blocking_params->MR; } else { - // TODO: Have default slower path - assert(0 && "unsupported architecture"); - return; + if (fbgemmHasAvx512Support()) { + MCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512>::MCB; + KCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512>::KCB; + MR = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512>::MR; + } else if (fbgemmHasAvx2Support()) { + MCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx2>::MCB; + KCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx2>::KCB; + MR = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx2>::MR; + + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecture"); + return; + } } } else { throw std::runtime_error("Failed to initialize cpuinfo!"); @@ -149,7 +157,8 @@ void fbgemmPacked( ldc, outProcess, thread_id, - num_threads); + num_threads, + blocking_params); for (int i = i_begin; i < i_end; i += MCB) { // i is the element index mc = std::min(i_end - i, MCB); for (int kb = 0; kb < kBlocks; ++kb) { // kb is the block index @@ -209,7 +218,7 @@ template bool fbgemmOptimizedGConv(const conv_param_t<2>& conv_p); template bool fbgemmOptimizedGConv(const conv_param_t<3>& conv_p); bool fbgemmSupportedCPU() { - return (cpuinfo_initialize() && cpuinfo_has_x86_avx2()); + return (cpuinfo_initialize() && fbgemmHasAvx2Support()); } //////////////////////////////////////////////////////////////////////////////// @@ -223,7 +232,8 @@ bool fbgemmSupportedCPU() { uint32_t ldc, \ const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_Q_GRANS(PACK_A, ACC_T, RELU) \ INSTANTIATE_BASE(PACK_A, ACC_T, RELU, QuantizationGranularity::TENSOR); \ @@ -258,7 +268,8 @@ INSTANTIATE_ACC_T(PackAWithRowOffset); uint32_t ldc, \ const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \ INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \ @@ -293,7 +304,8 @@ INSTANTIATE_RELU(int16_t); uint32_t ldc, \ const ReQuantizeForFloat<RELU, Q_GRAN>& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_Q_GRANS(PACK_A, RELU) \ INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::TENSOR); \ @@ -323,7 +335,8 @@ INSTANTIATE_RELU(PackAWithQuantRowOffset); uint32_t ldc, \ const ReQuantizeForFloat<RELU, Q_GRAN>& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_Q_GRANS(ACC_T, RELU, SPATIAL_DIM) \ INSTANTIATE_BASE(ACC_T, RELU, SPATIAL_DIM, QuantizationGranularity::TENSOR); \ @@ -355,7 +368,8 @@ template void fbgemmPacked( uint32_t ldc, const ReQuantizeForFloat<false>& outProcess, int thread_id, - int num_threads); + int num_threads, + const BlockingFactors* blocking_params); //////////////////////////////////////////////////////////////////////////////// // DoSpmdmOnInpBuffer @@ -371,7 +385,8 @@ template void fbgemmPacked( int32_t, \ ReQuantizeOutput<RELU, Q_GRAN>>& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_Q_GRANS(PACK_A, RELU) \ INSTANTIATE_BASE(PACK_A, RELU, QuantizationGranularity::TENSOR); \ @@ -401,7 +416,8 @@ INSTANTIATE_RELU(PackAWithRowOffset); int32_t, \ ReQuantizeOutput<RELU, Q_GRAN>>& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_Q_GRANS(RELU) \ INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR); \ @@ -423,7 +439,8 @@ template void fbgemmPacked( const DoSpmdmOnInpBuffer<float, int32_t, ReQuantizeForFloat<false>>& outProcess, int thread_id, - int num_threads); + int num_threads, + const BlockingFactors* blocking_params); //////////////////////////////////////////////////////////////////////////////// // memCopy @@ -436,7 +453,8 @@ template void fbgemmPacked( uint32_t ldc, \ const memCopy<>& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_ACC_T(PACK_A) \ INSTANTIATE_BASE(PACK_A, int32_t) \ @@ -460,7 +478,8 @@ INSTANTIATE_ACC_T(PackAWithRowOffset); uint32_t ldc, \ const memCopy<>& outProcess, \ int thread_id, \ - int num_threads); + int num_threads, \ + const BlockingFactors* blocking_params); #define INSTANTIATE_SPATIAL_DIM(ACC_T) \ INSTANTIATE_BASE(ACC_T, 2); \ @@ -481,7 +500,8 @@ template void fbgemmPacked( uint32_t ldc, const memCopy<>& outProcess, int thread_id, - int num_threads); + int num_threads, + const BlockingFactors* blocking_params); template void fbgemmPacked( PackMatrix<PackAMatrix<uint8_t, int16_t>, uint8_t, int16_t>& packA, @@ -491,6 +511,8 @@ template void fbgemmPacked( uint32_t ldc, const DoNothing<int32_t, int32_t>& outProcess, int thread_id, - int num_threads); + int num_threads, + const BlockingFactors* blocking_params); + } // namespace fbgemm diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h index 7d8ac05..dccdfc5 100644 --- a/src/GenerateKernel.h +++ b/src/GenerateKernel.h @@ -39,8 +39,9 @@ class CodeGenBase { /** * @brief Constructor for initializing AVX2/AVX512 registers. */ - CodeGenBase() - : CRegs_avx2_{x86::ymm0, + CodeGenBase(const BlockingFactors* params = nullptr) + : blocking_params(params), + CRegs_avx2_{x86::ymm0, x86::ymm1, x86::ymm2, x86::ymm3, @@ -136,12 +137,21 @@ class CodeGenBase { bool accum, int leadingDimCRegAssign = 4); + const BlockingFactors* blocking_params; /** * @brief Generate filename to dump generated code * (debug-only) */ template <inst_set_t instSet> - std::string getCodeLoggingFile(bool accum, int mc, int nc) { + std::string getCodeLoggingFile( + bool accum, + int mc, + int nc, + int NCB, + int KCB, + int MR, + int NR, + int NR_MIN) { std::string fileName = "gemm_"; if (std::is_same<accT, std::int16_t>::value) { fileName += "acc16_"; @@ -153,6 +163,11 @@ class CodeGenBase { fileName += "accum-" + std::to_string(accum); fileName += "_MC-" + std::to_string(mc); fileName += "_NC-" + std::to_string(nc); + fileName += "_NCB-" + std::to_string(NCB); + fileName += "_NCB-" + std::to_string(KCB); + fileName += "_MR-" + std::to_string(MR); + fileName += "_NR-" + std::to_string(NR); + fileName += "_NR_MIN-" + std::to_string(NR_MIN); if (instSet == inst_set_t::avx512) { fileName += "_avx512"; } else if (instSet == inst_set_t::avx2) { @@ -174,7 +189,10 @@ class CodeGenBase { int VLEN_; ///< Vector width in elements. static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit. static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit. - static thread_local std::map<std::tuple<bool, int, int>, jit_micro_kernel_fp> + // The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr, nr_min + static thread_local std::map< + std::tuple<bool, int, int, int, int, int, int, int>, + jit_micro_kernel_fp> codeCache_; ///< JIT Code Cache for reuse. }; diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc index e5980b9..082518c 100644 --- a/src/GenerateKernelU8S8S32ACC16.cc +++ b/src/GenerateKernelU8S8S32ACC16.cc @@ -17,7 +17,7 @@ thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_; template <typename TA, typename TB, typename TC, typename accT> thread_local std::map< - std::tuple<bool, int, int>, + std::tuple<bool, int, int, int, int, int, int, int>, typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp> CodeGenBase<TA, TB, TC, accT>::codeCache_; @@ -136,11 +136,45 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( int32_t nc, int32_t kc, int32_t /* unused */) { - auto kernelSig = std::make_tuple(accum, mc, nc); + std::tuple<bool, int, int, int, int, int, int, int> kernelSig; + int kBlock; + int nBlock; + int mRegBlockSize; + int nRegBlockSize; + int nRegBlockSizeMin; + int row_interleave; + + if (blocking_params) { + kBlock = blocking_params->KCB; + nBlock = blocking_params->NCB; + mRegBlockSize = blocking_params->MR; + nRegBlockSize = blocking_params->NR; + nRegBlockSizeMin = blocking_params->NR_MIN; + row_interleave = blocking_params->ROW_INTERLEAVE; + } else { + kBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::KCB; + nBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::NCB; + mRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::MR; + nRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::NR; + nRegBlockSizeMin = + PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::NR_MIN; + row_interleave = + PackingTraits<uint8_t, int16_t, inst_set_t::avx2>::ROW_INTERLEAVE; + } + + kernelSig = std::make_tuple( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin); + if (codeCache_.find(kernelSig) != codeCache_.end()) { return codeCache_[kernelSig]; } - code_.reset(false); code_.init(rt_.getCodeInfo()); asmjit::X86Assembler assembler(&code_); @@ -148,22 +182,24 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( #if defined(FBGEMM_LOG_CODE) // generated code logging - FILE* codeLogfile = - fopen(getCodeLoggingFile<inst_set_t::avx2>(accum, mc, nc).c_str(), "w"); + FILE* codeLogfile = fopen( + getCodeLoggingFile<inst_set_t::avx2>( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin) + .c_str(), + "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { code_.setLogger(codeLogger); } #endif - constexpr int kBlock = PackingTraits<int8_t, int16_t, inst_set_t::avx2>::KCB; - constexpr int nBlock = PackingTraits<int8_t, int16_t, inst_set_t::avx2>::NCB; - constexpr int mRegBlockSize = - PackingTraits<int8_t, int16_t, inst_set_t::avx2>::MR; - // constexpr int nRegBlockSize = - // PackingTraits<int8_t, int16_t, inst_set_t::avx2>::NR; - constexpr int row_interleave = - PackingTraits<int8_t, int16_t, inst_set_t::avx2>::ROW_INTERLEAVE; int mRegBlocks = mc / mRegBlockSize; int mRegBlocksRem = mc % mRegBlockSize; assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index 2ded242..505fec1 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -131,7 +131,42 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( int32_t nc, int32_t kc, int32_t /* unused */) { - auto kernelSig = std::make_tuple(accum, mc, nc); + std::tuple<bool, int, int, int, int, int, int, int> kernelSig; + int kBlock; + int nBlock; + int mRegBlockSize; + int nRegBlockSize; + int nRegBlockSizeMin; + int row_interleave; + + if (blocking_params) { + kBlock = blocking_params->KCB; + nBlock = blocking_params->NCB; + mRegBlockSize = blocking_params->MR; + nRegBlockSize = blocking_params->NR; + nRegBlockSizeMin = blocking_params->NR_MIN; + row_interleave = blocking_params->ROW_INTERLEAVE; + } else { + kBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::KCB; + nBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::NCB; + mRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::MR; + nRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::NR; + nRegBlockSizeMin = + PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::NR_MIN; + row_interleave = + PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE; + } + + kernelSig = std::make_tuple( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin); + if (codeCache_.find(kernelSig) != codeCache_.end()) { return codeCache_[kernelSig]; } @@ -143,27 +178,24 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( #if defined(FBGEMM_LOG_CODE) // generated code logging - FILE* codeLogfile = - fopen(getCodeLoggingFile<inst_set_t::avx512>(accum, mc, nc).c_str(), "w"); + FILE* codeLogfile = fopen( + getCodeLoggingFile<inst_set_t::avx512>( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin) + .c_str(), + "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { code_.setLogger(codeLogger); } #endif - constexpr int kBlock = - PackingTraits<int8_t, int16_t, inst_set_t::avx512>::KCB; - constexpr int nBlock = - PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NCB; - constexpr int mRegBlockSize = - PackingTraits<int8_t, int16_t, inst_set_t::avx512>::MR; - constexpr int nRegBlockSize = - PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR; - constexpr int nRegBlockSizeMin = - PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR_MIN; - constexpr int row_interleave = - PackingTraits<int8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE; - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); int maxMRegs = mRegBlockSize; @@ -172,7 +204,6 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( maxMRegs * maxNRegs <= 24 && "MR*(NR*ROW_INTERLEAVE*8/512) \ must be <= 24(available registers constraint)"); - int mRegBlocks = mc / mRegBlockSize; int mRegBlocksRem = mc % mRegBlockSize; diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc index 203dd9a..ca750d9 100644 --- a/src/GenerateKernelU8S8S32ACC32.cc +++ b/src/GenerateKernelU8S8S32ACC32.cc @@ -17,7 +17,7 @@ thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_; template <typename TA, typename TB, typename TC, typename accT> thread_local std::map< - std::tuple<bool, int, int>, + std::tuple<bool, int, int, int, int, int, int, int>, typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp> CodeGenBase<TA, TB, TC, accT>::codeCache_; @@ -140,11 +140,45 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( int32_t nc, int32_t kc, int32_t /* unused */) { - auto kernelSig = std::make_tuple(accum, mc, nc); + std::tuple<bool, int, int, int, int, int, int, int> kernelSig; + int kBlock; + int nBlock; + int mRegBlockSize; + int nRegBlockSize; + int nRegBlockSizeMin; + int row_interleave; + + if (blocking_params) { + kBlock = blocking_params->KCB; + nBlock = blocking_params->NCB; + mRegBlockSize = blocking_params->MR; + nRegBlockSize = blocking_params->NR; + nRegBlockSizeMin = blocking_params->NR_MIN; + row_interleave = blocking_params->ROW_INTERLEAVE; + } else { + kBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::KCB; + nBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::NCB; + mRegBlockSize = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::MR; + nRegBlockSize = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::NR; + nRegBlockSizeMin = + PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::NR_MIN; + row_interleave = + PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::ROW_INTERLEAVE; + } + + kernelSig = std::make_tuple( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin); + if (codeCache_.find(kernelSig) != codeCache_.end()) { return codeCache_[kernelSig]; } - code_.reset(false); code_.init(rt_.getCodeInfo()); asmjit::X86Assembler assembler(&code_); @@ -152,20 +186,24 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( #if defined(FBGEMM_LOG_CODE) // generated code logging FILE* codeLogfile = - fopen(getCodeLoggingFile<inst_set_t::avx2>(accum, mc, nc).c_str(), "w"); + FILE* codeLogfile = fopen( + getCodeLoggingFile<inst_set_t::avx2>( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin) + .c_str(), + "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { code_.setLogger(codeLogger); } #endif - constexpr int kBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::KCB; - constexpr int nBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::NCB; - constexpr int mRegBlockSize = - PackingTraits<int8_t, int32_t, inst_set_t::avx2>::MR; - constexpr int row_interleave = - PackingTraits<int8_t, int32_t, inst_set_t::avx2>::ROW_INTERLEAVE; - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)"); int mRegBlocks = mc / mRegBlockSize; int mRegBlocksRem = mc % mRegBlockSize; diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc index 333aa9d..d1729e4 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc @@ -131,11 +131,45 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( int32_t nc, int32_t kc, int32_t /* unused */) { - auto kernelSig = std::make_tuple(accum, mc, nc); + std::tuple<bool, int, int, int, int, int, int, int> kernelSig; + int kBlock; + int nBlock; + int mRegBlockSize; + int nRegBlockSize; + int nRegBlockSizeMin; + int row_interleave; + + if (blocking_params) { + kBlock = blocking_params->KCB; + nBlock = blocking_params->NCB; + mRegBlockSize = blocking_params->MR; + nRegBlockSize = blocking_params->NR; + nRegBlockSizeMin = blocking_params->NR_MIN; + row_interleave = blocking_params->ROW_INTERLEAVE; + } else { + kBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512>::KCB; + nBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512>::NCB; + mRegBlockSize = PackingTraits<uint8_t, int32_t, inst_set_t::avx512>::MR; + nRegBlockSize = PackingTraits<uint8_t, int32_t, inst_set_t::avx512>::NR; + nRegBlockSizeMin = + PackingTraits<uint8_t, int32_t, inst_set_t::avx512>::NR_MIN; + row_interleave = + PackingTraits<uint8_t, int32_t, inst_set_t::avx512>::ROW_INTERLEAVE; + } + + kernelSig = std::make_tuple( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin); + if (codeCache_.find(kernelSig) != codeCache_.end()) { return codeCache_[kernelSig]; } - code_.reset(false); code_.init(rt_.getCodeInfo()); asmjit::X86Assembler assembler(&code_); @@ -143,27 +177,24 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( #if defined(FBGEMM_LOG_CODE) // generated code logging - FILE* codeLogfile = - fopen(getCodeLoggingFile<inst_set_t::avx512>(accum, mc, nc).c_str(), "w"); + FILE* codeLogfile = fopen( + getCodeLoggingFile<inst_set_t::avx512>( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin) + .c_str(), + "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { code_.setLogger(codeLogger); } #endif - constexpr int kBlock = - PackingTraits<int8_t, int32_t, inst_set_t::avx512>::KCB; - constexpr int nBlock = - PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NCB; - constexpr int mRegBlockSize = - PackingTraits<int8_t, int32_t, inst_set_t::avx512>::MR; - constexpr int nRegBlockSize = - PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NR; - constexpr int nRegBlockSizeMin = - PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NR_MIN; - constexpr int row_interleave = - PackingTraits<int8_t, int32_t, inst_set_t::avx512>::ROW_INTERLEAVE; - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); int maxMRegs = mRegBlockSize; diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc index db019db..89ec13e 100644 --- a/src/PackAMatrix.cc +++ b/src/PackAMatrix.cc @@ -20,24 +20,36 @@ PackAMatrix<T, accT>::PackAMatrix( const T* smat, int32_t ld, inpType* pmat, - int groups) - : PackMatrix<PackAMatrix<T, accT>, T, accT>(nRow, nCol, pmat, groups), + int groups, + const BlockingFactors* params) + : PackMatrix<PackAMatrix<T, accT>, T, accT>( + nRow, + nCol, + pmat, + groups, + params), trans_(trans), smat_(smat), ld_(ld) { - if (fbgemmHasAvx512Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; - row_interleave_B_ = - PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; - row_interleave_B_ = - PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + if (params) { + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); + if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx2Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } } if (BaseType::numCols() % groups != 0) { throw std::runtime_error( @@ -46,8 +58,7 @@ PackAMatrix<T, accT>::PackAMatrix( } if (pmat) { BaseType::buf_ = pmat; - } - else { + } else { BaseType::bufAllocatedHere_ = true; BaseType::buf_ = (T*)fbgemmAlignedAlloc( 64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)); diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc index 93408da..fb4556c 100644 --- a/src/PackAWithIm2Col.cc +++ b/src/PackAWithIm2Col.cc @@ -23,7 +23,8 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col( inpType* pmat, int32_t a_zero_pt, int32_t* row_offset, - bool b_symmetric) + bool b_symmetric, + const BlockingFactors* params) : PackMatrix<PackAWithIm2Col<T, accT, SPATIAL_DIM>, T, accT>( conv_p.MB * std::accumulate( @@ -38,25 +39,33 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col( std::multiplies<int>()) * conv_p.IC, pmat, - conv_p.G), + conv_p.G, + params), conv_p_(conv_p), sdata_(sdata), a_zero_pt_(a_zero_pt) { static_assert( SPATIAL_DIM == 2 || SPATIAL_DIM == 3, "unsupported conv dimension "); - if (fbgemmHasAvx512Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; - row_interleave_B_ = - PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; - row_interleave_B_ = - PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + + if (params) { + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; } else { - // TODO: Have default slower path - assert(0 && "unsupported architecure"); + if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx2Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } } if (BaseType::numCols() % conv_p.G != 0) { throw std::runtime_error( @@ -145,8 +154,7 @@ void pack_a_with_im2col_opt( std::memcpy( out + (i - block.row_start) * BCOL + j + s * IC, sdata + - ((n * IN_DIM_H + h_in) * IN_DIM_W + -PAD_W + w * STRIDE_W + - s) * + ((n * IN_DIM_H + h_in) * IN_DIM_W + -PAD_W + w * STRIDE_W + s) * IC, sizeof(uint8_t) * mid_len * IC); s += mid_len; @@ -459,17 +467,22 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::printPackedMatrix( } template <typename T, typename accT, int SPATIAL_DIM> -int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize() { +int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize( + const BlockingFactors* params) { if (cpuinfo_initialize()) { + if (params){ + return params->MCB; + } else { if (fbgemmHasAvx512Support()) { - return PackingTraits<T, accT, inst_set_t::avx512>::MCB; + return PackingTraits<T, accT, inst_set_t::avx512>::MCB; } else if (fbgemmHasAvx2Support()) { - return PackingTraits<T, accT, inst_set_t::avx2>::MCB; - } else { + return PackingTraits<T, accT, inst_set_t::avx2>::MCB; + } else { // TODO: Have default slower path assert(0 && "unsupported architecture"); return -1; } + } } else { throw std::runtime_error("Failed to initialize cpuinfo!"); } diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc index 2929ebb..175425f 100644 --- a/src/PackAWithQuantRowOffset.cc +++ b/src/PackAWithQuantRowOffset.cc @@ -28,12 +28,14 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset( float scale, int32_t zero_pt, int groups, - int32_t* row_offset) + int32_t* row_offset, + const BlockingFactors* params) : PackMatrix<PackAWithQuantRowOffset<T, accT>, T, accT>( nRow, nCol, pmat, - groups), + groups, + params), trans_(trans), smat_(smat), ld_(ld), @@ -41,20 +43,30 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset( zero_pt_(zero_pt), row_offset_(row_offset) { rowOffsetAllocatedHere = false; - - if (fbgemmHasAvx512Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; - row_interleave_B_ = - PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; - row_interleave_B_ = - PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + if (params) { + if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } } else { - // TODO: Have default slower path - assert(0 && "unknown architecure"); + if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx2Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unknown architecure"); + } } if (BaseType::numCols() % groups != 0) { throw std::runtime_error( @@ -179,15 +191,20 @@ void PackAWithQuantRowOffset<T, accT>::printPackedMatrix(std::string name) { } template <typename T, typename accT> -int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize() { +int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize( + const BlockingFactors* params) { if (cpuinfo_initialize()) { - if (fbgemmHasAvx512Support()) { - return PackingTraits<T, accT, inst_set_t::avx512>::MCB; - } else if (fbgemmHasAvx2Support()) { - return PackingTraits<T, accT, inst_set_t::avx2>::MCB; + if (params) { + return params->MCB; } else { - assert(0 && "unsupported architecture"); - return -1; + if (fbgemmHasAvx512Support()) { + return PackingTraits<T, accT, inst_set_t::avx512>::MCB; + } else if (fbgemmHasAvx2Support()) { + return PackingTraits<T, accT, inst_set_t::avx2>::MCB; + } else { + assert(0 && "unsupported architecture"); + return -1; + } } } else { throw std::runtime_error("Failed to initialize cpuinfo!"); diff --git a/src/PackAWithRowOffset.cc b/src/PackAWithRowOffset.cc index 7777f1a..139a6d3 100644 --- a/src/PackAWithRowOffset.cc +++ b/src/PackAWithRowOffset.cc @@ -24,31 +24,38 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset( uint32_t ld, inpType* pmat, int groups, - int32_t* row_offset) + int32_t* row_offset, + const BlockingFactors* params) : PackMatrix<PackAWithRowOffset<T, accT>, T, accT>( nRow, nCol, pmat, - groups), + groups, + params), trans_(trans), smat_(smat), ld_(ld), row_offset_(row_offset) { rowOffsetAllocatedHere = false; - - if (fbgemmHasAvx512Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; - row_interleave_B_ = - PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; - row_interleave_B_ = - PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + if (params) { + BaseType::brow_ = params->MCB; + BaseType::bcol_ = params->KCB; + row_interleave_B_ = params->ROW_INTERLEAVE; } else { - // TODO: Have default slower path - assert(0 && "unknown architecure"); + if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx2Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } else { + // TODO: Have default slower path + assert(0 && "unknown architecure"); + } } if (BaseType::numCols() % groups != 0) { throw std::runtime_error( @@ -169,17 +176,22 @@ void PackAWithRowOffset<T, accT>::printPackedMatrix(std::string name) { } template <typename T, typename accT> -int PackAWithRowOffset<T, accT>::rowOffsetBufferSize() { +int PackAWithRowOffset<T, accT>::rowOffsetBufferSize( + const BlockingFactors* params) { if (cpuinfo_initialize()) { + if (params){ + return params->MCB; + } else { if (fbgemmHasAvx512Support()) { - return PackingTraits<T, accT, inst_set_t::avx512>::MCB; - } else if (fbgemmHasAvx2Support()) { - return PackingTraits<T, accT, inst_set_t::avx2>::MCB; + return PackingTraits<T, accT, inst_set_t::avx512>::MCB; + } else if (fbgemmHasAvx2Support()) { + return PackingTraits<T, accT, inst_set_t::avx2>::MCB; } else { // TODO: Have default slower path assert(0 && "unsupported architecture"); return -1; } + } } else { throw std::runtime_error("Failed to initialize cpuinfo!"); } diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc index 48641ff..472c802 100644 --- a/src/PackBMatrix.cc +++ b/src/PackBMatrix.cc @@ -174,23 +174,36 @@ PackBMatrix<T, accT>::PackBMatrix( const T* smat, int32_t ld, inpType* pmat, - int groups) - : PackMatrix<PackBMatrix<T, accT>, T, accT>(nRow, nCol, pmat, groups), + int groups, + const BlockingFactors* params) + : PackMatrix<PackBMatrix<T, accT>, T, accT>( + nRow, + nCol, + pmat, + groups, + params), trans_(trans), smat_(smat), ld_(ld) { - if (fbgemmHasAvx512Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB; - row_interleave_ = - PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; - } else if (fbgemmHasAvx2Support()) { - BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; - BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::NCB; - row_interleave_ = PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + if (params) { + BaseType::brow_ = params->KCB; + BaseType::bcol_ = params->NCB; + row_interleave_ = params->ROW_INTERLEAVE; } else { - // Error - assert(0 && "unknown architecure"); + if (fbgemmHasAvx512Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB; + row_interleave_ = + PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx2Support()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx2>::KCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx2>::NCB; + row_interleave_ = + PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE; + } else { + // Error + assert(0 && "unknown architecure"); + } } if (BaseType::numRows() % groups != 0) { throw std::runtime_error( diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc index 316fc06..e93b97c 100644 --- a/src/PackMatrix.cc +++ b/src/PackMatrix.cc @@ -18,33 +18,57 @@ PackMatrix<PT, inpType, accType>::PackMatrix( int32_t rows, int32_t cols, inpType* buf, - int groups) + int groups, + const BlockingFactors* params) : buf_(buf), nrows_(rows), ncols_(cols), G_(groups) { bufAllocatedHere_ = false; + blocking_params = params; if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } } template <typename PT, typename inpType, typename accType> -int PackMatrix<PT, inpType, accType>::packedBufferSize(int rows, int cols) { +int PackMatrix<PT, inpType, accType>::packedBufferSize( + int rows, + int cols, + const BlockingFactors* params) { + int MCB, KCB, NCB; + if (params) { + MCB = params->MCB; + NCB = params->NCB; + KCB = params->KCB; + } else { + if (fbgemmHasAvx512Support()) { + MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB; + NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB; + KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB; + } else if (fbgemmHasAvx2Support()) { + MCB = PackingTraits<inpType, accType, inst_set_t::avx2>::MCB; + NCB = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB; + KCB = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB; + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + return -1; + } + } + if (fbgemmHasAvx512Support()) { if (isA()) { - return PackingTraits<inpType, accType, inst_set_t::avx512>::MCB * - PackingTraits<inpType, accType, inst_set_t::avx512>::KCB; + return MCB * KCB; } else { - int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB; - int colBlock = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB; + int rowBlock = KCB; + int colBlock = NCB; return (((rows + rowBlock - 1) / rowBlock) * rowBlock) * (((cols + colBlock - 1) / colBlock) * colBlock); } } else if (fbgemmHasAvx2Support()) { if (isA()) { - return PackingTraits<inpType, accType, inst_set_t::avx2>::MCB * - PackingTraits<inpType, accType, inst_set_t::avx2>::KCB; + return MCB * KCB; } else { - int rowBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::KCB; - int colBlock = PackingTraits<inpType, accType, inst_set_t::avx2>::NCB; + int rowBlock = KCB; + int colBlock = NCB; return (((rows + rowBlock - 1) / rowBlock) * rowBlock) * (((cols + colBlock - 1) / colBlock) * colBlock); } |