Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC32Avx512.cc')
-rw-r--r--src/GenerateKernelU8S8S32ACC32Avx512.cc8
1 files changed, 7 insertions, 1 deletions
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index 0dcc321..333aa9d 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -159,11 +159,13 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
PackingTraits<int8_t, int32_t, inst_set_t::avx512>::MR;
constexpr int nRegBlockSize =
PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NR;
+ constexpr int nRegBlockSizeMin =
+ PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NR_MIN;
constexpr int row_interleave =
PackingTraits<int8_t, int32_t, inst_set_t::avx512>::ROW_INTERLEAVE;
assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSize == 0 && "nc must be a multiple of NR");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
int maxMRegs = mRegBlockSize;
int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
assert(
@@ -301,6 +303,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
a->cmp(jIdx, jLoopTrips);
a->jl(LoopNBlocks);
@@ -382,6 +386,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
a->cmp(jIdx, jLoopTrips);
a->jl(LoopNRem);
}