diff options
author | Daya S Khudia <dskhudia@fb.com> | 2019-03-21 20:03:36 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-03-21 20:07:54 +0300 |
commit | d53c0220cf1749802736bba192c5e37f430df7a0 (patch) | |
tree | 49698cc737645bfe604a899d596c5bfe9325580e /src | |
parent | fe1c3d91772703a9c4f00fa04a9acaeeeffcf83c (diff) |
Further optimize the acc16 kernel; the cache-blocking dimension for the B matrix is now free to be autotuned (#88)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/88
acc16 version
The generated assembly kernel now has one more loop (over the NR tiles in an NCB block). This change also frees NCB as an independent dimension that can be auto-tuned.
Reviewed By: jianyuh
Differential Revision: D14516232
fbshipit-source-id: f9bac9e7cdd3c89135d35a61c59a275c9a76562b
Diffstat (limited to 'src')
-rw-r--r-- | src/GenerateKernelU8S8S32ACC16Avx512.cc | 106 |
1 files changed, 85 insertions, 21 deletions
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index cd230c5..6f3f276 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -150,15 +150,22 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NCB; constexpr int mRegBlockSize = PackingTraits<int8_t, int16_t, inst_set_t::avx512>::MR; - // constexpr int nRegBlockSize = - // PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR; + constexpr int nRegBlockSize = + PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR; constexpr int row_interleave = PackingTraits<int8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE; + + assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); + assert(nc % nRegBlockSize == 0 && "nc must be a multiple of NR"); + int maxMRegs = mRegBlockSize; + int maxNRegs = nRegBlockSize * row_interleave / VLEN_; + assert( + maxMRegs * maxNRegs <= 28 && + "MR*(NR*ROW_INTERLEAVE*8/512) \ + must be <= 28(available registers constraint)"); + int mRegBlocks = mc / mRegBlockSize; int mRegBlocksRem = mc % mRegBlockSize; - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - // assert((nc == nRegBlockSize) && - //"nc must be equal to the number of register blocks"); // arguments to the function created asmjit::X86Gp buffer_A = a->zdi(); @@ -180,7 +187,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + asmjit::X86Reg::kKindGp, + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); asmjit::FuncArgsMapper args(&func); args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); @@ -193,31 +201,38 @@ CodeGenBase<uint8_t, int8_t, int32_t, 
int16_t>::getOrCreate<inst_set_t::avx512>( asmjit::FuncUtils::emitProlog(a, layout); asmjit::FuncUtils::allocArgs(a, layout, args); - asmjit::Label Loopk = a->newLabel(); asmjit::Label LoopMBlocks = a->newLabel(); + asmjit::Label LoopNBlocks = a->newLabel(); + asmjit::Label Loopk = a->newLabel(); asmjit::X86Gp buffer_B_saved = a->gpzRef(10); asmjit::X86Gp C_Offset = a->gpzRef(11); // asmjit::X86Gp B_pf_saved = a->gpzRef(12); asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp kIdx = a->gpzRef(14); + asmjit::X86Gp jIdx = a->gpzRef(14); + asmjit::X86Gp kIdx = a->gpzRef(15); - int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + // a->mov(B_pf_saved, B_pf); + + int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + int colRegs = std::min(currColRegs, maxNRegs); if (mRegBlocks > 0) { // move 0 to iteration variables a->mov(iIdx, 0); - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - // a->mov(B_pf_saved, B_pf); - a->bind(LoopMBlocks); a->inc(iIdx); + a->mov(jIdx, 0); + + a->bind(LoopNBlocks); + a->inc(jIdx); int rowRegs = mRegBlockSize; // init C registers - initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs); + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); // init k loop index a->mov(kIdx, 0); @@ -226,7 +241,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); genComputeBlock<inst_set_t::avx512>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); // update buffer_A address for next k iteration a->add( @@ -244,16 +259,40 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( // store C matrix storeCRegs<inst_set_t::avx512>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum); + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - // increment A for next block + // reset 
A a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul( + C_Offset, + jIdx, + static_cast<asmjit::Imm>( + nRegBlockSize * row_interleave * sizeof(int8_t))); + a->add(buffer_B, C_Offset); + + // increment C for next block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNBlocks); + + // increment A for next block a->add( buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); - // increment C for next block + + // increment C for next A block + a->sub( + CBase, + static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t))); a->imul( C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t))); a->add(CBase, C_Offset); + // reset B a->mov(buffer_B, buffer_B_saved); // a->mov(B_pf, B_pf_saved); @@ -263,11 +302,16 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( } // generate code for remainder if (mRegBlocksRem > 0) { + asmjit::Label LoopNRem = a->newLabel(); asmjit::Label LoopkRem = a->newLabel(); int rowRegs = mRegBlocksRem; + a->mov(jIdx, 0); + a->bind(LoopNRem); + a->inc(jIdx); + // init C registers - initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs); + initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs); // init k loop index a->mov(kIdx, 0); @@ -277,7 +321,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); genComputeBlock<inst_set_t::avx512>( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); // update buffer_A address for next k iteration a->add( @@ -293,9 +337,29 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( a->cmp(kIdx, kSize); a->jl(LoopkRem); + // reset A + a->sub(buffer_A, kSize); + + // B for 
next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul( + C_Offset, + jIdx, + static_cast<asmjit::Imm>( + nRegBlockSize * row_interleave * sizeof(int8_t))); + a->add(buffer_B, C_Offset); + // store C matrix storeCRegs<inst_set_t::avx512>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum); + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + + // increment C for next block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNRem); } asmjit::FuncUtils::emitEpilog(a, layout); |