Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC16Avx512.cc')
-rw-r--r--src/GenerateKernelU8S8S32ACC16Avx512.cc473
1 files changed, 229 insertions, 244 deletions
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index a49e440..b67d8e8 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -167,262 +167,247 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
-
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx512>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx512>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
- int maxMRegs = mRegBlockSize;
- int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
- assert(
- (maxMRegs+1) * maxNRegs <= 28 &&
- "number of zmm registers for C + one row for loading B: \
+ assert(kc % row_interleave == 0 &&
+ "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert((maxMRegs + 1) * maxNRegs <= 28 &&
+ "number of zmm registers for C + one row for loading B: \
MR*(NR*ROW_INTERLEAVE*8/512) + (NR*ROW_INTERLEAVE*8/512) \
must be <= 28(available registers constraint)");
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
-
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::Label LoopNBlocks = a->newLabel();
- asmjit::Label Loopk = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- // x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp jIdx = a->gpz(14);
- x86::Gp kIdx = a->gpz(15);
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- // a->mov(B_pf_saved, B_pf);
-
- int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- int colRegs = std::min(currColRegs, maxNRegs);
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
- a->mov(jIdx, 0);
-
- a->bind(LoopNBlocks);
- a->inc(jIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
-
- // increment C for next block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNBlocks);
-
- // increment A for next block
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next A block
- a->sub(
- CBase,
- static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
- a->imul(
- C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
- a->add(CBase, C_Offset);
-
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- // a->mov(B_pf, B_pf_saved);
-
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopNRem = a->newLabel();
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- a->mov(jIdx, 0);
- a->bind(LoopNRem);
- a->inc(jIdx);
-
- // init C registers
- initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // increment C for next block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNRem);
- }
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+
+ // increment C for next block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(CBase, static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize *
+ sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg,
+ static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ // a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(C_Offset, jIdx,
+ static_cast<asmjit::Imm>(nRegBlockSize * row_interleave *
+ sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum, colRegs);
+
+ // increment C for next block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
- a->emitEpilog(frame);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm