Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJianyu Huang <jianyuhuang@fb.com>2019-08-06 21:55:17 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-08-06 21:59:00 +0300
commitcf34b9a26b609109b18d6498f0608faddb7a911b (patch)
tree1ceaddaf942edb9debcafad7491b750fc3a5f066
parentd8b3323668fdd15dc70e9cb43ab16e96f4846eeb (diff)
Back out "[fbgemm] Integrate VNNI into FBGEMM master branch"
Summary: Original commit changeset: fcaa13cc3159 ASMJIT requires the CMake version to be 3.8 However, FBGEMM and PyTorch only need the CMake version to be 3.5+. This caused the build failure in FBGEMM: https://circleci.com/gh/pytorch/FBGEMM/122#build-timing/containers/0 Reviewed By: dskhudia Differential Revision: D16670547 fbshipit-source-id: 506714c3db1cb82cf98895f58f82f235128f5285
-rw-r--r--CMakeLists.txt2
-rw-r--r--include/fbgemm/PackingTraits-inl.h50
-rw-r--r--include/fbgemm/Utils.h7
-rw-r--r--src/ExecuteKernelU8S8.cc47
-rw-r--r--src/Fbgemm.cc18
-rw-r--r--src/GenerateKernel.h30
-rw-r--r--src/GenerateKernelU8S8S32ACC16.cc91
-rw-r--r--src/GenerateKernelU8S8S32ACC16Avx512.cc94
-rw-r--r--src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc102
-rw-r--r--src/GenerateKernelU8S8S32ACC32.cc87
-rw-r--r--src/GenerateKernelU8S8S32ACC32Avx512.cc95
-rw-r--r--src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc431
-rw-r--r--src/GroupwiseConv.h100
-rw-r--r--src/GroupwiseConvAcc32Avx2.cc177
-rw-r--r--src/PackAMatrix.cc10
-rw-r--r--src/PackAWithIm2Col.cc14
-rw-r--r--src/PackAWithQuantRowOffset.cc14
-rw-r--r--src/PackAWithRowOffset.cc14
-rw-r--r--src/PackBMatrix.cc25
-rw-r--r--src/PackMatrix.cc9
-rw-r--r--src/PackWeightMatrixForGConv.cc8
-rw-r--r--src/Utils.cc3
-rw-r--r--test/GConvTest.cc4
m---------third_party/asmjit0
24 files changed, 378 insertions, 1054 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 817f699..b575e17 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,10 +33,8 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
src/FbgemmI8Spmdm.cc
src/GenerateKernelU8S8S32ACC16.cc
src/GenerateKernelU8S8S32ACC16Avx512.cc
- src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
src/GenerateKernelU8S8S32ACC32.cc
src/GenerateKernelU8S8S32ACC32Avx512.cc
- src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
src/GroupwiseConvAcc32Avx2.cc
src/PackAMatrix.cc
src/PackAWithIm2Col.cc
diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h
index baccfad..76eb425 100644
--- a/include/fbgemm/PackingTraits-inl.h
+++ b/include/fbgemm/PackingTraits-inl.h
@@ -222,53 +222,3 @@ struct PackingTraits<
128}; ///< Cache block for N dimension (multiple of NR).
static constexpr int KCB{256}; ///< Cache block for K dimension.
};
-
-/**
- * @brief Helper struct to type specialize for int16_t and int32_t together.
- */
-template <typename T>
-struct is_16or32bit {
- static constexpr bool value =
- std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value;
-};
-
-/**
- * @brief Packing parameter specialization for accumulation into 32-bit/16-bit
- * integers.
- *
- * Since there is no int16_t accumulation for AVX512 VNNI, we redirect int16_t
- * to int32_t accumulation and use the same blocking parameters as int32_t.
- *
- * This is picked when T is of int8 type (signed or unsigned) and instruction
- * set is avx512_vnni.
- */
-template <typename T, typename accT>
-struct PackingTraits<
- T,
- accT,
- inst_set_t::avx512_vnni,
- typename std::enable_if<
- is_8bit<T>::value && is_16or32bit<accT>::value>::type> {
- static constexpr int MR{8}; ///< Register block for M dimension.
- static constexpr int NR_MIN{
- 16}; ///< Minimum register block for N dimension.
- ///< 16 because 16*ROW_INTERLEAVE int8 elements
- ///< completely fill a 512-bit wide vector.
- static constexpr int NR{
- 32}; ///< Register block for N dimension.
- ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements
- ///< completely fill a 512-bit wide vector. Total registers used for
- ///< N dimension: NR*ROW_INTERLEAVE*8/VLEN. We use MR x
- ///< NR*ROW_INTERLEAVE*8/VLEN zmm registers
- ///< for C accumulations.
-
- static constexpr int ROW_INTERLEAVE{
- 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing
- ///< B matrix.
-
- static constexpr int MCB{
- 128}; ///< Cache block for M dimension (multiple of MR).
- static constexpr int NCB{
- 32}; ///< Cache block for N dimension (multiple of NR).
- static constexpr int KCB{256}; ///< Cache block for K dimension.
-};
diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h
index 3f8522b..107cf07 100644
--- a/include/fbgemm/Utils.h
+++ b/include/fbgemm/Utils.h
@@ -29,7 +29,7 @@ enum class matrix_op_t { NoTranspose, Transpose };
/**
* @brief Typed enum for supported instruction sets.
*/
-enum class inst_set_t { anyarch, avx2, avx512, avx512_vnni };
+enum class inst_set_t { anyarch, avx2, avx512 };
/**
* @brief Typed enum for optimized paths for convolutions
@@ -100,11 +100,6 @@ FBGEMM_API bool fbgemmHasAvx512Support();
FBGEMM_API bool fbgemmHasAvx2Support();
/**
- * @brief Are we running on a AVX512_VNNI supported cpu?
- */
-FBGEMM_API bool fbgemmHasAvx512VnniSupport();
-
-/**
* @brief Helper struct to enable autotuning of FBGEMM packing and kernels.
*
* This structure is optional. If not used, the default values for these
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index 0a4ff55..f7292fd 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -49,8 +49,7 @@ ExecuteKernel<
throw std::runtime_error("Failed to initialize cpuinfo!");
}
if (params) {
- if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support() ||
- fbgemmHasAvx2Support()) {
+ if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
mbSize_ = params->MCB;
nbSize_ = params->NCB;
nrMinSize_ = params->NR_MIN;
@@ -60,20 +59,7 @@ ExecuteKernel<
assert(0 && "unsupported architecure");
}
} else {
- if (fbgemmHasAvx512VnniSupport()) {
- mbSize_ = PackingTraits<
- int8_t,
- typename packingAMatrix::accType,
- inst_set_t::avx512_vnni>::MCB;
- nbSize_ = PackingTraits<
- int8_t,
- typename packingAMatrix::accType,
- inst_set_t::avx512_vnni>::NCB;
- nrMinSize_ = PackingTraits<
- int8_t,
- typename packingAMatrix::accType,
- inst_set_t::avx512_vnni>::NR_MIN;
- } else if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
mbSize_ = PackingTraits<
int8_t,
typename packingAMatrix::accType,
@@ -132,25 +118,7 @@ void ExecuteKernel<
typename BaseType::jit_micro_kernel_fp fn;
- if (fbgemmHasAvx512VnniSupport()) {
- if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) {
- // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
- CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
- fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>(
- accum,
- packed_rows_A,
- packedB_.blockColSize(),
- packedA_.numPackedCols(),
- nbSize_);
- } else {
- fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
- accum,
- packed_rows_A,
- packedB_.blockColSize(),
- packedA_.numPackedCols(),
- nbSize_);
- }
- } else if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum,
packed_rows_A,
@@ -180,10 +148,7 @@ void ExecuteKernel<
if (jb == bColBlocks - 1) {
int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_;
if (nc != nbSize_) {
- if (fbgemmHasAvx512VnniSupport()) {
- fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
- accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
- } else if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
} else if (fbgemmHasAvx2Support()) {
@@ -248,7 +213,7 @@ void ExecuteKernel<
int32_t nSize =
C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols();
if (nSize) {
- if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
@@ -273,7 +238,7 @@ void ExecuteKernel<
if (C_buffer_start == C_tile_) {
// When C_tile_ scratchpad was used to avoid accessing memory past
// C_buffer_ .
- if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index 4f7026f..0f2f6fb 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -48,8 +48,7 @@ void fbgemmPacked(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
- !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -63,20 +62,7 @@ void fbgemmPacked(
MR = blocking_params->MR;
} else {
- if (fbgemmHasAvx512VnniSupport()) {
- MCB = PackingTraits<
- typename packingAMatrix::inpType,
- typename packingAMatrix::accType,
- inst_set_t::avx512_vnni>::MCB;
- KCB = PackingTraits<
- typename packingAMatrix::inpType,
- typename packingAMatrix::accType,
- inst_set_t::avx512_vnni>::KCB;
- MR = PackingTraits<
- typename packingAMatrix::inpType,
- typename packingAMatrix::accType,
- inst_set_t::avx512_vnni>::MR;
- } else if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
MCB = PackingTraits<
typename packingAMatrix::inpType,
typename packingAMatrix::accType,
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
index e52097e..dccdfc5 100644
--- a/src/GenerateKernel.h
+++ b/src/GenerateKernel.h
@@ -18,7 +18,7 @@ namespace fbgemm {
namespace x86 = asmjit::x86;
/**
- * @brief AVX2/AVX512/AVX512VNNI JIT assembly code generator.
+ * @brief AVX2/AVX512 JIT assembly code generator.
* @tparam TA Type of matrix A.
* @tparam TB Type of matrix B.
* @tparam TC Type of matrix C.
@@ -104,7 +104,7 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void initCRegs(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCRegAssign = 4);
@@ -114,10 +114,10 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void genComputeBlock(
- x86::Emitter* a,
- x86::Gp buffer_A,
- x86::Gp buffer_B,
- x86::Gp B_pf,
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp B_pf,
int rowRegs,
int colRegs,
int lda,
@@ -129,11 +129,11 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void storeCRegs(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int rowRegs,
int colRegs,
- x86::Gp C_Offset,
- x86::Gp ldcReg,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
bool accum,
int leadingDimCRegAssign = 4);
@@ -168,9 +168,7 @@ class CodeGenBase {
fileName += "_MR-" + std::to_string(MR);
fileName += "_NR-" + std::to_string(NR);
fileName += "_NR_MIN-" + std::to_string(NR_MIN);
- if (instSet == inst_set_t::avx512_vnni) {
- fileName += "_avx512vnni";
- } else if (instSet == inst_set_t::avx512) {
+ if (instSet == inst_set_t::avx512) {
fileName += "_avx512";
} else if (instSet == inst_set_t::avx2) {
fileName += "_avx2";
@@ -180,10 +178,12 @@ class CodeGenBase {
}
private:
- x86::Ymm CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
- x86::Zmm
+ asmjit::X86Ymm
+ CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
+ asmjit::X86Zmm
CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel.
- x86::Zmm AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
+ asmjit::X86Zmm
+ AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index 1e7e081..718b883 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -31,7 +31,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -53,18 +53,18 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
inst_set_t::avx2>(
- x86::Emitter* a,
- x86::Gp buffer_A,
- x86::Gp buffer_B,
- x86::Gp /* unused (reserved for prefetching)*/,
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp /* unused (reserved for prefetching)*/,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- x86::Ymm AReg = x86::ymm12;
+ asmjit::X86Ymm AReg = x86::ymm12;
- x86::Ymm tmpReg = x86::ymm14;
+ asmjit::X86Ymm tmpReg = x86::ymm14;
for (int i = 0; i < rowRegs; ++i) {
// broadcast A
@@ -95,15 +95,15 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int rowRegs,
int colRegs,
- x86::Gp C_Offset,
- x86::Gp ldcReg,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
bool accum,
int leadingDimCReg) {
- x86::Xmm extractDest128 = x86::xmm15;
- x86::Ymm extractDest256 = x86::ymm15;
+ asmjit::X86Xmm extractDest128 = x86::xmm15;
+ asmjit::X86Ymm extractDest256 = x86::ymm15;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
@@ -112,7 +112,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
a->vextracti128(
extractDest128, CRegs_avx2_[i * leadingDimCReg + j], idx);
a->vpmovsxwd(extractDest256, extractDest128);
- x86::Mem destAddr = x86::dword_ptr(
+ asmjit::X86Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t));
if (accum) {
a->vpaddd(extractDest256, extractDest256, destAddr);
@@ -176,9 +176,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
return codeCache_[kernelSig];
}
code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
@@ -207,45 +207,46 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
//"nc must be equal to the number of register blocks");
// arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
+ asmjit::X86Gp buffer_A = a->zdi();
+ asmjit::X86Gp buffer_B = a->zsi();
+ asmjit::X86Gp B_pf = a->zdx();
+ asmjit::X86Gp CBase = a->zcx();
+ asmjit::X86Gp kSize = a->gpzRef(8);
+ asmjit::X86Gp ldcReg = a->gpzRef(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrame frame;
- frame.init(func);
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
- asmjit::FuncArgsAssignment args(&func);
+ asmjit::FuncArgsMapper args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- args.updateFuncFrame(frame);
- frame.finalize();
+ args.updateFrameInfo(ffi);
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
+
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
asmjit::Label Loopk = a->newLabel();
asmjit::Label LoopMBlocks = a->newLabel();
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- // x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp kIdx = a->gpz(14);
+ asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
+ asmjit::X86Gp C_Offset = a->gpzRef(11);
+ // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
+ asmjit::X86Gp iIdx = a->gpzRef(13);
+ asmjit::X86Gp kIdx = a->gpzRef(14);
int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
if (mRegBlocks > 0) {
@@ -288,7 +289,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
a->jl(Loopk);
// store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ storeCRegs<inst_set_t::avx2>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum);
// increment A for next block
a->sub(buffer_A, kSize);
@@ -338,10 +340,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
a->jl(LoopkRem);
// store C matrix
- storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ storeCRegs<inst_set_t::avx2>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum);
}
- a->emitEpilog(frame);
+ asmjit::FuncUtils::emitEpilog(a, layout);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index a49e440..c95757b 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -19,7 +19,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx512>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -41,18 +41,18 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
inst_set_t::avx512>(
- x86::Emitter* a,
- x86::Gp buffer_A,
- x86::Gp buffer_B,
- x86::Gp /* unused (reserved for prefetching)*/,
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp /* unused (reserved for prefetching)*/,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- x86::Zmm AReg = x86::zmm29;
+ asmjit::X86Zmm AReg = x86::zmm29;
- x86::Zmm tmpReg = x86::zmm30;
+ asmjit::X86Zmm tmpReg = x86::zmm30;
// We start allocating BRegs from zmm27 and then allocate zmm26 and so on.
for (int j = 0; j < colRegs; ++j) {
@@ -66,7 +66,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
a->vpbroadcastw(
AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
for (int j = 0; j < colRegs; ++j) {
- a->vpmaddubsw(tmpReg, AReg, AllRegs_avx512_[27 - j]);
+ a->vpmaddubsw(
+ tmpReg, AReg, AllRegs_avx512_[27-j]);
a->vpaddsw(
CRegs_avx512_[i * leadingDimCReg + j],
tmpReg,
@@ -89,16 +90,15 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
inst_set_t::avx512>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int rowRegs,
int colRegs,
- x86::Gp C_Offset,
- x86::Gp ldcReg,
-
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
bool accum,
int leadingDimCReg) {
- x86::Ymm extractDest256 = x86::ymm31;
- x86::Zmm extractDest512 = x86::zmm31;
+ asmjit::X86Ymm extractDest256 = x86::ymm31;
+ asmjit::X86Zmm extractDest512 = x86::zmm31;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
@@ -107,7 +107,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
a->vextracti32x8(
extractDest256, CRegs_avx512_[i * leadingDimCReg + j], idx);
a->vpmovsxwd(extractDest512, extractDest256);
- x86::Mem destAddr = x86::dword_ptr(
+ asmjit::X86Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t));
if (accum) {
a->vpaddd(extractDest512, extractDest512, destAddr);
@@ -172,9 +172,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
}
code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
@@ -209,49 +209,49 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
+ asmjit::X86Gp buffer_A = a->zdi();
+ asmjit::X86Gp buffer_B = a->zsi();
+ asmjit::X86Gp B_pf = a->zdx();
+ asmjit::X86Gp CBase = a->zcx();
+ asmjit::X86Gp kSize = a->gpzRef(8);
+ asmjit::X86Gp ldcReg = a->gpzRef(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrame frame;
- frame.init(func);
-
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp,
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- asmjit::FuncArgsAssignment args(&func);
+ asmjit::FuncArgsMapper args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- args.updateFuncFrame(frame);
- frame.finalize();
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
asmjit::Label LoopMBlocks = a->newLabel();
asmjit::Label LoopNBlocks = a->newLabel();
asmjit::Label Loopk = a->newLabel();
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- // x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp jIdx = a->gpz(14);
- x86::Gp kIdx = a->gpz(15);
+ asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
+ asmjit::X86Gp C_Offset = a->gpzRef(11);
+ // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
+ asmjit::X86Gp iIdx = a->gpzRef(13);
+ asmjit::X86Gp jIdx = a->gpzRef(14);
+ asmjit::X86Gp kIdx = a->gpzRef(15);
// save B_buffer address
a->mov(buffer_B_saved, buffer_B);
@@ -407,7 +407,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
a->jl(LoopNRem);
}
- a->emitEpilog(frame);
+ asmjit::FuncUtils::emitEpilog(a, layout);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
deleted file mode 100644
index f559aba..0000000
--- a/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Copyright (c) Facebook, Inc. and its affiliates.
- * All rights reserved.
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-#include <iostream>
-#include "GenerateKernel.h"
-
-namespace fbgemm {
-
-namespace x86 = asmjit::x86;
-
-/**
- * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit
- * Accumulation kernel.
- */
-template <>
-template <>
-void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
- inst_set_t::avx512_vnni>(
- x86::Emitter* a,
- int rowRegs,
- int colRegs,
- int leadingDimCReg) {
- assert(0 && "Accumulation to int16_t is not available for VNNI!");
-
- // For AVX512VNNI, redirect to int32_t accumulation.
- CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
- codeObj.initCRegs<inst_set_t::avx512_vnni>(
- a, rowRegs, colRegs, leadingDimCReg);
-}
-
-/**
- * Generate AVX512 instructions for computing block in the rank-k update of
- * 16-bit Accmulation kernel.
- */
-template <>
-template <>
-void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
- inst_set_t::avx512_vnni>(
- x86::Emitter* a,
- x86::Gp buffer_A,
- x86::Gp buffer_B,
- x86::Gp /* unused (reserved for prefetching)*/,
- int rowRegs,
- int colRegs,
- int lda,
- int leadingDimCReg) {
- assert(0 && "Accumulation to int16_t is not available for VNNI!");
-
- // For AVX512VNNI, redirect to int32_t accumulation.
- CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
- codeObj.genComputeBlock<inst_set_t::avx512_vnni>(
- a, buffer_A, buffer_B, buffer_B, rowRegs, colRegs, lda, leadingDimCReg);
-}
-
-/**
- * Generate AVX512 instructions for storing the C registers back to the memory
- * in 16-bit Accumulation kernel.
- */
-template <>
-template <>
-void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
- inst_set_t::avx512_vnni>(
- x86::Emitter* a,
- int rowRegs,
- int colRegs,
- x86::Gp C_Offset,
- x86::Gp ldcReg,
- bool accum,
- int leadingDimCReg) {
- assert(0 && "Accumulation to int16_t is not available for VNNI!");
-
- // For AVX512VNNI, redirect to int32_t accumulation.
- CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
- codeObj.storeCRegs<inst_set_t::avx512_vnni>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, leadingDimCReg);
-}
-
-/**
- * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel.
- *
- */
-template <>
-template <>
-CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp
-CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<
- inst_set_t::avx512_vnni>(
- bool accum,
- int32_t mc,
- int32_t nc,
- int32_t kc,
- int32_t /* unused */) {
- assert(0 && "Accumulation to int16_t is not available for VNNI!");
-
- // For AVX512VNNI, redirect to int32_t accumulation.
- CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
- return codeObj.getOrCreate<inst_set_t::avx512_vnni>(accum, mc, nc, kc, kc);
-}
-
-} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index 6b54743..58643ad 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -31,7 +31,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -53,25 +53,25 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
inst_set_t::avx2>(
- x86::Emitter* a,
- x86::Gp buffer_A,
- x86::Gp buffer_B,
- x86::Gp B_pf,
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp B_pf,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- x86::Ymm AReg = x86::ymm12;
+ asmjit::X86Ymm AReg = x86::ymm12;
// used for matrix B
- x86::Ymm BReg = x86::ymm13;
+ asmjit::X86Ymm BReg = x86::ymm13;
// Contains 16-bit 1s
- x86::Ymm oneReg = x86::ymm15;
+ asmjit::X86Ymm oneReg = x86::ymm15;
// temporary register
- x86::Ymm res1 = x86::ymm14;
+ asmjit::X86Ymm res1 = x86::ymm14;
for (int j = 0; j < colRegs; ++j) {
// load B
@@ -99,11 +99,11 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int rowRegs,
int colRegs,
- x86::Gp C_Offset,
- x86::Gp ldcReg,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
bool accum,
int leadingDimCReg) {
for (int i = 0; i < rowRegs; ++i) {
@@ -177,9 +177,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
return codeCache_[kernelSig];
}
code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
FILE* codeLogfile = fopen(
@@ -205,48 +205,49 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
+ asmjit::X86Gp buffer_A = a->zdi();
+ asmjit::X86Gp buffer_B = a->zsi();
+ asmjit::X86Gp B_pf = a->zdx();
+ asmjit::X86Gp CBase = a->zcx();
+ asmjit::X86Gp kSize = a->gpzRef(8);
+ asmjit::X86Gp ldcReg = a->gpzRef(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrame frame;
- frame.init(func);
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
- asmjit::FuncArgsAssignment args(&func);
+ asmjit::FuncArgsMapper args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- args.updateFuncFrame(frame);
- frame.finalize();
+ args.updateFrameInfo(ffi);
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
+
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
asmjit::Label Loopk = a->newLabel();
asmjit::Label LoopMBlocks = a->newLabel();
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp kIdx = a->gpz(14);
- // x86::Gp B_pf = a->gpz(8);
+ asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
+ asmjit::X86Gp C_Offset = a->gpzRef(11);
+ asmjit::X86Gp B_pf_saved = a->gpzRef(12);
+ asmjit::X86Gp iIdx = a->gpzRef(13);
+ asmjit::X86Gp kIdx = a->gpzRef(14);
+ // asmjit::X86Gp B_pf = a->gpzRef(8);
- x86::Ymm oneReg = x86::ymm15;
+ asmjit::X86Ymm oneReg = x86::ymm15;
// create 16-bit 1s
// i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
// and so on
@@ -357,7 +358,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
}
- a->emitEpilog(frame);
+ asmjit::FuncUtils::emitEpilog(a, layout);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index fe35627..12243ee 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -19,7 +19,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx512>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -41,25 +41,25 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
inst_set_t::avx512>(
- x86::Emitter* a,
- x86::Gp buffer_A,
- x86::Gp buffer_B,
- x86::Gp B_pf,
+ asmjit::X86Emitter* a,
+ asmjit::X86Gp buffer_A,
+ asmjit::X86Gp buffer_B,
+ asmjit::X86Gp B_pf,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- x86::Zmm AReg = x86::zmm31;
+ asmjit::X86Zmm AReg = x86::zmm31;
// used for matrix B
- x86::Zmm BReg = x86::zmm30;
+ asmjit::X86Zmm BReg = x86::zmm30;
// Contains 16-bit 1s
- x86::Zmm oneReg = x86::zmm29;
+ asmjit::X86Zmm oneReg = x86::zmm29;
// temporary register
- x86::Zmm res1 = x86::zmm28;
+ asmjit::X86Zmm res1 = x86::zmm28;
for (int j = 0; j < colRegs; ++j) {
// load B
@@ -87,17 +87,18 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
inst_set_t::avx512>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int rowRegs,
int colRegs,
- x86::Gp C_Offset,
- x86::Gp ldcReg,
+ asmjit::X86Gp C_Offset,
+ asmjit::X86Gp ldcReg,
bool accum,
int leadingDimCReg) {
for (int i = 0; i < rowRegs; ++i) {
if (i != 0) {
a->add(C_Offset, ldcReg);
- } else {
+ }
+ else {
a->mov(C_Offset, static_cast<asmjit::Imm>(0));
}
for (int j = 0; j < colRegs; ++j) {
@@ -167,9 +168,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
return codeCache_[kernelSig];
}
code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
@@ -204,52 +205,52 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
+ asmjit::X86Gp buffer_A = a->zdi();
+ asmjit::X86Gp buffer_B = a->zsi();
+ asmjit::X86Gp B_pf = a->zdx();
+ asmjit::X86Gp CBase = a->zcx();
+ asmjit::X86Gp kSize = a->gpzRef(8);
+ asmjit::X86Gp ldcReg = a->gpzRef(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrame frame;
- frame.init(func);
-
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp,
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- asmjit::FuncArgsAssignment args(&func);
+ asmjit::FuncArgsMapper args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- args.updateFuncFrame(frame);
- frame.finalize();
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
asmjit::Label LoopMBlocks = a->newLabel();
asmjit::Label LoopNBlocks = a->newLabel();
asmjit::Label Loopk = a->newLabel();
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp jIdx = a->gpz(14);
- x86::Gp kIdx = a->gpz(15);
- // x86::Gp B_pf = a->gpz(8);
+ asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
+ asmjit::X86Gp C_Offset = a->gpzRef(11);
+ asmjit::X86Gp B_pf_saved = a->gpzRef(12);
+ asmjit::X86Gp iIdx = a->gpzRef(13);
+ asmjit::X86Gp jIdx = a->gpzRef(14);
+ asmjit::X86Gp kIdx = a->gpzRef(15);
+ // asmjit::X86Gp B_pf = a->gpzRef(8);
- x86::Zmm oneReg = x86::zmm29;
+ asmjit::X86Zmm oneReg = x86::zmm29;
// create 16-bit 1s
// i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
// and so on
@@ -419,7 +420,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
a->jl(LoopNRem);
}
- a->emitEpilog(frame);
+ asmjit::FuncUtils::emitEpilog(a, layout);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
deleted file mode 100644
index 8ae0745..0000000
--- a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
+++ /dev/null
@@ -1,431 +0,0 @@
-/*
- * Copyright (c) Facebook, Inc. and its affiliates.
- * All rights reserved.
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-#include <iostream>
-#include "GenerateKernel.h"
-
-namespace fbgemm {
-
-namespace x86 = asmjit::x86;
-
-/**
- * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit
- * Accumulation kernel.
- */
-template <>
-template <>
-void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
- inst_set_t::avx512_vnni>(
- x86::Emitter* a,
- int rowRegs,
- int colRegs,
- int leadingDimCReg) {
- for (int i = 0; i < rowRegs; ++i) {
- for (int j = 0; j < colRegs; ++j) {
- a->vxorps(
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j]);
- }
- }
-}
-
-/**
- * Generate AVX512 instructions for computing block in the rank-k update of
- * 32-bit Accmulation kernel.
- */
-template <>
-template <>
-void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
- inst_set_t::avx512_vnni>(
- x86::Emitter* a,
- x86::Gp buffer_A,
- x86::Gp buffer_B,
- x86::Gp B_pf,
- int rowRegs,
- int colRegs,
- int lda,
- int leadingDimCReg) {
- // used for matrix A
- x86::Zmm AReg = x86::zmm31;
-
- // used for matrix B
- x86::Zmm BReg = x86::zmm30;
-
- for (int j = 0; j < colRegs; ++j) {
- // load B
- a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
- // load A, broadcast and fmas
- for (int i = 0; i < rowRegs; ++i) {
- a->vpbroadcastd(
- AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
- a->vpdpbusd(CRegs_avx512_[i * leadingDimCReg + j], AReg, BReg);
- }
- a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
- }
-}
-
-/**
- * Generate AVX512 instructions for storing the C registers back to the memory
- * in 32-bit Accumulation kernel.
- */
-template <>
-template <>
-void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
- inst_set_t::avx512_vnni>(
- x86::Emitter* a,
- int rowRegs,
- int colRegs,
- x86::Gp C_Offset,
- x86::Gp ldcReg,
- bool accum,
- int leadingDimCReg) {
- for (int i = 0; i < rowRegs; ++i) {
- if (i != 0) {
- a->add(C_Offset, ldcReg);
- } else {
- a->mov(C_Offset, static_cast<asmjit::Imm>(0));
- }
- for (int j = 0; j < colRegs; ++j) {
- if (accum) {
- a->vpaddd(
- CRegs_avx512_[i * leadingDimCReg + j],
- CRegs_avx512_[i * leadingDimCReg + j],
- x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)));
- }
- a->vmovups(
- x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)),
- CRegs_avx512_[i * leadingDimCReg + j]);
- }
- }
-}
-
-/**
- * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel.
- *
- */
-template <>
-template <>
-CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
-CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<
- inst_set_t::avx512_vnni>(
- bool accum,
- int32_t mc,
- int32_t nc,
- int32_t kc,
- int32_t /* unused */) {
- std::tuple<bool, int, int, int, int, int, int, int> kernelSig;
- int kBlock;
- int nBlock;
- int mRegBlockSize;
- int nRegBlockSize;
- int nRegBlockSizeMin;
- int row_interleave;
-
- if (blocking_params) {
- kBlock = blocking_params->KCB;
- nBlock = blocking_params->NCB;
- mRegBlockSize = blocking_params->MR;
- nRegBlockSize = blocking_params->NR;
- nRegBlockSizeMin = blocking_params->NR_MIN;
- row_interleave = blocking_params->ROW_INTERLEAVE;
- } else {
- kBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::KCB;
- nBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NCB;
- mRegBlockSize =
- PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::MR;
- nRegBlockSize =
- PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR;
- nRegBlockSizeMin =
- PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR_MIN;
- row_interleave = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::
- ROW_INTERLEAVE;
- }
-
- kernelSig = std::make_tuple(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin);
-
- if (codeCache_.find(kernelSig) != codeCache_.end()) {
- return codeCache_[kernelSig];
- }
- code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
-
-#if defined(FBGEMM_LOG_CODE)
- // generated code logging
- FILE* codeLogfile = fopen(
- getCodeLoggingFile<inst_set_t::avx512_vnni>(
- accum,
- mc,
- nc,
- nBlock,
- kBlock,
- mRegBlockSize,
- nRegBlockSize,
- nRegBlockSizeMin)
- .c_str(),
- "w");
- asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
- if (codeLogger) {
- code_.setLogger(codeLogger);
- }
-#endif
-
- assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
- assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
- int maxMRegs = mRegBlockSize;
- int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
- assert(
- maxMRegs * maxNRegs <= 28 &&
- "MR*(NR*ROW_INTERLEAVE*8/512) \
- must be <= 28(available registers constraint)");
-
- int mRegBlocks = mc / mRegBlockSize;
- int mRegBlocksRem = mc % mRegBlockSize;
-
- // arguments to the function created
- x86::Gp buffer_A = a->zdi();
- x86::Gp buffer_B = a->zsi();
- x86::Gp B_pf = a->zdx();
- x86::Gp CBase = a->zcx();
- x86::Gp kSize = a->gpz(8);
- x86::Gp ldcReg = a->gpz(9);
-
- asmjit::FuncDetail func;
- func.init(
- asmjit::
- FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
- asmjit::CallConv::kIdHost));
-
- asmjit::FuncFrame frame;
- frame.init(func);
-
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp,
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
-
- args.updateFuncFrame(frame);
- frame.finalize();
-
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
-
- asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::Label LoopNBlocks = a->newLabel();
- asmjit::Label Loopk = a->newLabel();
-
- x86::Gp buffer_B_saved = a->gpz(10);
- x86::Gp C_Offset = a->gpz(11);
- x86::Gp B_pf_saved = a->gpz(12);
- x86::Gp iIdx = a->gpz(13);
- x86::Gp jIdx = a->gpz(14);
- x86::Gp kIdx = a->gpz(15);
- // x86::Gp B_pf = a->gpz(8);
-
- x86::Zmm oneReg = x86::zmm29;
- // create 16-bit 1s
- // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
- // and so on
- // a->vpcmpeqw(oneReg, oneReg, oneReg);
- a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
- a->vpsrlw(oneReg, oneReg, 15);
- a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
-
- // save B_buffer address
- a->mov(buffer_B_saved, buffer_B);
- a->mov(B_pf_saved, B_pf);
-
- int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
- int colRegs = std::min(currColRegs, maxNRegs);
- if (mRegBlocks > 0) {
- // move 0 to iteration variables
- a->mov(iIdx, 0);
-
- a->bind(LoopMBlocks);
- a->inc(iIdx);
- a->mov(jIdx, 0);
-
- a->bind(LoopNBlocks);
- a->inc(jIdx);
-
- int rowRegs = mRegBlockSize;
-
- // init C registers
- initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(Loopk);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512_vnni>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
-
- a->cmp(kIdx, kSize);
- a->jl(Loopk);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512_vnni>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // reset A
- a->sub(buffer_A, kSize);
-
- // B for next block
- a->mov(buffer_B, buffer_B_saved);
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNBlocks);
-
- // increment A for next block
- a->add(
- buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
-
- // increment C for next A block
- a->sub(
- CBase,
- static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
- a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
- a->add(CBase, C_Offset);
-
- // reset B
- a->mov(buffer_B, buffer_B_saved);
- a->mov(B_pf, B_pf_saved);
- a->cmp(iIdx, mRegBlocks);
- a->jl(LoopMBlocks);
- }
- // generate code for remainder
- if (mRegBlocksRem > 0) {
- asmjit::Label LoopNRem = a->newLabel();
- asmjit::Label LoopkRem = a->newLabel();
- int rowRegs = mRegBlocksRem;
-
- a->mov(jIdx, 0);
- a->bind(LoopNRem);
- a->inc(jIdx);
-
- // init C registers
- initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
-
- // init k loop index
- a->mov(kIdx, 0);
- a->bind(LoopkRem);
-
- // k is incremented by row_interleave
- a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
-
- genComputeBlock<inst_set_t::avx512_vnni>(
- a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
-
- // update buffer_A address for next k iteration
- a->add(
- buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
-
- // update buffer_B address for next k iteration
- a->add(
- buffer_B,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
- a->add(
- B_pf,
- static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
-
- a->cmp(kIdx, kSize);
- a->jl(LoopkRem);
-
- // reset A
- a->sub(buffer_A, kSize);
- // B for next block
- // using C_Offset as temp reg
- a->imul(
- C_Offset,
- jIdx,
- static_cast<asmjit::Imm>(
- nRegBlockSize * row_interleave * sizeof(int8_t)));
- a->mov(buffer_B, buffer_B_saved);
- a->add(buffer_B, C_Offset);
- a->mov(B_pf, B_pf_saved);
- a->add(B_pf, C_Offset);
-
- // store C matrix
- storeCRegs<inst_set_t::avx512_vnni>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
-
- // increment C for next B block
- a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
-
- int jLoopTrips = currColRegs / maxNRegs;
- // jLoopTrips should be at least 1
- jLoopTrips = jLoopTrips ? jLoopTrips : 1;
- a->cmp(jIdx, jLoopTrips);
- a->jl(LoopNRem);
- }
-
- a->emitEpilog(frame);
-
- jit_micro_kernel_fp fn;
- asmjit::Error err = rt_.add(&fn, &code_);
- if (err) {
- std::cout << "Error: in fn add" << std::endl;
- return nullptr;
- }
- codeCache_[kernelSig] = fn;
-
-#if defined(FBGEMM_LOG_CODE)
- fclose(codeLogfile);
- delete codeLogger;
-#endif
-
- return fn;
-}
-
-} // namespace fbgemm
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
index 4c5eea5..1e6324e 100644
--- a/src/GroupwiseConv.h
+++ b/src/GroupwiseConv.h
@@ -128,58 +128,60 @@ class GenConvKernel {
const conv_param_t<SPATIAL_DIM>& conv_param);
template <inst_set_t instSet>
- void createVector16BitOne(x86::Emitter* a);
+ void createVector16BitOne(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void createVector8BitOne(x86::Emitter* a);
+ void createVector8BitOne(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void setToZeroPt(x86::Emitter* a, x86::Ymm destReg);
+ void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg);
template <inst_set_t instSet>
- void gen8bitFMA(x86::Emitter* a, x86::Ymm aReg, x86::Ymm wReg);
+ void
+ gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg);
template <inst_set_t instSet>
- void genForLoadingWeights(x86::Emitter* a, int c_offset);
+ void genForLoadingWeights(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genConstForPermutations(x86::Emitter* a);
+ void genConstForPermutations(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void genForTopEdge(x86::Emitter* a, int c_offset);
+ void genForTopEdge(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForLeftEdge(x86::Emitter* a, int c_offset);
+ void genForLeftEdge(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForRightEdge(x86::Emitter* a, int c_offset);
+ void genForRightEdge(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForBottomEdge(x86::Emitter* a, int c_offset);
+ void genForBottomEdge(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genCoreInsts(x86::Emitter* a, int c_offset);
+ void genCoreInsts(asmjit::X86Emitter* a, int c_offset);
template <inst_set_t instSet>
- void storeResult(x86::Emitter* a);
+ void storeResult(asmjit::X86Emitter* a);
// for Rowoffset kernel
// Add 4 consecutive numbers of 32 uint8 and emit 8 32-bit
template <inst_set_t instSet>
- void gen8BitSumX4(x86::Emitter* a, x86::Ymm aReg);
+ void gen8BitSumX4(asmjit::X86Emitter* a, asmjit::X86Ymm aReg);
// Add 8 consecutive numbers of 64 uint8 and emit 8 32-bit
template <inst_set_t instSet>
- void gen8BitSumX8(x86::Emitter* a, x86::Ymm aReg, x86::Ymm bReg);
+ void
+ gen8BitSumX8(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg);
// Add 16 consecutive numbers of 128 uint8 and emit 8 32-bit
template <inst_set_t instSet>
void gen8BitSumX16(
- x86::Emitter* a,
- x86::Ymm aReg,
- x86::Ymm bReg,
- x86::Ymm cReg,
- x86::Ymm dReg);
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg,
+ asmjit::X86Ymm bReg,
+ asmjit::X86Ymm cReg,
+ asmjit::X86Ymm dReg);
// Generate instruction sequence that loads 8-bit values and sum them up.
// Depending on C_per_G_, this function dispatches to gen8BitSumX4/8/16
@@ -189,33 +191,35 @@ class GenConvKernel {
// Internally, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0, 1], tmpReg1Avx2_,
// and resultRegAvx2_ are used.
template <inst_set_t instSet>
- void
- gen8BitSum(x86::Emitter* a, int act_offset, bool use_scratch_reg1 = true);
+ void gen8BitSum(
+ asmjit::X86Emitter* a,
+ int act_offset,
+ bool use_scratch_reg1 = true);
// Use scratchReg1_ and tmpReg1Avx2_ internally
template <inst_set_t instSet>
- void genZeroPtSum(x86::Emitter* a, int multiplier);
+ void genZeroPtSum(asmjit::X86Emitter* a, int multiplier);
template <inst_set_t instSet>
- void genForTopEdgeRowoffset(x86::Emitter* a);
+ void genForTopEdgeRowoffset(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void genForLeftEdgeRowoffset(x86::Emitter* a);
+ void genForLeftEdgeRowoffset(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void genForRightEdgeRowoffset(x86::Emitter* a);
+ void genForRightEdgeRowoffset(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void genForBottomEdgeRowoffset(x86::Emitter* a);
+ void genForBottomEdgeRowoffset(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void genRowoffsetCorners(x86::Emitter* a);
+ void genRowoffsetCorners(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void genRowoffsetCore(x86::Emitter* a);
+ void genRowoffsetCore(asmjit::X86Emitter* a);
template <inst_set_t instSet>
- void storeResultRowoffset(x86::Emitter* a, int offset = 0);
+ void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0);
static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
@@ -230,30 +234,30 @@ class GenConvKernel {
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
// avx2 specific
- x86::Ymm
+ asmjit::X86Ymm
WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel.
- x86::Ymm zeroPTRegAvx2_;
- x86::Ymm tmpReg1Avx2_;
- x86::Ymm stPermRegAvx2_;
- x86::Ymm actRegAvx2_;
- x86::Ymm resultRegAvx2_;
- x86::Ymm oneReg8BitAvx2_;
- x86::Ymm oneReg16BitAvx2_;
+ asmjit::X86Ymm zeroPTRegAvx2_;
+ asmjit::X86Ymm tmpReg1Avx2_;
+ asmjit::X86Ymm stPermRegAvx2_;
+ asmjit::X86Ymm actRegAvx2_;
+ asmjit::X86Ymm resultRegAvx2_;
+ asmjit::X86Ymm oneReg8BitAvx2_;
+ asmjit::X86Ymm oneReg16BitAvx2_;
// arguments to the function created
- x86::Gp in_acts_R_;
- x86::Gp wghts_R_;
- x86::Gp out_acts_R_;
- x86::Gp a_zero_pt_R_;
- x86::Gp H_R_;
- x86::Gp W_R_;
- x86::Gp row_offset_R_;
+ asmjit::X86Gp in_acts_R_;
+ asmjit::X86Gp wghts_R_;
+ asmjit::X86Gp out_acts_R_;
+ asmjit::X86Gp a_zero_pt_R_;
+ asmjit::X86Gp H_R_;
+ asmjit::X86Gp W_R_;
+ asmjit::X86Gp row_offset_R_;
// Used registers
- x86::Gp loopR1_;
- x86::Gp loopR2_;
- x86::Gp scratchReg1_;
- x86::Gp scratchReg2_;
+ asmjit::X86Gp loopR1_;
+ asmjit::X86Gp loopR2_;
+ asmjit::X86Gp scratchReg1_;
+ asmjit::X86Gp scratchReg2_;
// Other parameters
bool isAZeroPointZero_;
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index b140c83..e789695 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -104,7 +104,7 @@ jit_conv_kernel_fp getOrCreateConvKernel(
template <>
template <>
void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
- x86::Emitter* a) {
+ asmjit::X86Emitter* a) {
// create 8-bit 1s
// i.e., oneReg16BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains
// 0x01 and so on
@@ -115,7 +115,7 @@ void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
- x86::Emitter* a) {
+ asmjit::X86Emitter* a) {
// create 16-bit 1s
// i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31]
// contains 0x0001 and so on
@@ -125,11 +125,11 @@ void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
- x86::Emitter* a,
- x86::Ymm destReg) {
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm destReg) {
// make destReg all zeros
a->vxorps(destReg, destReg, destReg);
- x86::Xmm const_reg_xmm = x86::xmm10;
+ asmjit::X86Xmm const_reg_xmm = x86::xmm10;
// move zero point to xmm10
a->movq(const_reg_xmm, a_zero_pt_R_);
// make copies of zero point
@@ -143,9 +143,9 @@ void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
- x86::Emitter* a) {
- x86::Gp permute_const_reg = a->gpz(12);
- x86::Xmm const_reg_xmm = x86::xmm10;
+ asmjit::X86Emitter* a) {
+ asmjit::X86Gp permute_const_reg = a->gpzRef(12);
+ asmjit::X86Xmm const_reg_xmm = x86::xmm10;
// We have 1st group in even lanes and 2nd group in odd lanes.
// Permute to put 1st group to lower 128-bit and 2nd group in upper
// 128-bit.
@@ -159,7 +159,8 @@ void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) {
+void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
if (C_per_G_ == 4) {
// store with permutation
a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_);
@@ -170,7 +171,7 @@ void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) {
template <>
template <>
void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int offset) {
// store
if (C_per_G_ == 4) {
@@ -197,7 +198,7 @@ void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int c_offset) {
// load weights
for (int r = 0; r < R_; ++r) {
@@ -224,9 +225,9 @@ void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
- x86::Emitter* a,
- x86::Ymm aReg,
- x86::Ymm wReg) {
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg,
+ asmjit::X86Ymm wReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg);
a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_);
a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
@@ -235,8 +236,8 @@ void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
- x86::Emitter* a,
- x86::Ymm aReg) {
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_);
a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_);
a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
@@ -245,9 +246,9 @@ void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
- x86::Emitter* a,
- x86::Ymm aReg,
- x86::Ymm bReg) {
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg,
+ asmjit::X86Ymm bReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
// Let a[0] denote 0th (LSB) 8-bit of aReg
// After vpsadbw, a[0:2] = a[0] + ... + a[7]
@@ -266,11 +267,11 @@ void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
- x86::Emitter* a,
- x86::Ymm aReg,
- x86::Ymm bReg,
- x86::Ymm cReg,
- x86::Ymm dReg) {
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg,
+ asmjit::X86Ymm bReg,
+ asmjit::X86Ymm cReg,
+ asmjit::X86Ymm dReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
// After vpsadbw, a[0:2] = a[0] + ... + a[7]
// a[8:10] = a[8] + ... + a[15]
@@ -318,7 +319,7 @@ void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int act_offset,
bool use_scratch_reg1 /*=true*/) {
if (use_scratch_reg1) {
@@ -384,11 +385,11 @@ void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int multiplier) {
a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier));
// tmpReg1Avx2_ also uses xmm11
- x86::Xmm const_reg_xmm = x86::xmm11;
+ asmjit::X86Xmm const_reg_xmm = x86::xmm11;
a->movq(const_reg_xmm, scratchReg1_);
a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm);
a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_);
@@ -398,7 +399,7 @@ void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int c_offset) {
// top-left corner code
if (c_offset == 0) {
@@ -558,7 +559,7 @@ void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int c_offset) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
@@ -625,7 +626,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int c_offset) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -713,7 +714,7 @@ void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int c_offset) {
// bottom-left corner
// we updating the last row
@@ -905,7 +906,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genCoreInsts<inst_set_t::avx2>(
- x86::Emitter* a,
+ asmjit::X86Emitter* a,
int c_offset) {
// main compute
asmjit::Label LoopH = a->newLabel();
@@ -1010,9 +1011,9 @@ template <>
jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
#if defined(FBGEMM_LOG_CODE)
// log code to a file
@@ -1029,16 +1030,16 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
wghts_R_ = a->zsi();
out_acts_R_ = a->zdx();
a_zero_pt_R_ = a->zcx();
- H_R_ = a->gpz(8);
- W_R_ = a->gpz(9);
- row_offset_R_ = a->gpz(10);
+ H_R_ = a->gpzRef(8);
+ W_R_ = a->gpzRef(9);
+ row_offset_R_ = a->gpzRef(10);
// register for temporary use
- scratchReg1_ = a->gpz(12);
- scratchReg2_ = a->gpz(13);
+ scratchReg1_ = a->gpzRef(12);
+ scratchReg2_ = a->gpzRef(13);
asmjit::FuncDetail func;
- func.init(asmjit::FuncSignatureT<
+ func.init(asmjit::FuncSignature6<
void,
uint8_t*,
int8_t*,
@@ -1047,29 +1048,29 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
int32_t,
int32_t>(asmjit::CallConv::kIdHost));
- asmjit::FuncFrame frame;
- frame.init(func);
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
+ asmjit::FuncArgsMapper args(&func);
args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
- args.updateFuncFrame(frame);
- frame.finalize();
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
createVector16BitOne<inst_set_t::avx2>(a);
- loopR1_ = a->gpz(14);
- loopR2_ = a->gpz(15);
+ loopR1_ = a->gpzRef(14);
+ loopR2_ = a->gpzRef(15);
if (!isAZeroPointZero_) {
setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
@@ -1094,7 +1095,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
genCoreInsts<inst_set_t::avx2>(a, c);
}
- a->emitEpilog(frame);
+ asmjit::FuncUtils::emitEpilog(a, layout);
jit_conv_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
@@ -1116,7 +1117,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
- x86::Emitter* a) {
+ asmjit::X86Emitter* a) {
// top-left corner code
// zero out the results register
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
@@ -1212,7 +1213,7 @@ void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
- x86::Emitter* a) {
+ asmjit::X86Emitter* a) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
@@ -1255,7 +1256,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
- x86::Emitter* a) {
+ asmjit::X86Emitter* a) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -1325,7 +1326,7 @@ void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
- x86::Emitter* a) {
+ asmjit::X86Emitter* a) {
// bottom-left corner
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
@@ -1428,7 +1429,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genRowoffsetCore<inst_set_t::avx2>(
- x86::Emitter* a) {
+ asmjit::X86Emitter* a) {
// number of uint8 elements in input channels should be a multiple of 32
assert(C_ % 32 == 0);
@@ -1490,9 +1491,9 @@ jit_rowoffset_kernel_fp
GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
code_.reset(false);
- code_.init(rt_.codeInfo());
- x86::Assembler assembler(&code_);
- x86::Emitter* a = assembler.as<x86::Emitter>();
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
#if defined(FBGEMM_LOG_CODE)
// log code to a file
@@ -1509,45 +1510,45 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
a_zero_pt_R_ = a->zsi();
H_R_ = a->zdx();
W_R_ = a->zcx();
- row_offset_R_ = a->gpz(8);
+ row_offset_R_ = a->gpzRef(8);
// register for temporary use
- scratchReg1_ = a->gpz(12);
- scratchReg2_ = a->gpz(13);
+ scratchReg1_ = a->gpzRef(12);
+ scratchReg2_ = a->gpzRef(13);
- loopR1_ = a->gpz(14);
- loopR2_ = a->gpz(15);
+ loopR1_ = a->gpzRef(14);
+ loopR2_ = a->gpzRef(15);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignatureT<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
+ FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrame frame;
- frame.init(func);
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupVec,
- asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- frame.setDirtyRegs(
- x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
-
- asmjit::FuncArgsAssignment args(&func);
+ asmjit::FuncArgsMapper args(&func);
args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);
- args.updateFuncFrame(frame);
- frame.finalize();
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
- a->emitProlog(frame);
- a->emitArgsAssignment(frame, args);
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
// This uses xmm10 register temporarily. Should come before
// createVector8BitOne
if (!isAZeroPointZero_) {
// we can use xmm11 because ymm11 is used by tmpReg1Avx2_
- x86::Xmm const_reg_xmm = x86::xmm11;
+ asmjit::X86Xmm const_reg_xmm = x86::xmm11;
a->movq(const_reg_xmm, a_zero_pt_R_);
a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm);
@@ -1568,7 +1569,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
genRowoffsetCore<inst_set_t::avx2>(a);
- a->emitEpilog(frame);
+ asmjit::FuncUtils::emitEpilog(a, layout);
jit_rowoffset_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc
index 5fabf97..143e11d 100644
--- a/src/PackAMatrix.cc
+++ b/src/PackAMatrix.cc
@@ -34,8 +34,7 @@ PackAMatrix<T, accT>::PackAMatrix(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
- !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -44,12 +43,7 @@ PackAMatrix<T, accT>::PackAMatrix(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512VnniSupport()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
- row_interleave_B_ =
- PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
index 2aca27d..d731654 100644
--- a/src/PackAWithIm2Col.cc
+++ b/src/PackAWithIm2Col.cc
@@ -49,8 +49,7 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
- !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -59,12 +58,7 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512VnniSupport()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
- row_interleave_B_ =
- PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
@@ -484,9 +478,7 @@ int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512VnniSupport()) {
- return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
- } else if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc
index 0af05e8..0e5c598 100644
--- a/src/PackAWithQuantRowOffset.cc
+++ b/src/PackAWithQuantRowOffset.cc
@@ -45,8 +45,7 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
- !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -55,12 +54,7 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512VnniSupport()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
- row_interleave_B_ =
- PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
@@ -205,9 +199,7 @@ int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512VnniSupport()) {
- return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
- } else if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
diff --git a/src/PackAWithRowOffset.cc b/src/PackAWithRowOffset.cc
index e84c67b..733bf5c 100644
--- a/src/PackAWithRowOffset.cc
+++ b/src/PackAWithRowOffset.cc
@@ -39,8 +39,7 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
- !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -49,12 +48,7 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512VnniSupport()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
- row_interleave_B_ =
- PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
@@ -195,9 +189,7 @@ int PackAWithRowOffset<T, accT>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512VnniSupport()) {
- return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
- } else if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc
index bf43fab..0990edb 100644
--- a/src/PackBMatrix.cc
+++ b/src/PackBMatrix.cc
@@ -188,8 +188,7 @@ PackBMatrix<T, accT>::PackBMatrix(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
- !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -198,12 +197,7 @@ PackBMatrix<T, accT>::PackBMatrix(
BaseType::bcol_ = params->NCB;
row_interleave_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512VnniSupport()) {
- BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
- BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::NCB;
- row_interleave_ =
- PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
- } else if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB;
row_interleave_ =
@@ -323,16 +317,14 @@ void PackBMatrix<T, accT>::pack_unpack_(
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::pack(
- const block_type_t& block,
- const BlockingFactors* params) {
+void PackBMatrix<T, accT>::pack(const block_type_t& block,
+ const BlockingFactors* params) {
pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true, params);
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::unpack(
- T* origin_buf,
- const BlockingFactors* params) {
+void PackBMatrix<T, accT>::unpack(T* origin_buf,
+ const BlockingFactors* params) {
block_type_t blockB{BaseType::packedRowStart(),
BaseType::numPackedRows(),
BaseType::packedColStart(),
@@ -360,9 +352,8 @@ int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const {
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::printPackedMatrix(
- std::string name,
- const BlockingFactors* params) {
+void PackBMatrix<T, accT>::printPackedMatrix(std::string name,
+ const BlockingFactors* params) {
std::cout << name << ":"
<< "[" << BaseType::numPackedRows() << ", "
<< BaseType::numPackedCols() << "]" << std::endl;
diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc
index ff7b842..c7503dd 100644
--- a/src/PackMatrix.cc
+++ b/src/PackMatrix.cc
@@ -36,8 +36,7 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
- !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -47,11 +46,7 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize(
NCB = params->NCB;
KCB = params->KCB;
} else {
- if (fbgemmHasAvx512VnniSupport()) {
- MCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::MCB;
- NCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::NCB;
- KCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::KCB;
- } else if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512Support()) {
MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB;
NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc
index f6ad59e..ba6adf3 100644
--- a/src/PackWeightMatrixForGConv.cc
+++ b/src/PackWeightMatrixForGConv.cc
@@ -106,7 +106,7 @@ inline int PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::packed_index_(
* on 2 groups at a time and full SIMD width can be efficiently utilized even
* while working on 1 group at a time.
* In this case, the layout is G (C/4) R S K 4
- */
+*/
template <typename T, typename accT, int SPATIAL_DIM>
void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_(
@@ -148,9 +148,9 @@ void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_(
if (ispack) {
transposeConvWeights(conv_param_, src, dst);
} else {
- // TODO: Wrap this as a inverseTransposeConvWeights()?
- // For unpack & transposed, call transposeConvWeights()
- // G (R S C/G) K/G => G K/G (R S C/G)
+ // TODO: Wrap this as a inverseTransposeConvWeights()?
+ // For unpack & transposed, call transposeConvWeights()
+ // G (R S C/G) K/G => G K/G (R S C/G)
for (int r = 0; r < R; ++r) {
for (int s = 0; s < S; ++s) {
for (int k = 0; k < OC_per_G; ++k) {
diff --git a/src/Utils.cc b/src/Utils.cc
index 5214e41..0fa620d 100644
--- a/src/Utils.cc
+++ b/src/Utils.cc
@@ -202,7 +202,4 @@ bool fbgemmHasAvx2Support() {
return (cpuinfo_has_x86_avx2());
}
-bool fbgemmHasAvx512VnniSupport() {
- return (cpuinfo_has_x86_avx512vnni());
-}
} // namespace fbgemm
diff --git a/test/GConvTest.cc b/test/GConvTest.cc
index 8c1fb82..0074535 100644
--- a/test/GConvTest.cc
+++ b/test/GConvTest.cc
@@ -465,8 +465,8 @@ TEST_P(fbgemmGConvPackTest, PackUnpackTest) {
for (int i = 0; i < weight_len; ++i) {
EXPECT_EQ(Bint8.data()[i], unpack_buf.data()[i])
<< "Pack/Unpack results differ at index " << i
- << ", Reference: " << static_cast<int>(Bint8.data()[i])
- << ", Pack-Unpacked: " << static_cast<int>(unpack_buf.data()[i]);
+ << ", Reference: " << static_cast<int> (Bint8.data()[i])
+ << ", Pack-Unpacked: " << static_cast<int> (unpack_buf.data()[i]);
}
} // for each shape
}
diff --git a/third_party/asmjit b/third_party/asmjit
-Subproject 5d40561d14f93dc45613bfa03155d1dfb4f5825
+Subproject 673dcefaa048c5f5a2bf8b85daf8f7b9978d018