Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAleks Zi <zlateski@fb.com>2019-09-16 21:03:32 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-09-16 21:07:52 +0300
commit53f0c0d175ae4283609a5b251052f9c6598b8aee (patch)
treea9b3b48a081e8a296c8e4aa046fbde08153e6b5a
parent96f2b9db2ea2972b6b8c04ed165a1854220a5e0b (diff)
A bit more refactoring
Summary: Small refactor of the avx2 acc32 generator Reviewed By: dskhudia Differential Revision: D17138005 fbshipit-source-id: 06ded92c5bebb35070a45578feb96e418f8d8489
-rw-r--r--src/GenerateKernelU8S8S32ACC32.cc73
1 files changed, 23 insertions, 50 deletions
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index a0fe26c..226e974 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -218,7 +218,6 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
a->emitProlog(frame);
a->emitArgsAssignment(frame, args);
- asmjit::Label Loopk = a->newLabel();
asmjit::Label LoopMBlocks = a->newLabel();
x86::Gp buffer_B_saved = a->gpz(10);
@@ -238,25 +237,16 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
a->mov(C_Offset, 0);
int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- a->mov(B_pf_saved, B_pf);
+ auto issueLoopOverK = [&](int rowRegs) {
+ asmjit::Label LoopKLabel = a->newLabel();
- a->bind(LoopMBlocks);
- a->inc(iIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
+ // Init C (result) vector registers
initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
- // init k loop index
+ // Loops over K
a->mov(kIdx, 0);
- a->bind(Loopk);
+ a->bind(LoopKLabel);
// k is incremented by row_interleave
a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
@@ -274,15 +264,28 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
sizeof(int8_t)));
- // a->add(B_pf, 32*sizeof(float));
-
a->cmp(kIdx, kSize);
- a->jl(Loopk);
+ a->jl(LoopKLabel);
// store C matrix
storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum,
colRegs);
+ };
+
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ issueLoopOverK(mRegBlockSize);
+
+ int rowRegs = mRegBlockSize;
// increment A for next block
a->sub(buffer_A, kSize);
a->add(buffer_A,
@@ -296,43 +299,13 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
// reset B
a->mov(buffer_B, buffer_B_saved);
a->mov(B_pf, B_pf_saved);
+
a->cmp(iIdx, mRegBlocks);
a->jl(LoopMBlocks);
}
// generate code for remainder
if (mRegBlocksRem > 0) {
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
- colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(buffer_A,
- static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
- a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum,
- colRegs);
+ issueLoopOverK(mRegBlocksRem);
}
a->emitEpilog(frame);