github.com/marian-nmt/FBGEMM.git

author    Jianyu Huang <jianyuhuang@fb.com>    2019-08-09 21:23:22 +0300
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>    2019-08-09 21:33:13 +0300
commit    7b156071d8912dcf6711c88578c30f0f0d05d3a6 (patch)
tree      b95540b1acbe2e17982f8a1c48fbe5c75a016d12
parent    122135c29b68de5176bd56de6ced936cdc63cb36 (diff)
Integrate VNNI into FBGEMM master branch (#114)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/114

This adds AVX512 VNNI support to FBGEMM. Previously there was a CMake version conflict: the PyTorch and FBGEMM OSS tests build with CMake 3.5, while asmjit required CMake 3.8+, which caused build failures on some platforms. The conflict is now resolved by a PR to asmjit that lowers its CMake requirement: https://github.com/asmjit/asmjit/pull/252.

Reviewed By: dskhudia

Differential Revision: D16720839

fbshipit-source-id: e5e5f2d26f924df8d9fb955f4a3758561fa73288
-rw-r--r--  CMakeLists.txt                                 2
-rw-r--r--  include/fbgemm/PackingTraits-inl.h            50
-rw-r--r--  include/fbgemm/Utils.h                         7
-rw-r--r--  src/ExecuteKernelU8S8.cc                      47
-rw-r--r--  src/Fbgemm.cc                                 18
-rw-r--r--  src/GenerateKernel.h                          30
-rw-r--r--  src/GenerateKernelU8S8S32ACC16.cc             91
-rw-r--r--  src/GenerateKernelU8S8S32ACC16Avx512.cc       94
-rw-r--r--  src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc  102
-rw-r--r--  src/GenerateKernelU8S8S32ACC32.cc             87
-rw-r--r--  src/GenerateKernelU8S8S32ACC32Avx512.cc       95
-rw-r--r--  src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc  431
-rw-r--r--  src/GroupwiseConv.h                          100
-rw-r--r--  src/GroupwiseConvAcc32Avx2.cc                177
-rw-r--r--  src/PackAMatrix.cc                            10
-rw-r--r--  src/PackAWithIm2Col.cc                        14
-rw-r--r--  src/PackAWithQuantRowOffset.cc                14
-rw-r--r--  src/PackAWithRowOffset.cc                     14
-rw-r--r--  src/PackBMatrix.cc                            25
-rw-r--r--  src/PackMatrix.cc                              9
-rw-r--r--  src/PackWeightMatrixForGConv.cc                8
-rw-r--r--  src/Utils.cc                                   3
-rw-r--r--  test/GConvTest.cc                              4
m---------  third_party/asmjit                             0
24 files changed, 1054 insertions, 378 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b575e17..817f699 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,8 +33,10 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
src/FbgemmI8Spmdm.cc
src/GenerateKernelU8S8S32ACC16.cc
src/GenerateKernelU8S8S32ACC16Avx512.cc
+ src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
src/GenerateKernelU8S8S32ACC32.cc
src/GenerateKernelU8S8S32ACC32Avx512.cc
+ src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
src/GroupwiseConvAcc32Avx2.cc
src/PackAMatrix.cc
src/PackAWithIm2Col.cc
diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h
index 76eb425..baccfad 100644
--- a/include/fbgemm/PackingTraits-inl.h
+++ b/include/fbgemm/PackingTraits-inl.h
@@ -222,3 +222,53 @@ struct PackingTraits<
128}; ///< Cache block for N dimension (multiple of NR).
static constexpr int KCB{256}; ///< Cache block for K dimension.
};
+
+/**
+ * @brief Helper struct to type specialize for int16_t and int32_t together.
+ */
+template <typename T>
+struct is_16or32bit {
+ static constexpr bool value =
+ std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value;
+};
+
+/**
+ * @brief Packing parameter specialization for accumulation into 32-bit/16-bit
+ * integers.
+ *
+ * Since there is no int16_t accumulation for AVX512 VNNI, we redirect int16_t
+ * to int32_t accumulation and use the same blocking parameters as int32_t.
+ *
+ * This is picked when T is of int8 type (signed or unsigned) and instruction
+ * set is avx512_vnni.
+ */
+template <typename T, typename accT>
+struct PackingTraits<
+ T,
+ accT,
+ inst_set_t::avx512_vnni,
+ typename std::enable_if<
+ is_8bit<T>::value && is_16or32bit<accT>::value>::type> {
+ static constexpr int MR{8}; ///< Register block for M dimension.
+ static constexpr int NR_MIN{
+ 16}; ///< Minimum register block for N dimension.
+ ///< 16 because 16*ROW_INTERLEAVE int8 elements
+ ///< completely fill a 512-bit wide vector.
+ static constexpr int NR{
+ 32}; ///< Register block for N dimension.
+ ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements
+ ///< completely fill a 512-bit wide vector. Total registers used for
+ ///< N dimension: NR*ROW_INTERLEAVE*8/VLEN. We use MR x
+ ///< NR*ROW_INTERLEAVE*8/VLEN zmm registers
+ ///< for C accumulations.
+
+ static constexpr int ROW_INTERLEAVE{
+ 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing
+ ///< B matrix.
+
+ static constexpr int MCB{
+ 128}; ///< Cache block for M dimension (multiple of MR).
+ static constexpr int NCB{
+ 32}; ///< Cache block for N dimension (multiple of NR).
+ static constexpr int KCB{256}; ///< Cache block for K dimension.
+};
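
For reference, a standalone sketch (mock types only, not FBGEMM's actual headers) of how the enable_if guard above funnels both accumulator widths into a single blocking specialization:

#include <cstdint>
#include <type_traits>

enum class inst_set_t { anyarch, avx2, avx512, avx512_vnni };

template <typename T>
struct is_16or32bit {
  static constexpr bool value =
      std::is_same<T, std::int16_t>::value ||
      std::is_same<T, std::int32_t>::value;
};

template <typename T, typename accT, inst_set_t isa, typename Enable = void>
struct PackingTraits; // primary template; only the guarded one is used here

template <typename accT>
struct PackingTraits<
    std::uint8_t,
    accT,
    inst_set_t::avx512_vnni,
    typename std::enable_if<is_16or32bit<accT>::value>::type> {
  static constexpr int MCB{128}; // same numbers as the specialization above
  static constexpr int KCB{256};
};

static_assert(
    PackingTraits<std::uint8_t, std::int16_t, inst_set_t::avx512_vnni>::KCB ==
        PackingTraits<std::uint8_t, std::int32_t, inst_set_t::avx512_vnni>::KCB,
    "int16_t and int32_t accumulation share one VNNI blocking");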
diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h
index 107cf07..3f8522b 100644
--- a/include/fbgemm/Utils.h
+++ b/include/fbgemm/Utils.h
@@ -29,7 +29,7 @@ enum class matrix_op_t { NoTranspose, Transpose };
/**
* @brief Typed enum for supported instruction sets.
*/
-enum class inst_set_t { anyarch, avx2, avx512 };
+enum class inst_set_t { anyarch, avx2, avx512, avx512_vnni };
/**
* @brief Typed enum for optimized paths for convolutions
@@ -100,6 +100,11 @@ FBGEMM_API bool fbgemmHasAvx512Support();
FBGEMM_API bool fbgemmHasAvx2Support();
/**
+ * @brief Are we running on an AVX512_VNNI supported cpu?
+ */
+FBGEMM_API bool fbgemmHasAvx512VnniSupport();
+
+/**
* @brief Helper struct to enable autotuning of FBGEMM packing and kernels.
*
* This structure is optional. If not used, the default values for these
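
The new capability check slots in ahead of the existing ones everywhere a kernel is selected. A minimal sketch of that probe order, assuming this header's declarations (the helper name pickInstSet is illustrative, not part of the patch):

#include "fbgemm/Utils.h"

namespace fbgemm {

// Probe from newest to oldest ISA; the patch uses this order throughout.
inline inst_set_t pickInstSet() {
  if (fbgemmHasAvx512VnniSupport()) {
    return inst_set_t::avx512_vnni;
  } else if (fbgemmHasAvx512Support()) {
    return inst_set_t::avx512;
  } else if (fbgemmHasAvx2Support()) {
    return inst_set_t::avx2;
  }
  return inst_set_t::anyarch;
}

} // namespace fbgemm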
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index f7292fd..0a4ff55 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -49,7 +49,8 @@ ExecuteKernel<
throw std::runtime_error("Failed to initialize cpuinfo!");
}
if (params) {
- if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support() ||
+ fbgemmHasAvx2Support()) {
mbSize_ = params->MCB;
nbSize_ = params->NCB;
nrMinSize_ = params->NR_MIN;
@@ -59,7 +60,20 @@ ExecuteKernel<
assert(0 && "unsupported architecture");
}
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ mbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::MCB;
+ nbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::NCB;
+ nrMinSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::NR_MIN;
+ } else if (fbgemmHasAvx512Support()) {
mbSize_ = PackingTraits<
int8_t,
typename packingAMatrix::accType,
@@ -118,7 +132,25 @@ void ExecuteKernel<
typename BaseType::jit_micro_kernel_fp fn;
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) {
+ // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>(
+ accum,
+ packed_rows_A,
+ packedB_.blockColSize(),
+ packedA_.numPackedCols(),
+ nbSize_);
+ } else {
+ fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
+ accum,
+ packed_rows_A,
+ packedB_.blockColSize(),
+ packedA_.numPackedCols(),
+ nbSize_);
+ }
+ } else if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum,
packed_rows_A,
@@ -148,7 +180,10 @@ void ExecuteKernel<
if (jb == bColBlocks - 1) {
int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_;
if (nc != nbSize_) {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
+ accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
+ } else if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
} else if (fbgemmHasAvx2Support()) {
@@ -213,7 +248,7 @@ void ExecuteKernel<
int32_t nSize =
C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols();
if (nSize) {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
@@ -238,7 +273,7 @@ void ExecuteKernel<
if (C_buffer_start == C_tile_) {
// When C_tile_ scratchpad was used to avoid accessing memory past
// C_buffer_ .
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
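
Since VNNI offers no 16-bit accumulation, the branch above hands int16_t requests to the int32_t code generator. The same redirect expressed as a standalone type alias (the alias name is illustrative, not part of the patch):

#include <cstdint>
#include <type_traits>

// Under avx512_vnni, an int16_t accumulator request is served by the
// int32_t generator; other accumulator types keep their own path.
template <typename AccT>
using effective_vnni_acc_t = typename std::conditional<
    std::is_same<AccT, std::int16_t>::value,
    std::int32_t,
    AccT>::type;

static_assert(
    std::is_same<effective_vnni_acc_t<std::int16_t>, std::int32_t>::value,
    "int16_t accumulation redirects to int32_t under VNNI");
static_assert(
    std::is_same<effective_vnni_acc_t<std::int32_t>, std::int32_t>::value,
    "int32_t accumulation is unchanged");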
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index 0f2f6fb..4f7026f 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -48,7 +48,8 @@ void fbgemmPacked(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecture");
}
@@ -62,7 +63,20 @@ void fbgemmPacked(
MR = blocking_params->MR;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ MCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::MCB;
+ KCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::KCB;
+ MR = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::MR;
+ } else if (fbgemmHasAvx512Support()) {
MCB = PackingTraits<
typename packingAMatrix::inpType,
typename packingAMatrix::accType,
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
index dccdfc5..e52097e 100644
--- a/src/GenerateKernel.h
+++ b/src/GenerateKernel.h
@@ -18,7 +18,7 @@ namespace fbgemm {
namespace x86 = asmjit::x86;
/**
- * @brief AVX2/AVX512 JIT assembly code generator.
+ * @brief AVX2/AVX512/AVX512VNNI JIT assembly code generator.
* @tparam TA Type of matrix A.
* @tparam TB Type of matrix B.
* @tparam TC Type of matrix C.
@@ -104,7 +104,7 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void initCRegs(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCRegAssign = 4);
@@ -114,10 +114,10 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void genComputeBlock(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp B_pf,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda,
@@ -129,11 +129,11 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void storeCRegs(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCRegAssign = 4);
@@ -168,7 +168,9 @@ class CodeGenBase {
fileName += "_MR-" + std::to_string(MR);
fileName += "_NR-" + std::to_string(NR);
fileName += "_NR_MIN-" + std::to_string(NR_MIN);
- if (instSet == inst_set_t::avx512) {
+ if (instSet == inst_set_t::avx512_vnni) {
+ fileName += "_avx512vnni";
+ } else if (instSet == inst_set_t::avx512) {
fileName += "_avx512";
} else if (instSet == inst_set_t::avx2) {
fileName += "_avx2";
@@ -178,12 +180,10 @@ class CodeGenBase {
}
private:
- asmjit::X86Ymm
- CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
- asmjit::X86Zmm
+ x86::Ymm CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
+ x86::Zmm
CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel.
- asmjit::X86Zmm
- AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
+ x86::Zmm AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
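
Alongside the VNNI additions, this header and every generator below migrate to the newer asmjit API; the '-'/'+' pairs throughout the rest of this patch follow one mapping:

  asmjit::X86Emitter*                 ->  x86::Emitter*
  asmjit::X86Gp/X86Xmm/X86Ymm/X86Zmm  ->  x86::Gp/x86::Xmm/x86::Ymm/x86::Zmm
  asmjit::X86Mem                      ->  x86::Mem
  a->gpzRef(n)                        ->  a->gpz(n)
  rt_.getCodeInfo()                   ->  rt_.codeInfo()
  assembler.asEmitter()               ->  assembler.as<x86::Emitter>()
  FuncSignature6<...>                 ->  FuncSignatureT<...>
  FuncFrameInfo + FuncFrameLayout     ->  FuncFrame (init, then finalize)
  FuncArgsMapper                      ->  FuncArgsAssignment
  args.updateFrameInfo(ffi)           ->  args.updateFuncFrame(frame)
  asmjit::Utils::mask(...)            ->  asmjit::Support::bitMask(...)
  X86Reg::kKindVec / kKindGp          ->  x86::Reg::kGroupVec / kGroupGp
  FuncUtils::emitProlog/emitEpilog    ->  a->emitProlog / a->emitEpilog
  FuncUtils::allocArgs                ->  a->emitArgsAssignment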
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index 718b883..1e7e081 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -31,7 +31,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -53,18 +53,18 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp /* unused (reserved for prefetching)*/,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp /* unused (reserved for prefetching)*/,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- asmjit::X86Ymm AReg = x86::ymm12;
+ x86::Ymm AReg = x86::ymm12;
- asmjit::X86Ymm tmpReg = x86::ymm14;
+ x86::Ymm tmpReg = x86::ymm14;
for (int i = 0; i < rowRegs; ++i) {
// broadcast A
@@ -95,15 +95,15 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
- asmjit::X86Xmm extractDest128 = x86::xmm15;
- asmjit::X86Ymm extractDest256 = x86::ymm15;
+ x86::Xmm extractDest128 = x86::xmm15;
+ x86::Ymm extractDest256 = x86::ymm15;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
@@ -112,7 +112,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
a->vextracti128(
extractDest128, CRegs_avx2_[i * leadingDimCReg + j], idx);
a->vpmovsxwd(extractDest256, extractDest128);
- asmjit::X86Mem destAddr = x86::dword_ptr(
+ x86::Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t));
if (accum) {
a->vpaddd(extractDest256, extractDest256, destAddr);
@@ -176,9 +176,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
return codeCache_[kernelSig];
}
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
@@ -207,46 +207,45 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
//"nc must be equal to the number of register blocks");
// arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
- asmjit::FuncArgsMapper args(&func);
+ asmjit::FuncArgsAssignment args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- args.updateFrameInfo(ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
-
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
asmjit::Label Loopk = a->newLabel();
asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp kIdx = a->gpzRef(14);
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
if (mRegBlocks > 0) {
@@ -289,8 +288,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
a->jl(Loopk);
// store C matrix
- storeCRegs<inst_set_t::avx2>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
// increment A for next block
a->sub(buffer_A, kSize);
@@ -340,11 +338,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
a->jl(LoopkRem);
// store C matrix
- storeCRegs<inst_set_t::avx2>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index c95757b..a49e440 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -19,7 +19,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -41,18 +41,18 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp /* unused (reserved for prefetching)*/,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp /* unused (reserved for prefetching)*/,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- asmjit::X86Zmm AReg = x86::zmm29;
+ x86::Zmm AReg = x86::zmm29;
- asmjit::X86Zmm tmpReg = x86::zmm30;
+ x86::Zmm tmpReg = x86::zmm30;
// We start allocating BRegs from zmm27 and then allocate zmm26 and so on.
for (int j = 0; j < colRegs; ++j) {
@@ -66,8 +66,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
a->vpbroadcastw(
AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
for (int j = 0; j < colRegs; ++j) {
- a->vpmaddubsw(
- tmpReg, AReg, AllRegs_avx512_[27-j]);
+ a->vpmaddubsw(tmpReg, AReg, AllRegs_avx512_[27 - j]);
a->vpaddsw(
CRegs_avx512_[i * leadingDimCReg + j],
tmpReg,
@@ -90,15 +89,16 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
- asmjit::X86Ymm extractDest256 = x86::ymm31;
- asmjit::X86Zmm extractDest512 = x86::zmm31;
+ x86::Ymm extractDest256 = x86::ymm31;
+ x86::Zmm extractDest512 = x86::zmm31;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
@@ -107,7 +107,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
a->vextracti32x8(
extractDest256, CRegs_avx512_[i * leadingDimCReg + j], idx);
a->vpmovsxwd(extractDest512, extractDest256);
- asmjit::X86Mem destAddr = x86::dword_ptr(
+ x86::Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t));
if (accum) {
a->vpaddd(extractDest512, extractDest512, destAddr);
@@ -172,9 +172,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
}
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
@@ -209,49 +209,49 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp,
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
asmjit::Label LoopMBlocks = a->newLabel();
asmjit::Label LoopNBlocks = a->newLabel();
asmjit::Label Loopk = a->newLabel();
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp jIdx = a->gpzRef(14);
- asmjit::X86Gp kIdx = a->gpzRef(15);
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
// save B_buffer address
a->mov(buffer_B_saved, buffer_B);
@@ -407,7 +407,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
a->jl(LoopNRem);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
new file mode 100644
index 0000000..f559aba
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCReg) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ codeObj.initCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, leadingDimCReg);
+}
+
+/**
+ * Generate AVX512 instructions for computing block in the rank-k update of
+ * 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp /* unused (reserved for prefetching)*/,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCReg) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ codeObj.genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, buffer_B, rowRegs, colRegs, lda, leadingDimCReg);
+}
+
+/**
+ * Generate AVX512 instructions for storing the C registers back to the memory
+ * in 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
+ bool accum,
+ int leadingDimCReg) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ codeObj.storeCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, leadingDimCReg);
+}
+
+/**
+ * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<
+ inst_set_t::avx512_vnni>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ return codeObj.getOrCreate<inst_set_t::avx512_vnni>(accum, mc, nc, kc, kc);
+}
+
+} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index 58643ad..6b54743 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -31,7 +31,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -53,25 +53,25 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp B_pf,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- asmjit::X86Ymm AReg = x86::ymm12;
+ x86::Ymm AReg = x86::ymm12;
// used for matrix B
- asmjit::X86Ymm BReg = x86::ymm13;
+ x86::Ymm BReg = x86::ymm13;
// Contains 16-bit 1s
- asmjit::X86Ymm oneReg = x86::ymm15;
+ x86::Ymm oneReg = x86::ymm15;
// temporary register
- asmjit::X86Ymm res1 = x86::ymm14;
+ x86::Ymm res1 = x86::ymm14;
for (int j = 0; j < colRegs; ++j) {
// load B
@@ -99,11 +99,11 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
for (int i = 0; i < rowRegs; ++i) {
@@ -177,9 +177,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
return codeCache_[kernelSig];
}
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
FILE* codeLogfile = fopen(
@@ -205,49 +205,48 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
- asmjit::FuncArgsMapper args(&func);
+ asmjit::FuncArgsAssignment args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- args.updateFrameInfo(ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
-
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
asmjit::Label Loopk = a->newLabel();
asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp kIdx = a->gpzRef(14);
- // asmjit::X86Gp B_pf = a->gpzRef(8);
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
+ // x86::Gp B_pf = a->gpz(8);
- asmjit::X86Ymm oneReg = x86::ymm15;
+ x86::Ymm oneReg = x86::ymm15;
// create 16-bit 1s
// i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
// and so on
@@ -358,7 +357,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index 12243ee..fe35627 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -19,7 +19,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -41,25 +41,25 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp B_pf,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- asmjit::X86Zmm AReg = x86::zmm31;
+ x86::Zmm AReg = x86::zmm31;
// used for matrix B
- asmjit::X86Zmm BReg = x86::zmm30;
+ x86::Zmm BReg = x86::zmm30;
// Contains 16-bit 1s
- asmjit::X86Zmm oneReg = x86::zmm29;
+ x86::Zmm oneReg = x86::zmm29;
// temporary register
- asmjit::X86Zmm res1 = x86::zmm28;
+ x86::Zmm res1 = x86::zmm28;
for (int j = 0; j < colRegs; ++j) {
// load B
@@ -87,18 +87,17 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
for (int i = 0; i < rowRegs; ++i) {
if (i != 0) {
a->add(C_Offset, ldcReg);
- }
- else {
+ } else {
a->mov(C_Offset, static_cast<asmjit::Imm>(0));
}
for (int j = 0; j < colRegs; ++j) {
@@ -168,9 +167,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
return codeCache_[kernelSig];
}
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
@@ -205,52 +204,52 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp,
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
asmjit::Label LoopMBlocks = a->newLabel();
asmjit::Label LoopNBlocks = a->newLabel();
asmjit::Label Loopk = a->newLabel();
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp jIdx = a->gpzRef(14);
- asmjit::X86Gp kIdx = a->gpzRef(15);
- // asmjit::X86Gp B_pf = a->gpzRef(8);
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+ // x86::Gp B_pf = a->gpz(8);
- asmjit::X86Zmm oneReg = x86::zmm29;
+ x86::Zmm oneReg = x86::zmm29;
// create 16-bit 1s
// i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
// and so on
@@ -420,7 +419,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
a->jl(LoopNRem);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
new file mode 100644
index 0000000..8ae0745
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
@@ -0,0 +1,431 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCReg) {
+ for (int i = 0; i < rowRegs; ++i) {
+ for (int j = 0; j < colRegs; ++j) {
+ a->vxorps(
+ CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs_avx512_[i * leadingDimCReg + j]);
+ }
+ }
+}
+
+/**
+ * Generate AVX512 instructions for computing block in the rank-k update of
+ * 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCReg) {
+ // used for matrix A
+ x86::Zmm AReg = x86::zmm31;
+
+ // used for matrix B
+ x86::Zmm BReg = x86::zmm30;
+
+ for (int j = 0; j < colRegs; ++j) {
+ // load B
+ a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ // load A, broadcast and fmas
+ for (int i = 0; i < rowRegs; ++i) {
+ a->vpbroadcastd(
+ AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
+ a->vpdpbusd(CRegs_avx512_[i * leadingDimCReg + j], AReg, BReg);
+ }
+ a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
+ }
+}
+
+/**
+ * Generate AVX512 instructions for storing the C registers back to the memory
+ * in 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
+ bool accum,
+ int leadingDimCReg) {
+ for (int i = 0; i < rowRegs; ++i) {
+ if (i != 0) {
+ a->add(C_Offset, ldcReg);
+ } else {
+ a->mov(C_Offset, static_cast<asmjit::Imm>(0));
+ }
+ for (int j = 0; j < colRegs; ++j) {
+ if (accum) {
+ a->vpaddd(
+ CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs_avx512_[i * leadingDimCReg + j],
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)));
+ }
+ a->vmovups(
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)),
+ CRegs_avx512_[i * leadingDimCReg + j]);
+ }
+ }
+}
+
+/**
+ * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<
+ inst_set_t::avx512_vnni>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ std::tuple<bool, int, int, int, int, int, int, int> kernelSig;
+ int kBlock;
+ int nBlock;
+ int mRegBlockSize;
+ int nRegBlockSize;
+ int nRegBlockSizeMin;
+ int row_interleave;
+
+ if (blocking_params) {
+ kBlock = blocking_params->KCB;
+ nBlock = blocking_params->NCB;
+ mRegBlockSize = blocking_params->MR;
+ nRegBlockSize = blocking_params->NR;
+ nRegBlockSizeMin = blocking_params->NR_MIN;
+ row_interleave = blocking_params->ROW_INTERLEAVE;
+ } else {
+ kBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::KCB;
+ nBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NCB;
+ mRegBlockSize =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::MR;
+ nRegBlockSize =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR;
+ nRegBlockSizeMin =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR_MIN;
+ row_interleave = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::
+ ROW_INTERLEAVE;
+ }
+
+ kernelSig = std::make_tuple(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin);
+
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+ code_.reset(false);
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
+
+#if defined(FBGEMM_LOG_CODE)
+ // generated code logging
+ FILE* codeLogfile = fopen(
+ getCodeLoggingFile<inst_set_t::avx512_vnni>(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code_.setLogger(codeLogger);
+ }
+#endif
+
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert(
+ maxMRegs * maxNRegs <= 28 &&
+ "MR*(NR*ROW_INTERLEAVE*8/512) \
+ must be <= 28(available registers constraint)");
+
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+ // x86::Gp B_pf = a->gpz(8);
+
+ x86::Zmm oneReg = x86::zmm29;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ // a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+
+ // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(
+ C_Offset,
+ jIdx,
+ static_cast<asmjit::Imm>(
+ nRegBlockSize * row_interleave * sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(
+ CBase,
+ static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+ // B for next block
+ // using C_Offset as temp reg
+ a->imul(
+ C_Offset,
+ jIdx,
+ static_cast<asmjit::Imm>(
+ nRegBlockSize * row_interleave * sizeof(int8_t)));
+ a->mov(buffer_B, buffer_B_saved);
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
+
+ a->emitEpilog(frame);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
+
+#if defined(FBGEMM_LOG_CODE)
+ fclose(codeLogfile);
+ delete codeLogger;
+#endif
+
+ return fn;
+}
+
+} // namespace fbgemm
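
The genComputeBlock above is the heart of the patch: one vpdpbusd replaces the AVX512 vpmaddubsw + vpmaddwd + vpaddd sequence. A scalar reference (illustration only, not part of the patch) of what a single vpdpbusd computes; in the kernel, vpbroadcastd has already replicated the same four A bytes into every lane:

#include <cstdint>

// Per 32-bit lane of a 512-bit register: dot product of four uint8
// activations with four int8 weights, accumulated into int32 without
// saturation. ROW_INTERLEAVE = 4 packs B so the four bytes of a lane
// are four consecutive values along the K dimension.
void vpdpbusd_ref(
    std::int32_t c[16],
    const std::uint8_t a[64],
    const std::int8_t b[64]) {
  for (int lane = 0; lane < 16; ++lane) {
    std::int32_t acc = c[lane];
    for (int k = 0; k < 4; ++k) {
      acc += static_cast<std::int32_t>(a[4 * lane + k]) *
          static_cast<std::int32_t>(b[4 * lane + k]);
    }
    c[lane] = acc;
  }
}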
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
index 1e6324e..4c5eea5 100644
--- a/src/GroupwiseConv.h
+++ b/src/GroupwiseConv.h
@@ -128,60 +128,58 @@ class GenConvKernel {
const conv_param_t<SPATIAL_DIM>& conv_param);
template <inst_set_t instSet>
- void createVector16BitOne(asmjit::X86Emitter* a);
+ void createVector16BitOne(x86::Emitter* a);
template <inst_set_t instSet>
- void createVector8BitOne(asmjit::X86Emitter* a);
+ void createVector8BitOne(x86::Emitter* a);
template <inst_set_t instSet>
- void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg);
+ void setToZeroPt(x86::Emitter* a, x86::Ymm destReg);
template <inst_set_t instSet>
- void
- gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg);
+ void gen8bitFMA(x86::Emitter* a, x86::Ymm aReg, x86::Ymm wReg);
template <inst_set_t instSet>
- void genForLoadingWeights(asmjit::X86Emitter* a, int c_offset);
+ void genForLoadingWeights(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genConstForPermutations(asmjit::X86Emitter* a);
+ void genConstForPermutations(x86::Emitter* a);
template <inst_set_t instSet>
- void genForTopEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForTopEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForLeftEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForLeftEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForRightEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForRightEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForBottomEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForBottomEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genCoreInsts(asmjit::X86Emitter* a, int c_offset);
+ void genCoreInsts(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void storeResult(asmjit::X86Emitter* a);
+ void storeResult(x86::Emitter* a);
// for Rowoffset kernel
// Add 4 consecutive numbers of 32 uint8 and emit 8 32-bit
template <inst_set_t instSet>
- void gen8BitSumX4(asmjit::X86Emitter* a, asmjit::X86Ymm aReg);
+ void gen8BitSumX4(x86::Emitter* a, x86::Ymm aReg);
// Add 8 consecutive numbers of 64 uint8 and emit 8 32-bit
template <inst_set_t instSet>
- void
- gen8BitSumX8(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg);
+ void gen8BitSumX8(x86::Emitter* a, x86::Ymm aReg, x86::Ymm bReg);
// Add 16 consecutive numbers of 128 uint8 and emit 8 32-bit
template <inst_set_t instSet>
void gen8BitSumX16(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm bReg,
- asmjit::X86Ymm cReg,
- asmjit::X86Ymm dReg);
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm bReg,
+ x86::Ymm cReg,
+ x86::Ymm dReg);
// Generate instruction sequence that loads 8-bit values and sum them up.
// Depending on C_per_G_, this function dispatches to gen8BitSumX4/8/16
@@ -191,35 +189,33 @@ class GenConvKernel {
// Internally, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0, 1], tmpReg1Avx2_,
// and resultRegAvx2_ are used.
template <inst_set_t instSet>
- void gen8BitSum(
- asmjit::X86Emitter* a,
- int act_offset,
- bool use_scratch_reg1 = true);
+ void
+ gen8BitSum(x86::Emitter* a, int act_offset, bool use_scratch_reg1 = true);
// Use scratchReg1_ and tmpReg1Avx2_ internally
template <inst_set_t instSet>
- void genZeroPtSum(asmjit::X86Emitter* a, int multiplier);
+ void genZeroPtSum(x86::Emitter* a, int multiplier);
template <inst_set_t instSet>
- void genForTopEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForTopEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genForLeftEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForLeftEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genForRightEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForRightEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genForBottomEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForBottomEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genRowoffsetCorners(asmjit::X86Emitter* a);
+ void genRowoffsetCorners(x86::Emitter* a);
template <inst_set_t instSet>
- void genRowoffsetCore(asmjit::X86Emitter* a);
+ void genRowoffsetCore(x86::Emitter* a);
template <inst_set_t instSet>
- void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0);
+ void storeResultRowoffset(x86::Emitter* a, int offset = 0);
static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
@@ -234,30 +230,30 @@ class GenConvKernel {
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
// avx2 specific
- asmjit::X86Ymm
+ x86::Ymm
WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel.
- asmjit::X86Ymm zeroPTRegAvx2_;
- asmjit::X86Ymm tmpReg1Avx2_;
- asmjit::X86Ymm stPermRegAvx2_;
- asmjit::X86Ymm actRegAvx2_;
- asmjit::X86Ymm resultRegAvx2_;
- asmjit::X86Ymm oneReg8BitAvx2_;
- asmjit::X86Ymm oneReg16BitAvx2_;
+ x86::Ymm zeroPTRegAvx2_;
+ x86::Ymm tmpReg1Avx2_;
+ x86::Ymm stPermRegAvx2_;
+ x86::Ymm actRegAvx2_;
+ x86::Ymm resultRegAvx2_;
+ x86::Ymm oneReg8BitAvx2_;
+ x86::Ymm oneReg16BitAvx2_;
// arguments to the function created
- asmjit::X86Gp in_acts_R_;
- asmjit::X86Gp wghts_R_;
- asmjit::X86Gp out_acts_R_;
- asmjit::X86Gp a_zero_pt_R_;
- asmjit::X86Gp H_R_;
- asmjit::X86Gp W_R_;
- asmjit::X86Gp row_offset_R_;
+ x86::Gp in_acts_R_;
+ x86::Gp wghts_R_;
+ x86::Gp out_acts_R_;
+ x86::Gp a_zero_pt_R_;
+ x86::Gp H_R_;
+ x86::Gp W_R_;
+ x86::Gp row_offset_R_;
// Used registers
- asmjit::X86Gp loopR1_;
- asmjit::X86Gp loopR2_;
- asmjit::X86Gp scratchReg1_;
- asmjit::X86Gp scratchReg2_;
+ x86::Gp loopR1_;
+ x86::Gp loopR2_;
+ x86::Gp scratchReg1_;
+ x86::Gp scratchReg2_;
// Other parameters
bool isAZeroPointZero_;
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index e789695..b140c83 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -104,7 +104,7 @@ jit_conv_kernel_fp getOrCreateConvKernel(
template <>
template <>
void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// create 8-bit 1s
// i.e., oneReg16BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains
// 0x01 and so on
@@ -115,7 +115,7 @@ void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// create 16-bit 1s
// i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31]
// contains 0x0001 and so on
@@ -125,11 +125,11 @@ void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm destReg) {
+ x86::Emitter* a,
+ x86::Ymm destReg) {
// make destReg all zeros
a->vxorps(destReg, destReg, destReg);
- asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+ x86::Xmm const_reg_xmm = x86::xmm10;
// move zero point to xmm10
a->movq(const_reg_xmm, a_zero_pt_R_);
// make copies of zero point
@@ -143,9 +143,9 @@ void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
- asmjit::X86Gp permute_const_reg = a->gpzRef(12);
- asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+ x86::Emitter* a) {
+ x86::Gp permute_const_reg = a->gpz(12);
+ x86::Xmm const_reg_xmm = x86::xmm10;
// We have 1st group in even lanes and 2nd group in odd lanes.
// Permute to put 1st group to lower 128-bit and 2nd group in upper
// 128-bit.
@@ -159,8 +159,7 @@ void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) {
if (C_per_G_ == 4) {
// store with permutation
a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_);
@@ -171,7 +170,7 @@ void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int offset) {
// store
if (C_per_G_ == 4) {
@@ -198,7 +197,7 @@ void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// load weights
for (int r = 0; r < R_; ++r) {
@@ -225,9 +224,9 @@ void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm wReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm wReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg);
a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_);
a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
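The three instructions above (vpmaddubsw, vpmaddwd against a vector of 16-bit ones, vpaddd) are the AVX2 emulation of a u8 x s8 -> s32 dot product over groups of four bytes; the VNNI kernels this commit adds fuse the sequence into a single vpdpbusd. A minimal intrinsics sketch of the same arithmetic, with illustrative names (not FBGEMM code):

#include <immintrin.h>

// AVX2 path: same ones-vector trick as oneReg16BitAvx2_. Note that
// vpmaddubsw saturates its intermediate s16 sums; vpdpbusd does not.
static inline __m256i dot4_avx2(__m256i acc, __m256i a_u8, __m256i w_s8) {
  const __m256i ones16 = _mm256_set1_epi16(1);
  __m256i t = _mm256_maddubs_epi16(a_u8, w_s8); // vpmaddubsw: byte pairs -> s16
  t = _mm256_madd_epi16(ones16, t);             // vpmaddwd: s16 pairs -> s32
  return _mm256_add_epi32(acc, t);              // vpaddd: accumulate
}

// AVX512-VNNI (with AVX512VL for the 256-bit form): one fused instruction.
static inline __m256i dot4_vnni(__m256i acc, __m256i a_u8, __m256i w_s8) {
  return _mm256_dpbusd_epi32(acc, a_u8, w_s8);  // vpdpbusd
}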
@@ -236,8 +235,8 @@ void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_);
a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_);
a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
@@ -246,9 +245,9 @@ void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm bReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm bReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
// Let a[0] denote 0th (LSB) 8-bit of aReg
// After vpsadbw, a[0:2] = a[0] + ... + a[7]
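The vpsadbw trick this comment describes sums unsigned bytes by taking a sum of absolute differences against zero; a one-call intrinsics sketch (illustrative name, not FBGEMM code):

#include <immintrin.h>

// Each 64-bit lane of the result holds, in its low 16 bits, the sum of the
// corresponding eight input bytes: a[0]+...+a[7], a[8]+...+a[15], and so on.
static inline __m256i byte_sums_x8(__m256i a) {
  return _mm256_sad_epu8(a, _mm256_setzero_si256()); // vpsadbw against 0
}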
@@ -267,11 +266,11 @@ void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm bReg,
- asmjit::X86Ymm cReg,
- asmjit::X86Ymm dReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm bReg,
+ x86::Ymm cReg,
+ x86::Ymm dReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
// After vpsadbw, a[0:2] = a[0] + ... + a[7]
// a[8:10] = a[8] + ... + a[15]
@@ -319,7 +318,7 @@ void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int act_offset,
bool use_scratch_reg1 /*=true*/) {
if (use_scratch_reg1) {
@@ -385,11 +384,11 @@ void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int multiplier) {
a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier));
// tmpReg1Avx2_ also uses xmm11
- asmjit::X86Xmm const_reg_xmm = x86::xmm11;
+ x86::Xmm const_reg_xmm = x86::xmm11;
a->movq(const_reg_xmm, scratchReg1_);
a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm);
a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_);
@@ -399,7 +398,7 @@ void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// top-left corner code
if (c_offset == 0) {
@@ -559,7 +558,7 @@ void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
@@ -626,7 +625,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -714,7 +713,7 @@ void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// bottom-left corner
// we are updating the last row
@@ -906,7 +905,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genCoreInsts<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// main compute
asmjit::Label LoopH = a->newLabel();
@@ -1011,9 +1010,9 @@ template <>
jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// log code to a file
@@ -1030,16 +1029,16 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
wghts_R_ = a->zsi();
out_acts_R_ = a->zdx();
a_zero_pt_R_ = a->zcx();
- H_R_ = a->gpzRef(8);
- W_R_ = a->gpzRef(9);
- row_offset_R_ = a->gpzRef(10);
+ H_R_ = a->gpz(8);
+ W_R_ = a->gpz(9);
+ row_offset_R_ = a->gpz(10);
// register for temporary use
- scratchReg1_ = a->gpzRef(12);
- scratchReg2_ = a->gpzRef(13);
+ scratchReg1_ = a->gpz(12);
+ scratchReg2_ = a->gpz(13);
asmjit::FuncDetail func;
- func.init(asmjit::FuncSignature6<
+ func.init(asmjit::FuncSignatureT<
void,
uint8_t*,
int8_t*,
@@ -1048,29 +1047,29 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
int32_t,
int32_t>(asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
createVector16BitOne<inst_set_t::avx2>(a);
- loopR1_ = a->gpzRef(14);
- loopR2_ = a->gpzRef(15);
+ loopR1_ = a->gpz(14);
+ loopR2_ = a->gpz(15);
if (!isAZeroPointZero_) {
setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
@@ -1095,7 +1094,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
genCoreInsts<inst_set_t::avx2>(a, c);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_conv_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
@@ -1117,7 +1116,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// top-left corner code
// zero out the results register
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
@@ -1213,7 +1212,7 @@ void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
@@ -1256,7 +1255,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -1326,7 +1325,7 @@ void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// bottom-left corner
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
@@ -1429,7 +1428,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genRowoffsetCore<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// number of uint8 elements in input channels should be a multiple of 32
assert(C_ % 32 == 0);
@@ -1491,9 +1490,9 @@ jit_rowoffset_kernel_fp
GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// log code to a file
@@ -1510,45 +1509,45 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
a_zero_pt_R_ = a->zsi();
H_R_ = a->zdx();
W_R_ = a->zcx();
- row_offset_R_ = a->gpzRef(8);
+ row_offset_R_ = a->gpz(8);
// register for temporary use
- scratchReg1_ = a->gpzRef(12);
- scratchReg2_ = a->gpzRef(13);
+ scratchReg1_ = a->gpz(12);
+ scratchReg2_ = a->gpz(13);
- loopR1_ = a->gpzRef(14);
- loopR2_ = a->gpzRef(15);
+ loopR1_ = a->gpz(14);
+ loopR2_ = a->gpz(15);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
+ FuncSignatureT<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
// This uses xmm10 register temporarily. Should come before
// createVector8BitOne
if (!isAZeroPointZero_) {
// we can use xmm11 because ymm11 is used by tmpReg1Avx2_
- asmjit::X86Xmm const_reg_xmm = x86::xmm11;
+ x86::Xmm const_reg_xmm = x86::xmm11;
a->movq(const_reg_xmm, a_zero_pt_R_);
a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm);
@@ -1569,7 +1568,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
genRowoffsetCore<inst_set_t::avx2>(a);
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_rowoffset_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
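Taken together, the hunks in this file migrate from asmjit's legacy FuncFrameInfo / FuncArgsMapper / FuncFrameLayout / FuncUtils flow to the current FuncFrame / FuncArgsAssignment API. A condensed sketch of the new flow, using only calls that appear in this diff (the stub signature and register choices are illustrative):

#include <asmjit/asmjit.h>

using namespace asmjit;

typedef void (*stub_fp)(uint8_t*, int32_t);

stub_fp buildStub(JitRuntime& rt) {
  CodeHolder code;
  code.init(rt.codeInfo());                        // was rt.getCodeInfo()
  x86::Assembler assembler(&code);
  x86::Emitter* a = assembler.as<x86::Emitter>();  // was asEmitter()

  x86::Gp dst = a->zdi();
  x86::Gp len = a->zsi();

  FuncDetail func;
  func.init(FuncSignatureT<void, uint8_t*, int32_t>(CallConv::kIdHost));

  FuncFrame frame;                 // replaces FuncFrameInfo + FuncFrameLayout
  frame.init(func);
  frame.setDirtyRegs(x86::Reg::kGroupGp, Support::bitMask(10, 11));

  FuncArgsAssignment args(&func);  // replaces FuncArgsMapper
  args.assignAll(dst, len);
  args.updateFuncFrame(frame);     // was args.updateFrameInfo(ffi)
  frame.finalize();                // was FuncFrameLayout::init(func, ffi)

  a->emitProlog(frame);                // was FuncUtils::emitProlog(a, layout)
  a->emitArgsAssignment(frame, args);  // was FuncUtils::allocArgs(a, layout, args)
  // ... kernel body is emitted here ...
  a->emitEpilog(frame);                // was FuncUtils::emitEpilog(a, layout)

  stub_fp fn = nullptr;
  rt.add(&fn, &code);
  return fn;
}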
diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc
index 143e11d..5fabf97 100644
--- a/src/PackAMatrix.cc
+++ b/src/PackAMatrix.cc
@@ -34,7 +34,8 @@ PackAMatrix<T, accT>::PackAMatrix(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -43,7 +44,12 @@ PackAMatrix<T, accT>::PackAMatrix(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
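The same four-way dispatch recurs in every packing class this commit touches (PackAWithIm2Col, PackAWithQuantRowOffset, PackAWithRowOffset, PackBMatrix, and PackMatrix below): prefer the avx512_vnni blocking parameters, then avx512, then avx2, and assert on anything else. Distilled to its shape (a sketch of the pattern, not a verbatim excerpt):

// Blocking-parameter dispatch shared by the Pack* constructors.
if (fbgemmHasAvx512VnniSupport()) {
  BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
  BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
  row_interleave_B_ =
      PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
} else if (fbgemmHasAvx512Support()) {
  // ... avx512 blocking parameters ...
} else if (fbgemmHasAvx2Support()) {
  // ... avx2 blocking parameters ...
} else {
  assert(0 && "unknown architecture");
}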
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
index d731654..2aca27d 100644
--- a/src/PackAWithIm2Col.cc
+++ b/src/PackAWithIm2Col.cc
@@ -49,7 +49,8 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -58,7 +59,12 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
@@ -478,7 +484,9 @@ int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ } else if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc
index 0e5c598..0af05e8 100644
--- a/src/PackAWithQuantRowOffset.cc
+++ b/src/PackAWithQuantRowOffset.cc
@@ -45,7 +45,8 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -54,7 +55,12 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
@@ -199,7 +205,9 @@ int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ } else if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
diff --git a/src/PackAWithRowOffset.cc b/src/PackAWithRowOffset.cc
index 733bf5c..e84c67b 100644
--- a/src/PackAWithRowOffset.cc
+++ b/src/PackAWithRowOffset.cc
@@ -39,7 +39,8 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -48,7 +49,12 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
@@ -189,7 +195,9 @@ int PackAWithRowOffset<T, accT>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ } else if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc
index 0990edb..bf43fab 100644
--- a/src/PackBMatrix.cc
+++ b/src/PackBMatrix.cc
@@ -188,7 +188,8 @@ PackBMatrix<T, accT>::PackBMatrix(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -197,7 +198,12 @@ PackBMatrix<T, accT>::PackBMatrix(
BaseType::bcol_ = params->NCB;
row_interleave_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::NCB;
+ row_interleave_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB;
row_interleave_ =
@@ -317,14 +323,16 @@ void PackBMatrix<T, accT>::pack_unpack_(
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::pack(const block_type_t& block,
- const BlockingFactors* params) {
+void PackBMatrix<T, accT>::pack(
+ const block_type_t& block,
+ const BlockingFactors* params) {
pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true, params);
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::unpack(T* origin_buf,
- const BlockingFactors* params) {
+void PackBMatrix<T, accT>::unpack(
+ T* origin_buf,
+ const BlockingFactors* params) {
block_type_t blockB{BaseType::packedRowStart(),
BaseType::numPackedRows(),
BaseType::packedColStart(),
@@ -352,8 +360,9 @@ int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const {
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::printPackedMatrix(std::string name,
- const BlockingFactors* params) {
+void PackBMatrix<T, accT>::printPackedMatrix(
+ std::string name,
+ const BlockingFactors* params) {
std::cout << name << ":"
<< "[" << BaseType::numPackedRows() << ", "
<< BaseType::numPackedCols() << "]" << std::endl;
diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc
index c7503dd..ff7b842 100644
--- a/src/PackMatrix.cc
+++ b/src/PackMatrix.cc
@@ -36,7 +36,8 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -46,7 +47,11 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize(
NCB = params->NCB;
KCB = params->KCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ MCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::MCB;
+ NCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::NCB;
+ KCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::KCB;
+ } else if (fbgemmHasAvx512Support()) {
MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB;
NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc
index ba6adf3..f6ad59e 100644
--- a/src/PackWeightMatrixForGConv.cc
+++ b/src/PackWeightMatrixForGConv.cc
@@ -106,7 +106,7 @@ inline int PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::packed_index_(
* on 2 groups at a time and full SIMD width can be efficiently utilized even
* while working on 1 group at a time.
* In this case, the layout is G (C/4) R S K 4
-*/
+ */
template <typename T, typename accT, int SPATIAL_DIM>
void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_(
@@ -148,9 +148,9 @@ void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_(
if (ispack) {
transposeConvWeights(conv_param_, src, dst);
} else {
- // TODO: Wrap this as an inverseTransposeConvWeights()?
- // For unpack & transposed, call transposeConvWeights()
- // G (R S C/G) K/G => G K/G (R S C/G)
+ // TODO: Wrap this as an inverseTransposeConvWeights()?
+ // For unpack & transposed, call transposeConvWeights()
+ // G (R S C/G) K/G => G K/G (R S C/G)
for (int r = 0; r < R; ++r) {
for (int s = 0; s < S; ++s) {
for (int k = 0; k < OC_per_G; ++k) {
diff --git a/src/Utils.cc b/src/Utils.cc
index 0fa620d..5214e41 100644
--- a/src/Utils.cc
+++ b/src/Utils.cc
@@ -202,4 +202,7 @@ bool fbgemmHasAvx2Support() {
return (cpuinfo_has_x86_avx2());
}
+bool fbgemmHasAvx512VnniSupport() {
+ return (cpuinfo_has_x86_avx512vnni());
+}
} // namespace fbgemm
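fbgemmHasAvx512VnniSupport() is a thin wrapper over cpuinfo, so cpuinfo_initialize() must succeed before it is called, exactly as the Pack* constructors above ensure. A minimal probe program, assuming the three helpers are declared in fbgemm/Utils.h as the header change in this commit suggests:

#include <cstdio>
#include <cpuinfo.h>
#include "fbgemm/Utils.h"

int main() {
  if (!cpuinfo_initialize()) { // same precondition the Pack* ctors enforce
    std::fprintf(stderr, "Failed to initialize cpuinfo!\n");
    return 1;
  }
  std::printf("avx2=%d avx512=%d avx512vnni=%d\n",
              static_cast<int>(fbgemm::fbgemmHasAvx2Support()),
              static_cast<int>(fbgemm::fbgemmHasAvx512Support()),
              static_cast<int>(fbgemm::fbgemmHasAvx512VnniSupport()));
  return 0;
}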
diff --git a/test/GConvTest.cc b/test/GConvTest.cc
index 0074535..8c1fb82 100644
--- a/test/GConvTest.cc
+++ b/test/GConvTest.cc
@@ -465,8 +465,8 @@ TEST_P(fbgemmGConvPackTest, PackUnpackTest) {
for (int i = 0; i < weight_len; ++i) {
EXPECT_EQ(Bint8.data()[i], unpack_buf.data()[i])
<< "Pack/Unpack results differ at index " << i
- << ", Reference: " << static_cast<int> (Bint8.data()[i])
- << ", Pack-Unpacked: " << static_cast<int> (unpack_buf.data()[i]);
+ << ", Reference: " << static_cast<int>(Bint8.data()[i])
+ << ", Pack-Unpacked: " << static_cast<int>(unpack_buf.data()[i]);
}
} // for each shape
}
diff --git a/third_party/asmjit b/third_party/asmjit
-Subproject 673dcefaa048c5f5a2bf8b85daf8f7b9978d018
+Subproject 3d510b3540776d4961f5eac83af3643d31cde18