github.com/marian-nmt/FBGEMM.git
author     Jianyu Huang <jianyuhuang@fb.com>  2019-08-06 19:35:42 +0300
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-08-06 19:50:51 +0300
commit     d8b3323668fdd15dc70e9cb43ab16e96f4846eeb (patch)
tree       d48a6818c14575d92e68bf1ffb621d646a6c893e
parent     0d5d057ca941ebb511bdc6178fc26c23e6c4a953 (diff)
Integrate VNNI into FBGEMM master branch (#113)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/113

Adds VNNI support to FBGEMM.

Reviewed By: dskhudia

Differential Revision: D16276574

fbshipit-source-id: 832ccdb27339489ebc138f3b2678e53d107c1b79
-rw-r--r--  CMakeLists.txt                                 2
-rw-r--r--  include/fbgemm/PackingTraits-inl.h            50
-rw-r--r--  include/fbgemm/Utils.h                         7
-rw-r--r--  src/ExecuteKernelU8S8.cc                      47
-rw-r--r--  src/Fbgemm.cc                                 18
-rw-r--r--  src/GenerateKernel.h                          30
-rw-r--r--  src/GenerateKernelU8S8S32ACC16.cc             91
-rw-r--r--  src/GenerateKernelU8S8S32ACC16Avx512.cc       94
-rw-r--r--  src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc  102
-rw-r--r--  src/GenerateKernelU8S8S32ACC32.cc             87
-rw-r--r--  src/GenerateKernelU8S8S32ACC32Avx512.cc       95
-rw-r--r--  src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc  431
-rw-r--r--  src/GroupwiseConv.h                          100
-rw-r--r--  src/GroupwiseConvAcc32Avx2.cc                177
-rw-r--r--  src/PackAMatrix.cc                            10
-rw-r--r--  src/PackAWithIm2Col.cc                        14
-rw-r--r--  src/PackAWithQuantRowOffset.cc                14
-rw-r--r--  src/PackAWithRowOffset.cc                     14
-rw-r--r--  src/PackBMatrix.cc                            25
-rw-r--r--  src/PackMatrix.cc                              9
-rw-r--r--  src/PackWeightMatrixForGConv.cc                8
-rw-r--r--  src/Utils.cc                                   3
-rw-r--r--  test/GConvTest.cc                              4
m---------  third_party/asmjit                             0
24 files changed, 1054 insertions(+), 378 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b575e17..817f699 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,8 +33,10 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
src/FbgemmI8Spmdm.cc
src/GenerateKernelU8S8S32ACC16.cc
src/GenerateKernelU8S8S32ACC16Avx512.cc
+ src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
src/GenerateKernelU8S8S32ACC32.cc
src/GenerateKernelU8S8S32ACC32Avx512.cc
+ src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
src/GroupwiseConvAcc32Avx2.cc
src/PackAMatrix.cc
src/PackAWithIm2Col.cc
diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h
index 76eb425..baccfad 100644
--- a/include/fbgemm/PackingTraits-inl.h
+++ b/include/fbgemm/PackingTraits-inl.h
@@ -222,3 +222,53 @@ struct PackingTraits<
128}; ///< Cache block for N dimension (multiple of NR).
static constexpr int KCB{256}; ///< Cache block for K dimension.
};
+
+/**
+ * @brief Helper struct to type specialize for int16_t and int32_t together.
+ */
+template <typename T>
+struct is_16or32bit {
+ static constexpr bool value =
+ std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value;
+};
+
+/**
+ * @brief Packing parameter specialization for accumulation into 32-bit/16-bit
+ * integers.
+ *
+ * Since there is no int16_t accumulation for AVX512 VNNI, we redirect int16_t
+ * to int32_t accumulation and use the same blocking parameters as int32_t.
+ *
+ * This is picked when T is of int8 type (signed or unsigned) and instruction
+ * set is avx512_vnni.
+ */
+template <typename T, typename accT>
+struct PackingTraits<
+ T,
+ accT,
+ inst_set_t::avx512_vnni,
+ typename std::enable_if<
+ is_8bit<T>::value && is_16or32bit<accT>::value>::type> {
+ static constexpr int MR{8}; ///< Register block for M dimension.
+ static constexpr int NR_MIN{
+ 16}; ///< Minimum register block for N dimension.
+ ///< 16 because 16*ROW_INTERLEAVE int8 elements
+ ///< completely fill a 512-bit wide vector.
+ static constexpr int NR{
+ 32}; ///< Register block for N dimension.
+ ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements
+ ///< completely fill a 512-bit wide vector. Total registers used for
+ ///< N dimension: NR*ROW_INTERLEAVE*8/VLEN. We use MR x
+ ///< NR*ROW_INTERLEAVE*8/VLEN zmm registers
+ ///< for C accumulations.
+
+ static constexpr int ROW_INTERLEAVE{
+ 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing
+ ///< B matrix.
+
+ static constexpr int MCB{
+ 128}; ///< Cache block for M dimension (multiple of MR).
+ static constexpr int NCB{
+ 32}; ///< Cache block for N dimension (multiple of NR).
+ static constexpr int KCB{256}; ///< Cache block for K dimension.
+};
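To make the new avx512_vnni blocking parameters concrete, the register budget they imply can be re-derived by hand. The following is a minimal standalone sketch, not FBGEMM code; the local names are illustrations, and the 28-register ceiling is the one asserted later in getOrCreate():

    #include <cstdio>

    int main() {
      constexpr int MR = 8;              // register block in M (from the traits above)
      constexpr int NR = 32;             // register block in N
      constexpr int ROW_INTERLEAVE = 4;  // rows of B packed per 32-bit column
      constexpr int VLEN_BITS = 512;     // zmm width in bits

      // zmm registers per N block: NR * ROW_INTERLEAVE int8 elements,
      // 8 bits each, spread across 512-bit vectors.
      constexpr int zmmPerNBlock = NR * ROW_INTERLEAVE * 8 / VLEN_BITS;  // = 2
      constexpr int cTileRegs = MR * zmmPerNBlock;                       // = 16

      static_assert(cTileRegs <= 28, "C tile must fit the 28-register budget");
      std::printf("zmm per N block: %d, zmm for the C tile: %d\n",
                  zmmPerNBlock, cTileRegs);
      return 0;
    }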
diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h
index 107cf07..3f8522b 100644
--- a/include/fbgemm/Utils.h
+++ b/include/fbgemm/Utils.h
@@ -29,7 +29,7 @@ enum class matrix_op_t { NoTranspose, Transpose };
/**
* @brief Typed enum for supported instruction sets.
*/
-enum class inst_set_t { anyarch, avx2, avx512 };
+enum class inst_set_t { anyarch, avx2, avx512, avx512_vnni };
/**
* @brief Typed enum for optimized paths for convolutions
@@ -100,6 +100,11 @@ FBGEMM_API bool fbgemmHasAvx512Support();
FBGEMM_API bool fbgemmHasAvx2Support();
/**
+ * @brief Are we running on an AVX512_VNNI supported CPU?
+ */
+FBGEMM_API bool fbgemmHasAvx512VnniSupport();
+
+/**
* @brief Helper struct to enable autotuning of FBGEMM packing and kernels.
*
* This structure is optional. If not used, the default values for these
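Together with the new enum value, fbgemmHasAvx512VnniSupport() enables a three-level runtime dispatch. A hedged sketch of how a caller might pick an instruction set — chooseInstSet() is illustrative, not an FBGEMM API, but the three detection functions are the ones declared above:

    #include "fbgemm/Utils.h"

    static fbgemm::inst_set_t chooseInstSet() {
      using fbgemm::inst_set_t;
      if (fbgemm::fbgemmHasAvx512VnniSupport()) {
        return inst_set_t::avx512_vnni;  // newest path first, as in ExecuteKernelU8S8.cc
      }
      if (fbgemm::fbgemmHasAvx512Support()) {
        return inst_set_t::avx512;
      }
      if (fbgemm::fbgemmHasAvx2Support()) {
        return inst_set_t::avx2;
      }
      return inst_set_t::anyarch;  // generic fallback
    }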
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index f7292fd..0a4ff55 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -49,7 +49,8 @@ ExecuteKernel<
throw std::runtime_error("Failed to initialize cpuinfo!");
}
if (params) {
- if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support() ||
+ fbgemmHasAvx2Support()) {
mbSize_ = params->MCB;
nbSize_ = params->NCB;
nrMinSize_ = params->NR_MIN;
@@ -59,7 +60,20 @@ ExecuteKernel<
assert(0 && "unsupported architecture");
}
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ mbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::MCB;
+ nbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::NCB;
+ nrMinSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::NR_MIN;
+ } else if (fbgemmHasAvx512Support()) {
mbSize_ = PackingTraits<
int8_t,
typename packingAMatrix::accType,
@@ -118,7 +132,25 @@ void ExecuteKernel<
typename BaseType::jit_micro_kernel_fp fn;
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) {
+ // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>(
+ accum,
+ packed_rows_A,
+ packedB_.blockColSize(),
+ packedA_.numPackedCols(),
+ nbSize_);
+ } else {
+ fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
+ accum,
+ packed_rows_A,
+ packedB_.blockColSize(),
+ packedA_.numPackedCols(),
+ nbSize_);
+ }
+ } else if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum,
packed_rows_A,
@@ -148,7 +180,10 @@ void ExecuteKernel<
if (jb == bColBlocks - 1) {
int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_;
if (nc != nbSize_) {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
+ accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
+ } else if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
} else if (fbgemmHasAvx2Support()) {
@@ -213,7 +248,7 @@ void ExecuteKernel<
int32_t nSize =
C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols();
if (nSize) {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
@@ -238,7 +273,7 @@ void ExecuteKernel<
if (C_buffer_start == C_tile_) {
// When C_tile_ scratchpad was used to avoid accessing memory past
// C_buffer_ .
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index 0f2f6fb..4f7026f 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -48,7 +48,8 @@ void fbgemmPacked(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecture");
}
@@ -62,7 +63,20 @@ void fbgemmPacked(
MR = blocking_params->MR;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ MCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::MCB;
+ KCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::KCB;
+ MR = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::MR;
+ } else if (fbgemmHasAvx512Support()) {
MCB = PackingTraits<
typename packingAMatrix::inpType,
typename packingAMatrix::accType,
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
index dccdfc5..e52097e 100644
--- a/src/GenerateKernel.h
+++ b/src/GenerateKernel.h
@@ -18,7 +18,7 @@ namespace fbgemm {
namespace x86 = asmjit::x86;
/**
- * @brief AVX2/AVX512 JIT assembly code generator.
+ * @brief AVX2/AVX512/AVX512VNNI JIT assembly code generator.
* @tparam TA Type of matrix A.
* @tparam TB Type of matrix B.
* @tparam TC Type of matrix C.
@@ -104,7 +104,7 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void initCRegs(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCRegAssign = 4);
@@ -114,10 +114,10 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void genComputeBlock(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp B_pf,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda,
@@ -129,11 +129,11 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void storeCRegs(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCRegAssign = 4);
@@ -168,7 +168,9 @@ class CodeGenBase {
fileName += "_MR-" + std::to_string(MR);
fileName += "_NR-" + std::to_string(NR);
fileName += "_NR_MIN-" + std::to_string(NR_MIN);
- if (instSet == inst_set_t::avx512) {
+ if (instSet == inst_set_t::avx512_vnni) {
+ fileName += "_avx512vnni";
+ } else if (instSet == inst_set_t::avx512) {
fileName += "_avx512";
} else if (instSet == inst_set_t::avx2) {
fileName += "_avx2";
@@ -178,12 +180,10 @@ class CodeGenBase {
}
private:
- asmjit::X86Ymm
- CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
- asmjit::X86Zmm
+ x86::Ymm CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
+ x86::Zmm
CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel.
- asmjit::X86Zmm
- AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
+ x86::Zmm AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index 718b883..1e7e081 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -31,7 +31,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -53,18 +53,18 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp /* unused (reserved for prefetching)*/,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp /* unused (reserved for prefetching)*/,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- asmjit::X86Ymm AReg = x86::ymm12;
+ x86::Ymm AReg = x86::ymm12;
- asmjit::X86Ymm tmpReg = x86::ymm14;
+ x86::Ymm tmpReg = x86::ymm14;
for (int i = 0; i < rowRegs; ++i) {
// broadcast A
@@ -95,15 +95,15 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
- asmjit::X86Xmm extractDest128 = x86::xmm15;
- asmjit::X86Ymm extractDest256 = x86::ymm15;
+ x86::Xmm extractDest128 = x86::xmm15;
+ x86::Ymm extractDest256 = x86::ymm15;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
@@ -112,7 +112,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
a->vextracti128(
extractDest128, CRegs_avx2_[i * leadingDimCReg + j], idx);
a->vpmovsxwd(extractDest256, extractDest128);
- asmjit::X86Mem destAddr = x86::dword_ptr(
+ x86::Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t));
if (accum) {
a->vpaddd(extractDest256, extractDest256, destAddr);
@@ -176,9 +176,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
return codeCache_[kernelSig];
}
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
@@ -207,46 +207,45 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
//"nc must be equal to the number of register blocks");
// arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
- asmjit::FuncArgsMapper args(&func);
+ asmjit::FuncArgsAssignment args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- args.updateFrameInfo(ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
-
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
asmjit::Label Loopk = a->newLabel();
asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp kIdx = a->gpzRef(14);
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
if (mRegBlocks > 0) {
@@ -289,8 +288,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
a->jl(Loopk);
// store C matrix
- storeCRegs<inst_set_t::avx2>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
// increment A for next block
a->sub(buffer_A, kSize);
@@ -340,11 +338,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
a->jl(LoopkRem);
// store C matrix
- storeCRegs<inst_set_t::avx2>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index c95757b..a49e440 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -19,7 +19,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -41,18 +41,18 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp /* unused (reserved for prefetching)*/,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp /* unused (reserved for prefetching)*/,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- asmjit::X86Zmm AReg = x86::zmm29;
+ x86::Zmm AReg = x86::zmm29;
- asmjit::X86Zmm tmpReg = x86::zmm30;
+ x86::Zmm tmpReg = x86::zmm30;
// We start allocating BRegs from zmm27 and then allocate zmm26 and so on.
for (int j = 0; j < colRegs; ++j) {
@@ -66,8 +66,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
a->vpbroadcastw(
AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
for (int j = 0; j < colRegs; ++j) {
- a->vpmaddubsw(
- tmpReg, AReg, AllRegs_avx512_[27-j]);
+ a->vpmaddubsw(tmpReg, AReg, AllRegs_avx512_[27 - j]);
a->vpaddsw(
CRegs_avx512_[i * leadingDimCReg + j],
tmpReg,
@@ -90,15 +89,16 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
- asmjit::X86Ymm extractDest256 = x86::ymm31;
- asmjit::X86Zmm extractDest512 = x86::zmm31;
+ x86::Ymm extractDest256 = x86::ymm31;
+ x86::Zmm extractDest512 = x86::zmm31;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
@@ -107,7 +107,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
a->vextracti32x8(
extractDest256, CRegs_avx512_[i * leadingDimCReg + j], idx);
a->vpmovsxwd(extractDest512, extractDest256);
- asmjit::X86Mem destAddr = x86::dword_ptr(
+ x86::Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t));
if (accum) {
a->vpaddd(extractDest512, extractDest512, destAddr);
@@ -172,9 +172,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
}
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
@@ -209,49 +209,49 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp,
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
asmjit::Label LoopMBlocks = a->newLabel();
asmjit::Label LoopNBlocks = a->newLabel();
asmjit::Label Loopk = a->newLabel();
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp jIdx = a->gpzRef(14);
- asmjit::X86Gp kIdx = a->gpzRef(15);
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
// save B_buffer address
a->mov(buffer_B_saved, buffer_B);
@@ -407,7 +407,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
a->jl(LoopNRem);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
new file mode 100644
index 0000000..f559aba
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCReg) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ codeObj.initCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, leadingDimCReg);
+}
+
+/**
+ * Generate AVX512 instructions for computing block in the rank-k update of
+ * 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp /* unused (reserved for prefetching)*/,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCReg) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ codeObj.genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, buffer_B, rowRegs, colRegs, lda, leadingDimCReg);
+}
+
+/**
+ * Generate AVX512 instructions for storing the C registers back to the memory
+ * in 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
+ bool accum,
+ int leadingDimCReg) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ codeObj.storeCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, leadingDimCReg);
+}
+
+/**
+ * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<
+ inst_set_t::avx512_vnni>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ return codeObj.getOrCreate<inst_set_t::avx512_vnni>(accum, mc, nc, kc, kc);
+}
+
+} // namespace fbgemm
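The pattern in this new file is worth calling out: every int16_t specialization asserts in debug builds and otherwise forwards to the int32_t code generator, because vpdpbusd only accumulates into 32-bit lanes. A stripped-down sketch of the forwarding shape, with local stand-in types rather than FBGEMM code:

    #include <cassert>

    struct Int32KernelGen {
      // stands in for CodeGenBase<uint8_t, int8_t, int32_t, int32_t>
      void emit() { /* generate the vpdpbusd-based int32 kernel */ }
    };

    struct Int16KernelGen {
      // stands in for the int16_t specializations above: trip the assert
      // in debug builds, redirect to the int32 generator otherwise
      void emit() {
        assert(0 && "Accumulation to int16_t is not available for VNNI!");
        Int32KernelGen redirect;  // same blocking parameters (see PackingTraits)
        redirect.emit();
      }
    };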
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index 58643ad..6b54743 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -31,7 +31,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -53,25 +53,25 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp B_pf,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- asmjit::X86Ymm AReg = x86::ymm12;
+ x86::Ymm AReg = x86::ymm12;
// used for matrix B
- asmjit::X86Ymm BReg = x86::ymm13;
+ x86::Ymm BReg = x86::ymm13;
// Contains 16-bit 1s
- asmjit::X86Ymm oneReg = x86::ymm15;
+ x86::Ymm oneReg = x86::ymm15;
// temporary register
- asmjit::X86Ymm res1 = x86::ymm14;
+ x86::Ymm res1 = x86::ymm14;
for (int j = 0; j < colRegs; ++j) {
// load B
@@ -99,11 +99,11 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
for (int i = 0; i < rowRegs; ++i) {
@@ -177,9 +177,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
return codeCache_[kernelSig];
}
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
FILE* codeLogfile = fopen(
@@ -205,49 +205,48 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
- asmjit::FuncArgsMapper args(&func);
+ asmjit::FuncArgsAssignment args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- args.updateFrameInfo(ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
-
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
asmjit::Label Loopk = a->newLabel();
asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp kIdx = a->gpzRef(14);
- // asmjit::X86Gp B_pf = a->gpzRef(8);
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
+ // x86::Gp B_pf = a->gpz(8);
- asmjit::X86Ymm oneReg = x86::ymm15;
+ x86::Ymm oneReg = x86::ymm15;
// create 16-bit 1s
// i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
// and so on
@@ -358,7 +357,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index 12243ee..fe35627 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -19,7 +19,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -41,25 +41,25 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp B_pf,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- asmjit::X86Zmm AReg = x86::zmm31;
+ x86::Zmm AReg = x86::zmm31;
// used for matrix B
- asmjit::X86Zmm BReg = x86::zmm30;
+ x86::Zmm BReg = x86::zmm30;
// Contains 16-bit 1s
- asmjit::X86Zmm oneReg = x86::zmm29;
+ x86::Zmm oneReg = x86::zmm29;
// temporary register
- asmjit::X86Zmm res1 = x86::zmm28;
+ x86::Zmm res1 = x86::zmm28;
for (int j = 0; j < colRegs; ++j) {
// load B
@@ -87,18 +87,17 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
for (int i = 0; i < rowRegs; ++i) {
if (i != 0) {
a->add(C_Offset, ldcReg);
- }
- else {
+ } else {
a->mov(C_Offset, static_cast<asmjit::Imm>(0));
}
for (int j = 0; j < colRegs; ++j) {
@@ -168,9 +167,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
return codeCache_[kernelSig];
}
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
@@ -205,52 +204,52 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp,
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
asmjit::Label LoopMBlocks = a->newLabel();
asmjit::Label LoopNBlocks = a->newLabel();
asmjit::Label Loopk = a->newLabel();
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp jIdx = a->gpzRef(14);
- asmjit::X86Gp kIdx = a->gpzRef(15);
- // asmjit::X86Gp B_pf = a->gpzRef(8);
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+ // x86::Gp B_pf = a->gpz(8);
- asmjit::X86Zmm oneReg = x86::zmm29;
+ x86::Zmm oneReg = x86::zmm29;
// create 16-bit 1s
// i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
// and so on
@@ -420,7 +419,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
a->jl(LoopNRem);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
new file mode 100644
index 0000000..8ae0745
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
@@ -0,0 +1,431 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCReg) {
+ for (int i = 0; i < rowRegs; ++i) {
+ for (int j = 0; j < colRegs; ++j) {
+ a->vxorps(
+ CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs_avx512_[i * leadingDimCReg + j]);
+ }
+ }
+}
+
+/**
+ * Generate AVX512 instructions for computing block in the rank-k update of
+ * 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCReg) {
+ // used for matrix A
+ x86::Zmm AReg = x86::zmm31;
+
+ // used for matrix B
+ x86::Zmm BReg = x86::zmm30;
+
+ for (int j = 0; j < colRegs; ++j) {
+ // load B
+ a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ // load A, broadcast and fmas
+ for (int i = 0; i < rowRegs; ++i) {
+ a->vpbroadcastd(
+ AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
+ a->vpdpbusd(CRegs_avx512_[i * leadingDimCReg + j], AReg, BReg);
+ }
+ a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
+ }
+}
+
+/**
+ * Generate AVX512 instructions for storing the C registers back to the memory
+ * in 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
+ bool accum,
+ int leadingDimCReg) {
+ for (int i = 0; i < rowRegs; ++i) {
+ if (i != 0) {
+ a->add(C_Offset, ldcReg);
+ } else {
+ a->mov(C_Offset, static_cast<asmjit::Imm>(0));
+ }
+ for (int j = 0; j < colRegs; ++j) {
+ if (accum) {
+ a->vpaddd(
+ CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs_avx512_[i * leadingDimCReg + j],
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)));
+ }
+ a->vmovups(
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)),
+ CRegs_avx512_[i * leadingDimCReg + j]);
+ }
+ }
+}
+
+/**
+ * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<
+ inst_set_t::avx512_vnni>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ std::tuple<bool, int, int, int, int, int, int, int> kernelSig;
+ int kBlock;
+ int nBlock;
+ int mRegBlockSize;
+ int nRegBlockSize;
+ int nRegBlockSizeMin;
+ int row_interleave;
+
+ if (blocking_params) {
+ kBlock = blocking_params->KCB;
+ nBlock = blocking_params->NCB;
+ mRegBlockSize = blocking_params->MR;
+ nRegBlockSize = blocking_params->NR;
+ nRegBlockSizeMin = blocking_params->NR_MIN;
+ row_interleave = blocking_params->ROW_INTERLEAVE;
+ } else {
+ kBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::KCB;
+ nBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NCB;
+ mRegBlockSize =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::MR;
+ nRegBlockSize =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR;
+ nRegBlockSizeMin =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR_MIN;
+ row_interleave = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::
+ ROW_INTERLEAVE;
+ }
+
+ kernelSig = std::make_tuple(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin);
+
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+ code_.reset(false);
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
+
+#if defined(FBGEMM_LOG_CODE)
+ // generated code logging
+ FILE* codeLogfile = fopen(
+ getCodeLoggingFile<inst_set_t::avx512_vnni>(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code_.setLogger(codeLogger);
+ }
+#endif
+
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert(
+ maxMRegs * maxNRegs <= 28 &&
+ "MR*(NR*ROW_INTERLEAVE*8/512) \
+ must be <= 28(available registers constraint)");
+
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+ // x86::Gp B_pf = a->gpz(8);
+
+ x86::Zmm oneReg = x86::zmm29;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ // a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+
+ // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(
+ C_Offset,
+ jIdx,
+ static_cast<asmjit::Imm>(
+ nRegBlockSize * row_interleave * sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(
+ CBase,
+ static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+ // B for next block
+ // using C_Offset as temp reg
+ a->imul(
+ C_Offset,
+ jIdx,
+ static_cast<asmjit::Imm>(
+ nRegBlockSize * row_interleave * sizeof(int8_t)));
+ a->mov(buffer_B, buffer_B_saved);
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
+
+ a->emitEpilog(frame);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
+
+#if defined(FBGEMM_LOG_CODE)
+ fclose(codeLogfile);
+ delete codeLogger;
+#endif
+
+ return fn;
+}
+
+} // namespace fbgemm
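The heart of this new 32-bit kernel is the vpdpbusd-based genComputeBlock above: a single instruction replaces the vpmaddubsw / vpmaddwd / vpaddd triple of the plain AVX512 path, which is also why the oneReg constant of 16-bit ones is no longer needed in the inner loop. Per 32-bit lane, the instruction computes a four-way u8-by-s8 dot product and accumulates it. A scalar model of one lane, offered as an assumption-level sketch rather than FBGEMM code:

    #include <cstdint>

    // One 32-bit lane of vpdpbusd: c += sum over k < 4 of u8(a[k]) * s8(b[k]).
    // This is why ROW_INTERLEAVE is 4: four K-rows of B are packed into each
    // 32-bit column so a single instruction consumes them.
    int32_t vpdpbusd_lane(const uint8_t a[4], const int8_t b[4], int32_t c) {
      for (int k = 0; k < 4; ++k) {
        c += static_cast<int32_t>(a[k]) * static_cast<int32_t>(b[k]);
      }
      return c;
    }

Unlike vpmaddubsw, whose 16-bit intermediate sums saturate, the dword accumulate modeled here does not (the saturating variant is vpdpbusds).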
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
index 1e6324e..4c5eea5 100644
--- a/src/GroupwiseConv.h
+++ b/src/GroupwiseConv.h
@@ -128,60 +128,58 @@ class GenConvKernel {
const conv_param_t<SPATIAL_DIM>& conv_param);
template <inst_set_t instSet>
- void createVector16BitOne(asmjit::X86Emitter* a);
+ void createVector16BitOne(x86::Emitter* a);
template <inst_set_t instSet>
- void createVector8BitOne(asmjit::X86Emitter* a);
+ void createVector8BitOne(x86::Emitter* a);
template <inst_set_t instSet>
- void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg);
+ void setToZeroPt(x86::Emitter* a, x86::Ymm destReg);
template <inst_set_t instSet>
- void
- gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg);
+ void gen8bitFMA(x86::Emitter* a, x86::Ymm aReg, x86::Ymm wReg);
template <inst_set_t instSet>
- void genForLoadingWeights(asmjit::X86Emitter* a, int c_offset);
+ void genForLoadingWeights(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genConstForPermutations(asmjit::X86Emitter* a);
+ void genConstForPermutations(x86::Emitter* a);
template <inst_set_t instSet>
- void genForTopEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForTopEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForLeftEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForLeftEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForRightEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForRightEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForBottomEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForBottomEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genCoreInsts(asmjit::X86Emitter* a, int c_offset);
+ void genCoreInsts(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void storeResult(asmjit::X86Emitter* a);
+ void storeResult(x86::Emitter* a);
// for Rowoffset kernel
// Add 4 consecutive numbers of 32 uint8 and emit 8 32-bit
template <inst_set_t instSet>
- void gen8BitSumX4(asmjit::X86Emitter* a, asmjit::X86Ymm aReg);
+ void gen8BitSumX4(x86::Emitter* a, x86::Ymm aReg);
// Add 8 consecutive numbers of 64 uint8 and emit 8 32-bit
template <inst_set_t instSet>
- void
- gen8BitSumX8(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg);
+ void gen8BitSumX8(x86::Emitter* a, x86::Ymm aReg, x86::Ymm bReg);
// Add 16 consecutive numbers of 128 uint8 and emit 8 32-bit
template <inst_set_t instSet>
void gen8BitSumX16(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm bReg,
- asmjit::X86Ymm cReg,
- asmjit::X86Ymm dReg);
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm bReg,
+ x86::Ymm cReg,
+ x86::Ymm dReg);
// Generate instruction sequence that loads 8-bit values and sum them up.
// Depending on C_per_G_, this function dispatches to gen8BitSumX4/8/16
@@ -191,35 +189,33 @@ class GenConvKernel {
// Internally, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0, 1], tmpReg1Avx2_,
// and resultRegAvx2_ are used.
template <inst_set_t instSet>
- void gen8BitSum(
- asmjit::X86Emitter* a,
- int act_offset,
- bool use_scratch_reg1 = true);
+ void
+ gen8BitSum(x86::Emitter* a, int act_offset, bool use_scratch_reg1 = true);
// Use scratchReg1_ and tmpReg1Avx2_ internally
template <inst_set_t instSet>
- void genZeroPtSum(asmjit::X86Emitter* a, int multiplier);
+ void genZeroPtSum(x86::Emitter* a, int multiplier);
template <inst_set_t instSet>
- void genForTopEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForTopEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genForLeftEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForLeftEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genForRightEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForRightEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genForBottomEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForBottomEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genRowoffsetCorners(asmjit::X86Emitter* a);
+ void genRowoffsetCorners(x86::Emitter* a);
template <inst_set_t instSet>
- void genRowoffsetCore(asmjit::X86Emitter* a);
+ void genRowoffsetCore(x86::Emitter* a);
template <inst_set_t instSet>
- void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0);
+ void storeResultRowoffset(x86::Emitter* a, int offset = 0);
static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
@@ -234,30 +230,30 @@ class GenConvKernel {
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
// avx2 specific
- asmjit::X86Ymm
+ x86::Ymm
WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel.
- asmjit::X86Ymm zeroPTRegAvx2_;
- asmjit::X86Ymm tmpReg1Avx2_;
- asmjit::X86Ymm stPermRegAvx2_;
- asmjit::X86Ymm actRegAvx2_;
- asmjit::X86Ymm resultRegAvx2_;
- asmjit::X86Ymm oneReg8BitAvx2_;
- asmjit::X86Ymm oneReg16BitAvx2_;
+ x86::Ymm zeroPTRegAvx2_;
+ x86::Ymm tmpReg1Avx2_;
+ x86::Ymm stPermRegAvx2_;
+ x86::Ymm actRegAvx2_;
+ x86::Ymm resultRegAvx2_;
+ x86::Ymm oneReg8BitAvx2_;
+ x86::Ymm oneReg16BitAvx2_;
// arguments to the function created
- asmjit::X86Gp in_acts_R_;
- asmjit::X86Gp wghts_R_;
- asmjit::X86Gp out_acts_R_;
- asmjit::X86Gp a_zero_pt_R_;
- asmjit::X86Gp H_R_;
- asmjit::X86Gp W_R_;
- asmjit::X86Gp row_offset_R_;
+ x86::Gp in_acts_R_;
+ x86::Gp wghts_R_;
+ x86::Gp out_acts_R_;
+ x86::Gp a_zero_pt_R_;
+ x86::Gp H_R_;
+ x86::Gp W_R_;
+ x86::Gp row_offset_R_;
// Used registers
- asmjit::X86Gp loopR1_;
- asmjit::X86Gp loopR2_;
- asmjit::X86Gp scratchReg1_;
- asmjit::X86Gp scratchReg2_;
+ x86::Gp loopR1_;
+ x86::Gp loopR2_;
+ x86::Gp scratchReg1_;
+ x86::Gp scratchReg2_;
// Other parameters
bool isAZeroPointZero_;
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index e789695..b140c83 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -104,7 +104,7 @@ jit_conv_kernel_fp getOrCreateConvKernel(
template <>
template <>
void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// create 8-bit 1s
// i.e., oneReg16BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains
// 0x01 and so on
@@ -115,7 +115,7 @@ void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// create 16-bit 1s
// i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31]
// contains 0x0001 and so on
@@ -125,11 +125,11 @@ void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm destReg) {
+ x86::Emitter* a,
+ x86::Ymm destReg) {
// make destReg all zeros
a->vxorps(destReg, destReg, destReg);
- asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+ x86::Xmm const_reg_xmm = x86::xmm10;
// move zero point to xmm10
a->movq(const_reg_xmm, a_zero_pt_R_);
// make copies of zero point
@@ -143,9 +143,9 @@ void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
- asmjit::X86Gp permute_const_reg = a->gpzRef(12);
- asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+ x86::Emitter* a) {
+ x86::Gp permute_const_reg = a->gpz(12);
+ x86::Xmm const_reg_xmm = x86::xmm10;
// We have 1st group in even lanes and 2nd group in odd lanes.
// Permute to put 1st group to lower 128-bit and 2nd group in upper
// 128-bit.
@@ -159,8 +159,7 @@ void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) {
if (C_per_G_ == 4) {
// store with permutation
a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_);
@@ -171,7 +170,7 @@ void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int offset) {
// store
if (C_per_G_ == 4) {
@@ -198,7 +197,7 @@ void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// load weights
for (int r = 0; r < R_; ++r) {
@@ -225,9 +224,9 @@ void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm wReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm wReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg);
a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_);
a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
@@ -236,8 +235,8 @@ void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_);
a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_);
a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
@@ -246,9 +245,9 @@ void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm bReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm bReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
// Let a[0] denote 0th (LSB) 8-bit of aReg
// After vpsadbw, a[0:2] = a[0] + ... + a[7]
@@ -267,11 +266,11 @@ void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm bReg,
- asmjit::X86Ymm cReg,
- asmjit::X86Ymm dReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm bReg,
+ x86::Ymm cReg,
+ x86::Ymm dReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
// After vpsadbw, a[0:2] = a[0] + ... + a[7]
// a[8:10] = a[8] + ... + a[15]
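
For reference, the partial sums these comments describe come from vpsadbw against an all-zero register: the sum of absolute differences with zero is simply the sum of the bytes, deposited per 64-bit lane. An intrinsics sketch of that building block (illustration only):

    #include <immintrin.h>

    // Each 64-bit lane of the result holds the sum of the corresponding
    // eight unsigned bytes of a (zero-extended to 64 bits).
    static inline __m256i byte_sums(__m256i a) {
      return _mm256_sad_epu8(a, _mm256_setzero_si256());  // vpsadbw
    }
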
@@ -319,7 +318,7 @@ void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int act_offset,
bool use_scratch_reg1 /*=true*/) {
if (use_scratch_reg1) {
@@ -385,11 +384,11 @@ void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int multiplier) {
a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier));
// tmpReg1Avx2_ also uses xmm11
- asmjit::X86Xmm const_reg_xmm = x86::xmm11;
+ x86::Xmm const_reg_xmm = x86::xmm11;
a->movq(const_reg_xmm, scratchReg1_);
a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm);
a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_);
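
The body above is the usual scalar-broadcast multiply: move the immediate into a GP register, movq it into an xmm, vpbroadcastd across all lanes, then vpmulld against the zero-point vector. A hedged intrinsics equivalent (illustration only; FBGEMM emits this with asmjit):

    #include <cstdint>
    #include <immintrin.h>

    // multiplier broadcast to every 32-bit lane, then multiplied lane-wise
    // into the zero-point vector, as genZeroPtSum does in generated code.
    static inline __m256i zero_pt_sum(__m256i zero_pt, std::int32_t multiplier) {
      __m256i m = _mm256_set1_epi32(multiplier);  // movq + vpbroadcastd
      return _mm256_mullo_epi32(zero_pt, m);      // vpmulld
    }
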
@@ -399,7 +398,7 @@ void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// top-left corner code
if (c_offset == 0) {
@@ -559,7 +558,7 @@ void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
@@ -626,7 +625,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -714,7 +713,7 @@ void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// bottom-left corner
// we are updating the last row
@@ -906,7 +905,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genCoreInsts<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// main compute
asmjit::Label LoopH = a->newLabel();
@@ -1011,9 +1010,9 @@ template <>
jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// log code to a file
@@ -1030,16 +1029,16 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
wghts_R_ = a->zsi();
out_acts_R_ = a->zdx();
a_zero_pt_R_ = a->zcx();
- H_R_ = a->gpzRef(8);
- W_R_ = a->gpzRef(9);
- row_offset_R_ = a->gpzRef(10);
+ H_R_ = a->gpz(8);
+ W_R_ = a->gpz(9);
+ row_offset_R_ = a->gpz(10);
// register for temporary use
- scratchReg1_ = a->gpzRef(12);
- scratchReg2_ = a->gpzRef(13);
+ scratchReg1_ = a->gpz(12);
+ scratchReg2_ = a->gpz(13);
asmjit::FuncDetail func;
- func.init(asmjit::FuncSignature6<
+ func.init(asmjit::FuncSignatureT<
void,
uint8_t*,
int8_t*,
@@ -1048,29 +1047,29 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
int32_t,
int32_t>(asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
createVector16BitOne<inst_set_t::avx2>(a);
- loopR1_ = a->gpzRef(14);
- loopR2_ = a->gpzRef(15);
+ loopR1_ = a->gpz(14);
+ loopR2_ = a->gpz(15);
if (!isAZeroPointZero_) {
setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
@@ -1095,7 +1094,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
genCoreInsts<inst_set_t::avx2>(a, c);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_conv_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
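
This hunk is the crux of the asmjit migration: FuncFrameInfo, FuncFrameLayout, FuncArgsMapper, and FuncUtils collapse into FuncFrame plus FuncArgsAssignment, and prolog/epilog emission moves onto the emitter itself. A self-contained sketch of the new flow, using a trivial hypothetical kernel signature rather than FBGEMM's:

    #include <cstdint>
    #include <asmjit/asmjit.h>

    using jit_fn = void (*)(std::int32_t*);

    jit_fn buildKernel(asmjit::JitRuntime& rt) {
      asmjit::CodeHolder code;
      code.init(rt.codeInfo());
      asmjit::x86::Assembler assembler(&code);
      asmjit::x86::Emitter* a = assembler.as<asmjit::x86::Emitter>();

      asmjit::FuncDetail func;
      func.init(asmjit::FuncSignatureT<void, std::int32_t*>(
          asmjit::CallConv::kIdHost));

      asmjit::FuncFrame frame;                // replaces FuncFrameInfo/FuncFrameLayout
      frame.init(func);
      frame.setDirtyRegs(asmjit::x86::Reg::kGroupGp,
                         asmjit::Support::bitMask(10, 11));

      asmjit::x86::Gp out = a->zdi();
      asmjit::FuncArgsAssignment args(&func); // replaces FuncArgsMapper
      args.assignAll(out);
      args.updateFuncFrame(frame);            // replaces args.updateFrameInfo(ffi)
      frame.finalize();                       // replaces FuncFrameLayout::init

      a->emitProlog(frame);                   // replaces FuncUtils::emitProlog
      a->emitArgsAssignment(frame, args);     // replaces FuncUtils::allocArgs
      a->mov(asmjit::x86::dword_ptr(out), 42);  // kernel body: *out = 42
      a->emitEpilog(frame);                   // replaces FuncUtils::emitEpilog

      jit_fn fn = nullptr;
      rt.add(&fn, &code);  // finalize and relocate into executable memory
      return fn;
    }
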
@@ -1117,7 +1116,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// top-left corner code
// zero out the results register
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
@@ -1213,7 +1212,7 @@ void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
@@ -1256,7 +1255,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -1326,7 +1325,7 @@ void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// bottom-left corner
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
@@ -1429,7 +1428,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genRowoffsetCore<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// number of uint8 elements in input channels should be a multiple of 32
assert(C_ % 32 == 0);
@@ -1491,9 +1490,9 @@ jit_rowoffset_kernel_fp
GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// log code to a file
@@ -1510,45 +1509,45 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
a_zero_pt_R_ = a->zsi();
H_R_ = a->zdx();
W_R_ = a->zcx();
- row_offset_R_ = a->gpzRef(8);
+ row_offset_R_ = a->gpz(8);
// register for temporary use
- scratchReg1_ = a->gpzRef(12);
- scratchReg2_ = a->gpzRef(13);
+ scratchReg1_ = a->gpz(12);
+ scratchReg2_ = a->gpz(13);
- loopR1_ = a->gpzRef(14);
- loopR2_ = a->gpzRef(15);
+ loopR1_ = a->gpz(14);
+ loopR2_ = a->gpz(15);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
+ FuncSignatureT<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
// This uses xmm10 register temporarily. Should come before
// createVector8BitOne
if (!isAZeroPointZero_) {
// we can use xmm11 because ymm11 is used by tmpReg1Avx2_
- asmjit::X86Xmm const_reg_xmm = x86::xmm11;
+ x86::Xmm const_reg_xmm = x86::xmm11;
a->movq(const_reg_xmm, a_zero_pt_R_);
a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm);
@@ -1569,7 +1568,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
genRowoffsetCore<inst_set_t::avx2>(a);
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_rowoffset_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc
index 143e11d..5fabf97 100644
--- a/src/PackAMatrix.cc
+++ b/src/PackAMatrix.cc
@@ -34,7 +34,8 @@ PackAMatrix<T, accT>::PackAMatrix(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -43,7 +44,12 @@ PackAMatrix<T, accT>::PackAMatrix(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
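
The same VNNI-first probe order recurs below in PackAWithIm2Col, PackAWithQuantRowOffset, PackAWithRowOffset, PackBMatrix, and PackMatrix; the A-side variant shown here mirrors PackAMatrix, while PackBMatrix picks KCB/NCB and PackMatrix picks MCB/NCB/KCB. Written out once as a sketch; the chooseBlocking helper is hypothetical (this change inlines the ladder in each constructor instead of factoring it out):

    #include <cassert>
    #include "fbgemm/Fbgemm.h"  // assumed home of PackingTraits, inst_set_t,
                                // and the fbgemmHas*Support() probes

    namespace fbgemm {

    template <typename T, typename accT>
    void chooseBlocking(int& brow, int& bcol, int& rowInterleave) {
      if (fbgemmHasAvx512VnniSupport()) {
        brow = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
        bcol = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
        rowInterleave =
            PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
      } else if (fbgemmHasAvx512Support()) {
        brow = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
        bcol = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
        rowInterleave =
            PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
      } else if (fbgemmHasAvx2Support()) {
        brow = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
        bcol = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
        rowInterleave =
            PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
      } else {
        assert(0 && "unknown architecture");
      }
    }

    } // namespace fbgemm
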
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
index d731654..2aca27d 100644
--- a/src/PackAWithIm2Col.cc
+++ b/src/PackAWithIm2Col.cc
@@ -49,7 +49,8 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -58,7 +59,12 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
@@ -478,7 +484,9 @@ int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ } else if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc
index 0e5c598..0af05e8 100644
--- a/src/PackAWithQuantRowOffset.cc
+++ b/src/PackAWithQuantRowOffset.cc
@@ -45,7 +45,8 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -54,7 +55,12 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
@@ -199,7 +205,9 @@ int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ } else if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
diff --git a/src/PackAWithRowOffset.cc b/src/PackAWithRowOffset.cc
index 733bf5c..e84c67b 100644
--- a/src/PackAWithRowOffset.cc
+++ b/src/PackAWithRowOffset.cc
@@ -39,7 +39,8 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -48,7 +49,12 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
@@ -189,7 +195,9 @@ int PackAWithRowOffset<T, accT>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ } else if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
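
Because rowOffsetBufferSize() now consults the VNNI traits first, row-offset scratch buffers sized through it automatically track the dispatched MCB. A hedged usage sketch, assuming the usual default accT of std::int32_t (illustration, not part of this diff):

    #include <cstdint>
    #include <vector>
    #include "fbgemm/Fbgemm.h"

    // Size the row-offset scratch from whichever MCB the runtime probe picks.
    std::vector<std::int32_t> row_offset_buf(
        fbgemm::PackAWithRowOffset<std::uint8_t>::rowOffsetBufferSize());
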
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc
index 0990edb..bf43fab 100644
--- a/src/PackBMatrix.cc
+++ b/src/PackBMatrix.cc
@@ -188,7 +188,8 @@ PackBMatrix<T, accT>::PackBMatrix(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -197,7 +198,12 @@ PackBMatrix<T, accT>::PackBMatrix(
BaseType::bcol_ = params->NCB;
row_interleave_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::NCB;
+ row_interleave_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB;
row_interleave_ =
@@ -317,14 +323,16 @@ void PackBMatrix<T, accT>::pack_unpack_(
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::pack(const block_type_t& block,
- const BlockingFactors* params) {
+void PackBMatrix<T, accT>::pack(
+ const block_type_t& block,
+ const BlockingFactors* params) {
pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true, params);
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::unpack(T* origin_buf,
- const BlockingFactors* params) {
+void PackBMatrix<T, accT>::unpack(
+ T* origin_buf,
+ const BlockingFactors* params) {
block_type_t blockB{BaseType::packedRowStart(),
BaseType::numPackedRows(),
BaseType::packedColStart(),
@@ -352,8 +360,9 @@ int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const {
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::printPackedMatrix(std::string name,
- const BlockingFactors* params) {
+void PackBMatrix<T, accT>::printPackedMatrix(
+ std::string name,
+ const BlockingFactors* params) {
std::cout << name << ":"
<< "[" << BaseType::numPackedRows() << ", "
<< BaseType::numPackedCols() << "]" << std::endl;
diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc
index c7503dd..ff7b842 100644
--- a/src/PackMatrix.cc
+++ b/src/PackMatrix.cc
@@ -36,7 +36,8 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -46,7 +47,11 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize(
NCB = params->NCB;
KCB = params->KCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ MCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::MCB;
+ NCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::NCB;
+ KCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::KCB;
+ } else if (fbgemmHasAvx512Support()) {
MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB;
NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc
index ba6adf3..f6ad59e 100644
--- a/src/PackWeightMatrixForGConv.cc
+++ b/src/PackWeightMatrixForGConv.cc
@@ -106,7 +106,7 @@ inline int PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::packed_index_(
* on 2 groups at a time and full SIMD width can be efficiently utilized even
* while working on 1 group at a time.
* In this case, the layout is G (C/4) R S K 4
-*/
+ */
template <typename T, typename accT, int SPATIAL_DIM>
void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_(
@@ -148,9 +148,9 @@ void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_(
if (ispack) {
transposeConvWeights(conv_param_, src, dst);
} else {
- // TODO: Wrap this as a inverseTransposeConvWeights()?
- // For unpack & transposed, call transposeConvWeights()
- // G (R S C/G) K/G => G K/G (R S C/G)
+ // TODO: Wrap this as an inverseTransposeConvWeights()?
+ // For unpack & transposed, call transposeConvWeights()
+ // G (R S C/G) K/G => G K/G (R S C/G)
for (int r = 0; r < R; ++r) {
for (int s = 0; s < S; ++s) {
for (int k = 0; k < OC_per_G; ++k) {
diff --git a/src/Utils.cc b/src/Utils.cc
index 0fa620d..5214e41 100644
--- a/src/Utils.cc
+++ b/src/Utils.cc
@@ -202,4 +202,7 @@ bool fbgemmHasAvx2Support() {
return (cpuinfo_has_x86_avx2());
}
+bool fbgemmHasAvx512VnniSupport() {
+ return (cpuinfo_has_x86_avx512vnni());
+}
} // namespace fbgemm
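
fbgemmHasAvx512VnniSupport() is a thin wrapper over cpuinfo, mirroring the existing fbgemmHasAvx512Support() and fbgemmHasAvx2Support() probes; like them it presumes cpuinfo_initialize() has already succeeded, which the Pack* constructors above check before probing. A small standalone usage sketch (the header location is assumed from the declarations this patch adds to include/fbgemm/Utils.h):

    #include <cstdio>
    #include "fbgemm/Utils.h"

    int main() {
      // FBGEMM's packing paths call cpuinfo_initialize() before probing; a
      // standalone caller should enter through those paths or do the same.
      std::printf("AVX512-VNNI available: %s\n",
                  fbgemm::fbgemmHasAvx512VnniSupport() ? "yes" : "no");
      return 0;
    }
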
diff --git a/test/GConvTest.cc b/test/GConvTest.cc
index 0074535..8c1fb82 100644
--- a/test/GConvTest.cc
+++ b/test/GConvTest.cc
@@ -465,8 +465,8 @@ TEST_P(fbgemmGConvPackTest, PackUnpackTest) {
for (int i = 0; i < weight_len; ++i) {
EXPECT_EQ(Bint8.data()[i], unpack_buf.data()[i])
<< "Pack/Unpack results differ at index " << i
- << ", Reference: " << static_cast<int> (Bint8.data()[i])
- << ", Pack-Unpacked: " << static_cast<int> (unpack_buf.data()[i]);
+ << ", Reference: " << static_cast<int>(Bint8.data()[i])
+ << ", Pack-Unpacked: " << static_cast<int>(unpack_buf.data()[i]);
}
} // for each shape
}
diff --git a/third_party/asmjit b/third_party/asmjit
-Subproject 673dcefaa048c5f5a2bf8b85daf8f7b9978d018
+Subproject 5d40561d14f93dc45613bfa03155d1dfb4f5825