diff options
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC16Avx512.cc')
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16Avx512.cc | 65 |
1 files changed, 48 insertions, 17 deletions
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index 2ded242..505fec1 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -131,7 +131,42 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( int32_t nc, int32_t kc, int32_t /* unused */) { - auto kernelSig = std::make_tuple(accum, mc, nc); + std::tuple<bool, int, int, int, int, int, int, int> kernelSig; + int kBlock; + int nBlock; + int mRegBlockSize; + int nRegBlockSize; + int nRegBlockSizeMin; + int row_interleave; + + if (blocking_params) { + kBlock = blocking_params->KCB; + nBlock = blocking_params->NCB; + mRegBlockSize = blocking_params->MR; + nRegBlockSize = blocking_params->NR; + nRegBlockSizeMin = blocking_params->NR_MIN; + row_interleave = blocking_params->ROW_INTERLEAVE; + } else { + kBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::KCB; + nBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::NCB; + mRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::MR; + nRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::NR; + nRegBlockSizeMin = + PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::NR_MIN; + row_interleave = + PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE; + } + + kernelSig = std::make_tuple( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin); + if (codeCache_.find(kernelSig) != codeCache_.end()) { return codeCache_[kernelSig]; } @@ -143,27 +178,24 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( #if defined(FBGEMM_LOG_CODE) // generated code logging - FILE* codeLogfile = - fopen(getCodeLoggingFile<inst_set_t::avx512>(accum, mc, nc).c_str(), "w"); + FILE* codeLogfile = fopen( + getCodeLoggingFile<inst_set_t::avx512>( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin) + .c_str(), + "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); if (codeLogger) { code_.setLogger(codeLogger); } #endif - constexpr int kBlock = - PackingTraits<int8_t, int16_t, inst_set_t::avx512>::KCB; - constexpr int nBlock = - PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NCB; - constexpr int mRegBlockSize = - PackingTraits<int8_t, int16_t, inst_set_t::avx512>::MR; - constexpr int nRegBlockSize = - PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR; - constexpr int nRegBlockSizeMin = - PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR_MIN; - constexpr int row_interleave = - PackingTraits<int8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE; - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); int maxMRegs = mRegBlockSize; @@ -172,7 +204,6 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( maxMRegs * maxNRegs <= 24 && "MR*(NR*ROW_INTERLEAVE*8/512) \ must be <= 24(available registers constraint)"); - int mRegBlocks = mc / mRegBlockSize; int mRegBlocksRem = mc % mRegBlockSize; |