Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC16.cc')
-rw-r--r--src/GenerateKernelU8S8S32ACC16.cc357
1 files changed, 172 insertions, 185 deletions
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index 1e7e081..cbd5877 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -9,18 +9,6 @@
namespace fbgemm {
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::JitRuntime CodeGenBase<TA, TB, TC, accT>::rt_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local asmjit::CodeHolder CodeGenBase<TA, TB, TC, accT>::code_;
-
-template <typename TA, typename TB, typename TC, typename accT>
-thread_local std::map<
- std::tuple<bool, int, int, int, int, int, int, int>,
- typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
- CodeGenBase<TA, TB, TC, accT>::codeCache_;
-
namespace x86 = asmjit::x86;
/**
@@ -35,12 +23,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
int rowRegs,
int colRegs,
int leadingDimCReg) {
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
- CRegs_avx2_[i * leadingDimCReg + j],
- CRegs_avx2_[i * leadingDimCReg + j],
- CRegs_avx2_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j),
+ CRegs(i * leadingDimCReg + j));
}
}
}
@@ -66,6 +55,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
x86::Ymm tmpReg = x86::ymm14;
+ using CRegs = x86::Ymm;
+
for (int i = 0; i < rowRegs; ++i) {
// broadcast A
a->vpbroadcastw(
@@ -74,9 +65,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
a->vpmaddubsw(
tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
a->vpaddsw(
- CRegs_avx2_[i * leadingDimCReg + j],
+ CRegs(i * leadingDimCReg + j),
tmpReg,
- CRegs_avx2_[i * leadingDimCReg + j]);
+ CRegs(i * leadingDimCReg + j));
// Prefetching is hurting performance in some cases
// because prefetch instructions itself consumes a slot
// in pipeline issue thus slowing down the kernel.
@@ -105,12 +96,13 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
x86::Xmm extractDest128 = x86::xmm15;
x86::Ymm extractDest256 = x86::ymm15;
+ using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
for (int j = 0; j < colRegs; ++j) {
for (int idx = 0; idx < 2; ++idx) {
a->vextracti128(
- extractDest128, CRegs_avx2_[i * leadingDimCReg + j], idx);
+ extractDest128, CRegs(i * leadingDimCReg + j), idx);
a->vpmovsxwd(extractDest256, extractDest128);
x86::Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t));
@@ -172,191 +164,186 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
nRegBlockSize,
nRegBlockSizeMin);
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter *a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx2>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
+ // generated code logging
+ FILE *codeLogfile = fopen(getCodeLoggingFile<inst_set_t::avx2>(
+ accum, mc, nc, nBlock, kBlock, mRegBlockSize,
+ nRegBlockSize, nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger *codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code.setLogger(codeLogger);
+ }
#endif
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- // assert((nc == nRegBlockSize) &&
- //"nc must be equal to the number of register blocks");
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label Loopk = a->newLabel();
- asmjit::Label LoopMBlocks = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- // x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp kIdx = a->gpz(14);
-
- int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- // a->mov(B_pf_saved, B_pf);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx2>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
-
- // increment A for next block
- a->sub(buffer_A, kSize);
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
- // increment C for next block
- a->imul(
- C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
- a->add(CBase, C_Offset);
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- // a->mov(B_pf, B_pf_saved);
-
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+ assert(kc % row_interleave == 0 &&
+ "kc must be a multiple of row_interleave");
+ // assert((nc == nRegBlockSize) &&
+ //"nc must be equal to the number of register blocks");
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::FuncSignatureT<void, uint8_t *, int8_t *, int8_t *, int32_t *,
+ int, int>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label Loopk = a->newLabel();
+ asmjit::Label LoopMBlocks = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
+
+ int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ // a->mov(B_pf_saved, B_pf);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock);
+
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum);
+
+ // increment A for next block
+ a->sub(buffer_A, kSize);
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+ // increment C for next block
+ a->imul(C_Offset, ldcReg,
+ static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
+ a->add(CBase, C_Offset);
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ // a->mov(B_pf, B_pf_saved);
+
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
- // init C registers
- initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
+ // init C registers
+ initCRegs<inst_set_t::avx2>(a, rowRegs, colRegs);
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
- genComputeBlock<inst_set_t::avx2>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
+ genComputeBlock<inst_set_t::avx2>(a, buffer_A, buffer_B, B_pf, rowRegs,
+ colRegs, kBlock);
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+ // update buffer_A address for next k iteration
+ a->add(buffer_A,
+ static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
- // sizeof(int8_t)));
+ // update buffer_B address for next k iteration
+ a->add(buffer_B, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(nBlock * row_interleave *
+ // sizeof(int8_t)));
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
- // store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
- }
+ // store C matrix
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg,
+ accum);
+ }
- a->emitEpilog(frame);
+ a->emitEpilog(frame);
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
+ jit_micro_kernel_fp fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
+ fclose(codeLogfile);
+ delete codeLogger;
#endif
- return fn;
+ return fn;
+ });
}
} // namespace fbgemm