Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJianyu Huang <jianyuhuang@fb.com>2019-08-06 19:35:42 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-08-06 19:50:51 +0300
commitd8b3323668fdd15dc70e9cb43ab16e96f4846eeb (patch)
treed48a6818c14575d92e68bf1ffb621d646a6c893e /src/GenerateKernelU8S8S32ACC16Avx512.cc
parent0d5d057ca941ebb511bdc6178fc26c23e6c4a953 (diff)
Integrate VNNI into FBGEMM master branch (#113)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/113 Adding the VNNI support in FBGEMM. Reviewed By: dskhudia Differential Revision: D16276574 fbshipit-source-id: 832ccdb27339489ebc138f3b2678e53d107c1b79
Diffstat (limited to 'src/GenerateKernelU8S8S32ACC16Avx512.cc')
-rw-r--r--src/GenerateKernelU8S8S32ACC16Avx512.cc94
1 files changed, 47 insertions, 47 deletions
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index c95757b..a49e440 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -19,7 +19,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -41,18 +41,18 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp /* unused (reserved for prefetching)*/,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp /* unused (reserved for prefetching)*/,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- asmjit::X86Zmm AReg = x86::zmm29;
+ x86::Zmm AReg = x86::zmm29;
- asmjit::X86Zmm tmpReg = x86::zmm30;
+ x86::Zmm tmpReg = x86::zmm30;
// We start allocating BRegs from zmm27 and then allocate zmm26 and so on.
for (int j = 0; j < colRegs; ++j) {
@@ -66,8 +66,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
a->vpbroadcastw(
AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
for (int j = 0; j < colRegs; ++j) {
- a->vpmaddubsw(
- tmpReg, AReg, AllRegs_avx512_[27-j]);
+ a->vpmaddubsw(tmpReg, AReg, AllRegs_avx512_[27 - j]);
a->vpaddsw(
CRegs_avx512_[i * leadingDimCReg + j],
tmpReg,
@@ -90,15 +89,16 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
+
bool accum,
int leadingDimCReg) {
- asmjit::X86Ymm extractDest256 = x86::ymm31;
- asmjit::X86Zmm extractDest512 = x86::zmm31;
+ x86::Ymm extractDest256 = x86::ymm31;
+ x86::Zmm extractDest512 = x86::zmm31;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
@@ -107,7 +107,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
a->vextracti32x8(
extractDest256, CRegs_avx512_[i * leadingDimCReg + j], idx);
a->vpmovsxwd(extractDest512, extractDest256);
- asmjit::X86Mem destAddr = x86::dword_ptr(
+ x86::Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t));
if (accum) {
a->vpaddd(extractDest512, extractDest512, destAddr);
@@ -172,9 +172,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
}
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
@@ -209,49 +209,49 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp,
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
asmjit::Label LoopMBlocks = a->newLabel();
asmjit::Label LoopNBlocks = a->newLabel();
asmjit::Label Loopk = a->newLabel();
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp jIdx = a->gpzRef(14);
- asmjit::X86Gp kIdx = a->gpzRef(15);
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
// save B_buffer address
a->mov(buffer_B_saved, buffer_B);
@@ -407,7 +407,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
a->jl(LoopNRem);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);