From 53f0c0d175ae4283609a5b251052f9c6598b8aee Mon Sep 17 00:00:00 2001 From: Aleks Zi Date: Mon, 16 Sep 2019 11:03:32 -0700 Subject: A bit more refactoring Summary: Small refactor of the avx2 acc32 generator Reviewed By: dskhudia Differential Revision: D17138005 fbshipit-source-id: 06ded92c5bebb35070a45578feb96e418f8d8489 --- src/GenerateKernelU8S8S32ACC32.cc | 73 ++++++++++++--------------------------- 1 file changed, 23 insertions(+), 50 deletions(-) diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc index a0fe26c..226e974 100644 --- a/src/GenerateKernelU8S8S32ACC32.cc +++ b/src/GenerateKernelU8S8S32ACC32.cc @@ -218,7 +218,6 @@ CodeGenBase::getOrCreate( a->emitProlog(frame); a->emitArgsAssignment(frame, args); - asmjit::Label Loopk = a->newLabel(); asmjit::Label LoopMBlocks = a->newLabel(); x86::Gp buffer_B_saved = a->gpz(10); @@ -238,25 +237,16 @@ CodeGenBase::getOrCreate( a->mov(C_Offset, 0); int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - a->mov(B_pf_saved, B_pf); + auto issueLoopOverK = [&](int rowRegs) { + asmjit::Label LoopKLabel = a->newLabel(); - a->bind(LoopMBlocks); - a->inc(iIdx); - - int rowRegs = mRegBlockSize; - - // init C registers + // Init C (result) vector registers initCRegs(a, rowRegs, colRegs, colRegs); - // init k loop index + // Loops over K a->mov(kIdx, 0); - a->bind(Loopk); + a->bind(LoopKLabel); // k is incremented by row_interleave a->add(kIdx, static_cast(row_interleave)); @@ -274,15 +264,28 @@ CodeGenBase::getOrCreate( a->add(B_pf, static_cast(nBlock * row_interleave * sizeof(int8_t))); - // a->add(B_pf, 32*sizeof(float)); - a->cmp(kIdx, kSize); - a->jl(Loopk); + a->jl(LoopKLabel); // store C matrix storeCRegs(a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + }; + + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + a->bind(LoopMBlocks); + a->inc(iIdx); + + issueLoopOverK(mRegBlockSize); + + int rowRegs = mRegBlockSize; // increment A for next block a->sub(buffer_A, kSize); a->add(buffer_A, @@ -296,43 +299,13 @@ CodeGenBase::getOrCreate( // reset B a->mov(buffer_B, buffer_B_saved); a->mov(B_pf, B_pf_saved); + a->cmp(iIdx, mRegBlocks); a->jl(LoopMBlocks); } // generate code for remainder if (mRegBlocksRem > 0) { - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; - - // init C registers - initCRegs(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); - - // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); - - genComputeBlock(a, buffer_A, buffer_B, B_pf, rowRegs, - colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add(buffer_A, - static_cast(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add(buffer_B, static_cast(nBlock * row_interleave * - sizeof(int8_t))); - a->add(B_pf, static_cast(nBlock * row_interleave * - sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(LoopkRem); - - // store C matrix - storeCRegs(a, rowRegs, colRegs, C_Offset, ldcReg, accum, - colRegs); + issueLoopOverK(mRegBlocksRem); } a->emitEpilog(frame); -- cgit v1.2.3