diff options
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC32.cc')
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32.cc | 60 |
1 files changed, 49 insertions, 11 deletions
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc index 203dd9a..ca750d9 100644 --- a/src/GenerateKernelU8S8S32ACC32.cc +++ b/src/GenerateKernelU8S8S32ACC32.cc @@ -17,7 +17,7 @@ thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_; template <typename TA, typename TB, typename TC, typename accT> thread_local std::map< - std::tuple<bool, int, int>, + std::tuple<bool, int, int, int, int, int, int, int>, typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp> CodeGenBase<TA, TB, TC, accT>::codeCache_; @@ -140,11 +140,45 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( int32_t nc, int32_t kc, int32_t /* unused */) { - auto kernelSig = std::make_tuple(accum, mc, nc); + std::tuple<bool, int, int, int, int, int, int, int> kernelSig; + int kBlock; + int nBlock; + int mRegBlockSize; + int nRegBlockSize; + int nRegBlockSizeMin; + int row_interleave; + + if (blocking_params) { + kBlock = blocking_params->KCB; + nBlock = blocking_params->NCB; + mRegBlockSize = blocking_params->MR; + nRegBlockSize = blocking_params->NR; + nRegBlockSizeMin = blocking_params->NR_MIN; + row_interleave = blocking_params->ROW_INTERLEAVE; + } else { + kBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::KCB; + nBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::NCB; + mRegBlockSize = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::MR; + nRegBlockSize = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::NR; + nRegBlockSizeMin = + PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::NR_MIN; + row_interleave = + PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::ROW_INTERLEAVE; + } + + kernelSig = std::make_tuple( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin); + if (codeCache_.find(kernelSig) != codeCache_.end()) { return codeCache_[kernelSig]; } - code_.reset(false); code_.init(rt_.getCodeInfo()); asmjit::X86Assembler assembler(&code_); @@ -152,20 +186,24 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( #if defined(FBGEMM_LOG_CODE) // generated code logging FILE* codeLogfile = - fopen(getCodeLoggingFile<inst_set_t::avx2>(accum, mc, nc).c_str(), "w"); + FILE* codeLogfile = fopen( + getCodeLoggingFile<inst_set_t::avx2>( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin) + .c_str(), + "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { code_.setLogger(codeLogger); } #endif - constexpr int kBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::KCB; - constexpr int nBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::NCB; - constexpr int mRegBlockSize = - PackingTraits<int8_t, int32_t, inst_set_t::avx2>::MR; - constexpr int row_interleave = - PackingTraits<int8_t, int32_t, inst_set_t::avx2>::ROW_INTERLEAVE; - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)"); int mRegBlocks = mc / mRegBlockSize; int mRegBlocksRem = mc % mRegBlockSize; |