diff options
author | Young Jin Kim <youki@microsoft.com> | 2019-08-15 02:01:38 +0300 |
---|---|---|
committer | Young Jin Kim <youki@microsoft.com> | 2019-08-15 02:01:38 +0300 |
commit | bb5063533256a8a5a91a812f6a193d7f352a2a3a (patch) | |
tree | 1e64f7127e589ea32a01785198c1cff8fa2813dd | |
parent | eb8fede25bd048da6fd396654936703a474f0504 (diff) | |
parent | a6d1d3eed7ba858d4532fc297b7a4ee984e6e7e3 (diff) |
Merge branch 'upstream/master' into youki/prepack_constrcopyPublic
32 files changed, 1398 insertions, 473 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index e6c7419..0460799 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,8 +37,10 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc src/FbgemmI8Spmdm.cc src/GenerateKernelU8S8S32ACC16.cc src/GenerateKernelU8S8S32ACC16Avx512.cc + src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc src/GenerateKernelU8S8S32ACC32.cc src/GenerateKernelU8S8S32ACC32Avx512.cc + src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc src/GroupwiseConvAcc32Avx2.cc src/PackAMatrix.cc src/PackAWithIm2Col.cc @@ -12,9 +12,9 @@ row-wise quantization and outlier-aware quantization. FBGEMM also exploits fusion opportunities in order to overcome the unique challenges of matrix multiplication at lower precision with bandwidth-bound operations. -FBGEMM is used as a backend of Caffe2 quantized operators for x86 machines -(https://github.com/pytorch/pytorch/tree/master/caffe2/quantization/server). -We also plan to integrate FBGEMM into PyTorch. +FBGEMM is used as a backend of Caffe2 and PyTorch quantized operators for x86 machines: +* Caffe2: https://github.com/pytorch/pytorch/tree/master/caffe2/quantization/server +* PyTorch: https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native/quantized/cpu ## Examples diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index 7f428ed..70f6294 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -458,14 +458,17 @@ class FBGEMM_API PackBMatrix final std::int32_t addr(std::int32_t i, std::int32_t j) const; /** - * @brief Packs a block of source matrix into pmat buffer. + * @brief Packs a block of source matrix into pmat buffer. The blocking + * parameters are needed to compute the buffer size of each group. + * It will use default blocking parameters if params is not provided. */ - void pack(const block_type_t& block); + void pack(const block_type_t& block, const BlockingFactors* params = nullptr); /** * @brief Print the packed block. */ - void printPackedMatrix(std::string name); + void printPackedMatrix(std::string name, + const BlockingFactors* params = nullptr); /** * @return true if meta information like matrix shape is the same. @@ -480,7 +483,7 @@ class FBGEMM_API PackBMatrix final * @brief Unpack pmat buffer to the origin_buf (Used for the serialization to * recover weight matrix). */ - void unpack(T* origin_buf); + void unpack(T* origin_buf, const BlockingFactors* params = nullptr); ~PackBMatrix() {} @@ -497,7 +500,8 @@ class FBGEMM_API PackBMatrix final const block_type_t& block, T* unpack_buf, T* pack_buf, - bool ispack); + bool ispack, + const BlockingFactors* params = nullptr); }; /** @@ -645,6 +649,11 @@ class FBGEMM_API PackWeightsForConv { bool isPackingCompliant(const conv_param_t<SPATIAL_DIM>& conv_p); /** + * @brief Returns a string of mismatching parameters + */ + std::string mismatchingParams(const conv_param_t<SPATIAL_DIM>& conv_p); + + /** * @brief Unpack packed matric into origin_buf (Used for the serialization to * recover weight matrix). */ diff --git a/include/fbgemm/FbgemmFP16.h b/include/fbgemm/FbgemmFP16.h index 3d84977..8da0b56 100644 --- a/include/fbgemm/FbgemmFP16.h +++ b/include/fbgemm/FbgemmFP16.h @@ -104,6 +104,14 @@ class PackedGemmMatrixFP16 { } } + void setPacked(bool p) { + packed_ = p; + } + + bool packed() const { + return packed_; + } + void initializeMemory() { // allocate and initialize packed memory const int padding = 1024; // required by sw pipelined kernels @@ -128,6 +136,16 @@ class PackedGemmMatrixFP16 { #endif } + void unpackFromSrc(const matrix_op_t trans, float16* src_mat) { + bool tr = (trans == matrix_op_t::Transpose); + for (int i = 0; i < numRows(); i++) { + for (int j = 0; j < numCols(); j++) { + pmat_[tr ? i + numRows() * j : i * numCols() + j] = src_mat[addr(i, j)]; + } + } + packed_ = false; + } + // protected: // blocked row-major format address arithmetic uint64_t addr(const int r_, const int c_) const { @@ -163,6 +181,19 @@ class PackedGemmMatrixFP16 { pmat_[addr(i, j)]); } } + packed_ = true; + } + + // This function takes in an unpacked float16 matrix of the same size and + // packs it. There is no floating type conversion. + void packFromSrc(const matrix_op_t trans, const float16* smat) { + bool tr = (trans == matrix_op_t::Transpose); + for (int i = 0; i < numRows(); ++i) { + for (int j = 0; j < numCols(); ++j) { + pmat_[addr(i, j)] = smat[tr ? i + numRows() * j : i * numCols() + j]; + } + } + packed_ = true; } const float16& operator()(const int r, const int c) const { @@ -210,6 +241,7 @@ class PackedGemmMatrixFP16 { uint64_t size_; int kernel_ncol_blocks_; float16* pmat_; + bool packed_{false}; friend void cblas_gemm_compute( const matrix_op_t transa, diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h index 76eb425..baccfad 100644 --- a/include/fbgemm/PackingTraits-inl.h +++ b/include/fbgemm/PackingTraits-inl.h @@ -222,3 +222,53 @@ struct PackingTraits< 128}; ///< Cache block for N dimension (multiple of NR). static constexpr int KCB{256}; ///< Cache block for K dimension. }; + +/** + * @brief Helper struct to type specialize for int16_t and int32_t together. + */ +template <typename T> +struct is_16or32bit { + static constexpr bool value = + std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value; +}; + +/** + * @brief Packing parameter specialization for accumulation into 32-bit/16-bit + * integers. + * + * Since there is no int16_t accumulation for AVX512 VNNI, we redirect int16_t + * to int32_t accumulation and use the same blocking parameters as int32_t. + * + * This is picked when T is of int8 type (signed or unsigned) and instruction + * set is avx512_vnni. + */ +template <typename T, typename accT> +struct PackingTraits< + T, + accT, + inst_set_t::avx512_vnni, + typename std::enable_if< + is_8bit<T>::value && is_16or32bit<accT>::value>::type> { + static constexpr int MR{8}; ///< Register block for M dimension. + static constexpr int NR_MIN{ + 16}; ///< Minimum register block for N dimension. + ///< 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. + static constexpr int NR{ + 32}; ///< Register block for N dimension. + ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements + ///< completely fill a 512-bit wide vector. Total registers used for + ///< N dimension: NR*ROW_INTERLEAVE*8/VLEN. We use MR x + ///< NR*ROW_INTERLEAVE*8/VLEN zmm registers + ///< for C accumulations. + + static constexpr int ROW_INTERLEAVE{ + 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing + ///< B matrix. + + static constexpr int MCB{ + 128}; ///< Cache block for M dimension (multiple of MR). + static constexpr int NCB{ + 32}; ///< Cache block for N dimension (multiple of NR). + static constexpr int KCB{256}; ///< Cache block for K dimension. +}; diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h index eac0bcd..3976790 100644 --- a/include/fbgemm/Utils.h +++ b/include/fbgemm/Utils.h @@ -5,6 +5,7 @@ * LICENSE file in the root directory of this source tree. */ #pragma once +#include <array> #include <string> #include <type_traits> #include "FbgemmBuild.h" @@ -39,7 +40,7 @@ enum class matrix_op_t { NoTranspose, Transpose }; /** * @brief Typed enum for supported instruction sets. */ -enum class inst_set_t { anyarch, avx2, avx512 }; +enum class inst_set_t { anyarch, avx2, avx512, avx512_vnni }; /** * @brief Typed enum for optimized paths for convolutions @@ -110,6 +111,11 @@ FBGEMM_API bool fbgemmHasAvx512Support(); FBGEMM_API bool fbgemmHasAvx2Support(); /** + * @brief Are we running on a AVX512_VNNI supported cpu? + */ +FBGEMM_API bool fbgemmHasAvx512VnniSupport(); + +/** * @brief Helper struct to enable autotuning of FBGEMM packing and kernels. * * This structure is optional. If not used, the default values for these @@ -126,6 +132,16 @@ struct FBGEMM_API BlockingFactors { int NCB; }; +template <int SIZE, typename T = std::int32_t> +FBGEMM_API std::string arrayToString(const std::array<T, SIZE>& inp) { + std::string out = "["; + for (int i = 0; i < SIZE; ++i) { + out += std::to_string(inp[i]); + out += (i != SIZE - 1) ? std::string(", ") : std::string("]"); + } + return out; +} + template <typename accT = std::int32_t> FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) { constexpr bool is_32bit = std::is_same<accT, int32_t>::value; diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc index f7292fd..0a4ff55 100644 --- a/src/ExecuteKernelU8S8.cc +++ b/src/ExecuteKernelU8S8.cc @@ -49,7 +49,8 @@ ExecuteKernel< throw std::runtime_error("Failed to initialize cpuinfo!"); } if (params) { - if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { + if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support() || + fbgemmHasAvx2Support()) { mbSize_ = params->MCB; nbSize_ = params->NCB; nrMinSize_ = params->NR_MIN; @@ -59,7 +60,20 @@ ExecuteKernel< assert(0 && "unsupported architecure"); } } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + mbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::MCB; + nbSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::NCB; + nrMinSize_ = PackingTraits< + int8_t, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::NR_MIN; + } else if (fbgemmHasAvx512Support()) { mbSize_ = PackingTraits< int8_t, typename packingAMatrix::accType, @@ -118,7 +132,25 @@ void ExecuteKernel< typename BaseType::jit_micro_kernel_fp fn; - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) { + // For AVX512VNNI, we redirect int16_t to int32_t accumulation. + CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>( + accum, + packed_rows_A, + packedB_.blockColSize(), + packedA_.numPackedCols(), + nbSize_); + } else { + fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>( + accum, + packed_rows_A, + packedB_.blockColSize(), + packedA_.numPackedCols(), + nbSize_); + } + } else if (fbgemmHasAvx512Support()) { fn = BaseType::template getOrCreate<inst_set_t::avx512>( accum, packed_rows_A, @@ -148,7 +180,10 @@ void ExecuteKernel< if (jb == bColBlocks - 1) { int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_; if (nc != nbSize_) { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>( + accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_); + } else if (fbgemmHasAvx512Support()) { fn = BaseType::template getOrCreate<inst_set_t::avx512>( accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_); } else if (fbgemmHasAvx2Support()) { @@ -213,7 +248,7 @@ void ExecuteKernel< int32_t nSize = C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols(); if (nSize) { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) { // TODO: avx512 path // Currently use avx2 code outputProcess_.template f<inst_set_t::avx2>( @@ -238,7 +273,7 @@ void ExecuteKernel< if (C_buffer_start == C_tile_) { // When C_tile_ scratchpad was used to avoid accessing memory past // C_buffer_ . - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) { // TODO: avx512 path // Currently use avx2 code outputProcess_.template f<inst_set_t::avx2>( diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc index 2f641ee..1052044 100644 --- a/src/Fbgemm.cc +++ b/src/Fbgemm.cc @@ -48,7 +48,8 @@ void fbgemmPacked( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -62,7 +63,20 @@ void fbgemmPacked( MR = blocking_params->MR; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + MCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::MCB; + KCB = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::KCB; + MR = PackingTraits< + typename packingAMatrix::inpType, + typename packingAMatrix::accType, + inst_set_t::avx512_vnni>::MR; + } else if (fbgemmHasAvx512Support()) { MCB = PackingTraits< typename packingAMatrix::inpType, typename packingAMatrix::accType, diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc index 027e6c5..33d1535 100644 --- a/src/FbgemmConv.cc +++ b/src/FbgemmConv.cc @@ -73,9 +73,14 @@ int fbgemmConv( "Only 2D and 3D convolutions are supported"); if (!packed_weights.isPackingCompliant(conv_p)) { - throw std::logic_error( - "[FBGEMM_CONV_ERROR] Prepacked weights can't be used" - " with these convolution parameters!"); + std::string msg = + "[FBGEMM_CONV_ERROR] Convolution parameters " + "mismatch between pre-packed weights and conv invocation! "; + msg += packed_weights.mismatchingParams(conv_p); + msg += std::string( + " Please pack weights using the same parameters " + "with which convolution operation is invoked!"); + throw std::logic_error(msg); } switch (ConvFastPath<SPATIAL_DIM, ACC_T>(conv_p)) { diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h index dccdfc5..e52097e 100644 --- a/src/GenerateKernel.h +++ b/src/GenerateKernel.h @@ -18,7 +18,7 @@ namespace fbgemm { namespace x86 = asmjit::x86; /** - * @brief AVX2/AVX512 JIT assembly code generator. + * @brief AVX2/AVX512/AVX512VNNI JIT assembly code generator. * @tparam TA Type of matrix A. * @tparam TB Type of matrix B. * @tparam TC Type of matrix C. @@ -104,7 +104,7 @@ class CodeGenBase { */ template <inst_set_t instSet> void initCRegs( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, int leadingDimCRegAssign = 4); @@ -114,10 +114,10 @@ class CodeGenBase { */ template <inst_set_t instSet> void genComputeBlock( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp B_pf, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, int rowRegs, int colRegs, int lda, @@ -129,11 +129,11 @@ class CodeGenBase { */ template <inst_set_t instSet> void storeCRegs( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, bool accum, int leadingDimCRegAssign = 4); @@ -168,7 +168,9 @@ class CodeGenBase { fileName += "_MR-" + std::to_string(MR); fileName += "_NR-" + std::to_string(NR); fileName += "_NR_MIN-" + std::to_string(NR_MIN); - if (instSet == inst_set_t::avx512) { + if (instSet == inst_set_t::avx512_vnni) { + fileName += "_avx512vnni"; + } else if (instSet == inst_set_t::avx512) { fileName += "_avx512"; } else if (instSet == inst_set_t::avx2) { fileName += "_avx2"; @@ -178,12 +180,10 @@ class CodeGenBase { } private: - asmjit::X86Ymm - CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel. - asmjit::X86Zmm + x86::Ymm CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel. + x86::Zmm CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel. - asmjit::X86Zmm - AllRegs_avx512_[32]; ///< all AVX512 zmm registers. + x86::Zmm AllRegs_avx512_[32]; ///< all AVX512 zmm registers. int vectorWidth_; ///< Vector width in bits. int VLEN_; ///< Vector width in elements. diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc index 718b883..1e7e081 100644 --- a/src/GenerateKernelU8S8S32ACC16.cc +++ b/src/GenerateKernelU8S8S32ACC16.cc @@ -31,7 +31,7 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { @@ -53,18 +53,18 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp /* unused (reserved for prefetching)*/, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp /* unused (reserved for prefetching)*/, int rowRegs, int colRegs, int lda, int leadingDimCReg) { // used for matrix A - asmjit::X86Ymm AReg = x86::ymm12; + x86::Ymm AReg = x86::ymm12; - asmjit::X86Ymm tmpReg = x86::ymm14; + x86::Ymm tmpReg = x86::ymm14; for (int i = 0; i < rowRegs; ++i) { // broadcast A @@ -95,15 +95,15 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, bool accum, int leadingDimCReg) { - asmjit::X86Xmm extractDest128 = x86::xmm15; - asmjit::X86Ymm extractDest256 = x86::ymm15; + x86::Xmm extractDest128 = x86::xmm15; + x86::Ymm extractDest256 = x86::ymm15; for (int i = 0; i < rowRegs; ++i) { a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t))); @@ -112,7 +112,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< a->vextracti128( extractDest128, CRegs_avx2_[i * leadingDimCReg + j], idx); a->vpmovsxwd(extractDest256, extractDest128); - asmjit::X86Mem destAddr = x86::dword_ptr( + x86::Mem destAddr = x86::dword_ptr( a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t)); if (accum) { a->vpaddd(extractDest256, extractDest256, destAddr); @@ -176,9 +176,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( return codeCache_[kernelSig]; } code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // generated code logging @@ -207,46 +207,45 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( //"nc must be equal to the number of register blocks"); // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + asmjit::FuncFrame frame; + frame.init(func); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); - asmjit::FuncArgsMapper args(&func); + asmjit::FuncArgsAssignment args(&func); args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFrameInfo(ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); - - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); asmjit::Label Loopk = a->newLabel(); asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - // asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp kIdx = a->gpzRef(14); + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + // x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp kIdx = a->gpz(14); int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; if (mRegBlocks > 0) { @@ -289,8 +288,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( a->jl(Loopk); // store C matrix - storeCRegs<inst_set_t::avx2>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum); + storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); // increment A for next block a->sub(buffer_A, kSize); @@ -340,11 +338,10 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>( a->jl(LoopkRem); // store C matrix - storeCRegs<inst_set_t::avx2>( - a, rowRegs, colRegs, C_Offset, ldcReg, accum); + storeCRegs<inst_set_t::avx2>(a, rowRegs, colRegs, C_Offset, ldcReg, accum); } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_micro_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index c95757b..a49e440 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -19,7 +19,7 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< inst_set_t::avx512>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { @@ -41,18 +41,18 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< inst_set_t::avx512>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp /* unused (reserved for prefetching)*/, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp /* unused (reserved for prefetching)*/, int rowRegs, int colRegs, int lda, int leadingDimCReg) { // used for matrix A - asmjit::X86Zmm AReg = x86::zmm29; + x86::Zmm AReg = x86::zmm29; - asmjit::X86Zmm tmpReg = x86::zmm30; + x86::Zmm tmpReg = x86::zmm30; // We start allocating BRegs from zmm27 and then allocate zmm26 and so on. for (int j = 0; j < colRegs; ++j) { @@ -66,8 +66,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< a->vpbroadcastw( AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); for (int j = 0; j < colRegs; ++j) { - a->vpmaddubsw( - tmpReg, AReg, AllRegs_avx512_[27-j]); + a->vpmaddubsw(tmpReg, AReg, AllRegs_avx512_[27 - j]); a->vpaddsw( CRegs_avx512_[i * leadingDimCReg + j], tmpReg, @@ -90,15 +89,16 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< inst_set_t::avx512>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum, int leadingDimCReg) { - asmjit::X86Ymm extractDest256 = x86::ymm31; - asmjit::X86Zmm extractDest512 = x86::zmm31; + x86::Ymm extractDest256 = x86::ymm31; + x86::Zmm extractDest512 = x86::zmm31; for (int i = 0; i < rowRegs; ++i) { a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t))); @@ -107,7 +107,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< a->vextracti32x8( extractDest256, CRegs_avx512_[i * leadingDimCReg + j], idx); a->vpmovsxwd(extractDest512, extractDest256); - asmjit::X86Mem destAddr = x86::dword_ptr( + x86::Mem destAddr = x86::dword_ptr( a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t)); if (accum) { a->vpaddd(extractDest512, extractDest512, destAddr); @@ -172,9 +172,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( } code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // generated code logging @@ -209,49 +209,49 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::FuncFrame frame; + frame.init(func); - asmjit::FuncArgsMapper args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - args.updateFrameInfo(ffi); + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); asmjit::Label LoopMBlocks = a->newLabel(); asmjit::Label LoopNBlocks = a->newLabel(); asmjit::Label Loopk = a->newLabel(); - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - // asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp jIdx = a->gpzRef(14); - asmjit::X86Gp kIdx = a->gpzRef(15); + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + // x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); // save B_buffer address a->mov(buffer_B_saved, buffer_B); @@ -407,7 +407,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>( a->jl(LoopNRem); } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_micro_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc new file mode 100644 index 0000000..f559aba --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc @@ -0,0 +1,102 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <iostream> +#include "GenerateKernel.h" + +namespace fbgemm { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit + * Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCReg) { + assert(0 && "Accumulation to int16_t is not available for VNNI!"); + + // For AVX512VNNI, redirect to int32_t accumulation. + CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + codeObj.initCRegs<inst_set_t::avx512_vnni>( + a, rowRegs, colRegs, leadingDimCReg); +} + +/** + * Generate AVX512 instructions for computing block in the rank-k update of + * 16-bit Accmulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp /* unused (reserved for prefetching)*/, + int rowRegs, + int colRegs, + int lda, + int leadingDimCReg) { + assert(0 && "Accumulation to int16_t is not available for VNNI!"); + + // For AVX512VNNI, redirect to int32_t accumulation. + CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + codeObj.genComputeBlock<inst_set_t::avx512_vnni>( + a, buffer_A, buffer_B, buffer_B, rowRegs, colRegs, lda, leadingDimCReg); +} + +/** + * Generate AVX512 instructions for storing the C registers back to the memory + * in 16-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + int rowRegs, + int colRegs, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum, + int leadingDimCReg) { + assert(0 && "Accumulation to int16_t is not available for VNNI!"); + + // For AVX512VNNI, redirect to int32_t accumulation. + CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + codeObj.storeCRegs<inst_set_t::avx512_vnni>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, leadingDimCReg); +} + +/** + * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel. + * + */ +template <> +template <> +CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::jit_micro_kernel_fp +CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate< + inst_set_t::avx512_vnni>( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + assert(0 && "Accumulation to int16_t is not available for VNNI!"); + + // For AVX512VNNI, redirect to int32_t accumulation. + CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj; + return codeObj.getOrCreate<inst_set_t::avx512_vnni>(accum, mc, nc, kc, kc); +} + +} // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc index 58643ad..6b54743 100644 --- a/src/GenerateKernelU8S8S32ACC32.cc +++ b/src/GenerateKernelU8S8S32ACC32.cc @@ -31,7 +31,7 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { @@ -53,25 +53,25 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp B_pf, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, int rowRegs, int colRegs, int lda, int leadingDimCReg) { // used for matrix A - asmjit::X86Ymm AReg = x86::ymm12; + x86::Ymm AReg = x86::ymm12; // used for matrix B - asmjit::X86Ymm BReg = x86::ymm13; + x86::Ymm BReg = x86::ymm13; // Contains 16-bit 1s - asmjit::X86Ymm oneReg = x86::ymm15; + x86::Ymm oneReg = x86::ymm15; // temporary register - asmjit::X86Ymm res1 = x86::ymm14; + x86::Ymm res1 = x86::ymm14; for (int j = 0; j < colRegs; ++j) { // load B @@ -99,11 +99,11 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, bool accum, int leadingDimCReg) { for (int i = 0; i < rowRegs; ++i) { @@ -177,9 +177,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( return codeCache_[kernelSig]; } code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // generated code logging FILE* codeLogfile = fopen( @@ -205,49 +205,48 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + asmjit::FuncFrame frame; + frame.init(func); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); - asmjit::FuncArgsMapper args(&func); + asmjit::FuncArgsAssignment args(&func); args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFrameInfo(ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); - - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); asmjit::Label Loopk = a->newLabel(); asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp kIdx = a->gpzRef(14); - // asmjit::X86Gp B_pf = a->gpzRef(8); + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp kIdx = a->gpz(14); + // x86::Gp B_pf = a->gpz(8); - asmjit::X86Ymm oneReg = x86::ymm15; + x86::Ymm oneReg = x86::ymm15; // create 16-bit 1s // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 // and so on @@ -358,7 +357,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>( a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_micro_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc index 12243ee..fe35627 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc @@ -19,7 +19,7 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< inst_set_t::avx512>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { @@ -41,25 +41,25 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< inst_set_t::avx512>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp B_pf, + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, int rowRegs, int colRegs, int lda, int leadingDimCReg) { // used for matrix A - asmjit::X86Zmm AReg = x86::zmm31; + x86::Zmm AReg = x86::zmm31; // used for matrix B - asmjit::X86Zmm BReg = x86::zmm30; + x86::Zmm BReg = x86::zmm30; // Contains 16-bit 1s - asmjit::X86Zmm oneReg = x86::zmm29; + x86::Zmm oneReg = x86::zmm29; // temporary register - asmjit::X86Zmm res1 = x86::zmm28; + x86::Zmm res1 = x86::zmm28; for (int j = 0; j < colRegs; ++j) { // load B @@ -87,18 +87,17 @@ template <> template <> void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< inst_set_t::avx512>( - asmjit::X86Emitter* a, + x86::Emitter* a, int rowRegs, int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, + x86::Gp C_Offset, + x86::Gp ldcReg, bool accum, int leadingDimCReg) { for (int i = 0; i < rowRegs; ++i) { if (i != 0) { a->add(C_Offset, ldcReg); - } - else { + } else { a->mov(C_Offset, static_cast<asmjit::Imm>(0)); } for (int j = 0; j < colRegs; ++j) { @@ -168,9 +167,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( return codeCache_[kernelSig]; } code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // generated code logging @@ -205,52 +204,52 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignature6<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::FuncFrame frame; + frame.init(func); - asmjit::FuncArgsMapper args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - args.updateFrameInfo(ffi); + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); asmjit::Label LoopMBlocks = a->newLabel(); asmjit::Label LoopNBlocks = a->newLabel(); asmjit::Label Loopk = a->newLabel(); - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp jIdx = a->gpzRef(14); - asmjit::X86Gp kIdx = a->gpzRef(15); - // asmjit::X86Gp B_pf = a->gpzRef(8); + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); + // x86::Gp B_pf = a->gpz(8); - asmjit::X86Zmm oneReg = x86::zmm29; + x86::Zmm oneReg = x86::zmm29; // create 16-bit 1s // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 // and so on @@ -420,7 +419,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>( a->jl(LoopNRem); } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_micro_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc new file mode 100644 index 0000000..8ae0745 --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc @@ -0,0 +1,431 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <iostream> +#include "GenerateKernel.h" + +namespace fbgemm { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit + * Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCReg) { + for (int i = 0; i < rowRegs; ++i) { + for (int j = 0; j < colRegs; ++j) { + a->vxorps( + CRegs_avx512_[i * leadingDimCReg + j], + CRegs_avx512_[i * leadingDimCReg + j], + CRegs_avx512_[i * leadingDimCReg + j]); + } + } +} + +/** + * Generate AVX512 instructions for computing block in the rank-k update of + * 32-bit Accmulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + x86::Gp buffer_A, + x86::Gp buffer_B, + x86::Gp B_pf, + int rowRegs, + int colRegs, + int lda, + int leadingDimCReg) { + // used for matrix A + x86::Zmm AReg = x86::zmm31; + + // used for matrix B + x86::Zmm BReg = x86::zmm30; + + for (int j = 0; j < colRegs; ++j) { + // load B + a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + // load A, broadcast and fmas + for (int i = 0; i < rowRegs; ++i) { + a->vpbroadcastd( + AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); + a->vpdpbusd(CRegs_avx512_[i * leadingDimCReg + j], AReg, BReg); + } + a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); + } +} + +/** + * Generate AVX512 instructions for storing the C registers back to the memory + * in 32-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegs< + inst_set_t::avx512_vnni>( + x86::Emitter* a, + int rowRegs, + int colRegs, + x86::Gp C_Offset, + x86::Gp ldcReg, + bool accum, + int leadingDimCReg) { + for (int i = 0; i < rowRegs; ++i) { + if (i != 0) { + a->add(C_Offset, ldcReg); + } else { + a->mov(C_Offset, static_cast<asmjit::Imm>(0)); + } + for (int j = 0; j < colRegs; ++j) { + if (accum) { + a->vpaddd( + CRegs_avx512_[i * leadingDimCReg + j], + CRegs_avx512_[i * leadingDimCReg + j], + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t))); + } + a->vmovups( + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)), + CRegs_avx512_[i * leadingDimCReg + j]); + } + } +} + +/** + * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel. + * + */ +template <> +template <> +CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp +CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate< + inst_set_t::avx512_vnni>( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + std::tuple<bool, int, int, int, int, int, int, int> kernelSig; + int kBlock; + int nBlock; + int mRegBlockSize; + int nRegBlockSize; + int nRegBlockSizeMin; + int row_interleave; + + if (blocking_params) { + kBlock = blocking_params->KCB; + nBlock = blocking_params->NCB; + mRegBlockSize = blocking_params->MR; + nRegBlockSize = blocking_params->NR; + nRegBlockSizeMin = blocking_params->NR_MIN; + row_interleave = blocking_params->ROW_INTERLEAVE; + } else { + kBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::KCB; + nBlock = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NCB; + mRegBlockSize = + PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::MR; + nRegBlockSize = + PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR; + nRegBlockSizeMin = + PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>::NR_MIN; + row_interleave = PackingTraits<uint8_t, int32_t, inst_set_t::avx512_vnni>:: + ROW_INTERLEAVE; + } + + kernelSig = std::make_tuple( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin); + + if (codeCache_.find(kernelSig) != codeCache_.end()) { + return codeCache_[kernelSig]; + } + code_.reset(false); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); + +#if defined(FBGEMM_LOG_CODE) + // generated code logging + FILE* codeLogfile = fopen( + getCodeLoggingFile<inst_set_t::avx512_vnni>( + accum, + mc, + nc, + nBlock, + kBlock, + mRegBlockSize, + nRegBlockSize, + nRegBlockSizeMin) + .c_str(), + "w"); + asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code_.setLogger(codeLogger); + } +#endif + + assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); + assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); + int maxMRegs = mRegBlockSize; + int maxNRegs = nRegBlockSize * row_interleave / VLEN_; + assert( + maxMRegs * maxNRegs <= 28 && + "MR*(NR*ROW_INTERLEAVE*8/512) \ + must be <= 28(available registers constraint)"); + + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created + x86::Gp buffer_A = a->zdi(); + x86::Gp buffer_B = a->zsi(); + x86::Gp B_pf = a->zdx(); + x86::Gp CBase = a->zcx(); + x86::Gp kSize = a->gpz(8); + x86::Gp ldcReg = a->gpz(9); + + asmjit::FuncDetail func; + func.init( + asmjit:: + FuncSignatureT<void, uint8_t*, int8_t*, int8_t*, int32_t*, int, int>( + asmjit::CallConv::kIdHost)); + + asmjit::FuncFrame frame; + frame.init(func); + + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + + asmjit::FuncArgsAssignment args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFuncFrame(frame); + frame.finalize(); + + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); + + asmjit::Label LoopMBlocks = a->newLabel(); + asmjit::Label LoopNBlocks = a->newLabel(); + asmjit::Label Loopk = a->newLabel(); + + x86::Gp buffer_B_saved = a->gpz(10); + x86::Gp C_Offset = a->gpz(11); + x86::Gp B_pf_saved = a->gpz(12); + x86::Gp iIdx = a->gpz(13); + x86::Gp jIdx = a->gpz(14); + x86::Gp kIdx = a->gpz(15); + // x86::Gp B_pf = a->gpz(8); + + x86::Zmm oneReg = x86::zmm29; + // create 16-bit 1s + // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 + // and so on + // a->vpcmpeqw(oneReg, oneReg, oneReg); + a->vpternlogd(oneReg, oneReg, oneReg, 0xff); + a->vpsrlw(oneReg, oneReg, 15); + a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t))); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + + int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + int colRegs = std::min(currColRegs, maxNRegs); + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + a->bind(LoopMBlocks); + a->inc(iIdx); + a->mov(jIdx, 0); + + a->bind(LoopNBlocks); + a->inc(jIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512_vnni>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add( + buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add( + buffer_B, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + a->add( + B_pf, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + + // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs<inst_set_t::avx512_vnni>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + + // reset A + a->sub(buffer_A, kSize); + + // B for next block + a->mov(buffer_B, buffer_B_saved); + // using C_Offset as temp reg + a->imul( + C_Offset, + jIdx, + static_cast<asmjit::Imm>( + nRegBlockSize * row_interleave * sizeof(int8_t))); + a->add(buffer_B, C_Offset); + a->mov(B_pf, B_pf_saved); + a->add(B_pf, C_Offset); + + // increment C for next B block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNBlocks); + + // increment A for next block + a->add( + buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t))); + + // increment C for next A block + a->sub( + CBase, + static_cast<asmjit::Imm>(jLoopTrips * nRegBlockSize * sizeof(int32_t))); + a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs)); + a->add(CBase, C_Offset); + + // reset B + a->mov(buffer_B, buffer_B_saved); + a->mov(B_pf, B_pf_saved); + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopNRem = a->newLabel(); + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + a->mov(jIdx, 0); + a->bind(LoopNRem); + a->inc(jIdx); + + // init C registers + initCRegs<inst_set_t::avx512_vnni>(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, static_cast<asmjit::Imm>(row_interleave)); + + genComputeBlock<inst_set_t::avx512_vnni>( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add( + buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add( + buffer_B, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + a->add( + B_pf, + static_cast<asmjit::Imm>(nBlock * row_interleave * sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // reset A + a->sub(buffer_A, kSize); + // B for next block + // using C_Offset as temp reg + a->imul( + C_Offset, + jIdx, + static_cast<asmjit::Imm>( + nRegBlockSize * row_interleave * sizeof(int8_t))); + a->mov(buffer_B, buffer_B_saved); + a->add(buffer_B, C_Offset); + a->mov(B_pf, B_pf_saved); + a->add(B_pf, C_Offset); + + // store C matrix + storeCRegs<inst_set_t::avx512_vnni>( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + + // increment C for next B block + a->add(CBase, static_cast<asmjit::Imm>(nRegBlockSize * sizeof(int32_t))); + + int jLoopTrips = currColRegs / maxNRegs; + // jLoopTrips should be at least 1 + jLoopTrips = jLoopTrips ? jLoopTrips : 1; + a->cmp(jIdx, jLoopTrips); + a->jl(LoopNRem); + } + + a->emitEpilog(frame); + + jit_micro_kernel_fp fn; + asmjit::Error err = rt_.add(&fn, &code_); + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + codeCache_[kernelSig] = fn; + +#if defined(FBGEMM_LOG_CODE) + fclose(codeLogfile); + delete codeLogger; +#endif + + return fn; +} + +} // namespace fbgemm diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h index 1e6324e..4c5eea5 100644 --- a/src/GroupwiseConv.h +++ b/src/GroupwiseConv.h @@ -128,60 +128,58 @@ class GenConvKernel { const conv_param_t<SPATIAL_DIM>& conv_param); template <inst_set_t instSet> - void createVector16BitOne(asmjit::X86Emitter* a); + void createVector16BitOne(x86::Emitter* a); template <inst_set_t instSet> - void createVector8BitOne(asmjit::X86Emitter* a); + void createVector8BitOne(x86::Emitter* a); template <inst_set_t instSet> - void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg); + void setToZeroPt(x86::Emitter* a, x86::Ymm destReg); template <inst_set_t instSet> - void - gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg); + void gen8bitFMA(x86::Emitter* a, x86::Ymm aReg, x86::Ymm wReg); template <inst_set_t instSet> - void genForLoadingWeights(asmjit::X86Emitter* a, int c_offset); + void genForLoadingWeights(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genConstForPermutations(asmjit::X86Emitter* a); + void genConstForPermutations(x86::Emitter* a); template <inst_set_t instSet> - void genForTopEdge(asmjit::X86Emitter* a, int c_offset); + void genForTopEdge(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genForLeftEdge(asmjit::X86Emitter* a, int c_offset); + void genForLeftEdge(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genForRightEdge(asmjit::X86Emitter* a, int c_offset); + void genForRightEdge(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genForBottomEdge(asmjit::X86Emitter* a, int c_offset); + void genForBottomEdge(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void genCoreInsts(asmjit::X86Emitter* a, int c_offset); + void genCoreInsts(x86::Emitter* a, int c_offset); template <inst_set_t instSet> - void storeResult(asmjit::X86Emitter* a); + void storeResult(x86::Emitter* a); // for Rowoffset kernel // Add 4 consecutive numbers of 32 uint8 and emit 8 32-bit template <inst_set_t instSet> - void gen8BitSumX4(asmjit::X86Emitter* a, asmjit::X86Ymm aReg); + void gen8BitSumX4(x86::Emitter* a, x86::Ymm aReg); // Add 8 consecutive numbers of 64 uint8 and emit 8 32-bit template <inst_set_t instSet> - void - gen8BitSumX8(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg); + void gen8BitSumX8(x86::Emitter* a, x86::Ymm aReg, x86::Ymm bReg); // Add 16 consecutive numbers of 128 uint8 and emit 8 32-bit template <inst_set_t instSet> void gen8BitSumX16( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm bReg, - asmjit::X86Ymm cReg, - asmjit::X86Ymm dReg); + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm bReg, + x86::Ymm cReg, + x86::Ymm dReg); // Generate instruction sequence that loads 8-bit values and sum them up. // Depending on C_per_G_, this function dispatches to gen8BitSumX4/8/16 @@ -191,35 +189,33 @@ class GenConvKernel { // Internally, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0, 1], tmpReg1Avx2_, // and resultRegAvx2_ are used. template <inst_set_t instSet> - void gen8BitSum( - asmjit::X86Emitter* a, - int act_offset, - bool use_scratch_reg1 = true); + void + gen8BitSum(x86::Emitter* a, int act_offset, bool use_scratch_reg1 = true); // Use scratchReg1_ and tmpReg1Avx2_ internally template <inst_set_t instSet> - void genZeroPtSum(asmjit::X86Emitter* a, int multiplier); + void genZeroPtSum(x86::Emitter* a, int multiplier); template <inst_set_t instSet> - void genForTopEdgeRowoffset(asmjit::X86Emitter* a); + void genForTopEdgeRowoffset(x86::Emitter* a); template <inst_set_t instSet> - void genForLeftEdgeRowoffset(asmjit::X86Emitter* a); + void genForLeftEdgeRowoffset(x86::Emitter* a); template <inst_set_t instSet> - void genForRightEdgeRowoffset(asmjit::X86Emitter* a); + void genForRightEdgeRowoffset(x86::Emitter* a); template <inst_set_t instSet> - void genForBottomEdgeRowoffset(asmjit::X86Emitter* a); + void genForBottomEdgeRowoffset(x86::Emitter* a); template <inst_set_t instSet> - void genRowoffsetCorners(asmjit::X86Emitter* a); + void genRowoffsetCorners(x86::Emitter* a); template <inst_set_t instSet> - void genRowoffsetCore(asmjit::X86Emitter* a); + void genRowoffsetCore(x86::Emitter* a); template <inst_set_t instSet> - void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0); + void storeResultRowoffset(x86::Emitter* a, int offset = 0); static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit. static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit. @@ -234,30 +230,30 @@ class GenConvKernel { int vectorWidth_; ///< Vector width in bits. int VLEN_; ///< Vector width in elements. // avx2 specific - asmjit::X86Ymm + x86::Ymm WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel. - asmjit::X86Ymm zeroPTRegAvx2_; - asmjit::X86Ymm tmpReg1Avx2_; - asmjit::X86Ymm stPermRegAvx2_; - asmjit::X86Ymm actRegAvx2_; - asmjit::X86Ymm resultRegAvx2_; - asmjit::X86Ymm oneReg8BitAvx2_; - asmjit::X86Ymm oneReg16BitAvx2_; + x86::Ymm zeroPTRegAvx2_; + x86::Ymm tmpReg1Avx2_; + x86::Ymm stPermRegAvx2_; + x86::Ymm actRegAvx2_; + x86::Ymm resultRegAvx2_; + x86::Ymm oneReg8BitAvx2_; + x86::Ymm oneReg16BitAvx2_; // arguments to the function created - asmjit::X86Gp in_acts_R_; - asmjit::X86Gp wghts_R_; - asmjit::X86Gp out_acts_R_; - asmjit::X86Gp a_zero_pt_R_; - asmjit::X86Gp H_R_; - asmjit::X86Gp W_R_; - asmjit::X86Gp row_offset_R_; + x86::Gp in_acts_R_; + x86::Gp wghts_R_; + x86::Gp out_acts_R_; + x86::Gp a_zero_pt_R_; + x86::Gp H_R_; + x86::Gp W_R_; + x86::Gp row_offset_R_; // Used registers - asmjit::X86Gp loopR1_; - asmjit::X86Gp loopR2_; - asmjit::X86Gp scratchReg1_; - asmjit::X86Gp scratchReg2_; + x86::Gp loopR1_; + x86::Gp loopR2_; + x86::Gp scratchReg1_; + x86::Gp scratchReg2_; // Other parameters bool isAZeroPointZero_; diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index e789695..b140c83 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -104,7 +104,7 @@ jit_conv_kernel_fp getOrCreateConvKernel( template <> template <> void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // create 8-bit 1s // i.e., oneReg16BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains // 0x01 and so on @@ -115,7 +115,7 @@ void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // create 16-bit 1s // i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31] // contains 0x0001 and so on @@ -125,11 +125,11 @@ void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm destReg) { + x86::Emitter* a, + x86::Ymm destReg) { // make destReg all zeros a->vxorps(destReg, destReg, destReg); - asmjit::X86Xmm const_reg_xmm = x86::xmm10; + x86::Xmm const_reg_xmm = x86::xmm10; // move zero point to xmm10 a->movq(const_reg_xmm, a_zero_pt_R_); // make copies of zero point @@ -143,9 +143,9 @@ void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>( - asmjit::X86Emitter* a) { - asmjit::X86Gp permute_const_reg = a->gpzRef(12); - asmjit::X86Xmm const_reg_xmm = x86::xmm10; + x86::Emitter* a) { + x86::Gp permute_const_reg = a->gpz(12); + x86::Xmm const_reg_xmm = x86::xmm10; // We have 1st group in even lanes and 2nd group in odd lanes. // Permute to put 1st group to lower 128-bit and 2nd group in upper // 128-bit. @@ -159,8 +159,7 @@ void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>( template <> template <> -void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>( - asmjit::X86Emitter* a) { +void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(x86::Emitter* a) { if (C_per_G_ == 4) { // store with permutation a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_); @@ -171,7 +170,7 @@ void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int offset) { // store if (C_per_G_ == 4) { @@ -198,7 +197,7 @@ void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // load weights for (int r = 0; r < R_; ++r) { @@ -225,9 +224,9 @@ void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm wReg) { + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm wReg) { a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg); a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_); a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); @@ -236,8 +235,8 @@ void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg) { + x86::Emitter* a, + x86::Ymm aReg) { a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_); a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_); a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); @@ -246,9 +245,9 @@ void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm bReg) { + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm bReg) { a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_); // Let a[0] denote 0th (LSB) 8-bit of aReg // After vpsadbw, a[0:2] = a[0] + ... + a[7] @@ -267,11 +266,11 @@ void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>( - asmjit::X86Emitter* a, - asmjit::X86Ymm aReg, - asmjit::X86Ymm bReg, - asmjit::X86Ymm cReg, - asmjit::X86Ymm dReg) { + x86::Emitter* a, + x86::Ymm aReg, + x86::Ymm bReg, + x86::Ymm cReg, + x86::Ymm dReg) { a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_); // After vpsadbw, a[0:2] = a[0] + ... + a[7] // a[8:10] = a[8] + ... + a[15] @@ -319,7 +318,7 @@ void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int act_offset, bool use_scratch_reg1 /*=true*/) { if (use_scratch_reg1) { @@ -385,11 +384,11 @@ void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int multiplier) { a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier)); // tmpReg1Avx2_ also uses xmm11 - asmjit::X86Xmm const_reg_xmm = x86::xmm11; + x86::Xmm const_reg_xmm = x86::xmm11; a->movq(const_reg_xmm, scratchReg1_); a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm); a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_); @@ -399,7 +398,7 @@ void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // top-left corner code if (c_offset == 0) { @@ -559,7 +558,7 @@ void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); @@ -626,7 +625,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -714,7 +713,7 @@ void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // bottom-left corner // we updating the last row @@ -906,7 +905,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genCoreInsts<inst_set_t::avx2>( - asmjit::X86Emitter* a, + x86::Emitter* a, int c_offset) { // main compute asmjit::Label LoopH = a->newLabel(); @@ -1011,9 +1010,9 @@ template <> jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // log code to a file @@ -1030,16 +1029,16 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( wghts_R_ = a->zsi(); out_acts_R_ = a->zdx(); a_zero_pt_R_ = a->zcx(); - H_R_ = a->gpzRef(8); - W_R_ = a->gpzRef(9); - row_offset_R_ = a->gpzRef(10); + H_R_ = a->gpz(8); + W_R_ = a->gpz(9); + row_offset_R_ = a->gpz(10); // register for temporary use - scratchReg1_ = a->gpzRef(12); - scratchReg2_ = a->gpzRef(13); + scratchReg1_ = a->gpz(12); + scratchReg2_ = a->gpz(13); asmjit::FuncDetail func; - func.init(asmjit::FuncSignature6< + func.init(asmjit::FuncSignatureT< void, uint8_t*, int8_t*, @@ -1048,29 +1047,29 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( int32_t, int32_t>(asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15)); + asmjit::FuncFrame frame; + frame.init(func); - asmjit::FuncArgsMapper args(&func); - args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - args.updateFrameInfo(ffi); + asmjit::FuncArgsAssignment args(&func); + args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); createVector16BitOne<inst_set_t::avx2>(a); - loopR1_ = a->gpzRef(14); - loopR2_ = a->gpzRef(15); + loopR1_ = a->gpz(14); + loopR2_ = a->gpz(15); if (!isAZeroPointZero_) { setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_); @@ -1095,7 +1094,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( genCoreInsts<inst_set_t::avx2>(a, c); } - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_conv_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); @@ -1117,7 +1116,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // top-left corner code // zero out the results register a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); @@ -1213,7 +1212,7 @@ void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_)); @@ -1256,7 +1255,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -1326,7 +1325,7 @@ void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // bottom-left corner // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); @@ -1429,7 +1428,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( template <> template <> void GenConvKernel<2, int32_t>::genRowoffsetCore<inst_set_t::avx2>( - asmjit::X86Emitter* a) { + x86::Emitter* a) { // number of uint8 elements in input channels should be a multiple of 32 assert(C_ % 32 == 0); @@ -1491,9 +1490,9 @@ jit_rowoffset_kernel_fp GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( const conv_param_t<2>& conv_param) { code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); + code_.init(rt_.codeInfo()); + x86::Assembler assembler(&code_); + x86::Emitter* a = assembler.as<x86::Emitter>(); #if defined(FBGEMM_LOG_CODE) // log code to a file @@ -1510,45 +1509,45 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( a_zero_pt_R_ = a->zsi(); H_R_ = a->zdx(); W_R_ = a->zcx(); - row_offset_R_ = a->gpzRef(8); + row_offset_R_ = a->gpz(8); // register for temporary use - scratchReg1_ = a->gpzRef(12); - scratchReg2_ = a->gpzRef(13); + scratchReg1_ = a->gpz(12); + scratchReg2_ = a->gpz(13); - loopR1_ = a->gpzRef(14); - loopR2_ = a->gpzRef(15); + loopR1_ = a->gpz(14); + loopR2_ = a->gpz(15); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>( + FuncSignatureT<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>( asmjit::CallConv::kIdHost)); - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15)); + asmjit::FuncFrame frame; + frame.init(func); - asmjit::FuncArgsMapper args(&func); - args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_); + frame.setDirtyRegs( + x86::Reg::kGroupVec, + asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + frame.setDirtyRegs( + x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - args.updateFrameInfo(ffi); + asmjit::FuncArgsAssignment args(&func); + args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_); - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); + args.updateFuncFrame(frame); + frame.finalize(); - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); + a->emitProlog(frame); + a->emitArgsAssignment(frame, args); // This uses xmm10 register temporarily. Should come before // createVector8BitOne if (!isAZeroPointZero_) { // we can use xmm11 because ymm11 is used by tmpReg1Avx2_ - asmjit::X86Xmm const_reg_xmm = x86::xmm11; + x86::Xmm const_reg_xmm = x86::xmm11; a->movq(const_reg_xmm, a_zero_pt_R_); a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm); @@ -1569,7 +1568,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( genRowoffsetCore<inst_set_t::avx2>(a); - asmjit::FuncUtils::emitEpilog(a, layout); + a->emitEpilog(frame); jit_rowoffset_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc index 143e11d..5fabf97 100644 --- a/src/PackAMatrix.cc +++ b/src/PackAMatrix.cc @@ -34,7 +34,8 @@ PackAMatrix<T, accT>::PackAMatrix( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -43,7 +44,12 @@ PackAMatrix<T, accT>::PackAMatrix( BaseType::bcol_ = params->KCB; row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; row_interleave_B_ = diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc index d731654..2aca27d 100644 --- a/src/PackAWithIm2Col.cc +++ b/src/PackAWithIm2Col.cc @@ -49,7 +49,8 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -58,7 +59,12 @@ PackAWithIm2Col<T, accT, SPATIAL_DIM>::PackAWithIm2Col( BaseType::bcol_ = params->KCB; row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; row_interleave_B_ = @@ -478,7 +484,9 @@ int PackAWithIm2Col<T, accT, SPATIAL_DIM>::rowOffsetBufferSize( if (params) { return params->MCB; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + } else if (fbgemmHasAvx512Support()) { return PackingTraits<T, accT, inst_set_t::avx512>::MCB; } else if (fbgemmHasAvx2Support()) { return PackingTraits<T, accT, inst_set_t::avx2>::MCB; diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc index 52caed4..13a8fad 100644 --- a/src/PackAWithQuantRowOffset.cc +++ b/src/PackAWithQuantRowOffset.cc @@ -45,7 +45,8 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -54,7 +55,12 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset( BaseType::bcol_ = params->KCB; row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; row_interleave_B_ = @@ -201,7 +207,9 @@ int PackAWithQuantRowOffset<T, accT>::rowOffsetBufferSize( if (params) { return params->MCB; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + } else if (fbgemmHasAvx512Support()) { return PackingTraits<T, accT, inst_set_t::avx512>::MCB; } else if (fbgemmHasAvx2Support()) { return PackingTraits<T, accT, inst_set_t::avx2>::MCB; diff --git a/src/PackAWithRowOffset.cc b/src/PackAWithRowOffset.cc index 733bf5c..e84c67b 100644 --- a/src/PackAWithRowOffset.cc +++ b/src/PackAWithRowOffset.cc @@ -39,7 +39,8 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -48,7 +49,12 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset( BaseType::bcol_ = params->KCB; row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + row_interleave_B_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::MCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; row_interleave_B_ = @@ -189,7 +195,9 @@ int PackAWithRowOffset<T, accT>::rowOffsetBufferSize( if (params) { return params->MCB; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + return PackingTraits<T, accT, inst_set_t::avx512_vnni>::MCB; + } else if (fbgemmHasAvx512Support()) { return PackingTraits<T, accT, inst_set_t::avx512>::MCB; } else if (fbgemmHasAvx2Support()) { return PackingTraits<T, accT, inst_set_t::avx2>::MCB; diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc index b19b5d4..c237ac4 100644 --- a/src/PackBMatrix.cc +++ b/src/PackBMatrix.cc @@ -188,7 +188,8 @@ PackBMatrix<T, accT>::PackBMatrix( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -197,7 +198,12 @@ PackBMatrix<T, accT>::PackBMatrix( BaseType::bcol_ = params->NCB; row_interleave_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::KCB; + BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512_vnni>::NCB; + row_interleave_ = + PackingTraits<T, accT, inst_set_t::avx512_vnni>::ROW_INTERLEAVE; + } else if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits<T, accT, inst_set_t::avx512>::KCB; BaseType::bcol_ = PackingTraits<T, accT, inst_set_t::avx512>::NCB; row_interleave_ = @@ -228,7 +234,7 @@ PackBMatrix<T, accT>::PackBMatrix( BaseType::numGroups() * BaseType::blockRows() * BaseType::brow_ * BaseType::blockCols() * BaseType::bcol_ * sizeof(T)); } - pack(block); + pack(block, params); } template <typename T, typename accT> @@ -294,7 +300,8 @@ void PackBMatrix<T, accT>::pack_unpack_( const block_type_t& block, T* unpack_buf, T* pack_buf, - bool ispack) { + bool ispack, + const BlockingFactors* params) { assert((BaseType::blockRowSize() % row_interleave_) == 0); assert((block.row_start % BaseType::blockRowSize()) == 0); assert((block.col_start % BaseType::blockColSize()) == 0); @@ -303,7 +310,7 @@ void PackBMatrix<T, accT>::pack_unpack_( bool tr = (trans_ == matrix_op_t::Transpose); for (int g = 0; g < BaseType::numGroups(); ++g) { T* pack_buf_cur = pack_buf + - g * BaseType::packedBufferSize(block.row_size, block.col_size); + g * BaseType::packedBufferSize(block.row_size, block.col_size, params); for (int i = block.row_start; i < block.row_start + block.row_size; ++i) { int r_offset = ((i / BaseType::blockRowSize()) * BaseType::blockCols()) * (BaseType::blockRowSize() * BaseType::blockColSize()) + @@ -374,17 +381,21 @@ void PackBMatrix<T, accT>::pack_unpack_( } template <typename T, typename accT> -void PackBMatrix<T, accT>::pack(const block_type_t& block) { - pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true); +void PackBMatrix<T, accT>::pack( + const block_type_t& block, + const BlockingFactors* params) { + pack_unpack_(block, const_cast<T*>(smat_), BaseType::getBuf(), true, params); } template <typename T, typename accT> -void PackBMatrix<T, accT>::unpack(T* origin_buf) { +void PackBMatrix<T, accT>::unpack( + T* origin_buf, + const BlockingFactors* params) { block_type_t blockB{BaseType::packedRowStart(), BaseType::numPackedRows(), BaseType::packedColStart(), BaseType::numPackedCols()}; - pack_unpack_(blockB, origin_buf, BaseType::getBuf(), false); + pack_unpack_(blockB, origin_buf, BaseType::getBuf(), false, params); } template <typename T, typename accT> @@ -407,7 +418,9 @@ int32_t PackBMatrix<T, accT>::addr(int32_t r, int32_t c) const { } template <typename T, typename accT> -void PackBMatrix<T, accT>::printPackedMatrix(std::string name) { +void PackBMatrix<T, accT>::printPackedMatrix( + std::string name, + const BlockingFactors* params) { std::cout << name << ":" << "[" << BaseType::numPackedRows() << ", " << BaseType::numPackedCols() << "]" << std::endl; @@ -419,7 +432,7 @@ void PackBMatrix<T, accT>::printPackedMatrix(std::string name) { T* out = BaseType::getBuf() + g * BaseType::packedBufferSize( - BaseType::numPackedRows(), BaseType::numPackedCols()); + BaseType::numPackedRows(), BaseType::numPackedCols(), params); std::cout << "group: " << g << std::endl; for (auto nr = 0; nr < BaseType::blockRows(); ++nr) { auto rows = (nr == BaseType::blockRows() - 1) ? BaseType::lastBrow() diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc index c7503dd..ff7b842 100644 --- a/src/PackMatrix.cc +++ b/src/PackMatrix.cc @@ -36,7 +36,8 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && + !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -46,7 +47,11 @@ int PackMatrix<PT, inpType, accType>::packedBufferSize( NCB = params->NCB; KCB = params->KCB; } else { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + MCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::MCB; + NCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::NCB; + KCB = PackingTraits<inpType, accType, inst_set_t::avx512_vnni>::KCB; + } else if (fbgemmHasAvx512Support()) { MCB = PackingTraits<inpType, accType, inst_set_t::avx512>::MCB; NCB = PackingTraits<inpType, accType, inst_set_t::avx512>::NCB; KCB = PackingTraits<inpType, accType, inst_set_t::avx512>::KCB; diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc index ba6adf3..f6ad59e 100644 --- a/src/PackWeightMatrixForGConv.cc +++ b/src/PackWeightMatrixForGConv.cc @@ -106,7 +106,7 @@ inline int PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::packed_index_( * on 2 groups at a time and full SIMD width can be efficiently utilized even * while working on 1 group at a time. * In this case, the layout is G (C/4) R S K 4 -*/ + */ template <typename T, typename accT, int SPATIAL_DIM> void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_( @@ -148,9 +148,9 @@ void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_( if (ispack) { transposeConvWeights(conv_param_, src, dst); } else { - // TODO: Wrap this as a inverseTransposeConvWeights()? - // For unpack & transposed, call transposeConvWeights() - // G (R S C/G) K/G => G K/G (R S C/G) + // TODO: Wrap this as a inverseTransposeConvWeights()? + // For unpack & transposed, call transposeConvWeights() + // G (R S C/G) K/G => G K/G (R S C/G) for (int r = 0; r < R; ++r) { for (int s = 0; s < S; ++s) { for (int k = 0; k < OC_per_G; ++k) { diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc index 25b04af..44f210e 100644 --- a/src/PackWeightsForConv.cc +++ b/src/PackWeightsForConv.cc @@ -125,6 +125,74 @@ bool PackWeightsForConv<SPATIAL_DIM, T, accT>::isPackingCompliant( test_conv_p.dilation.begin()); } +template <int SPATIAL_DIM, typename T, typename accT> +std::string PackWeightsForConv<SPATIAL_DIM, T, accT>::mismatchingParams( + const conv_param_t<SPATIAL_DIM>& test_conv_p) { + std::string msg = ""; + + auto combineStr = [](std::string id, std::string str1, std::string str2) { + std::string out = id + std::string(" "); + out += str1; + out += std::string(" vs ") + str2; + out += std::string(";"); + return out; + }; + + auto combineInt = [&combineStr](std::string id, int int1, int int2) { + return combineStr(id, std::to_string(int1), std::to_string(int2)); + }; + + if (conv_param_.IC != test_conv_p.IC) { + msg += combineInt("input_channels", conv_param_.IC, test_conv_p.IC); + } + if (conv_param_.OC != test_conv_p.OC) { + msg += combineInt("output_channels", conv_param_.IC, test_conv_p.IC); + } + if (conv_param_.G != test_conv_p.G) { + msg += combineInt("groups", conv_param_.G, test_conv_p.G); + } + + if (!std::equal( + conv_param_.K.begin(), conv_param_.K.end(), test_conv_p.K.begin())) { + msg += combineStr( + "kernel", + arrayToString<SPATIAL_DIM>(conv_param_.K), + arrayToString<SPATIAL_DIM>(test_conv_p.K)); + } + + if (!std::equal( + conv_param_.stride.begin(), + conv_param_.stride.end(), + test_conv_p.stride.begin())) { + msg += combineStr( + "stride", + arrayToString<SPATIAL_DIM>(conv_param_.stride), + arrayToString<SPATIAL_DIM>(test_conv_p.stride)); + } + + if (!std::equal( + conv_param_.pad.begin(), + conv_param_.pad.end(), + test_conv_p.pad.begin())) { + msg += combineStr( + "pad", + arrayToString<2 * SPATIAL_DIM>(conv_param_.pad), + arrayToString<2 * SPATIAL_DIM>(test_conv_p.pad)); + } + + if (!std::equal( + conv_param_.dilation.begin(), + conv_param_.dilation.end(), + test_conv_p.dilation.begin())) { + msg += combineStr( + "dilation", + arrayToString<SPATIAL_DIM>(conv_param_.dilation), + arrayToString<SPATIAL_DIM>(test_conv_p.dilation)); + } + + return msg; +} + template class PackWeightsForConv<2, int8_t, int32_t>; template class PackWeightsForConv<3, int8_t, int32_t>; diff --git a/src/Utils.cc b/src/Utils.cc index 355a5cb..af7d918 100644 --- a/src/Utils.cc +++ b/src/Utils.cc @@ -206,4 +206,7 @@ bool fbgemmHasAvx2Support() { return (cpuinfo_initialize() && cpuinfo_has_x86_avx2()); } +bool fbgemmHasAvx512VnniSupport() { + return (cpuinfo_has_x86_avx512vnni()); +} } // namespace fbgemm diff --git a/test/FP16Test.cc b/test/FP16Test.cc index eb49086..3267655 100644 --- a/test/FP16Test.cc +++ b/test/FP16Test.cc @@ -27,7 +27,26 @@ using namespace fbgemm; namespace { // The template parameter is transpose of A and B class FBGemmFP16Test - : public testing::TestWithParam<pair<matrix_op_t, matrix_op_t>> {}; + : public testing::TestWithParam<pair<matrix_op_t, matrix_op_t>> { + protected: + vector<vector<int>> GenShapes() const { + vector<vector<int>> shapes; + random_device r; + default_random_engine generator(r()); + uniform_int_distribution<int> dm(1, 256); + uniform_int_distribution<int> dnk(1, 1024); + for (int i = 0; i < 10; i++) { + int m = dm(generator); + int n = dnk(generator); + int k = dnk(generator); + shapes.push_back({m, n, k}); + if (m > 10) { + shapes.push_back({(m / 10) * 10, n, k}); + } + } + return shapes; + } +}; }; // namespace INSTANTIATE_TEST_CASE_P( @@ -44,21 +63,75 @@ INSTANTIATE_TEST_CASE_P( matrix_op_t::Transpose, matrix_op_t::Transpose)*/)); TEST_P(FBGemmFP16Test, Test) { - vector<vector<int>> shapes; - random_device r; - default_random_engine generator(r()); - uniform_int_distribution<int> dm(1, 256); - uniform_int_distribution<int> dnk(1, 1024); - for (int i = 0; i < 10; i++) { - int m = dm(generator); - int n = dnk(generator); - int k = dnk(generator); - shapes.push_back({m, n, k}); - if (m > 10) { - shapes.push_back({(m / 10) * 10, n, k}); + auto shapes = GenShapes(); + float alpha = 1.f, beta = 0.f; + matrix_op_t atrans, btrans; + tie(atrans, btrans) = GetParam(); + + for (auto s : shapes) { + int m = s[0]; + int n = s[1]; + int k = s[2]; + + cerr << "m = " << m << " n = " << n << " k = " << k; + if (atrans == matrix_op_t::Transpose) { + cerr << " A_transposed"; + } + if (btrans == matrix_op_t::Transpose) { + cerr << " B_transposed"; + } + cerr << endl; + + // initialize with small numbers + aligned_vector<int> Aint(m * k); + aligned_vector<int> Bint(k * n); + randFill(Aint, 0, 4); + randFill(Bint, 0, 4); + aligned_vector<float> A(Aint.begin(), Aint.end()); + aligned_vector<float> B(Bint.begin(), Bint.end()); + + aligned_vector<float> C(m * n, NAN); + + aligned_vector<float> A_ref(A), B_ref(B), C_ref(C); + + if (atrans == matrix_op_t::Transpose) { + transpose_matrix(A_ref.data(), k, m); + } + if (btrans == matrix_op_t::Transpose) { + transpose_matrix(B_ref.data(), n, k); + } + + // Gold via reference sgemm + matmul_fp_ref(m, n, k, k, n, n, A_ref.data(), B_ref.data(), C_ref.data()); + + // fbgemm fp16 + PackedGemmMatrixFP16 Bp(btrans, k, n, alpha, B.data()); +#ifdef _OPENMP +#pragma omp parallel +#endif + { + int num_threads = fbgemm_get_num_threads(); + int tid = fbgemm_get_thread_num(); + + cblas_gemm_compute( + atrans, m, A.data(), Bp, beta, C.data(), tid, num_threads); + } + + // correctness check + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + float expected = C_ref[i * n + j]; + float actual = C[i * n + j]; + EXPECT_EQ(expected, actual) + << "GEMM results differ at (" << i << ", " << j << "). ref " + << expected << " FBGemm " << actual; + } } } +} +TEST_P(FBGemmFP16Test, Unpack) { + auto shapes = GenShapes(); float alpha = 1.f, beta = 0.f; matrix_op_t atrans, btrans; tie(atrans, btrans) = GetParam(); @@ -101,6 +174,23 @@ TEST_P(FBGemmFP16Test, Test) { // fbgemm fp16 PackedGemmMatrixFP16 Bp(btrans, k, n, alpha, B.data()); + EXPECT_TRUE(Bp.packed()); + + // Test unpack + aligned_vector<float16> tmp(Bp.matSize()); + memcpy(tmp.data(), Bp.pmat(), Bp.matSize() * sizeof(float16)); + Bp.unpackFromSrc(btrans, tmp.data()); + EXPECT_FALSE(Bp.packed()); + memcpy(tmp.data(), Bp.pmat(), Bp.matSize() * sizeof(float16)); + for (int i = 0; i < k; ++i) { + for (int j = 0; j < n; ++j) { + EXPECT_EQ(B[i * n + j], cpu_half2float(tmp[i * n + j])); + } + } + + // Pack it back + Bp.packFromSrc(btrans, tmp.data()); + EXPECT_TRUE(Bp.packed()); #ifdef _OPENMP #pragma omp parallel diff --git a/test/GConvTest.cc b/test/GConvTest.cc index 0074535..8c1fb82 100644 --- a/test/GConvTest.cc +++ b/test/GConvTest.cc @@ -465,8 +465,8 @@ TEST_P(fbgemmGConvPackTest, PackUnpackTest) { for (int i = 0; i < weight_len; ++i) { EXPECT_EQ(Bint8.data()[i], unpack_buf.data()[i]) << "Pack/Unpack results differ at index " << i - << ", Reference: " << static_cast<int> (Bint8.data()[i]) - << ", Pack-Unpacked: " << static_cast<int> (unpack_buf.data()[i]); + << ", Reference: " << static_cast<int>(Bint8.data()[i]) + << ", Pack-Unpacked: " << static_cast<int>(unpack_buf.data()[i]); } } // for each shape } diff --git a/test/PackedRequantizeAcc16Test.cc b/test/PackedRequantizeAcc16Test.cc index 23af3eb..62b1303 100644 --- a/test/PackedRequantizeAcc16Test.cc +++ b/test/PackedRequantizeAcc16Test.cc @@ -94,6 +94,8 @@ static vector<vector<int>> GetShapes_() { {102, 512, 258}, {1024, 512, 258}, + + {120, 4, 288}, }; return shapes; } @@ -827,54 +829,67 @@ TEST_P(fbgemmPackUnpackAcc16Test, TestPackUnpack) { bool test_ld; tie(btrans, test_ld) = GetParam(); + BlockingFactors params; + params.MCB = 48; + params.NCB = 16; + params.KCB = 256; + params.MR = 1; + params.NR = 16; + params.ROW_INTERLEAVE = 4; + params.NR_MIN = 16; + vector<BlockingFactors*> vec_params_ptr = {¶ms, nullptr}; + for (auto shape : shapes) { for (int groups : {1, 3, 4}) { - int n = shape[1]; - int k = shape[2]; + for (auto params_ptr : vec_params_ptr) { + int n = shape[1]; + int k = shape[2]; - if (k % groups != 0) { - continue; - } - int k_per_group = k / groups; + if (k % groups != 0) { + continue; + } + int k_per_group = k / groups; - // kxn matrix - aligned_vector<int8_t> Bint8(k * n); - randFill<int8_t>(Bint8, -128, 127); + // kxn matrix + aligned_vector<int8_t> Bint8(k * n); + randFill<int8_t>(Bint8, -128, 127); - // To test lda != k , we just reduce k by half and use the original k - // as lda. - int n_adjusted = n; - if (test_ld) { - if (btrans == matrix_op_t::NoTranspose) { - n_adjusted = std::max(n / 2, 1); + // To test lda != k , we just reduce k by half and use the original k + // as lda. + int n_adjusted = n; + if (test_ld) { + if (btrans == matrix_op_t::NoTranspose) { + n_adjusted = std::max(n / 2, 1); + } } - } - // Note that packing for weight is performed during the constructor - // stage. - PackBMatrix<int8_t, int16_t> packedWeights( - btrans, - k, - n_adjusted, - Bint8.data(), - (btrans == matrix_op_t::Transpose) ? k_per_group : n, - nullptr, - groups); + // Note that packing for weight is performed during the constructor + // stage. + PackBMatrix<int8_t, int16_t> packedWeights( + btrans, + k, + n_adjusted, + Bint8.data(), + (btrans == matrix_op_t::Transpose) ? k_per_group : n, + nullptr, + groups, + params_ptr); - // Setup a buffer to get pack -> unpacked results - aligned_vector<int8_t> unpack_buf(k * n, 0); + // Setup a buffer to get pack -> unpacked results + aligned_vector<int8_t> unpack_buf(k * n, 0); - // Perform unpacking - packedWeights.unpack(unpack_buf.data()); + // Perform unpacking + packedWeights.unpack(unpack_buf.data(), params_ptr); - // Sanity check - for (int i = 0; i < k; i++) { - for (int j = 0; j < n_adjusted; j++) { - EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j]) + // Sanity check + for (int i = 0; i < k; i++) { + for (int j = 0; j < n_adjusted; j++) { + EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j]) << "Pack/Unpack results differ at index (" << i << ", " << j << ", Reference: " << static_cast<int>(Bint8.data()[i * n + j]) << ", Pack-Unpacked: " << static_cast<int>(unpack_buf.data()[i * n + j]); + } } } } diff --git a/test/PackedRequantizeTest.cc b/test/PackedRequantizeTest.cc index 11ef6ff..5338243 100644 --- a/test/PackedRequantizeTest.cc +++ b/test/PackedRequantizeTest.cc @@ -93,6 +93,8 @@ static vector<vector<int>> GetShapes_() { {102, 512, 258}, {1024, 512, 258}, + + {120, 4, 288}, }; return shapes; } @@ -766,54 +768,67 @@ TEST_P(fbgemmPackUnpackAcc32Test, TestPackUnpack) { bool test_ld; tie(btrans, test_ld) = GetParam(); + BlockingFactors params; + params.MCB = 48; + params.NCB = 16; + params.KCB = 256; + params.MR = 1; + params.NR = 16; + params.ROW_INTERLEAVE = 4; + params.NR_MIN = 16; + vector<BlockingFactors*> vec_params_ptr = {¶ms, nullptr}; + for (auto shape : shapes) { for (int groups : {1, 3, 4}) { - int n = shape[1]; - int k = shape[2]; + for (auto params_ptr : vec_params_ptr) { + int n = shape[1]; + int k = shape[2]; - if (k % groups != 0) { - continue; - } - int k_per_group = k / groups; + if (k % groups != 0) { + continue; + } + int k_per_group = k / groups; - // kxn matrix - aligned_vector<int8_t> Bint8(k * n); - randFill<int8_t>(Bint8, -128, 127); + // kxn matrix + aligned_vector<int8_t> Bint8(k * n); + randFill<int8_t>(Bint8, -128, 127); - // To test lda != k , we just reduce k by half and use the original k - // as lda. - int n_adjusted = n; - if (test_ld) { - if (btrans == matrix_op_t::NoTranspose) { - n_adjusted = std::max(n / 2, 1); + // To test lda != k , we just reduce k by half and use the original k + // as lda. + int n_adjusted = n; + if (test_ld) { + if (btrans == matrix_op_t::NoTranspose) { + n_adjusted = std::max(n / 2, 1); + } } - } - // Note that packing for weight is performed during the constructor - // stage. - PackBMatrix<int8_t> packedWeights( - btrans, - k, - n_adjusted, - Bint8.data(), - (btrans == matrix_op_t::Transpose) ? k_per_group : n, - nullptr, - groups); + // Note that packing for weight is performed during the constructor + // stage. + PackBMatrix<int8_t> packedWeights( + btrans, + k, + n_adjusted, + Bint8.data(), + (btrans == matrix_op_t::Transpose) ? k_per_group : n, + nullptr, + groups, + params_ptr); - // Setup a buffer to get pack -> unpacked results - aligned_vector<int8_t> unpack_buf(k * n, 0); + // Setup a buffer to get pack -> unpacked results + aligned_vector<int8_t> unpack_buf(k * n, 0); - // Perform unpacking - packedWeights.unpack(unpack_buf.data()); + // Perform unpacking + packedWeights.unpack(unpack_buf.data(), params_ptr); - // Sanity check - for (int i = 0; i < k; i++) { - for (int j = 0; j < n_adjusted; j++) { - EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j]) + // Sanity check + for (int i = 0; i < k; i++) { + for (int j = 0; j < n_adjusted; j++) { + EXPECT_EQ(Bint8.data()[i * n + j], unpack_buf.data()[i * n + j]) << "Pack/Unpack results differ at index (" << i << ", " << j << ", Reference: " << static_cast<int>(Bint8.data()[i * n + j]) << ", Pack-Unpacked: " << static_cast<int>(unpack_buf.data()[i * n + j]); + } } } } diff --git a/third_party/asmjit b/third_party/asmjit -Subproject 673dcefaa048c5f5a2bf8b85daf8f7b9978d018 +Subproject 4da474ac9aa2689e88d5e40a2f37628f302d7e3 |