Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDaya S Khudia <dskhudia@fb.com>2019-03-21 20:03:36 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-03-21 20:07:54 +0300
commitd53c0220cf1749802736bba192c5e37f430df7a0 (patch)
tree49698cc737645bfe604a899d596c5bfe9325580e /src
parentfe1c3d91772703a9c4f00fa04a9acaeeeffcf83c (diff)
Further optimize acc16 kernel and cache blocking dimension for B matrix is now free to be autotuned (#88)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/88 acc16 version We have one more loop (over NR tiles in NCB block) in the generated assembly kernel. This change also frees NCB as an independent dimension that can be auto-tuned. Reviewed By: jianyuh Differential Revision: D14516232 fbshipit-source-id: f9bac9e7cdd3c89135d35a61c59a275c9a76562b
Diffstat (limited to 'src')
-rw-r--r--src/GenerateKernelU8S8S32ACC16Avx512.cc106
1 files changed, 85 insertions, 21 deletions
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index cd230c5..6f3f276 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -150,15 +150,22 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NCB;
constexpr int mRegBlockSize =
PackingTraits<int8_t, int16_t, inst_set_t::avx512>::MR;
- // constexpr int nRegBlockSize =
- // PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR;
+ constexpr int nRegBlockSize =
+ PackingTraits<int8_t, int16_t, inst_set_t::avx512>::NR;
constexpr int row_interleave =
PackingTraits<int8_t, int16_t, inst_set_t::avx512>::ROW_INTERLEAVE;
+
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSize == 0 && "nc must be a multiple of NR");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert(
+ maxMRegs * maxNRegs <= 28 &&
+ "MR*(NR*ROW_INTERLEAVE*8/512) \
+ must be <= 28(available registers constraint)");
+
int mRegBlocks = mc / mRegBlockSize;
int mRegBlocksRem = mc % mRegBlockSize;
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- // assert((nc == nRegBlockSize) &&
- //"nc must be equal to the number of register blocks");
// arguments to the function created
asmjit::X86Gp buffer_A = a->zdi();
@@ -180,7 +187,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+ asmjit::X86Reg::kKindGp,
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
asmjit::FuncArgsMapper args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
@@ -193,31 +201,38 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
asmjit::FuncUtils::emitProlog(a, layout);
asmjit::FuncUtils::allocArgs(a, layout, args);
- asmjit::Label Loopk = a->newLabel();
asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
asmjit::X86Gp C_Offset = a->gpzRef(11);
// asmjit::X86Gp B_pf_saved = a->gpzRef(12);
asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp kIdx = a->gpzRef(14);
+ asmjit::X86Gp jIdx = a->gpzRef(14);
+ asmjit::X86Gp kIdx = a->gpzRef(15);
- int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
if (mRegBlocks > 0) {
// move 0 to iteration variables
a->mov(iIdx, 0);
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- // a->mov(B_pf_saved, B_pf);
-
a->bind(LoopMBlocks);
a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
int rowRegs = mRegBlockSize;
// init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs);
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
// init k loop index
a->mov(kIdx, 0);
@@ -226,7 +241,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
// update buffer_A address for next k iteration
a->add(
@@ -244,16 +259,40 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
// store C matrix
storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
- // increment A for next block
+ // reset A
a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(
+ C_Offset,
+ jIdx,
+ static_cast<asmjit::Imm>(
+ nRegBlockSize * row_interleave * sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+
+ // increment C for next block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
a->add(
buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
- // increment C for next block
+
+ // increment C for next A block
+ a->sub(
+ CBase,
+ static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
a->imul(
C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
a->add(CBase, C_Offset);
+
// reset B
a->mov(buffer_B, buffer_B_saved);
// a->mov(B_pf, B_pf_saved);
@@ -263,11 +302,16 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
}
// generate code for remainder
if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
asmjit::Label LoopkRem = a->newLabel();
int rowRegs = mRegBlocksRem;
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
// init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs);
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
// init k loop index
a->mov(kIdx, 0);
@@ -277,7 +321,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
// update buffer_A address for next k iteration
a->add(
@@ -293,9 +337,29 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
a->cmp(kIdx, kSize);
a->jl(LoopkRem);
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(
+ C_Offset,
+ jIdx,
+ static_cast<asmjit::Imm>(
+ nRegBlockSize * row_interleave * sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+
// store C matrix
storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // increment C for next block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
}
asmjit::FuncUtils::emitEpilog(a, layout);