diff options
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC32Avx512.cc')
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32Avx512.cc | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc index 0dcc321..333aa9d 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc @@ -159,11 +159,13 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( PackingTraits<int8_t, int32_t, inst_set_t::avx512>::MR; constexpr int nRegBlockSize = PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NR; + constexpr int nRegBlockSizeMin = + PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NR_MIN; constexpr int row_interleave = PackingTraits<int8_t, int32_t, inst_set_t::avx512>::ROW_INTERLEAVE; assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - assert(nc % nRegBlockSize == 0 && "nc must be a multiple of NR"); + assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); int maxMRegs = mRegBlockSize; int maxNRegs = nRegBlockSize * row_interleave / VLEN_; assert( @@ -301,6 +303,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; a->cmp(jIdx, jLoopTrips); a->jl(LoopNBlocks); @@ -382,6 +386,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; a->cmp(jIdx, jLoopTrips); a->jl(LoopNRem); } |