Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC32.cc')
-rw-r--r--src/GenerateKernelU8S8S32ACC32.cc60
1 files changed, 49 insertions, 11 deletions
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index 203dd9a..ca750d9 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -17,7 +17,7 @@ thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
template <typename TA, typename TB, typename TC, typename accT>
thread_local std::map<
- std::tuple<bool, int, int>,
+ std::tuple<bool, int, int, int, int, int, int, int>,
typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
CodeGenBase<TA, TB, TC, accT>::codeCache_;
@@ -140,11 +140,45 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
int32_t nc,
int32_t kc,
int32_t /* unused */) {
- auto kernelSig = std::make_tuple(accum, mc, nc);
+ std::tuple<bool, int, int, int, int, int, int, int> kernelSig;
+ int kBlock;
+ int nBlock;
+ int mRegBlockSize;
+ int nRegBlockSize;
+ int nRegBlockSizeMin;
+ int row_interleave;
+
+ if (blocking_params) {
+ kBlock = blocking_params->KCB;
+ nBlock = blocking_params->NCB;
+ mRegBlockSize = blocking_params->MR;
+ nRegBlockSize = blocking_params->NR;
+ nRegBlockSizeMin = blocking_params->NR_MIN;
+ row_interleave = blocking_params->ROW_INTERLEAVE;
+ } else {
+ kBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::KCB;
+ nBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::NCB;
+ mRegBlockSize = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::MR;
+ nRegBlockSize = PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::NR;
+ nRegBlockSizeMin =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::NR_MIN;
+ row_interleave =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx2>::ROW_INTERLEAVE;
+ }
+
+ kernelSig = std::make_tuple(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin);
+
if (codeCache_.find(kernelSig) != codeCache_.end()) {
return codeCache_[kernelSig];
}
-
code_.reset(false);
code_.init(rt_.getCodeInfo());
asmjit::X86Assembler assembler(&code_);
@@ -152,20 +186,24 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
#if defined(FBGEMM_LOG_CODE)
// generated code logging
FILE* codeLogfile =
- fopen(getCodeLoggingFile<inst_set_t::avx2>(accum, mc, nc).c_str(), "w");
+ FILE* codeLogfile = fopen(
+ getCodeLoggingFile<inst_set_t::avx2>(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin)
+ .c_str(),
+ "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
code_.setLogger(codeLogger);
}
#endif
- constexpr int kBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::KCB;
- constexpr int nBlock = PackingTraits<int8_t, int32_t, inst_set_t::avx2>::NCB;
- constexpr int mRegBlockSize =
- PackingTraits<int8_t, int32_t, inst_set_t::avx2>::MR;
- constexpr int row_interleave =
- PackingTraits<int8_t, int32_t, inst_set_t::avx2>::ROW_INTERLEAVE;
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
// assert(mc <= 12 && "mc must be <= 12 (available registers constraint)");
int mRegBlocks = mc / mRegBlockSize;
int mRegBlocksRem = mc % mRegBlockSize;