Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC32Avx512.cc')
-rw-r--r-- src/GenerateKernelU8S8S32ACC32Avx512.cc 95
1 files changed, 48 insertions, 47 deletions
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index fe35627..12243ee 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -19,7 +19,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx512>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -41,25 +41,25 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
inst_set_t::avx512>(
- x86::Emitter* a,
- x86::Gp buffer_A,
- x86::Gp buffer_B,
- x86::Gp B_pf,
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp B_pf,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- x86::Zmm AReg = x86::zmm31;
+ asmjit::X86Zmm AReg = x86::zmm31;
// used for matrix B
- x86::Zmm BReg = x86::zmm30;
+ asmjit::X86Zmm BReg = x86::zmm30;
// Contains 16-bit 1s
- x86::Zmm oneReg = x86::zmm29;
+ asmjit::X86Zmm oneReg = x86::zmm29;
// temporary register
- x86::Zmm res1 = x86::zmm28;
+ asmjit::X86Zmm res1 = x86::zmm28;
for (int j = 0; j < colRegs; ++j) {
// load B
@@ -87,17 +87,18 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
inst_set_t::avx512>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int rowRegs,
int colRegs,
- x86::Gp C_Offset,
- x86::Gp ldcReg,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
bool accum,
int leadingDimCReg) {
for (int i = 0; i < rowRegs; ++i) {
if (i != 0) {
a->add(C_Offset, ldcReg);
- } else {
+ }
+ else {
a->mov(C_Offset, static_cast<asmjit::Imm>(0));
}
for (int j = 0; j < colRegs; ++j) {
@@ -167,9 +168,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
return codeCache_[kernelSig];
}
code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
@@ -204,52 +205,52 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
+ asmjit::X86Gp buffer_A = a->zdi();
+ asmjit::X86Gp buffer_B = a->zsi();
+ asmjit::X86Gp B_pf = a->zdx();
+ asmjit::X86Gp CBase = a->zcx();
+ asmjit::X86Gp kSize = a->gpzRef(8);
+ asmjit::X86Gp ldcReg = a->gpzRef(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrame frame;
- frame.init(func);
-
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp,
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- asmjit::FuncArgsAssignment args(&func);
+ asmjit::FuncArgsMapper args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- args.updateFuncFrame(frame);
- frame.finalize();
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
asmjit::Label LoopMBlocks = a->newLabel();
asmjit::Label LoopNBlocks = a->newLabel();
asmjit::Label Loopk = a->newLabel();
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp jIdx = a->gpz(14);
- x86::Gp kIdx = a->gpz(15);
- // x86::Gp B_pf = a->gpz(8);
+ asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
+ asmjit::X86Gp C_Offset = a->gpzRef(11);
+ asmjit::X86Gp B_pf_saved = a->gpzRef(12);
+ asmjit::X86Gp iIdx = a->gpzRef(13);
+ asmjit::X86Gp jIdx = a->gpzRef(14);
+ asmjit::X86Gp kIdx = a->gpzRef(15);
+ // asmjit::X86Gp B_pf = a->gpzRef(8);
- x86::Zmm oneReg = x86::zmm29;
+ asmjit::X86Zmm oneReg = x86::zmm29;
// create 16-bit 1s
// i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
// and so on
@@ -419,7 +420,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
a->jl(LoopNRem);
}
- a->emitEpilog(frame);
+ asmjit::FuncUtils::emitEpilog(a, layout);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);