diff options
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC32Avx512.cc')
-rw-r--r-- | src/GenerateKernelU8S8S32ACC32Avx512.cc | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index 0621bb0..e292fa8 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -144,6 +144,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
 
   constexpr int kBlock =
       PackingTraits<int8_t, int32_t, inst_set_t::avx512>::KCB;
+  constexpr int nBlock =
+      PackingTraits<int8_t, int32_t, inst_set_t::avx512>::NCB;
   constexpr int mRegBlockSize =
       PackingTraits<int8_t, int32_t, inst_set_t::avx512>::MR;
   constexpr int row_interleave =
@@ -239,8 +241,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
 
     // update buffer_B address for next k iteration
     a->add(
-        buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
-    a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+        buffer_B,
+        static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+    a->add(
+        B_pf,
+        static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
 
     // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
 
@@ -291,8 +296,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
 
     // update buffer_B address for next k iteration
     a->add(
-        buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
-    a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+        buffer_B,
+        static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+    a->add(
+        B_pf,
+        static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
 
     a->cmp(kIdx, kSize);
     a->jl(LoopkRem);