Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC16Avx512.cc')
-rw-r--r-- src/GenerateKernelU8S8S32ACC16Avx512.cc | 65
1 files changed, 48 insertions, 17 deletions
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index 2ded242..505fec1 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -131,7 +131,42 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
int32_t nc,
int32_t kc,
int32_t /* unused */) {
- auto kernelSig = std::make_tuple(accum, mc, nc);
+ std::tuple<bool, int, int, int, int, int, int, int> kernelSig;
+ int kBlock;
+ int nBlock;
+ int mRegBlockSize;
+ int nRegBlockSize;
+ int nRegBlockSizeMin;
+ int row_interleave;
+
+ if (blocking_params) {
+ kBlock = blocking_params->KCB;
+ nBlock = blocking_params->NCB;
+ mRegBlockSize = blocking_params->MR;
+ nRegBlockSize = blocking_params->NR;
+ nRegBlockSizeMin = blocking_params->NR_MIN;
+ row_interleave = blocking_params->ROW_INTERLEAVE;
+ } else {
+ kBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::KCB;
+ nBlock = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::NCB;
+ mRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::MR;
+ nRegBlockSize = PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::NR;
+ nRegBlockSizeMin =
+ PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::NR_MIN;
+ row_interleave =
+ PackingTraits<uint8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE;
+ }
+
+ kernelSig = std::make_tuple(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin);
+
if (codeCache_.find(kernelSig) != codeCache_.end()) {
return codeCache_[kernelSig];
}
@@ -143,27 +178,24 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
#if defined(FBGEMM_LOG_CODE)
// generated code logging
- FILE* codeLogfile =
- fopen(getCodeLoggingFile<inst_set_t::avx512>(accum, mc, nc).c_str(), "w");
+ FILE* codeLogfile = fopen(
+ getCodeLoggingFile<inst_set_t::avx512>(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin)
+ .c_str(),
+ "w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
if (codeLogger) {
code_.setLogger(codeLogger);
}
#endif
- constexpr int kBlock =
- PackingTraits<int8_t, int16_t, inst_set_t::avx512>::KCB;
- constexpr int nBlock =
- PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NCB;
- constexpr int mRegBlockSize =
- PackingTraits<int8_t, int16_t, inst_set_t::avx512>::MR;
- constexpr int nRegBlockSize =
- PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR;
- constexpr int nRegBlockSizeMin =
- PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR_MIN;
- constexpr int row_interleave =
- PackingTraits<int8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE;
-
assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
int maxMRegs = mRegBlockSize;
@@ -172,7 +204,6 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
maxMRegs * maxNRegs <= 24 &&
"MR*(NR*ROW_INTERLEAVE*8/512) \
must be <= 24(available registers constraint)");
-
int mRegBlocks = mc / mRegBlockSize;
int mRegBlocksRem = mc % mRegBlockSize;