
github.com/marian-nmt/FBGEMM.git
author     Young Jin Kim <youki@microsoft.com>  2019-08-15 02:01:38 +0300
committer  Young Jin Kim <youki@microsoft.com>  2019-08-15 02:01:38 +0300
commit     bb5063533256a8a5a91a812f6a193d7f352a2a3a (patch)
tree       1e64f7127e589ea32a01785198c1cff8fa2813dd
parent     eb8fede25bd048da6fd396654936703a474f0504 (diff)
parent     a6d1d3eed7ba858d4532fc297b7a4ee984e6e7e3 (diff)

Merge branch 'upstream/master' into youki/prepack_constrcopy
-rw-r--r--  CMakeLists.txt                                2
-rw-r--r--  README.md                                     6
-rw-r--r--  include/fbgemm/Fbgemm.h                      19
-rw-r--r--  include/fbgemm/FbgemmFP16.h                  32
-rw-r--r--  include/fbgemm/PackingTraits-inl.h           50
-rw-r--r--  include/fbgemm/Utils.h                       18
-rw-r--r--  src/ExecuteKernelU8S8.cc                     47
-rw-r--r--  src/Fbgemm.cc                                18
-rw-r--r--  src/FbgemmConv.cc                            11
-rw-r--r--  src/GenerateKernel.h                         30
-rw-r--r--  src/GenerateKernelU8S8S32ACC16.cc            91
-rw-r--r--  src/GenerateKernelU8S8S32ACC16Avx512.cc      94
-rw-r--r--  src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc 102
-rw-r--r--  src/GenerateKernelU8S8S32ACC32.cc            87
-rw-r--r--  src/GenerateKernelU8S8S32ACC32Avx512.cc      95
-rw-r--r--  src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc 431
-rw-r--r--  src/GroupwiseConv.h                         100
-rw-r--r--  src/GroupwiseConvAcc32Avx2.cc               177
-rw-r--r--  src/PackAMatrix.cc                           10
-rw-r--r--  src/PackAWithIm2Col.cc                       14
-rw-r--r--  src/PackAWithQuantRowOffset.cc               14
-rw-r--r--  src/PackAWithRowOffset.cc                    14
-rw-r--r--  src/PackBMatrix.cc                           35
-rw-r--r--  src/PackMatrix.cc                             9
-rw-r--r--  src/PackWeightMatrixForGConv.cc               8
-rw-r--r--  src/PackWeightsForConv.cc                    68
-rw-r--r--  src/Utils.cc                                  3
-rw-r--r--  test/FP16Test.cc                            116
-rw-r--r--  test/GConvTest.cc                             4
-rw-r--r--  test/PackedRequantizeAcc16Test.cc            83
-rw-r--r--  test/PackedRequantizeTest.cc                 83
m---------  third_party/asmjit                            0
32 files changed, 1398 insertions(+), 473 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e6c7419..0460799 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,8 +37,10 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
src/FbgemmI8Spmdm.cc
src/GenerateKernelU8S8S32ACC16.cc
src/GenerateKernelU8S8S32ACC16Avx512.cc
+ src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
src/GenerateKernelU8S8S32ACC32.cc
src/GenerateKernelU8S8S32ACC32Avx512.cc
+ src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
src/GroupwiseConvAcc32Avx2.cc
src/PackAMatrix.cc
src/PackAWithIm2Col.cc
diff --git a/README.md b/README.md
index d287c44..5f3ca40 100644
--- a/README.md
+++ b/README.md
@@ -12,9 +12,9 @@ row-wise quantization and outlier-aware quantization. FBGEMM also exploits
fusion opportunities in order to overcome the unique challenges of matrix
multiplication at lower precision with bandwidth-bound operations.
-FBGEMM is used as a backend of Caffe2 quantized operators for x86 machines
-(https://github.com/pytorch/pytorch/tree/master/caffe2/quantization/server).
-We also plan to integrate FBGEMM into PyTorch.
+FBGEMM is used as a backend of Caffe2 and PyTorch quantized operators for x86 machines:
+* Caffe2: https://github.com/pytorch/pytorch/tree/master/caffe2/quantization/server
+* PyTorch: https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native/quantized/cpu
## Examples
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 7f428ed..70f6294 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -458,14 +458,17 @@ class FBGEMM_API PackBMatrix final
std::int32_t addr(std::int32_t i, std::int32_t j) const;
/**
- * @brief Packs a block of source matrix into pmat buffer.
+ * @brief Packs a block of source matrix into pmat buffer. The blocking
+ * parameters are needed to compute the buffer size of each group.
+ * It will use default blocking parameters if params is not provided.
*/
- void pack(const block_type_t& block);
+ void pack(const block_type_t& block, const BlockingFactors* params = nullptr);
/**
* @brief Print the packed block.
*/
- void printPackedMatrix(std::string name);
+ void printPackedMatrix(std::string name,
+ const BlockingFactors* params = nullptr);
/**
* @return true if meta information like matrix shape is the same.
@@ -480,7 +483,7 @@ class FBGEMM_API PackBMatrix final
* @brief Unpack pmat buffer to the origin_buf (Used for the serialization to
* recover weight matrix).
*/
- void unpack(T* origin_buf);
+ void unpack(T* origin_buf, const BlockingFactors* params = nullptr);
~PackBMatrix() {}
@@ -497,7 +500,8 @@ class FBGEMM_API PackBMatrix final
const block_type_t& block,
T* unpack_buf,
T* pack_buf,
- bool ispack);
+ bool ispack,
+ const BlockingFactors* params = nullptr);
};
/**
@@ -645,6 +649,11 @@ class FBGEMM_API PackWeightsForConv {
bool isPackingCompliant(const conv_param_t<SPATIAL_DIM>& conv_p);
/**
+ * @brief Returns a string of mismatching parameters
+ */
+ std::string mismatchingParams(const conv_param_t<SPATIAL_DIM>& conv_p);
+
+ /**
* @brief Unpack packed matrix into origin_buf (Used for the serialization to
* recover weight matrix).
*/
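
Note: the new optional BlockingFactors parameter must match between packing and unpacking, since it determines each group's buffer layout. A minimal sketch with illustrative, untuned values, assuming the constructor's trailing pmat/groups/params arguments as declared elsewhere in this header:

    #include <cstdint>
    #include <vector>
    #include "fbgemm/Fbgemm.h"
    using namespace fbgemm;

    void packUnpackSketch() {
      int K = 256, N = 512;                  // illustrative shapes
      std::vector<std::int8_t> Bint8(K * N); // K x N weight matrix

      BlockingFactors params;                // must satisfy isValidBlockingFactor
      params.MCB = 128; params.NCB = 32; params.KCB = 256;
      params.MR = 8;    params.NR = 32;  params.NR_MIN = 16;
      params.ROW_INTERLEAVE = 4;

      PackBMatrix<std::int8_t> packedB(
          matrix_op_t::NoTranspose, K, N, Bint8.data(), /*ld=*/N,
          /*pmat=*/nullptr, /*groups=*/1, &params);

      std::vector<std::int8_t> recovered(K * N);
      packedB.unpack(recovered.data(), &params); // same params used for packing
    }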
diff --git a/include/fbgemm/FbgemmFP16.h b/include/fbgemm/FbgemmFP16.h
index 3d84977..8da0b56 100644
--- a/include/fbgemm/FbgemmFP16.h
+++ b/include/fbgemm/FbgemmFP16.h
@@ -104,6 +104,14 @@ class PackedGemmMatrixFP16 {
}
}
+ void setPacked(bool p) {
+ packed_ = p;
+ }
+
+ bool packed() const {
+ return packed_;
+ }
+
void initializeMemory() {
// allocate and initialize packed memory
const int padding = 1024; // required by sw pipelined kernels
@@ -128,6 +136,16 @@ class PackedGemmMatrixFP16 {
#endif
}
+ void unpackFromSrc(const matrix_op_t trans, float16* src_mat) {
+ bool tr = (trans == matrix_op_t::Transpose);
+ for (int i = 0; i < numRows(); i++) {
+ for (int j = 0; j < numCols(); j++) {
+ pmat_[tr ? i + numRows() * j : i * numCols() + j] = src_mat[addr(i, j)];
+ }
+ }
+ packed_ = false;
+ }
+
// protected:
// blocked row-major format address arithmetic
uint64_t addr(const int r_, const int c_) const {
@@ -163,6 +181,19 @@ class PackedGemmMatrixFP16 {
pmat_[addr(i, j)]);
}
}
+ packed_ = true;
+ }
+
+ // This function takes in an unpacked float16 matrix of the same size and
+ // packs it. There is no floating type conversion.
+ void packFromSrc(const matrix_op_t trans, const float16* smat) {
+ bool tr = (trans == matrix_op_t::Transpose);
+ for (int i = 0; i < numRows(); ++i) {
+ for (int j = 0; j < numCols(); ++j) {
+ pmat_[addr(i, j)] = smat[tr ? i + numRows() * j : i * numCols() + j];
+ }
+ }
+ packed_ = true;
}
const float16& operator()(const int r, const int c) const {
@@ -210,6 +241,7 @@ class PackedGemmMatrixFP16 {
uint64_t size_;
int kernel_ncol_blocks_;
float16* pmat_;
+ bool packed_{false};
friend void cblas_gemm_compute(
const matrix_op_t transa,
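
The packed_ flag and the packFromSrc/unpackFromSrc pair let a PackedGemmMatrixFP16 be serialized in plain layout and re-packed on load with no float conversion. A round-trip sketch, assuming the class's matSize() accessor and eliding the raw buffer copies:

    // 'pm' is an existing, packed PackedGemmMatrixFP16.
    std::vector<float16> buf(pm.matSize());
    // ... copy pm's internal (packed) buffer into buf ...
    pm.unpackFromSrc(matrix_op_t::NoTranspose, buf.data());
    // pm now holds the plain matrix; pm.packed() == false.
    // ... copy pm's internal (now plain) buffer back into buf ...
    pm.packFromSrc(matrix_op_t::NoTranspose, buf.data());
    // pm is packed again; pm.packed() == true.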
diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h
index 76eb425..baccfad 100644
--- a/include/fbgemm/PackingTraits-inl.h
+++ b/include/fbgemm/PackingTraits-inl.h
@@ -222,3 +222,53 @@ struct PackingTraits<
128}; ///< Cache block for N dimension (multiple of NR).
static constexpr int KCB{256}; ///< Cache block for K dimension.
};
+
+/**
+ * @brief Helper struct to type specialize for int16_t and int32_t together.
+ */
+template <typename T>
+struct is_16or32bit {
+ static constexpr bool value =
+ std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value;
+};
+
+/**
+ * @brief Packing parameter specialization for accumulation into 32-bit/16-bit
+ * integers.
+ *
+ * Since there is no int16_t accumulation for AVX512 VNNI, we redirect int16_t
+ * to int32_t accumulation and use the same blocking parameters as int32_t.
+ *
+ * This is picked when T is of int8 type (signed or unsigned) and instruction
+ * set is avx512_vnni.
+ */
+template <typename T, typename accT>
+struct PackingTraits<
+ T,
+ accT,
+ inst_set_t::avx512_vnni,
+ typename std::enable_if<
+ is_8bit<T>::value && is_16or32bit<accT>::value>::type> {
+ static constexpr int MR{8}; ///< Register block for M dimension.
+ static constexpr int NR_MIN{
+ 16}; ///< Minimum register block for N dimension.
+ ///< 16 because 16*ROW_INTERLEAVE int8 elements
+ ///< completely fill a 512-bit wide vector.
+ static constexpr int NR{
+ 32}; ///< Register block for N dimension.
+ ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements
+ ///< completely fill a 512-bit wide vector. Total registers used for
+ ///< N dimension: NR*ROW_INTERLEAVE*8/VLEN. We use MR x
+ ///< NR*ROW_INTERLEAVE*8/VLEN zmm registers
+ ///< for C accumulations.
+
+ static constexpr int ROW_INTERLEAVE{
+ 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing
+ ///< B matrix.
+
+ static constexpr int MCB{
+ 128}; ///< Cache block for M dimension (multiple of MR).
+ static constexpr int NCB{
+ 32}; ///< Cache block for N dimension (multiple of NR).
+ static constexpr int KCB{256}; ///< Cache block for K dimension.
+};
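
The register arithmetic in the NR comment checks out: with a 512-bit vector, each micro-kernel row needs NR*ROW_INTERLEAVE*8/512 = 32*4*8/512 = 2 zmm accumulators, so MR = 8 rows use 16 of the 28 zmm registers reserved for C. As a compile-time restatement:

    constexpr int VLEN = 512; // zmm width in bits
    constexpr int MR = 8, NR = 32, ROW_INTERLEAVE = 4;
    constexpr int zmmPerRow = NR * ROW_INTERLEAVE * 8 / VLEN; // = 2
    static_assert(MR * zmmPerRow <= 28, "C accumulators exceed zmm budget");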
diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h
index eac0bcd..3976790 100644
--- a/include/fbgemm/Utils.h
+++ b/include/fbgemm/Utils.h
@@ -5,6 +5,7 @@
* LICENSE file in the root directory of this source tree.
*/
#pragma once
+#include <array>
#include <string>
#include <type_traits>
#include "FbgemmBuild.h"
@@ -39,7 +40,7 @@ enum class matrix_op_t { NoTranspose, Transpose };
/**
* @brief Typed enum for supported instruction sets.
*/
-enum class inst_set_t { anyarch, avx2, avx512 };
+enum class inst_set_t { anyarch, avx2, avx512, avx512_vnni };
/**
* @brief Typed enum for optimized paths for convolutions
@@ -110,6 +111,11 @@ FBGEMM_API bool fbgemmHasAvx512Support();
FBGEMM_API bool fbgemmHasAvx2Support();
/**
+ * @brief Are we running on an AVX512_VNNI supported cpu?
+ */
+FBGEMM_API bool fbgemmHasAvx512VnniSupport();
+
+/**
* @brief Helper struct to enable autotuning of FBGEMM packing and kernels.
*
* This structure is optional. If not used, the default values for these
@@ -126,6 +132,16 @@ struct FBGEMM_API BlockingFactors {
int NCB;
};
+template <int SIZE, typename T = std::int32_t>
+FBGEMM_API std::string arrayToString(const std::array<T, SIZE>& inp) {
+ std::string out = "[";
+ for (int i = 0; i < SIZE; ++i) {
+ out += std::to_string(inp[i]);
+ out += (i != SIZE - 1) ? std::string(", ") : std::string("]");
+ }
+ return out;
+}
+
template <typename accT = std::int32_t>
FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
constexpr bool is_32bit = std::is_same<accT, int32_t>::value;
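
arrayToString exists to format shape arrays in error messages (see the FbgemmConv.cc change below). Usage sketch:

    std::array<std::int32_t, 3> shape = {1, 28, 28};
    std::string s = arrayToString<3>(shape); // yields "[1, 28, 28]"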
diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc
index f7292fd..0a4ff55 100644
--- a/src/ExecuteKernelU8S8.cc
+++ b/src/ExecuteKernelU8S8.cc
@@ -49,7 +49,8 @@ ExecuteKernel<
throw std::runtime_error("Failed to initialize cpuinfo!");
}
if (params) {
- if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support() ||
+ fbgemmHasAvx2Support()) {
mbSize_ = params->MCB;
nbSize_ = params->NCB;
nrMinSize_ = params->NR_MIN;
@@ -59,7 +60,20 @@ ExecuteKernel<
assert(0 && "unsupported architecure");
}
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ mbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::MCB;
+ nbSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::NCB;
+ nrMinSize_ = PackingTraits<
+ int8_t,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::NR_MIN;
+ } else if (fbgemmHasAvx512Support()) {
mbSize_ = PackingTraits<
int8_t,
typename packingAMatrix::accType,
@@ -118,7 +132,25 @@ void ExecuteKernel<
typename BaseType::jit_micro_kernel_fp fn;
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) {
+ // For AVX512VNNI, we redirect int16_t to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>(
+ accum,
+ packed_rows_A,
+ packedB_.blockColSize(),
+ packedA_.numPackedCols(),
+ nbSize_);
+ } else {
+ fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
+ accum,
+ packed_rows_A,
+ packedB_.blockColSize(),
+ packedA_.numPackedCols(),
+ nbSize_);
+ }
+ } else if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum,
packed_rows_A,
@@ -148,7 +180,10 @@ void ExecuteKernel<
if (jb == bColBlocks - 1) {
int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_;
if (nc != nbSize_) {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
+ accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
+ } else if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
} else if (fbgemmHasAvx2Support()) {
@@ -213,7 +248,7 @@ void ExecuteKernel<
int32_t nSize =
C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols();
if (nSize) {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
@@ -238,7 +273,7 @@ void ExecuteKernel<
if (C_buffer_start == C_tile_) {
// When C_tile_ scratchpad was used to avoid accessing memory past
// C_buffer_ .
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
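
Every dispatch site in this file now probes ISAs from newest to oldest, and per the hunk above, int16_t accumulation requests are redirected to int32_t kernels on VNNI. Condensed, the pattern is:

    if (fbgemmHasAvx512VnniSupport()) {
      // inst_set_t::avx512_vnni kernels; int16 acc redirected to int32
    } else if (fbgemmHasAvx512Support()) {
      // inst_set_t::avx512 kernels
    } else if (fbgemmHasAvx2Support()) {
      // inst_set_t::avx2 kernels
    } else {
      assert(0 && "unsupported architecture");
    }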
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index 2f641ee..1052044 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -48,7 +48,8 @@ void fbgemmPacked(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}
@@ -62,7 +63,20 @@ void fbgemmPacked(
MR = blocking_params->MR;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ MCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::MCB;
+ KCB = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::KCB;
+ MR = PackingTraits<
+ typename packingAMatrix::inpType,
+ typename packingAMatrix::accType,
+ inst_set_t::avx512_vnni>::MR;
+ } else if (fbgemmHasAvx512Support()) {
MCB = PackingTraits<
typename packingAMatrix::inpType,
typename packingAMatrix::accType,
diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc
index 027e6c5..33d1535 100644
--- a/src/FbgemmConv.cc
+++ b/src/FbgemmConv.cc
@@ -73,9 +73,14 @@ int fbgemmConv(
"Only 2D and 3D convolutions are supported");
if (!packed_weights.isPackingCompliant(conv_p)) {
- throw std::logic_error(
- "[FBGEMM_CONV_ERROR] Prepacked weights can't be used"
- " with these convolution parameters!");
+ std::string msg =
+ "[FBGEMM_CONV_ERROR] Convolution parameters "
+ "mismatch between pre-packed weights and conv invocation! ";
+ msg += packed_weights.mismatchingParams(conv_p);
+ msg += std::string(
+ " Please pack weights using the same parameters "
+ "with which convolution operation is invoked!");
+ throw std::logic_error(msg);
}
switch (ConvFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h
index dccdfc5..e52097e 100644
--- a/src/GenerateKernel.h
+++ b/src/GenerateKernel.h
@@ -18,7 +18,7 @@ namespace fbgemm {
namespace x86 = asmjit::x86;
/**
- * @brief AVX2/AVX512 JIT assembly code generator.
+ * @brief AVX2/AVX512/AVX512VNNI JIT assembly code generator.
* @tparam TA Type of matrix A.
* @tparam TB Type of matrix B.
* @tparam TC Type of matrix C.
@@ -104,7 +104,7 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void initCRegs(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCRegAssign = 4);
@@ -114,10 +114,10 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void genComputeBlock(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp B_pf,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda,
@@ -129,11 +129,11 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void storeCRegs(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCRegAssign = 4);
@@ -168,7 +168,9 @@ class CodeGenBase {
fileName += "_MR-" + std::to_string(MR);
fileName += "_NR-" + std::to_string(NR);
fileName += "_NR_MIN-" + std::to_string(NR_MIN);
- if (instSet == inst_set_t::avx512) {
+ if (instSet == inst_set_t::avx512_vnni) {
+ fileName += "_avx512vnni";
+ } else if (instSet == inst_set_t::avx512) {
fileName += "_avx512";
} else if (instSet == inst_set_t::avx2) {
fileName += "_avx2";
@@ -178,12 +180,10 @@ class CodeGenBase {
}
private:
- asmjit::X86Ymm
- CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
- asmjit::X86Zmm
+ x86::Ymm CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
+ x86::Zmm
CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel.
- asmjit::X86Zmm
- AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
+ x86::Zmm AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index 718b883..1e7e081 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -31,7 +31,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -53,18 +53,18 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp /* unused (reserved for prefetching)*/,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp /* unused (reserved for prefetching)*/,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- asmjit::X86Ymm AReg = x86::ymm12;
+ x86::Ymm AReg = x86::ymm12;
- asmjit::X86Ymm tmpReg = x86::ymm14;
+ x86::Ymm tmpReg = x86::ymm14;
for (int i = 0; i < rowRegs; ++i) {
// broadcast A
@@ -95,15 +95,15 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
- asmjit::X86Xmm extractDest128 = x86::xmm15;
- asmjit::X86Ymm extractDest256 = x86::ymm15;
+ x86::Xmm extractDest128 = x86::xmm15;
+ x86::Ymm extractDest256 = x86::ymm15;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
@@ -112,7 +112,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
a->vextracti128(
extractDest128, CRegs_avx2_[i * leadingDimCReg + j], idx);
a->vpmovsxwd(extractDest256, extractDest128);
- asmjit::X86Mem destAddr = x86::dword_ptr(
+ x86::Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t));
if (accum) {
a->vpaddd(extractDest256, extractDest256, destAddr);
@@ -176,9 +176,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
return codeCache_[kernelSig];
}
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
@@ -207,46 +207,45 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
//"nc must be equal to the number of register blocks");
// arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
- asmjit::FuncArgsMapper args(&func);
+ asmjit::FuncArgsAssignment args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- args.updateFrameInfo(ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
-
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
asmjit::Label Loopk = a->newLabel();
asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp kIdx = a->gpzRef(14);
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
if (mRegBlocks > 0) {
@@ -289,8 +288,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
a->jl(Loopk);
// store C matrix
- storeCRegs<inst_set_t::avx2>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
// increment A for next block
a->sub(buffer_A, kSize);
@@ -340,11 +338,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
a->jl(LoopkRem);
// store C matrix
- storeCRegs<inst_set_t::avx2>(
- a, rowRegs, colRegs, C_Offset, ldcReg, accum);
+ storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
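
Most of the churn here is the asmjit API migration (X86Emitter to x86::Emitter, gpzRef to gpz, FuncFrameInfo/FuncFrameLayout/FuncUtils collapsed into FuncFrame plus emitter methods). Condensed from the hunks above, the new prologue/epilogue sequence is:

    asmjit::FuncDetail func;
    func.init(
        asmjit::FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
            asmjit::CallConv::kIdHost));

    asmjit::FuncFrame frame;
    frame.init(func);
    frame.setDirtyRegs(
        x86::Reg::kGroupVec,
        asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
            asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
    frame.setDirtyRegs(
        x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));

    asmjit::FuncArgsAssignment args(&func);
    args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
    args.updateFuncFrame(frame);
    frame.finalize();

    a->emitProlog(frame);               // was FuncUtils::emitProlog(a, layout)
    a->emitArgsAssignment(frame, args); // was FuncUtils::allocArgs(a, layout, args)
    // ... emit kernel body ...
    a->emitEpilog(frame);               // was FuncUtils::emitEpilog(a, layout)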
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc
index c95757b..a49e440 100644
--- a/src/GenerateKernelU8S8S32ACC16Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -19,7 +19,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -41,18 +41,18 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp /* unused (reserved for prefetching)*/,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp /* unused (reserved for prefetching)*/,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- asmjit::X86Zmm AReg = x86::zmm29;
+ x86::Zmm AReg = x86::zmm29;
- asmjit::X86Zmm tmpReg = x86::zmm30;
+ x86::Zmm tmpReg = x86::zmm30;
// We start allocating BRegs from zmm27 and then allocate zmm26 and so on.
for (int j = 0; j < colRegs; ++j) {
@@ -66,8 +66,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
a->vpbroadcastw(
AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
for (int j = 0; j < colRegs; ++j) {
- a->vpmaddubsw(
- tmpReg, AReg, AllRegs_avx512_[27-j]);
+ a->vpmaddubsw(tmpReg, AReg, AllRegs_avx512_[27 - j]);
a->vpaddsw(
CRegs_avx512_[i * leadingDimCReg + j],
tmpReg,
@@ -90,15 +89,16 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
+
bool accum,
int leadingDimCReg) {
- asmjit::X86Ymm extractDest256 = x86::ymm31;
- asmjit::X86Zmm extractDest512 = x86::zmm31;
+ x86::Ymm extractDest256 = x86::ymm31;
+ x86::Zmm extractDest512 = x86::zmm31;
for (int i = 0; i < rowRegs; ++i) {
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
@@ -107,7 +107,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
a->vextracti32x8(
extractDest256, CRegs_avx512_[i * leadingDimCReg + j], idx);
a->vpmovsxwd(extractDest512, extractDest256);
- asmjit::X86Mem destAddr = x86::dword_ptr(
+ x86::Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t));
if (accum) {
a->vpaddd(extractDest512, extractDest512, destAddr);
@@ -172,9 +172,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
}
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
@@ -209,49 +209,49 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp,
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
asmjit::Label LoopMBlocks = a->newLabel();
asmjit::Label LoopNBlocks = a->newLabel();
asmjit::Label Loopk = a->newLabel();
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- // asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp jIdx = a->gpzRef(14);
- asmjit::X86Gp kIdx = a->gpzRef(15);
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ // x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
// save B_buffer address
a->mov(buffer_B_saved, buffer_B);
@@ -407,7 +407,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
a->jl(LoopNRem);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
new file mode 100644
index 0000000..f559aba
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCReg) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ codeObj.initCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, leadingDimCReg);
+}
+
+/**
+ * Generate AVX512 instructions for computing block in the rank-k update of
+ * 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp /* unused (reserved for prefetching)*/,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCReg) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ codeObj.genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, buffer_B, rowRegs, colRegs, lda, leadingDimCReg);
+}
+
+/**
+ * Generate AVX512 instructions for storing the C registers back to the memory
+ * in 16-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
+ bool accum,
+ int leadingDimCReg) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ codeObj.storeCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, leadingDimCReg);
+}
+
+/**
+ * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<
+ inst_set_t::avx512_vnni>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ assert(0 && "Accumulation to int16_t is not available for VNNI!");
+
+ // For AVX512VNNI, redirect to int32_t accumulation.
+ CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+ return codeObj.getOrCreate<inst_set_t::avx512_vnni>(accum, mc, nc, kc, kc);
+}
+
+} // namespace fbgemm
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index 58643ad..6b54743 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -31,7 +31,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -53,25 +53,25 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp B_pf,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- asmjit::X86Ymm AReg = x86::ymm12;
+ x86::Ymm AReg = x86::ymm12;
// used for matrix B
- asmjit::X86Ymm BReg = x86::ymm13;
+ x86::Ymm BReg = x86::ymm13;
// Contains 16-bit 1s
- asmjit::X86Ymm oneReg = x86::ymm15;
+ x86::Ymm oneReg = x86::ymm15;
// temporary register
- asmjit::X86Ymm res1 = x86::ymm14;
+ x86::Ymm res1 = x86::ymm14;
for (int j = 0; j < colRegs; ++j) {
// load B
@@ -99,11 +99,11 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
for (int i = 0; i < rowRegs; ++i) {
@@ -177,9 +177,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
return codeCache_[kernelSig];
}
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
FILE* codeLogfile = fopen(
@@ -205,49 +205,48 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14));
+ asmjit::FuncFrame frame;
+ frame.init(func);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
- asmjit::FuncArgsMapper args(&func);
+ asmjit::FuncArgsAssignment args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- args.updateFrameInfo(ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
-
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
asmjit::Label Loopk = a->newLabel();
asmjit::Label LoopMBlocks = a->newLabel();
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp kIdx = a->gpzRef(14);
- // asmjit::X86Gp B_pf = a->gpzRef(8);
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp kIdx = a->gpz(14);
+ // x86::Gp B_pf = a->gpz(8);
- asmjit::X86Ymm oneReg = x86::ymm15;
+ x86::Ymm oneReg = x86::ymm15;
// create 16-bit 1s
// i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
// and so on
@@ -358,7 +357,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
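
For context on the oneReg trick used throughout the non-VNNI int32 kernels: vpmaddubsw multiplies u8*s8 pairs into saturating int16 sums, and vpmaddwd against the all-ones 16-bit register widens adjacent lanes into int32 for vpaddd. A scalar model of one int32 lane over four interleaved k values (real vpmaddubsw saturates at int16, which this model ignores):

    #include <cstdint>

    std::int32_t dot4(const std::uint8_t a[4], const std::int8_t b[4]) {
      int p01 = a[0] * b[0] + a[1] * b[1]; // vpmaddubsw, pair 0,1
      int p23 = a[2] * b[2] + a[3] * b[3]; // vpmaddubsw, pair 2,3
      return p01 + p23;                    // vpmaddwd with oneReg, then vpaddd
    }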
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc
index 12243ee..fe35627 100644
--- a/src/GenerateKernelU8S8S32ACC32Avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc
@@ -19,7 +19,7 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
@@ -41,25 +41,25 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
- asmjit::X86Gp buffer_A,
- asmjit::X86Gp buffer_B,
- asmjit::X86Gp B_pf,
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
// used for matrix A
- asmjit::X86Zmm AReg = x86::zmm31;
+ x86::Zmm AReg = x86::zmm31;
// used for matrix B
- asmjit::X86Zmm BReg = x86::zmm30;
+ x86::Zmm BReg = x86::zmm30;
// Contains 16-bit 1s
- asmjit::X86Zmm oneReg = x86::zmm29;
+ x86::Zmm oneReg = x86::zmm29;
// temporary register
- asmjit::X86Zmm res1 = x86::zmm28;
+ x86::Zmm res1 = x86::zmm28;
for (int j = 0; j < colRegs; ++j) {
// load B
@@ -87,18 +87,17 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
inst_set_t::avx512>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int rowRegs,
int colRegs,
- asmjit::X86Gp C_Offset,
- asmjit::X86Gp ldcReg,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
for (int i = 0; i < rowRegs; ++i) {
if (i != 0) {
a->add(C_Offset, ldcReg);
- }
- else {
+ } else {
a->mov(C_Offset, static_cast<asmjit::Imm>(0));
}
for (int j = 0; j < colRegs; ++j) {
@@ -168,9 +167,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
return codeCache_[kernelSig];
}
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// generated code logging
@@ -205,52 +204,52 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
int mRegBlocksRem = mc % mRegBlockSize;
// arguments to the function created
- asmjit::X86Gp buffer_A = a->zdi();
- asmjit::X86Gp buffer_B = a->zsi();
- asmjit::X86Gp B_pf = a->zdx();
- asmjit::X86Gp CBase = a->zcx();
- asmjit::X86Gp kSize = a->gpzRef(8);
- asmjit::X86Gp ldcReg = a->gpzRef(9);
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp,
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
asmjit::Label LoopMBlocks = a->newLabel();
asmjit::Label LoopNBlocks = a->newLabel();
asmjit::Label Loopk = a->newLabel();
- asmjit::X86Gp buffer_B_saved = a->gpzRef(10);
- asmjit::X86Gp C_Offset = a->gpzRef(11);
- asmjit::X86Gp B_pf_saved = a->gpzRef(12);
- asmjit::X86Gp iIdx = a->gpzRef(13);
- asmjit::X86Gp jIdx = a->gpzRef(14);
- asmjit::X86Gp kIdx = a->gpzRef(15);
- // asmjit::X86Gp B_pf = a->gpzRef(8);
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+ // x86::Gp B_pf = a->gpz(8);
- asmjit::X86Zmm oneReg = x86::zmm29;
+ x86::Zmm oneReg = x86::zmm29;
// create 16-bit 1s
// i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
// and so on
@@ -420,7 +419,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
a->jl(LoopNRem);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_micro_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
new file mode 100644
index 0000000..8ae0745
--- /dev/null
+++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
@@ -0,0 +1,431 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <iostream>
+#include "GenerateKernel.h"
+
+namespace fbgemm {
+
+namespace x86 = asmjit::x86;
+
+/**
+ * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit
+ * Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ int leadingDimCReg) {
+ for (int i = 0; i < rowRegs; ++i) {
+ for (int j = 0; j < colRegs; ++j) {
+ a->vxorps(
+ CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs_avx512_[i * leadingDimCReg + j]);
+ }
+ }
+}
+
+/**
+ * Generate AVX512 instructions for computing block in the rank-k update of
+ * 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ x86::Gp buffer_A,
+ x86::Gp buffer_B,
+ x86::Gp B_pf,
+ int rowRegs,
+ int colRegs,
+ int lda,
+ int leadingDimCReg) {
+ // used for matrix A
+ x86::Zmm AReg = x86::zmm31;
+
+ // used for matrix B
+ x86::Zmm BReg = x86::zmm30;
+
+ for (int j = 0; j < colRegs; ++j) {
+ // load B
+ a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
+ // load A, broadcast and fmas
+ for (int i = 0; i < rowRegs; ++i) {
+ a->vpbroadcastd(
+ AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t)));
+ a->vpdpbusd(CRegs_avx512_[i * leadingDimCReg + j], AReg, BReg);
+ }
+ a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t)));
+ }
+}
+
+/**
+ * Generate AVX512 instructions for storing the C registers back to the memory
+ * in 32-bit Accumulation kernel.
+ */
+template <>
+template <>
+void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs<
+ inst_set_t::avx512_vnni>(
+ x86::Emitter* a,
+ int rowRegs,
+ int colRegs,
+ x86::Gp C_Offset,
+ x86::Gp ldcReg,
+ bool accum,
+ int leadingDimCReg) {
+ for (int i = 0; i < rowRegs; ++i) {
+ if (i != 0) {
+ a->add(C_Offset, ldcReg);
+ } else {
+ a->mov(C_Offset, static_cast<asmjit::Imm>(0));
+ }
+ for (int j = 0; j < colRegs; ++j) {
+ if (accum) {
+ a->vpaddd(
+ CRegs_avx512_[i * leadingDimCReg + j],
+ CRegs_avx512_[i * leadingDimCReg + j],
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)));
+ }
+ a->vmovups(
+ x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)),
+ CRegs_avx512_[i * leadingDimCReg + j]);
+ }
+ }
+}
+
+/**
+ * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel.
+ *
+ */
+template <>
+template <>
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp
+CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<
+ inst_set_t::avx512_vnni>(
+ bool accum,
+ int32_t mc,
+ int32_t nc,
+ int32_t kc,
+ int32_t /* unused */) {
+ std::tuple<bool, int, int, int, int, int, int, int> kernelSig;
+ int kBlock;
+ int nBlock;
+ int mRegBlockSize;
+ int nRegBlockSize;
+ int nRegBlockSizeMin;
+ int row_interleave;
+
+ if (blocking_params) {
+ kBlock = blocking_params->KCB;
+ nBlock = blocking_params->NCB;
+ mRegBlockSize = blocking_params->MR;
+ nRegBlockSize = blocking_params->NR;
+ nRegBlockSizeMin = blocking_params->NR_MIN;
+ row_interleave = blocking_params->ROW_INTERLEAVE;
+ } else {
+ kBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::KCB;
+ nBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NCB;
+ mRegBlockSize =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::MR;
+ nRegBlockSize =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR;
+ nRegBlockSizeMin =
+ PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR_MIN;
+ row_interleave = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::
+ ROW_INTERLEAVE;
+ }
+
+ kernelSig = std::make_tuple(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin);
+
+ if (codeCache_.find(kernelSig) != codeCache_.end()) {
+ return codeCache_[kernelSig];
+ }
+ code_.reset(false);
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
+
+#if defined(FBGEMM_LOG_CODE)
+ // generated code logging
+ FILE* codeLogfile = fopen(
+ getCodeLoggingFile<inst_set_t::avx512_vnni>(
+ accum,
+ mc,
+ nc,
+ nBlock,
+ kBlock,
+ mRegBlockSize,
+ nRegBlockSize,
+ nRegBlockSizeMin)
+ .c_str(),
+ "w");
+ asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code_.setLogger(codeLogger);
+ }
+#endif
+
+ assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
+ assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
+ int maxMRegs = mRegBlockSize;
+ int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
+ assert(
+ maxMRegs * maxNRegs <= 28 &&
+ "MR*(NR*ROW_INTERLEAVE*8/512) \
+ must be <= 28(available registers constraint)");
+
+ int mRegBlocks = mc / mRegBlockSize;
+ int mRegBlocksRem = mc % mRegBlockSize;
+
+ // arguments to the function created
+ x86::Gp buffer_A = a->zdi();
+ x86::Gp buffer_B = a->zsi();
+ x86::Gp B_pf = a->zdx();
+ x86::Gp CBase = a->zcx();
+ x86::Gp kSize = a->gpz(8);
+ x86::Gp ldcReg = a->gpz(9);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
+
+ asmjit::Label LoopMBlocks = a->newLabel();
+ asmjit::Label LoopNBlocks = a->newLabel();
+ asmjit::Label Loopk = a->newLabel();
+
+ x86::Gp buffer_B_saved = a->gpz(10);
+ x86::Gp C_Offset = a->gpz(11);
+ x86::Gp B_pf_saved = a->gpz(12);
+ x86::Gp iIdx = a->gpz(13);
+ x86::Gp jIdx = a->gpz(14);
+ x86::Gp kIdx = a->gpz(15);
+ // x86::Gp B_pf = a->gpz(8);
+
+ x86::Zmm oneReg = x86::zmm29;
+ // create 16-bit 1s
+ // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001
+ // and so on
+ // a->vpcmpeqw(oneReg, oneReg, oneReg);
+ a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
+ a->vpsrlw(oneReg, oneReg, 15);
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
+
+ // save B_buffer address
+ a->mov(buffer_B_saved, buffer_B);
+ a->mov(B_pf_saved, B_pf);
+
+ int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
+ int colRegs = std::min(currColRegs, maxNRegs);
+ if (mRegBlocks > 0) {
+ // move 0 to iteration variables
+ a->mov(iIdx, 0);
+
+ a->bind(LoopMBlocks);
+ a->inc(iIdx);
+ a->mov(jIdx, 0);
+
+ a->bind(LoopNBlocks);
+ a->inc(jIdx);
+
+ int rowRegs = mRegBlockSize;
+
+ // init C registers
+ initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(Loopk);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+
+ // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(Loopk);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+
+ // B for next block
+ a->mov(buffer_B, buffer_B_saved);
+ // using C_Offset as temp reg
+ a->imul(
+ C_Offset,
+ jIdx,
+ static_cast<asmjit::Imm>(
+ nRegBlockSize * row_interleave * sizeof(int8_t)));
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNBlocks);
+
+ // increment A for next block
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
+
+ // increment C for next A block
+ a->sub(
+ CBase,
+ static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t)));
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
+ a->add(CBase, C_Offset);
+
+ // reset B
+ a->mov(buffer_B, buffer_B_saved);
+ a->mov(B_pf, B_pf_saved);
+ a->cmp(iIdx, mRegBlocks);
+ a->jl(LoopMBlocks);
+ }
+ // generate code for remainder
+ if (mRegBlocksRem > 0) {
+ asmjit::Label LoopNRem = a->newLabel();
+ asmjit::Label LoopkRem = a->newLabel();
+ int rowRegs = mRegBlocksRem;
+
+ a->mov(jIdx, 0);
+ a->bind(LoopNRem);
+ a->inc(jIdx);
+
+ // init C registers
+ initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs);
+
+ // init k loop index
+ a->mov(kIdx, 0);
+ a->bind(LoopkRem);
+
+ // k is incremented by row_interleave
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
+
+ genComputeBlock<inst_set_t::avx512_vnni>(
+ a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
+
+ // update buffer_A address for next k iteration
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
+
+ // update buffer_B address for next k iteration
+ a->add(
+ buffer_B,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+ a->add(
+ B_pf,
+ static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t)));
+
+ a->cmp(kIdx, kSize);
+ a->jl(LoopkRem);
+
+ // reset A
+ a->sub(buffer_A, kSize);
+ // B for next block
+ // using C_Offset as temp reg
+ a->imul(
+ C_Offset,
+ jIdx,
+ static_cast<asmjit::Imm>(
+ nRegBlockSize * row_interleave * sizeof(int8_t)));
+ a->mov(buffer_B, buffer_B_saved);
+ a->add(buffer_B, C_Offset);
+ a->mov(B_pf, B_pf_saved);
+ a->add(B_pf, C_Offset);
+
+ // store C matrix
+ storeCRegs<inst_set_t::avx512_vnni>(
+ a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs);
+
+ // increment C for next B block
+ a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t)));
+
+ int jLoopTrips = currColRegs / maxNRegs;
+ // jLoopTrips should be at least 1
+ jLoopTrips = jLoopTrips ? jLoopTrips : 1;
+ a->cmp(jIdx, jLoopTrips);
+ a->jl(LoopNRem);
+ }
+
+ a->emitEpilog(frame);
+
+ jit_micro_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ codeCache_[kernelSig] = fn;
+
+#if defined(FBGEMM_LOG_CODE)
+ fclose(codeLogfile);
+ delete codeLogger;
+#endif
+
+ return fn;
+}
+
+} // namespace fbgemm
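
The heart of the new VNNI path is vpdpbusd in genComputeBlock above: one instruction replaces the vpmaddubsw/vpmaddwd/vpaddd triple, accumulating a dot product of four unsigned 8-bit A values against four signed 8-bit B values into each int32 lane, without the intermediate int16 saturation of vpmaddubsw (and matching the ROW_INTERLEAVE of 4 in the VNNI PackingTraits). A scalar model of one lane:

    #include <cstdint>

    std::int32_t vpdpbusd_lane(std::int32_t c,
                               const std::uint8_t a[4],
                               const std::int8_t b[4]) {
      for (int k = 0; k < 4; ++k)
        c += static_cast<std::int32_t>(a[k]) * b[k];
      return c;
    }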
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
index 1e6324e..4c5eea5 100644
--- a/src/GroupwiseConv.h
+++ b/src/GroupwiseConv.h
@@ -128,60 +128,58 @@ class GenConvKernel {
const conv_param_t<SPATIAL_DIM>& conv_param);
template <inst_set_t instSet>
- void createVector16BitOne(asmjit::X86Emitter* a);
+ void createVector16BitOne(x86::Emitter* a);
template <inst_set_t instSet>
- void createVector8BitOne(asmjit::X86Emitter* a);
+ void createVector8BitOne(x86::Emitter* a);
template <inst_set_t instSet>
- void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg);
+ void setToZeroPt(x86::Emitter* a, x86::Ymm destReg);
template <inst_set_t instSet>
- void
- gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg);
+ void gen8bitFMA(x86::Emitter* a, x86::Ymm aReg, x86::Ymm wReg);
template <inst_set_t instSet>
- void genForLoadingWeights(asmjit::X86Emitter* a, int c_offset);
+ void genForLoadingWeights(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genConstForPermutations(asmjit::X86Emitter* a);
+ void genConstForPermutations(x86::Emitter* a);
template <inst_set_t instSet>
- void genForTopEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForTopEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForLeftEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForLeftEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForRightEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForRightEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genForBottomEdge(asmjit::X86Emitter* a, int c_offset);
+ void genForBottomEdge(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void genCoreInsts(asmjit::X86Emitter* a, int c_offset);
+ void genCoreInsts(x86::Emitter* a, int c_offset);
template <inst_set_t instSet>
- void storeResult(asmjit::X86Emitter* a);
+ void storeResult(x86::Emitter* a);
// for Rowoffset kernel
// Add 4 consecutive numbers of 32 uint8 and emit 8 32-bit
template <inst_set_t instSet>
- void gen8BitSumX4(asmjit::X86Emitter* a, asmjit::X86Ymm aReg);
+ void gen8BitSumX4(x86::Emitter* a, x86::Ymm aReg);
// Add 8 consecutive numbers of 64 uint8 and emit 8 32-bit
template <inst_set_t instSet>
- void
- gen8BitSumX8(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg);
+ void gen8BitSumX8(x86::Emitter* a, x86::Ymm aReg, x86::Ymm bReg);
// Add 16 consecutive numbers of 128 uint8 and emit 8 32-bit
template <inst_set_t instSet>
void gen8BitSumX16(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm bReg,
- asmjit::X86Ymm cReg,
- asmjit::X86Ymm dReg);
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm bReg,
+ x86::Ymm cReg,
+ x86::Ymm dReg);
// Generate instruction sequence that loads 8-bit values and sum them up.
// Depending on C_per_G_, this function dispatches to gen8BitSumX4/8/16
@@ -191,35 +189,33 @@ class GenConvKernel {
// Internally, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0, 1], tmpReg1Avx2_,
// and resultRegAvx2_ are used.
template <inst_set_t instSet>
- void gen8BitSum(
- asmjit::X86Emitter* a,
- int act_offset,
- bool use_scratch_reg1 = true);
+ void
+ gen8BitSum(x86::Emitter* a, int act_offset, bool use_scratch_reg1 = true);
// Use scratchReg1_ and tmpReg1Avx2_ internally
template <inst_set_t instSet>
- void genZeroPtSum(asmjit::X86Emitter* a, int multiplier);
+ void genZeroPtSum(x86::Emitter* a, int multiplier);
template <inst_set_t instSet>
- void genForTopEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForTopEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genForLeftEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForLeftEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genForRightEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForRightEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genForBottomEdgeRowoffset(asmjit::X86Emitter* a);
+ void genForBottomEdgeRowoffset(x86::Emitter* a);
template <inst_set_t instSet>
- void genRowoffsetCorners(asmjit::X86Emitter* a);
+ void genRowoffsetCorners(x86::Emitter* a);
template <inst_set_t instSet>
- void genRowoffsetCore(asmjit::X86Emitter* a);
+ void genRowoffsetCore(x86::Emitter* a);
template <inst_set_t instSet>
- void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0);
+ void storeResultRowoffset(x86::Emitter* a, int offset = 0);
static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
@@ -234,30 +230,30 @@ class GenConvKernel {
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
// avx2 specific
- asmjit::X86Ymm
+ x86::Ymm
WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel.
- asmjit::X86Ymm zeroPTRegAvx2_;
- asmjit::X86Ymm tmpReg1Avx2_;
- asmjit::X86Ymm stPermRegAvx2_;
- asmjit::X86Ymm actRegAvx2_;
- asmjit::X86Ymm resultRegAvx2_;
- asmjit::X86Ymm oneReg8BitAvx2_;
- asmjit::X86Ymm oneReg16BitAvx2_;
+ x86::Ymm zeroPTRegAvx2_;
+ x86::Ymm tmpReg1Avx2_;
+ x86::Ymm stPermRegAvx2_;
+ x86::Ymm actRegAvx2_;
+ x86::Ymm resultRegAvx2_;
+ x86::Ymm oneReg8BitAvx2_;
+ x86::Ymm oneReg16BitAvx2_;
// arguments to the function created
- asmjit::X86Gp in_acts_R_;
- asmjit::X86Gp wghts_R_;
- asmjit::X86Gp out_acts_R_;
- asmjit::X86Gp a_zero_pt_R_;
- asmjit::X86Gp H_R_;
- asmjit::X86Gp W_R_;
- asmjit::X86Gp row_offset_R_;
+ x86::Gp in_acts_R_;
+ x86::Gp wghts_R_;
+ x86::Gp out_acts_R_;
+ x86::Gp a_zero_pt_R_;
+ x86::Gp H_R_;
+ x86::Gp W_R_;
+ x86::Gp row_offset_R_;
// Used registers
- asmjit::X86Gp loopR1_;
- asmjit::X86Gp loopR2_;
- asmjit::X86Gp scratchReg1_;
- asmjit::X86Gp scratchReg2_;
+ x86::Gp loopR1_;
+ x86::Gp loopR2_;
+ x86::Gp scratchReg1_;
+ x86::Gp scratchReg2_;
// Other parameters
bool isAZeroPointZero_;
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index e789695..b140c83 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -104,7 +104,7 @@ jit_conv_kernel_fp getOrCreateConvKernel(
template <>
template <>
void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// create 8-bit 1s
// i.e., oneReg8BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains
// 0x01 and so on
@@ -115,7 +115,7 @@ void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// create 16-bit 1s
// i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31]
// contains 0x0001 and so on
@@ -125,11 +125,11 @@ void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm destReg) {
+ x86::Emitter* a,
+ x86::Ymm destReg) {
// make destReg all zeros
a->vxorps(destReg, destReg, destReg);
- asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+ x86::Xmm const_reg_xmm = x86::xmm10;
// move zero point to xmm10
a->movq(const_reg_xmm, a_zero_pt_R_);
// make copies of zero point
@@ -143,9 +143,9 @@ void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
- asmjit::X86Gp permute_const_reg = a->gpzRef(12);
- asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+ x86::Emitter* a) {
+ x86::Gp permute_const_reg = a->gpz(12);
+ x86::Xmm const_reg_xmm = x86::xmm10;
// We have 1st group in even lanes and 2nd group in odd lanes.
// Permute to put 1st group to lower 128-bit and 2nd group in upper
// 128-bit.
@@ -159,8 +159,7 @@ void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) {
if (C_per_G_ == 4) {
// store with permutation
a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_);
@@ -171,7 +170,7 @@ void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int offset) {
// store
if (C_per_G_ == 4) {
@@ -198,7 +197,7 @@ void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// load weights
for (int r = 0; r < R_; ++r) {
@@ -225,9 +224,9 @@ void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm wReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm wReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg);
a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_);
a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
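
For reference, the three-instruction idiom gen8bitFMA emits (vpmaddubsw, then vpmaddwd against a register of 16-bit ones, then vpaddd) corresponds to the following AVX2 intrinsics; this is an illustrative sketch, not FBGEMM API:

    #include <immintrin.h>

    // u8*s8 dot products accumulated into 32-bit lanes.
    static inline __m256i fma_u8s8_acc32(__m256i acc, __m256i a_u8, __m256i w_s8) {
      const __m256i ones16 = _mm256_set1_epi16(1);
      // 16 x i16: a[2i]*w[2i] + a[2i+1]*w[2i+1], with 16-bit saturation
      __m256i prod16 = _mm256_maddubs_epi16(a_u8, w_s8);
      // widen adjacent i16 pairs to 8 x i32 by multiply-add with ones
      __m256i prod32 = _mm256_madd_epi16(prod16, ones16);
      return _mm256_add_epi32(acc, prod32);
    }

Note that _mm256_maddubs_epi16 saturates at 16 bits, a known caveat of this idiom.
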
@@ -236,8 +235,8 @@ void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_);
a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_);
a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
@@ -246,9 +245,9 @@ void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm bReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm bReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
// Let a[0] denote 0th (LSB) 8-bit of aReg
// After vpsadbw, a[0:2] = a[0] + ... + a[7]
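
The row-offset sums build on vpsadbw, which adds each run of eight unsigned bytes into a 64-bit lane; against a zero register that is exactly a byte-group sum. A hedged intrinsics sketch of that core step (illustrative, not FBGEMM API):

    #include <immintrin.h>

    // Four 64-bit lanes, each holding the sum of its eight source bytes,
    // matching the a[0] + ... + a[7] pattern described in the comments above.
    static inline __m256i sum_groups_of_8_u8(__m256i a_u8) {
      return _mm256_sad_epu8(a_u8, _mm256_setzero_si256());
    }

gen8BitSumX8 then combines two such partial results into the eight 32-bit sums it emits.
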
@@ -267,11 +266,11 @@ void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
- asmjit::X86Ymm aReg,
- asmjit::X86Ymm bReg,
- asmjit::X86Ymm cReg,
- asmjit::X86Ymm dReg) {
+ x86::Emitter* a,
+ x86::Ymm aReg,
+ x86::Ymm bReg,
+ x86::Ymm cReg,
+ x86::Ymm dReg) {
a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_);
// After vpsadbw, a[0:2] = a[0] + ... + a[7]
// a[8:10] = a[8] + ... + a[15]
@@ -319,7 +318,7 @@ void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int act_offset,
bool use_scratch_reg1 /*=true*/) {
if (use_scratch_reg1) {
@@ -385,11 +384,11 @@ void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int multiplier) {
a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier));
// tmpReg1Avx2_ also uses xmm11
- asmjit::X86Xmm const_reg_xmm = x86::xmm11;
+ x86::Xmm const_reg_xmm = x86::xmm11;
a->movq(const_reg_xmm, scratchReg1_);
a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm);
a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_);
@@ -399,7 +398,7 @@ void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// top-left corner code
if (c_offset == 0) {
@@ -559,7 +558,7 @@ void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
@@ -626,7 +625,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -714,7 +713,7 @@ void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// bottom-left corner
// we are updating the last row
@@ -906,7 +905,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genCoreInsts<inst_set_t::avx2>(
- asmjit::X86Emitter* a,
+ x86::Emitter* a,
int c_offset) {
// main compute
asmjit::Label LoopH = a->newLabel();
@@ -1011,9 +1010,9 @@ template <>
jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// log code to a file
@@ -1030,16 +1029,16 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
wghts_R_ = a->zsi();
out_acts_R_ = a->zdx();
a_zero_pt_R_ = a->zcx();
- H_R_ = a->gpzRef(8);
- W_R_ = a->gpzRef(9);
- row_offset_R_ = a->gpzRef(10);
+ H_R_ = a->gpz(8);
+ W_R_ = a->gpz(9);
+ row_offset_R_ = a->gpz(10);
// register for temporary use
- scratchReg1_ = a->gpzRef(12);
- scratchReg2_ = a->gpzRef(13);
+ scratchReg1_ = a->gpz(12);
+ scratchReg2_ = a->gpz(13);
asmjit::FuncDetail func;
- func.init(asmjit::FuncSignature6<
+ func.init(asmjit::FuncSignatureT<
void,
uint8_t*,
int8_t*,
@@ -1048,29 +1047,29 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
int32_t,
int32_t>(asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
createVector16BitOne<inst_set_t::avx2>(a);
- loopR1_ = a->gpzRef(14);
- loopR2_ = a->gpzRef(15);
+ loopR1_ = a->gpz(14);
+ loopR2_ = a->gpz(15);
if (!isAZeroPointZero_) {
setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
@@ -1095,7 +1094,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
genCoreInsts<inst_set_t::avx2>(a, c);
}
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_conv_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
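
For orientation, the hunk above is the heart of the migration from asmjit's removed FuncFrameInfo/FuncFrameLayout/FuncUtils API to the newer FuncFrame/FuncArgsAssignment flow. A minimal standalone sketch of that flow, using only calls that appear in this diff (the trivial one-argument signature is illustrative):

    #include <asmjit/asmjit.h>
    #include <cstdint>
    using namespace asmjit;

    using StubFn = void (*)(int32_t*);

    StubFn buildStub(JitRuntime& rt) {
      CodeHolder code;
      code.init(rt.codeInfo());
      x86::Assembler assembler(&code);
      x86::Emitter* a = assembler.as<x86::Emitter>();

      FuncDetail func;
      func.init(FuncSignatureT<void, int32_t*>(CallConv::kIdHost));

      FuncFrame frame;
      frame.init(func);
      // declare clobbered registers so the prolog/epilog save and restore them
      frame.setDirtyRegs(x86::Reg::kGroupGp, Support::bitMask(12, 13));

      FuncArgsAssignment args(&func);
      args.assignAll(a->zdi());
      args.updateFuncFrame(frame);
      frame.finalize();

      a->emitProlog(frame);
      a->emitArgsAssignment(frame, args);
      // ... kernel body would be emitted here ...
      a->emitEpilog(frame);

      StubFn fn = nullptr;
      rt.add(&fn, &code);  // fn stays nullptr on failure, as checked above
      return fn;
    }
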
@@ -1117,7 +1116,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// top-left corner code
// zero out the results register
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
@@ -1213,7 +1212,7 @@ void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
@@ -1256,7 +1255,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -1326,7 +1325,7 @@ void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// bottom-left corner
// zero out
a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
@@ -1429,7 +1428,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
void GenConvKernel<2, int32_t>::genRowoffsetCore<inst_set_t::avx2>(
- asmjit::X86Emitter* a) {
+ x86::Emitter* a) {
// number of uint8 elements in input channels should be a multiple of 32
assert(C_ % 32 == 0);
@@ -1491,9 +1490,9 @@ jit_rowoffset_kernel_fp
GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
const conv_param_t<2>& conv_param) {
code_.reset(false);
- code_.init(rt_.getCodeInfo());
- asmjit::X86Assembler assembler(&code_);
- asmjit::X86Emitter* a = assembler.asEmitter();
+ code_.init(rt_.codeInfo());
+ x86::Assembler assembler(&code_);
+ x86::Emitter* a = assembler.as<x86::Emitter>();
#if defined(FBGEMM_LOG_CODE)
// log code to a file
@@ -1510,45 +1509,45 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
a_zero_pt_R_ = a->zsi();
H_R_ = a->zdx();
W_R_ = a->zcx();
- row_offset_R_ = a->gpzRef(8);
+ row_offset_R_ = a->gpz(8);
// register for temporary use
- scratchReg1_ = a->gpzRef(12);
- scratchReg2_ = a->gpzRef(13);
+ scratchReg1_ = a->gpz(12);
+ scratchReg2_ = a->gpz(13);
- loopR1_ = a->gpzRef(14);
- loopR2_ = a->gpzRef(15);
+ loopR1_ = a->gpz(14);
+ loopR2_ = a->gpz(15);
asmjit::FuncDetail func;
func.init(
asmjit::
- FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
+ FuncSignatureT<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
asmjit::CallConv::kIdHost));
- asmjit::FuncFrameInfo ffi;
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindVec,
- asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
- asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
- ffi.setDirtyRegs(
- asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+ asmjit::FuncFrame frame;
+ frame.init(func);
- asmjit::FuncArgsMapper args(&func);
- args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
- args.updateFrameInfo(ffi);
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);
- asmjit::FuncFrameLayout layout;
- layout.init(func, ffi);
+ args.updateFuncFrame(frame);
+ frame.finalize();
- asmjit::FuncUtils::emitProlog(a, layout);
- asmjit::FuncUtils::allocArgs(a, layout, args);
+ a->emitProlog(frame);
+ a->emitArgsAssignment(frame, args);
// This uses xmm10 register temporarily. Should come before
// createVector8BitOne
if (!isAZeroPointZero_) {
// we can use xmm11 because ymm11 is used by tmpReg1Avx2_
- asmjit::X86Xmm const_reg_xmm = x86::xmm11;
+ x86::Xmm const_reg_xmm = x86::xmm11;
a->movq(const_reg_xmm, a_zero_pt_R_);
a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm);
@@ -1569,7 +1568,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
genRowoffsetCore<inst_set_t::avx2>(a);
- asmjit::FuncUtils::emitEpilog(a, layout);
+ a->emitEpilog(frame);
jit_rowoffset_kernel_fp fn;
asmjit::Error err = rt_.add(&fn, &code_);
diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc
index 143e11d..5fabf97 100644
--- a/src/PackAMatrix.cc
+++ b/src/PackAMatrix.cc
@@ -34,7 +34,8 @@ PackAMatrix<T, accT>::PackAMatrix(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecture");
}
@@ -43,7 +44,12 @@ PackAMatrix<T, accT>::PackAMatrix(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
index d731654..2aca27d 100644
--- a/src/PackAWithIm2Col.cc
+++ b/src/PackAWithIm2Col.cc
@@ -49,7 +49,8 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecture");
}
@@ -58,7 +59,12 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
@@ -478,7 +484,9 @@ int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ } else if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
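
The same three-way dispatch (VNNI first, then AVX-512, then AVX2) recurs below in PackAWithQuantRowOffset, PackAWithRowOffset, PackBMatrix, and PackMatrix. Factored out as a hedged standalone sketch (the helper itself is illustrative; the traits and predicates are the ones used in this diff):

    // Illustrative helper, not FBGEMM API: pick blocking sizes for the
    // widest instruction set the CPU supports.
    template <typename T, typename accT>
    void chooseBlocking(int& brow, int& bcol, int& rowInterleave) {
      if (fbgemmHasAvx512VnniSupport()) {
        brow = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
        bcol = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
        rowInterleave = PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
      } else if (fbgemmHasAvx512Support()) {
        brow = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
        bcol = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
        rowInterleave = PackingTraits<T, accT, inst_set_t::avx512>::ROW_INTERLEAVE;
      } else {  // AVX2 fallback, as in the surrounding constructors
        brow = PackingTraits<T, accT, inst_set_t::avx2>::MCB;
        bcol = PackingTraits<T, accT, inst_set_t::avx2>::KCB;
        rowInterleave = PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
      }
    }
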
diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc
index 52caed4..13a8fad 100644
--- a/src/PackAWithQuantRowOffset.cc
+++ b/src/PackAWithQuantRowOffset.cc
@@ -45,7 +45,8 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecture");
}
@@ -54,7 +55,12 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
@@ -201,7 +207,9 @@ int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ } else if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
diff --git a/src/PackAWithRowOffset.cc b/src/PackAWithRowOffset.cc
index 733bf5c..e84c67b 100644
--- a/src/PackAWithRowOffset.cc
+++ b/src/PackAWithRowOffset.cc
@@ -39,7 +39,8 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecture");
}
@@ -48,7 +49,12 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset(
BaseType::bcol_ = params->KCB;
row_interleave_B_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ row_interleave_B_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
row_interleave_B_ =
@@ -189,7 +195,9 @@ int PackAWithRowOffset<T, accT>::rowOffsetBufferSize(
if (params) {
return params->MCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB;
+ } else if (fbgemmHasAvx512Support()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (fbgemmHasAvx2Support()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc
index b19b5d4..c237ac4 100644
--- a/src/PackBMatrix.cc
+++ b/src/PackBMatrix.cc
@@ -188,7 +188,8 @@ PackBMatrix<T, accT>::PackBMatrix(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecture");
}
@@ -197,7 +198,12 @@ PackBMatrix<T, accT>::PackBMatrix(
BaseType::bcol_ = params->NCB;
row_interleave_ = params->ROW_INTERLEAVE;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB;
+ BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::NCB;
+ row_interleave_ =
+ PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE;
+ } else if (fbgemmHasAvx512Support()) {
BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB;
BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB;
row_interleave_ =
@@ -228,7 +234,7 @@ PackBMatrix<T, accT>::PackBMatrix(
BaseType::numGroups() * BaseType::blockRows() * BaseType::brow_ *
BaseType::blockCols() * BaseType::bcol_ * sizeof(T));
}
- pack(block);
+ pack(block, params);
}
template <typename T, typename accT>
@@ -294,7 +300,8 @@ void PackBMatrix<T, accT>::pack_unpack_(
const block_type_t& block,
T* unpack_buf,
T* pack_buf,
- bool ispack) {
+ bool ispack,
+ const BlockingFactors* params) {
assert((BaseType::blockRowSize() % row_interleave_) == 0);
assert((block.row_start % BaseType::blockRowSize()) == 0);
assert((block.col_start % BaseType::blockColSize()) == 0);
@@ -303,7 +310,7 @@ void PackBMatrix<T, accT>::pack_unpack_(
bool tr = (trans_ == matrix_op_t::Transpose);
for (int g = 0; g < BaseType::numGroups(); ++g) {
T* pack_buf_cur = pack_buf +
- g * BaseType::packedBufferSize(block.row_size, block.col_size);
+ g * BaseType::packedBufferSize(block.row_size, block.col_size, params);
for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) *
(BaseType::blockRowSize() * BaseType::blockColSize()) +
@@ -374,17 +381,21 @@ void PackBMatrix<T, accT>::pack_unpack_(
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::pack(const block_type_t& block) {
- pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true);
+void PackBMatrix<T, accT>::pack(
+ const block_type_t& block,
+ const BlockingFactors* params) {
+ pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true, params);
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::unpack(T* origin_buf) {
+void PackBMatrix<T, accT>::unpack(
+ T* origin_buf,
+ const BlockingFactors* params) {
block_type_t blockB{BaseType::packedRowStart(),
BaseType::numPackedRows(),
BaseType::packedColStart(),
BaseType::numPackedCols()};
- pack_unpack_(blockB, origin_buf, BaseType::getBuf(), false);
+ pack_unpack_(blockB, origin_buf, BaseType::getBuf(), false, params);
}
template <typename T, typename accT>
@@ -407,7 +418,9 @@ int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const {
}
template <typename T, typename accT>
-void PackBMatrix<T, accT>::printPackedMatrix(std::string name) {
+void PackBMatrix<T, accT>::printPackedMatrix(
+ std::string name,
+ const BlockingFactors* params) {
std::cout << name << ":"
<< "[" << BaseType::numPackedRows() << ", "
<< BaseType::numPackedCols() << "]" << std::endl;
@@ -419,7 +432,7 @@ void PackBMatrix<T, accT>::printPackedMatrix(std::string name) {
T* out = BaseType::getBuf() +
g *
BaseType::packedBufferSize(
- BaseType::numPackedRows(), BaseType::numPackedCols());
+ BaseType::numPackedRows(), BaseType::numPackedCols(), params);
std::cout << "group: " << g << std::endl;
for (auto nr = 0; nr < BaseType::blockRows(); ++nr) {
auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow()
diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc
index c7503dd..ff7b842 100644
--- a/src/PackMatrix.cc
+++ b/src/PackMatrix.cc
@@ -36,7 +36,8 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
- if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
+ if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
+ !fbgemmHasAvx2Support())) {
assert(0 && "unknown architecture");
}
@@ -46,7 +47,11 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize(
NCB = params->NCB;
KCB = params->KCB;
} else {
- if (fbgemmHasAvx512Support()) {
+ if (fbgemmHasAvx512VnniSupport()) {
+ MCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::MCB;
+ NCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::NCB;
+ KCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::KCB;
+ } else if (fbgemmHasAvx512Support()) {
MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB;
NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB;
KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB;
diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc
index ba6adf3..f6ad59e 100644
--- a/src/PackWeightMatrixForGConv.cc
+++ b/src/PackWeightMatrixForGConv.cc
@@ -106,7 +106,7 @@ inline int PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::packed_index_(
* on 2 groups at a time and full SIMD width can be efficiently utilized even
* while working on 1 group at a time.
* In this case, the layout is G (C/4) R S K 4
-*/
+ */
template <typename T, typename accT, int SPATIAL_DIM>
void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_(
@@ -148,9 +148,9 @@ void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_(
if (ispack) {
transposeConvWeights(conv_param_, src, dst);
} else {
- // TODO: Wrap this as a inverseTransposeConvWeights()?
- // For unpack & transposed, call transposeConvWeights()
- // G (R S C/G) K/G => G K/G (R S C/G)
+ // TODO: Wrap this as a inverseTransposeConvWeights()?
+ // For unpack & transposed, call transposeConvWeights()
+ // G (R S C/G) K/G => G K/G (R S C/G)
for (int r = 0; r < R; ++r) {
for (int s = 0; s < S; ++s) {
for (int k = 0; k < OC_per_G; ++k) {
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
index 25b04af..44f210e 100644
--- a/src/PackWeightsForConv.cc
+++ b/src/PackWeightsForConv.cc
@@ -125,6 +125,74 @@ bool PackWeightsForConv<SPATIAL_DIM, T, accT>::isPackingCompliant(
test_conv_p.dilation.begin());
}
+template <int SPATIAL_DIM, typename T, typename accT>
+std::string PackWeightsForConv<SPATIAL_DIM, T, accT>::mismatchingParams(
+ const conv_param_t<SPATIAL_DIM>& test_conv_p) {
+ std::string msg = "";
+
+ auto combineStr = [](std::string id, std::string str1, std::string str2) {
+ std::string out = id + std::string(" ");
+ out += str1;
+ out += std::string(" vs ") + str2;
+ out += std::string(";");
+ return out;
+ };
+
+ auto combineInt = [&combineStr](std::string id, int int1, int int2) {
+ return combineStr(id, std::to_string(int1), std::to_string(int2));
+ };
+
+ if (conv_param_.IC != test_conv_p.IC) {
+ msg += combineInt("input_channels", conv_param_.IC, test_conv_p.IC);
+ }
+ if (conv_param_.OC != test_conv_p.OC) {
+ msg += combineInt("output_channels", conv_param_.IC, test_conv_p.IC);
+ }
+ if (conv_param_.G != test_conv_p.G) {
+ msg += combineInt("groups", conv_param_.G, test_conv_p.G);
+ }
+
+ if (!std::equal(
+ conv_param_.K.begin(), conv_param_.K.end(), test_conv_p.K.begin())) {
+ msg += combineStr(
+ "kernel",
+ arrayToString<SPATIAL_DIM>(conv_param_.K),
+ arrayToString<SPATIAL_DIM>(test_conv_p.K));
+ }
+
+ if (!std::equal(
+ conv_param_.stride.begin(),
+ conv_param_.stride.end(),
+ test_conv_p.stride.begin())) {
+ msg += combineStr(
+ "stride",
+ arrayToString<SPATIAL_DIM>(conv_param_.stride),
+ arrayToString<SPATIAL_DIM>(test_conv_p.stride));
+ }
+
+ if (!std::equal(
+ conv_param_.pad.begin(),
+ conv_param_.pad.end(),
+ test_conv_p.pad.begin())) {
+ msg += combineStr(
+ "pad",
+ arrayToString<2 * SPATIAL_DIM>(conv_param_.pad),
+ arrayToString<2 * SPATIAL_DIM>(test_conv_p.pad));
+ }
+
+ if (!std::equal(
+ conv_param_.dilation.begin(),
+ conv_param_.dilation.end(),
+ test_conv_p.dilation.begin())) {
+ msg += combineStr(
+ "dilation",
+ arrayToString<SPATIAL_DIM>(conv_param_.dilation),
+ arrayToString<SPATIAL_DIM>(test_conv_p.dilation));
+ }
+
+ return msg;
+}
+
template class PackWeightsForConv<2, int8_t, int32_t>;
template class PackWeightsForConv<3, int8_t, int32_t>;
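
The new mismatchingParams pairs naturally with the existing isPackingCompliant check: the first call tells you a convolution spec does not match the packed weights, the second says which fields differ. A hedged usage sketch (the helper and its error text are hypothetical):

    #include <stdexcept>

    template <int SPATIAL_DIM, typename T, typename accT>
    void checkCompliant(
        PackWeightsForConv<SPATIAL_DIM, T, accT>& weights,
        const conv_param_t<SPATIAL_DIM>& conv_p) {
      if (!weights.isPackingCompliant(conv_p)) {
        throw std::logic_error(
            "Packed weights are incompatible with this convolution: " +
            weights.mismatchingParams(conv_p));
      }
    }
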
diff --git a/src/Utils.cc b/src/Utils.cc
index 355a5cb..af7d918 100644
--- a/src/Utils.cc
+++ b/src/Utils.cc
@@ -206,4 +206,7 @@ bool fbgemmHasAvx2Support() {
return (cpuinfo_initialize() && cpuinfo_has_x86_avx2());
}
+bool fbgemmHasAvx512VnniSupport() {
+  return (cpuinfo_initialize() && cpuinfo_has_x86_avx512vnni());
+}
} // namespace fbgemm
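
These predicates are what the packing constructors above branch on. A hedged sketch of the runtime selection they enable (the helper is illustrative):

    // Illustrative only: report the widest instruction set FBGEMM would use.
    const char* bestIsa() {
      if (fbgemm::fbgemmHasAvx512VnniSupport()) return "avx512_vnni";
      if (fbgemm::fbgemmHasAvx512Support())     return "avx512";
      if (fbgemm::fbgemmHasAvx2Support())       return "avx2";
      return "unsupported";
    }
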
diff --git a/test/FP16Test.cc b/test/FP16Test.cc
index eb49086..3267655 100644
--- a/test/FP16Test.cc
+++ b/test/FP16Test.cc
@@ -27,7 +27,26 @@ using namespace fbgemm;
namespace {
// The template parameter is transpose of A and B
class FBGemmFP16Test
- : public testing::TestWithParam<pair<matrix_op_t, matrix_op_t>> {};
+ : public testing::TestWithParam<pair<matrix_op_t, matrix_op_t>> {
+ protected:
+ vector<vector<int>> GenShapes() const {
+ vector<vector<int>> shapes;
+ random_device r;
+ default_random_engine generator(r());
+ uniform_int_distribution<int> dm(1, 256);
+ uniform_int_distribution<int> dnk(1, 1024);
+ for (int i = 0; i < 10; i++) {
+ int m = dm(generator);
+ int n = dnk(generator);
+ int k = dnk(generator);
+ shapes.push_back({m, n, k});
+ if (m > 10) {
+ shapes.push_back({(m / 10) * 10, n, k});
+ }
+ }
+ return shapes;
+ }
+};
} // namespace
INSTANTIATE_TEST_CASE_P(
@@ -44,21 +63,75 @@ INSTANTIATE_TEST_CASE_P(
matrix_op_t::Transpose, matrix_op_t::Transpose)*/));
TEST_P(FBGemmFP16Test, Test) {
- vector<vector<int>> shapes;
- random_device r;
- default_random_engine generator(r());
- uniform_int_distribution<int> dm(1, 256);
- uniform_int_distribution<int> dnk(1, 1024);
- for (int i = 0; i < 10; i++) {
- int m = dm(generator);
- int n = dnk(generator);
- int k = dnk(generator);
- shapes.push_back({m, n, k});
- if (m > 10) {
- shapes.push_back({(m / 10) * 10, n, k});
+ auto shapes = GenShapes();
+ float alpha = 1.f, beta = 0.f;
+ matrix_op_t atrans, btrans;
+ tie(atrans, btrans) = GetParam();
+
+ for (auto s : shapes) {
+ int m = s[0];
+ int n = s[1];
+ int k = s[2];
+
+ cerr << "m = " << m << " n = " << n << " k = " << k;
+ if (atrans == matrix_op_t::Transpose) {
+ cerr << " A_transposed";
+ }
+ if (btrans == matrix_op_t::Transpose) {
+ cerr << " B_transposed";
+ }
+ cerr << endl;
+
+ // initialize with small numbers
+ aligned_vector<int> Aint(m * k);
+ aligned_vector<int> Bint(k * n);
+ randFill(Aint, 0, 4);
+ randFill(Bint, 0, 4);
+ aligned_vector<float> A(Aint.begin(), Aint.end());
+ aligned_vector<float> B(Bint.begin(), Bint.end());
+
+ aligned_vector<float> C(m * n, NAN);
+
+ aligned_vector<float> A_ref(A), B_ref(B), C_ref(C);
+
+ if (atrans == matrix_op_t::Transpose) {
+ transpose_matrix(A_ref.data(), k, m);
+ }
+ if (btrans == matrix_op_t::Transpose) {
+ transpose_matrix(B_ref.data(), n, k);
+ }
+
+ // Gold via reference sgemm
+ matmul_fp_ref(m, n, k, k, n, n, A_ref.data(), B_ref.data(), C_ref.data());
+
+ // fbgemm fp16
+ PackedGemmMatrixFP16 Bp(btrans, k, n, alpha, B.data());
+#ifdef _OPENMP
+#pragma omp parallel
+#endif
+ {
+ int num_threads = fbgemm_get_num_threads();
+ int tid = fbgemm_get_thread_num();
+
+ cblas_gemm_compute(
+ atrans, m, A.data(), Bp, beta, C.data(), tid, num_threads);
+ }
+
+ // correctness check
+ for (int i = 0; i < m; ++i) {
+ for (int j = 0; j < n; ++j) {
+ float expected = C_ref[i * n + j];
+ float actual = C[i * n + j];
+ EXPECT_EQ(expected, actual)
+ << "GEMM results differ at (" << i << ", " << j << "). ref "
+ << expected << " FBGemm " << actual;
+ }
}
}
+}
+TEST_P(FBGemmFP16Test, Unpack) {
+ auto shapes = GenShapes();
float alpha = 1.f, beta = 0.f;
matrix_op_t atrans, btrans;
tie(atrans, btrans) = GetParam();
@@ -101,6 +174,23 @@ TEST_P(FBGemmFP16Test, Test) {
// fbgemm fp16
PackedGemmMatrixFP16 Bp(btrans, k, n, alpha, B.data());
+ EXPECT_TRUE(Bp.packed());
+
+ // Test unpack
+ aligned_vector<float16> tmp(Bp.matSize());
+ memcpy(tmp.data(), Bp.pmat(), Bp.matSize() * sizeof(float16));
+ Bp.unpackFromSrc(btrans, tmp.data());
+ EXPECT_FALSE(Bp.packed());
+ memcpy(tmp.data(), Bp.pmat(), Bp.matSize() * sizeof(float16));
+ for (int i = 0; i < k; ++i) {
+ for (int j = 0; j < n; ++j) {
+ EXPECT_EQ(B[i * n + j], cpu_half2float(tmp[i * n + j]));
+ }
+ }
+
+ // Pack it back
+ Bp.packFromSrc(btrans, tmp.data());
+ EXPECT_TRUE(Bp.packed());
#ifdef _OPENMP
#pragma omp parallel
diff --git a/test/GConvTest.cc b/test/GConvTest.cc
index 0074535..8c1fb82 100644
--- a/test/GConvTest.cc
+++ b/test/GConvTest.cc
@@ -465,8 +465,8 @@ TEST_P(fbgemmGConvPackTest, PackUnpackTest) {
for (int i = 0; i < weight_len; ++i) {
EXPECT_EQ(Bint8.data()[i], unpack_buf.data()[i])
<< "Pack/Unpack results differ at index " << i
- << ", Reference: " << static_cast<int> (Bint8.data()[i])
- << ", Pack-Unpacked: " << static_cast<int> (unpack_buf.data()[i]);
+ << ", Reference: " << static_cast<int>(Bint8.data()[i])
+ << ", Pack-Unpacked: " << static_cast<int>(unpack_buf.data()[i]);
}
} // for each shape
}
diff --git a/test/PackedRequantizeAcc16Test.cc b/test/PackedRequantizeAcc16Test.cc
index 23af3eb..62b1303 100644
--- a/test/PackedRequantizeAcc16Test.cc
+++ b/test/PackedRequantizeAcc16Test.cc
@@ -94,6 +94,8 @@ static vector<vector<int>> GetShapes_() {
{102, 512, 258},
{1024, 512, 258},
+
+ {120, 4, 288},
};
return shapes;
}
@@ -827,54 +829,67 @@ TEST_P(fbgemmPackUnpackAcc16Test, TestPackUnpack) {
bool test_ld;
tie(btrans, test_ld) = GetParam();
+ BlockingFactors params;
+ params.MCB = 48;
+ params.NCB = 16;
+ params.KCB = 256;
+ params.MR = 1;
+ params.NR = 16;
+ params.ROW_INTERLEAVE = 4;
+ params.NR_MIN = 16;
+ vector<BlockingFactors*> vec_params_ptr = {&params, nullptr};
+
for (auto shape : shapes) {
for (int groups : {1, 3, 4}) {
- int n = shape[1];
- int k = shape[2];
+ for (auto params_ptr : vec_params_ptr) {
+ int n = shape[1];
+ int k = shape[2];
- if (k % groups != 0) {
- continue;
- }
- int k_per_group = k / groups;
+ if (k % groups != 0) {
+ continue;
+ }
+ int k_per_group = k / groups;
- // kxn matrix
- aligned_vector<int8_t> Bint8(k * n);
- randFill<int8_t>(Bint8, -128, 127);
+ // kxn matrix
+ aligned_vector<int8_t> Bint8(k * n);
+ randFill<int8_t>(Bint8, -128, 127);
- // To test lda != k , we just reduce k by half and use the original k
- // as lda.
- int n_adjusted = n;
- if (test_ld) {
- if (btrans == matrix_op_t::NoTranspose) {
- n_adjusted = std::max(n / 2, 1);
+ // To test lda != k , we just reduce k by half and use the original k
+ // as lda.
+ int n_adjusted = n;
+ if (test_ld) {
+ if (btrans == matrix_op_t::NoTranspose) {
+ n_adjusted = std::max(n / 2, 1);
+ }
}
- }
- // Note that packing for weight is performed during the constructor
- // stage.
- PackBMatrix<int8_t, int16_t> packedWeights(
- btrans,
- k,
- n_adjusted,
- Bint8.data(),
- (btrans == matrix_op_t::Transpose) ? k_per_group : n,
- nullptr,
- groups);
+ // Note that packing for weight is performed during the constructor
+ // stage.
+ PackBMatrix<int8_t, int16_t> packedWeights(
+ btrans,
+ k,
+ n_adjusted,
+ Bint8.data(),
+ (btrans == matrix_op_t::Transpose) ? k_per_group : n,
+ nullptr,
+ groups,
+ params_ptr);
- // Setup a buffer to get pack -> unpacked results
- aligned_vector<int8_t> unpack_buf(k * n, 0);
+ // Setup a buffer to get pack -> unpacked results
+ aligned_vector<int8_t> unpack_buf(k * n, 0);
- // Perform unpacking
- packedWeights.unpack(unpack_buf.data());
+ // Perform unpacking
+ packedWeights.unpack(unpack_buf.data(), params_ptr);
- // Sanity check
- for (int i = 0; i < k; i++) {
- for (int j = 0; j < n_adjusted; j++) {
- EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j])
+ // Sanity check
+ for (int i = 0; i < k; i++) {
+ for (int j = 0; j < n_adjusted; j++) {
+ EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j])
<< "Pack/Unpack results differ at index (" << i << ", " << j
<< ", Reference: " << static_cast<int>(Bint8.data()[i * n + j])
<< ", Pack-Unpacked: "
<< static_cast<int>(unpack_buf.data()[i * n + j]);
+ }
}
}
}
diff --git a/test/PackedRequantizeTest.cc b/test/PackedRequantizeTest.cc
index 11ef6ff..5338243 100644
--- a/test/PackedRequantizeTest.cc
+++ b/test/PackedRequantizeTest.cc
@@ -93,6 +93,8 @@ static vector<vector<int>> GetShapes_() {
{102, 512, 258},
{1024, 512, 258},
+
+ {120, 4, 288},
};
return shapes;
}
@@ -766,54 +768,67 @@ TEST_P(fbgemmPackUnpackAcc32Test, TestPackUnpack) {
bool test_ld;
tie(btrans, test_ld) = GetParam();
+ BlockingFactors params;
+ params.MCB = 48;
+ params.NCB = 16;
+ params.KCB = 256;
+ params.MR = 1;
+ params.NR = 16;
+ params.ROW_INTERLEAVE = 4;
+ params.NR_MIN = 16;
+ vector<BlockingFactors*> vec_params_ptr = {&params, nullptr};
+
for (auto shape : shapes) {
for (int groups : {1, 3, 4}) {
- int n = shape[1];
- int k = shape[2];
+ for (auto params_ptr : vec_params_ptr) {
+ int n = shape[1];
+ int k = shape[2];
- if (k % groups != 0) {
- continue;
- }
- int k_per_group = k / groups;
+ if (k % groups != 0) {
+ continue;
+ }
+ int k_per_group = k / groups;
- // kxn matrix
- aligned_vector<int8_t> Bint8(k * n);
- randFill<int8_t>(Bint8, -128, 127);
+ // kxn matrix
+ aligned_vector<int8_t> Bint8(k * n);
+ randFill<int8_t>(Bint8, -128, 127);
- // To test lda != k , we just reduce k by half and use the original k
- // as lda.
- int n_adjusted = n;
- if (test_ld) {
- if (btrans == matrix_op_t::NoTranspose) {
- n_adjusted = std::max(n / 2, 1);
+ // To test lda != k , we just reduce k by half and use the original k
+ // as lda.
+ int n_adjusted = n;
+ if (test_ld) {
+ if (btrans == matrix_op_t::NoTranspose) {
+ n_adjusted = std::max(n / 2, 1);
+ }
}
- }
- // Note that packing for weight is performed during the constructor
- // stage.
- PackBMatrix<int8_t> packedWeights(
- btrans,
- k,
- n_adjusted,
- Bint8.data(),
- (btrans == matrix_op_t::Transpose) ? k_per_group : n,
- nullptr,
- groups);
+ // Note that packing for weight is performed during the constructor
+ // stage.
+ PackBMatrix<int8_t> packedWeights(
+ btrans,
+ k,
+ n_adjusted,
+ Bint8.data(),
+ (btrans == matrix_op_t::Transpose) ? k_per_group : n,
+ nullptr,
+ groups,
+ params_ptr);
- // Setup a buffer to get pack -> unpacked results
- aligned_vector<int8_t> unpack_buf(k * n, 0);
+ // Setup a buffer to get pack -> unpacked results
+ aligned_vector<int8_t> unpack_buf(k * n, 0);
- // Perform unpacking
- packedWeights.unpack(unpack_buf.data());
+ // Perform unpacking
+ packedWeights.unpack(unpack_buf.data(), params_ptr);
- // Sanity check
- for (int i = 0; i < k; i++) {
- for (int j = 0; j < n_adjusted; j++) {
- EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j])
+ // Sanity check
+ for (int i = 0; i < k; i++) {
+ for (int j = 0; j < n_adjusted; j++) {
+ EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j])
<< "Pack/Unpack results differ at index (" << i << ", " << j
<< ", Reference: " << static_cast<int>(Bint8.data()[i * n + j])
<< ", Pack-Unpacked: "
<< static_cast<int>(unpack_buf.data()[i * n + j]);
+ }
}
}
}
diff --git a/third_party/asmjit b/third_party/asmjit
-Subproject 673dcefaa048c5f5a2bf8b85daf8f7b9978d018
+Subproject 4da474ac9aa2689e88d5e40a2f37628f302d7e3