diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Fbgemm.cc | 16 | ||||
-rw-r--r-- | src/GroupwiseConv.h | 248 | ||||
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc | 1552 | ||||
-rw-r--r-- | src/PackWeightMatrixForGConv.cc | 103 | ||||
-rw-r--r-- | src/RefImplementations.cc | 25 | ||||
-rw-r--r-- | src/RefImplementations.h | 8 |
6 files changed, 1952 insertions, 0 deletions
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index 45108d0..9384af6 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -192,6 +192,22 @@ void fbgemmPacked(
 #endif
 }
 
+/**
+ * @brief Tell whether a convolution can use the JIT groupwise kernel.
+ *
+ * The kernel is only generated for 2D convolutions with 4 input and 4 output
+ * channels per group, a group count that is a multiple of 8, a 3x3 filter,
+ * padding of 1 on every side, unit dilation and unit stride.
+ *
+ * @param conv_p Convolution parameters to inspect.
+ * @return true if every condition above holds, false otherwise.
+ */
+template <int SPATIAL_DIM>
+FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p) {
+  // Only the 2D case is supported; bail out before touching 2D-only indices.
+  if (SPATIAL_DIM != 2) {
+    return false;
+  }
+  // The per-group sizes below divide by G, so reject a non-positive group
+  // count and channel counts that do not split evenly across the groups
+  // (integer truncation would otherwise let e.g. IC = 33, G = 8 through).
+  if (conv_p.G <= 0 || conv_p.IC % conv_p.G != 0 ||
+      conv_p.OC % conv_p.G != 0) {
+    return false;
+  }
+  int C_per_G = conv_p.IC / conv_p.G;
+  int K_per_G = conv_p.OC / conv_p.G;
+
+  return (C_per_G == K_per_G) && (C_per_G == 4) && (conv_p.G % 8 == 0) &&
+      (conv_p.K[0] == conv_p.K[1]) && (conv_p.K[0] == 3) &&
+      (conv_p.pad[0] == 1) && (conv_p.pad[1] == 1) &&
+      (conv_p.pad[0] == conv_p.pad[2]) && (conv_p.pad[1] == conv_p.pad[3]) &&
+      (conv_p.dilation[0] == 1) && (conv_p.dilation[0] == conv_p.dilation[1]) &&
+      (conv_p.stride[0] == 1) && (conv_p.stride[0] == conv_p.stride[1]);
+}
+
+template bool fbgemmOptimizedGConv(const conv_param_t<2>& conv_p);
+template bool fbgemmOptimizedGConv(const conv_param_t<3>& conv_p);
+
 bool fbgemmSupportedCPU() {
   return (cpuinfo_initialize() && cpuinfo_has_x86_avx2());
 }
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
new file mode 100644
index 0000000..a46a895
--- /dev/null
+++ b/src/GroupwiseConv.h
@@ -0,0 +1,248 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <asmjit/asmjit.h>
+#include <cpuinfo.h>
+#include <cassert>
+#include <cstdint>
+#include <map>
+#include <stdexcept>
+#include <string>
+#include <tuple>
+#include "fbgemm/ConvUtils.h"
+#include "fbgemm/Fbgemm.h"
+#include "fbgemm/Utils.h"
+/*#define FBGEMM_LOG_CODE 1*/
+
+namespace fbgemm {
+
+namespace x86 = asmjit::x86;
+
+/// Signature of a JIT-generated groupwise convolution kernel.
+using jit_conv_kernel_fp = void (*)(
+    const uint8_t* in_acts,
+    int8_t* wghts,
+    int32_t* out_acts,
+    int32_t a_zero_pt,
+    int32_t height,
+    int32_t width);
+
+/// Signature of a JIT-generated row-offset kernel.
+using jit_rowoffset_kernel_fp = void (*)(
+    const uint8_t* in_acts,
+    int32_t a_zero_pt,
+    int32_t height,
+    int32_t width,
+    int32_t* row_offset);
+
+/**
+ * @brief Emits (via asmjit) the convolution and row-offset kernels for the
+ * groupwise convolutions accepted by fbgemmOptimizedGConv.
+ *
+ * Generated kernels are kept in per-thread caches (codeCache_ and
+ * codeCacheRowOffset_).
+ */
+template <typename accT = int32_t>
+class GenConvKernel {
+ public:
+  /**
+   * @param conv_param Convolution description; must satisfy
+   *                   fbgemmOptimizedGConv (asserted).
+   * @param zero_point Activation zero point; only whether it is zero is
+   *                   recorded here.
+   * @throws std::runtime_error if cpuinfo cannot be initialized or the CPU
+   *         supports neither AVX512F nor AVX2.
+   */
+  GenConvKernel(const conv_param_t<>& conv_param, std::int32_t zero_point)
+      : WRegs_avx2_{x86::ymm0,
+                    x86::ymm1,
+                    x86::ymm2,
+                    x86::ymm3,
+                    x86::ymm4,
+                    x86::ymm5,
+                    x86::ymm6,
+                    x86::ymm7,
+                    x86::ymm8} {
+    // vector width in bits
+    if (cpuinfo_initialize()) {
+      if (cpuinfo_has_x86_avx512f()) {
+        vectorWidth_ = 512;
+      } else if (cpuinfo_has_x86_avx2()) {
+        vectorWidth_ = 256;
+      } else {
+        // TODO: Have default path
+        assert(0 && "unsupported architecture");
+        // With NDEBUG the assert vanishes; returning here would leave every
+        // member uninitialized, so fail the same way as the branch below.
+        throw std::runtime_error("unsupported architecture");
+      }
+    } else {
+      throw std::runtime_error("Failed to initialize cpuinfo!");
+    }
+    zeroPTRegAvx2_ = x86::ymm9;
+    oneReg8BitAvx2_ = x86::ymm10;
+    tmpReg1Avx2_ = x86::ymm11;
+    stPermRegAvx2_ = x86::ymm12;
+    actRegAvx2_ = x86::ymm13;
+    resultRegAvx2_ = x86::ymm14;
+    oneReg16BitAvx2_ = x86::ymm15;
+
+    // vector width in elements; Each element is int8 or uint8
+    VLEN_ = vectorWidth_ / 8;
+
+    isZeroPointZero_ = (zero_point == 0);
+
+    G_ = conv_param.G;
+    K_per_G_ = conv_param.OC / conv_param.G;
+    K_ = conv_param.OC;
+    C_per_G_ = conv_param.IC / conv_param.G;
+    C_ = conv_param.IC;
+    R_ = conv_param.K[0];
+    S_ = conv_param.K[1];
+    H_ = conv_param.OUT_DIM[0];
+    W_ = conv_param.OUT_DIM[1];
+    H_PAD_ = conv_param.pad[0];
+    W_PAD_ = conv_param.pad[1];
+
+    assert(fbgemmOptimizedGConv(conv_param));
+  }
+
+  /**
+   * @brief Build the name of the file the generated code is logged to.
+   * @param rowOffsetKernel true for the row-offset kernel, false for the
+   *                        convolution kernel.
+   * @return A name encoding the conv parameters and the instruction set.
+   */
+  template <inst_set_t instSet>
+  std::string getCodeLoggingFile(bool rowOffsetKernel = false) {
+    std::string fileName = "conv_";
+    fileName += "G-" + std::to_string(G_);
+    fileName += "_K-" + std::to_string(K_);
+    fileName += "_C-" + std::to_string(C_);
+    fileName += "_R-" + std::to_string(R_);
+    fileName += "_S-" + std::to_string(S_);
+    fileName += "_PADH-" + std::to_string(H_PAD_);
+    fileName += "_PADW-" + std::to_string(W_PAD_);
+    fileName += "_isZeroPointZero-" + std::to_string(isZeroPointZero_);
+    if (rowOffsetKernel) {
+      fileName += "_rowOffset";
+    }
+
+    if (instSet == inst_set_t::avx512) {
+      fileName += "_avx512";
+    } else if (instSet == inst_set_t::avx2) {
+      fileName += "_avx2";
+    }
+    fileName += ".txt";
+    return fileName;
+  }
+
+  ~GenConvKernel() {}
+
+  template <inst_set_t instSet>
+  jit_conv_kernel_fp getOrCreate(const conv_param_t<>& conv_param);
+
+  template <inst_set_t instSet>
+  jit_rowoffset_kernel_fp getOrCreateRowOffset(
+      const conv_param_t<>& conv_param);
+
+  template <inst_set_t instSet>
+  void createVector16BitOne(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void createVector8BitOne(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg);
+
+  template <inst_set_t instSet>
+  void
+  gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg);
+
+  template <inst_set_t instSet>
+  void genForLoadingWeights(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void genConstForPermutations(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void genForTopEdge(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void genForLeftEdge(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void genForRightEdge(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void genForBottomEdge(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void genCoreInsts(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void storeResult(asmjit::X86Emitter* a, int offset = 0);
+
+  // for Rowoffset kernel
+  template <inst_set_t instSet>
+  void gen8BitSum(asmjit::X86Emitter* a, asmjit::X86Ymm aReg);
+
+  template <inst_set_t instSet>
+  void genForTopEdgeRowoffset(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void genForLeftEdgeRowoffset(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void genForRightEdgeRowoffset(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void genForBottomEdgeRowoffset(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void genRowoffsetCorners(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void genRowoffsetCore(asmjit::X86Emitter* a);
+
+  template <inst_set_t instSet>
+  void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0);
+
+  static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
+  static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
+  static thread_local std::
+      map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+          codeCache_; ///< JIT Code Cache for reuse.
+  static thread_local std::
+      map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+          codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel.
+
+ private:
+  int vectorWidth_; ///< Vector width in bits.
+  int VLEN_; ///< Vector width in elements.
+  // avx2 specific
+  asmjit::X86Ymm
+      WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel.
+  asmjit::X86Ymm zeroPTRegAvx2_;
+  asmjit::X86Ymm tmpReg1Avx2_;
+  asmjit::X86Ymm stPermRegAvx2_;
+  asmjit::X86Ymm actRegAvx2_;
+  asmjit::X86Ymm resultRegAvx2_;
+  asmjit::X86Ymm oneReg8BitAvx2_;
+  asmjit::X86Ymm oneReg16BitAvx2_;
+
+  // arguments to the function created
+  asmjit::X86Gp in_acts_R_;
+  asmjit::X86Gp wghts_R_;
+  asmjit::X86Gp out_acts_R_;
+  asmjit::X86Gp a_zero_pt_R_;
+  asmjit::X86Gp H_R_;
+  asmjit::X86Gp W_R_;
+  asmjit::X86Gp row_offset_R_;
+
+  // Used registers
+  asmjit::X86Gp loopR1_;
+  asmjit::X86Gp loopR2_;
+  asmjit::X86Gp scratchReg1_;
+  asmjit::X86Gp scratchReg2_;
+
+  // Other parameters
+  bool isZeroPointZero_;
+
+  // current conv parameters
+  int G_; ///< Number of groups
+  int K_; ///< Number of output channels
+  int K_per_G_; ///< Number of output channels per group
+  int C_; ///< Number of input channels
+  int C_per_G_; ///< Number of input channels per group
+  int R_; ///< Filter/Kernel height
+  int S_; ///< Filter/Kernel width
+  int H_; ///< Taken from conv_param.OUT_DIM[0]
+  int W_; ///< Taken from conv_param.OUT_DIM[1]
+  int H_PAD_; ///< Padding for height (top and bottom)
+  int W_PAD_; ///< Padding for width (left and right)
+};
+
+} // namespace fbgemm
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
new file mode 100644
index 0000000..8298f4c
--- /dev/null
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -0,0 +1,1552 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <asmjit/asmjit.h>
+#include <cpuinfo.h>
+#include <immintrin.h>
+#include <array>
+#include <iostream>
+#include <map>
+#include <stdexcept>
+#include <tuple>
+#include "GroupwiseConv.h"
+#include "RefImplementations.h"
+#include "TransposeUtils.h"
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm {
+
+using namespace std;
+
+template <typename accT>
+thread_local asmjit::JitRuntime GenConvKernel<accT>::rt_;
+
+template <typename accT>
+thread_local asmjit::CodeHolder GenConvKernel<accT>::code_;
+
+template <typename accT>
+thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+    GenConvKernel<accT>::codeCache_;
+
+template <typename accT>
+thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+    GenConvKernel<accT>::codeCacheRowOffset_;
+
+namespace x86 = asmjit::x86;
+
+/**
+ * @brief Reference computation of the row offsets of one group.
+ *
+ * For every output position the activations under the filter window are
+ * summed over the C_per_G channels of group groupNum; positions that fall
+ * outside the image contribute a_zero_point instead.
+ *
+ * @param conv_param   Convolution description.
+ * @param activations  Input activations, indexed as
+ *                     ((h * W + w) * G + group) * C_per_G + c.
+ * @param rowOffsetBuf Output buffer, written at index h * W + w.
+ * @param a_zero_point Value used for padded positions.
+ * @param groupNum     Group whose channels are summed.
+ */
+void calculateRowOffsets(
+    const conv_param_t<>& conv_param,
+    const uint8_t* activations,
+    int32_t* rowOffsetBuf,
+    int32_t a_zero_point,
+    int groupNum) {
+  int H = conv_param.OUT_DIM[0];
+  int W = conv_param.OUT_DIM[1];
+  int G = conv_param.G;
+  int C_per_G = conv_param.IC / conv_param.G;
+  int H_PAD = conv_param.pad[0];
+  int W_PAD = conv_param.pad[1];
+  // calculate row offset
+  for (int h = 0; h < H; ++h) {
+    for (int w = 0; w < W; ++w) {
+      int32_t sum = 0;
+      for (int r = 0; r < conv_param.K[0]; ++r) {
+        int h_in = -H_PAD + h + r;
+        for (int s = 0; s < conv_param.K[1]; ++s) {
+          int w_in = -W_PAD + w + s;
+          for (int c = 0; c < C_per_G; ++c) {
+            if (h_in < 0 || h_in >= H || w_in < 0 || w_in >= W) {
+              sum += a_zero_point;
+            } else {
+              sum +=
+                  activations[((h_in * W + w_in) * G + groupNum) * C_per_G + c];
+            }
+          }
+        }
+      }
+      rowOffsetBuf[h * W + w] = sum;
+    }
+  }
+}
+
+/**
+ * @brief Build the cache key of a kernel:
+ * (isZeroPointZero, G, channels per group in, channels per group out).
+ */
+tuple<bool, int, int, int> getKernelSig(
+    const conv_param_t<>& conv_param,
+    bool isZeroPointZero) {
+  int C_per_G = conv_param.IC / conv_param.G;
+  int K_per_G = conv_param.OC / conv_param.G;
+  auto kernelSig =
+      std::make_tuple(isZeroPointZero, conv_param.G, C_per_G, K_per_G);
+  return kernelSig;
+}
+
+/**
+ * @brief Return the cached convolution kernel for conv_param, generating it
+ * on a cache miss.
+ */
+template <typename accT = int32_t>
+jit_conv_kernel_fp getOrCreateConvKernel(
+    const conv_param_t<>& conv_param,
+    int a_zero_point) {
+  // Note: Wrong code is generated if it's not one of the supported convolution
+  assert(fbgemmOptimizedGConv<2>(conv_param));
+  auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
+  // Single lookup: reuse the iterator instead of find() followed by [].
+  auto cached = GenConvKernel<accT>::codeCache_.find(kernelSig);
+  if (cached != GenConvKernel<accT>::codeCache_.end()) {
+    return cached->second;
+  }
+  GenConvKernel<accT> genObj(conv_param, a_zero_point);
+  // TODO: Instruction set based dispatch
+  return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::createVector8BitOne<inst_set_t::avx2>(
+    asmjit::X86Emitter* a) {
+  // create 8-bit 1s
+  // i.e., oneReg8BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains
+  // 0x01 and so on
+  a->vpcmpeqw(oneReg8BitAvx2_, oneReg8BitAvx2_, oneReg8BitAvx2_);
+  a->vpabsb(oneReg8BitAvx2_, oneReg8BitAvx2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::createVector16BitOne<inst_set_t::avx2>(
+    asmjit::X86Emitter* a) {
+  // create 16-bit 1s
+  // i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31]
+  // contains 0x0001 and so on
+  a->vpcmpeqw(oneReg16BitAvx2_, oneReg16BitAvx2_, oneReg16BitAvx2_);
+  a->vpsrlw(oneReg16BitAvx2_, oneReg16BitAvx2_, 15);
+}
+
+// Note: this routine uses xmm10/ymm10 as a temporary.
+template <>
+template <>
+void GenConvKernel<int32_t>::setToZeroPt<inst_set_t::avx2>(
+    asmjit::X86Emitter* a,
+    asmjit::X86Ymm destReg) {
+  // make destReg all zeros
+  a->vxorps(destReg, destReg, destReg);
+  asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+  // move zero point to xmm10
+  a->movq(const_reg_xmm, a_zero_pt_R_);
+  // make copies of zero point
+  a->vbroadcastsd(x86::ymm10, const_reg_xmm);
+  // shuffle
+  // overall impact is that destReg contains 32 8-bit values equal to the lower
+  // 8-bits of a_zero_pt_R_
+  a->vpshufb(destReg, x86::ymm10, destReg);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genConstForPermutations<inst_set_t::avx2>(
+    asmjit::X86Emitter* a) {
+  asmjit::X86Gp permute_const_reg = a->gpzRef(12);
+  asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+  // We have 1st group in even lanes and 2nd group in odd lanes.
+  // Permute to put 1st group to lower 128-bit and 2nd group in upper
+  // 128-bit.
+  // load 7, 5, 3, 1, 6, 4, 2, 0 in a 64-bit reg
+  a->mov(permute_const_reg, 0x0705030106040200);
+  a->movq(const_reg_xmm, permute_const_reg);
+  // Zero extend 8 packed 8-bit integers in the low 8 bytes of const_reg_xmm to
+  // 8 packed 32-bit integers in stPermRegAvx2_
+  a->vpmovzxbd(stPermRegAvx2_, const_reg_xmm);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::storeResult<inst_set_t::avx2>(
+    asmjit::X86Emitter* a,
+    int offset) {
+  // store with permutation
+  a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_);
+  a->vmovups(x86::dword_ptr(out_acts_R_, offset), resultRegAvx2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::storeResultRowoffset<inst_set_t::avx2>(
+    asmjit::X86Emitter* a,
+    int offset) {
+  // store
+  a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForLoadingWeights<inst_set_t::avx2>(
+    asmjit::X86Emitter* a) {
+  // load weights
+  for (int r = 0; r < R_; ++r) {
+    for (int s = 0; s < S_; ++s) {
+      a->vmovaps(
+          WRegs_avx2_[r * S_ + s],
+          x86::dword_ptr(
+              wghts_R_,
+              (r * S_ + s) * G_ * K_per_G_ * C_per_G_ * sizeof(int8_t)));
+    }
+  }
+}
+
+// Accumulates the 8-bit products of aReg and wReg into resultRegAvx2_ as
+// 32-bit sums; tmpReg1Avx2_ is used as scratch.
+template <>
+template <>
+void GenConvKernel<int32_t>::gen8bitFMA<inst_set_t::avx2>(
+    asmjit::X86Emitter* a,
+    asmjit::X86Ymm aReg,
+    asmjit::X86Ymm wReg) {
+  a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg);
+  a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_);
+  a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
+}
+
+// Accumulates the sum of the 8-bit values of aReg into resultRegAvx2_ as
+// 32-bit sums; tmpReg1Avx2_ is used as scratch.
+template <>
+template <>
+void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>(
+    asmjit::X86Emitter* a,
+    asmjit::X86Ymm aReg) {
+  a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_);
+  a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_);
+  a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
+}
+
+// Emits the code for the first output row: top-left corner, the top edge
+// between the corners, and the top-right corner. Padded taps use the
+// zero-point register (skipped entirely when the zero point is 0).
+template <>
+template <>
+void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
+    asmjit::X86Emitter* a) {
+  // top-left corner code
+  // zero out the results register
+  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+  for (int r = 0; r < R_; ++r) {
+    int h_in = -H_PAD_ + r;
+    if (h_in >= 0) {
+      a->imul(
+          scratchReg1_,
+          W_R_,
+          static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+    }
+    for (int s = 0; s < S_; ++s) {
+      int w_in = -W_PAD_ + s;
+      if (h_in >= 0 && w_in >= 0) {
+        a->vbroadcastsd(
+            actRegAvx2_,
+            x86::dword_ptr(
+                in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
+        gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+      } else {
+        if (!isZeroPointZero_) {
+          gen8bitFMA<inst_set_t::avx2>(
+              a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+        }
+      }
+    }
+  }
+  storeResult<inst_set_t::avx2>(a);
+
+  a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+
+  // top edge excluding corners
+  asmjit::Label LoopTopEdge = a->newLabel();
+  a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+  a->bind(LoopTopEdge);
+  // zero out
+  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+  if (!isZeroPointZero_) {
+    for (int r = 0; r < H_PAD_; ++r) {
+      for (int s = 0; s < S_; ++s) {
+        // index by filter row as well as column, like every other edge
+        gen8bitFMA<inst_set_t::avx2>(
+            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+      }
+    }
+  }
+  for (int r = H_PAD_; r < R_; ++r) {
+    int h_in = -H_PAD_ + r;
+    a->imul(
+        scratchReg1_,
+        W_R_,
+        static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+    for (int s = 0; s < S_; ++s) {
+      a->vbroadcastsd(
+          actRegAvx2_,
+          x86::dword_ptr(
+              in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+    }
+  }
+  a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+
+  storeResult<inst_set_t::avx2>(a);
+
+  a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+  a->mov(loopR1_, W_R_);
+  a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_));
+  a->inc(loopR2_);
+  a->cmp(loopR2_, loopR1_);
+  a->jl(LoopTopEdge);
+  // undo the per-iteration advance of the input pointer
+  a->mov(scratchReg2_, W_R_);
+  a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+  a->sub(
+      scratchReg2_,
+      static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
+  a->sub(in_acts_R_, scratchReg2_);
+
+  // top-right corner code
+
+  // zero out
+  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+  if (!isZeroPointZero_) {
+    for (int r = 0; r < H_PAD_; ++r) {
+      for (int s = 0; s < S_; ++s) {
+        gen8bitFMA<inst_set_t::avx2>(
+            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+      }
+    }
+  }
+  for (int r = H_PAD_; r < R_; ++r) {
+    int h_in = -H_PAD_ + r;
+    for (int s = 0; s < S_ - W_PAD_; ++s) {
+      a->imul(
+          scratchReg1_,
+          W_R_,
+          static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+      a->mov(scratchReg2_, W_R_);
+      // column offset depends on the filter width S_, not its height R_
+      a->sub(scratchReg2_, static_cast<asmjit::Imm>(S_ - W_PAD_ - s));
+      a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+      a->add(scratchReg1_, scratchReg2_);
+      a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
+      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+    }
+    if (!isZeroPointZero_) {
+      for (int s = S_ - W_PAD_; s < S_; ++s) {
+        gen8bitFMA<inst_set_t::avx2>(
+            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+      }
+    }
+  }
+  storeResult<inst_set_t::avx2>(a);
+  a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+
+  // reset output activation pointer
+  a->imul(scratchReg1_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+  a->sub(out_acts_R_, scratchReg1_);
+}
+
+// Emits the loop over the first output column, excluding the two corners.
+template <>
+template <>
+void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>(
+    asmjit::X86Emitter* a) {
+  // left edge excluding corners
+  asmjit::Label LoopLeftEdge = a->newLabel();
+  a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+  a->bind(LoopLeftEdge);
+  // zero out
+  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+  a->mov(scratchReg1_, loopR1_);
+  a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
+  a->imul(scratchReg1_, W_R_);
+  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+  for (int r = 0; r < R_; ++r) {
+    if (!isZeroPointZero_) {
+      for (int s = 0; s < W_PAD_; ++s) {
+        gen8bitFMA<inst_set_t::avx2>(
+            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+      }
+    }
+    for (int s = W_PAD_; s < S_; ++s) {
+      a->vbroadcastsd(
+          actRegAvx2_,
+          x86::dword_ptr(
+              in_acts_R_,
+              scratchReg1_,
+              0,
+              (s - W_PAD_) * C_ * sizeof(uint8_t)));
+      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+    }
+    // move to the next input row
+    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+    a->add(scratchReg1_, scratchReg2_);
+  }
+
+  a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+  a->add(out_acts_R_, scratchReg2_);
+  storeResult<inst_set_t::avx2>(a);
+
+  a->inc(loopR1_);
+  a->mov(loopR2_, H_R_);
+  a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_));
+  a->cmp(loopR1_, loopR2_);
+  a->jl(LoopLeftEdge);
+
+  // reset output activation pointer
+  a->mov(scratchReg2_, H_R_);
+  a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_));
+  a->imul(scratchReg2_, W_R_);
+  a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+  a->sub(out_acts_R_, scratchReg2_);
+}
+
+// Emits the loop over the last output column, excluding the two corners.
+template <>
+template <>
+void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>(
+    asmjit::X86Emitter* a) {
+  // right edge excluding corners
+  asmjit::Label LoopRightEdge = a->newLabel();
+
+  // output pointer to the right edge
+  // (W_ + W_ - 1)*K_*sizeof(int32_t)
+  a->mov(scratchReg2_, W_R_);
+  a->imul(scratchReg2_, 2);
+  a->sub(scratchReg2_, 1);
+  a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+  a->add(out_acts_R_, scratchReg2_);
+
+  a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+  a->bind(LoopRightEdge);
+  // zero out
+  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+  a->mov(scratchReg1_, loopR1_);
+  a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
+  a->imul(scratchReg1_, W_R_);
+  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+
+  a->mov(scratchReg2_, W_R_);
+  a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * W_PAD_));
+  a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+  a->add(scratchReg1_, scratchReg2_);
+  for (int r = 0; r < R_; ++r) {
+    for (int s = 0; s < S_ - W_PAD_; ++s) {
+      a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
+      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+      a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+    }
+    if (!isZeroPointZero_) {
+      for (int s = S_ - W_PAD_; s < S_; ++s) {
+        gen8bitFMA<inst_set_t::avx2>(
+            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+      }
+    }
+
+    // rewind the column steps taken above, then move to the next input row
+    a->sub(
+        scratchReg1_,
+        static_cast<asmjit::Imm>((S_ - W_PAD_) * C_ * sizeof(uint8_t)));
+    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+    a->add(scratchReg1_, scratchReg2_);
+  }
+
+  // storeResult<inst_set_t::avx2>(a, (W_+W_-1)*K_*sizeof(int32_t));
+  storeResult<inst_set_t::avx2>(a);
+
+  a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+  a->add(out_acts_R_, scratchReg2_);
+  a->mov(loopR2_, H_R_);
+  a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_));
+  a->inc(loopR1_);
+  a->cmp(loopR1_, loopR2_);
+  a->jl(LoopRightEdge);
+
+  // reset base
+  a->mov(scratchReg2_, W_R_);
+  a->imul(scratchReg2_, 2);
+  a->sub(scratchReg2_, 1);
+  a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+  a->sub(out_acts_R_, scratchReg2_);
+
+  // reset loop increments
+  //(H_ - 2*H_PAD_)*W_*K_*sizeof(int32_t)
+  a->mov(scratchReg2_, H_R_);
+  a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_));
+  a->imul(scratchReg2_, W_R_);
+  a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+  a->sub(out_acts_R_, scratchReg2_);
+  // a->sub(out_acts_R_, (H_ - 2*H_PAD_)*W_*K_*sizeof(int32_t));
+}
+
+// Emits the code for the last output row: bottom-left corner, the bottom
+// edge between the corners, and the bottom-right corner.
+template <>
+template <>
+void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>(
+    asmjit::X86Emitter* a) {
+  // bottom-left corner
+  // zero out
+  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+  a->mov(scratchReg1_, H_R_);
+  a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
+  a->imul(scratchReg1_, W_R_);
+  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+  for (int r = 0; r < R_ - H_PAD_; ++r) {
+    if (!isZeroPointZero_) {
+      for (int s = 0; s < W_PAD_; ++s) {
+        gen8bitFMA<inst_set_t::avx2>(
+            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+      }
+    }
+    for (int s = W_PAD_; s < S_; ++s) {
+      a->vbroadcastsd(
+          actRegAvx2_,
+          x86::dword_ptr(
+              in_acts_R_,
+              scratchReg1_,
+              0,
+              (s - W_PAD_) * C_ * sizeof(uint8_t)));
+      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+    }
+    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+    a->add(scratchReg1_, scratchReg2_);
+  }
+  if (!isZeroPointZero_) {
+    for (int r = R_ - H_PAD_; r < R_; ++r) {
+      for (int s = 0; s < S_; ++s) {
+        gen8bitFMA<inst_set_t::avx2>(
+            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+      }
+    }
+  }
+
+  // we updating the last row
+  a->mov(scratchReg1_, H_R_);
+  a->sub(scratchReg1_, 1);
+  a->imul(scratchReg1_, W_R_);
+  a->imul(scratchReg1_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+  a->add(out_acts_R_, scratchReg1_);
+  // storeResult<inst_set_t::avx2>(a, (H_-1)*W_*K_*sizeof(int32_t));
+  storeResult<inst_set_t::avx2>(a);
+  a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+
+  // bottom edge excluding corners
+  asmjit::Label LoopBottomEdge = a->newLabel();
+  a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+  a->bind(LoopBottomEdge);
+  // zero out
+  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+  a->mov(scratchReg1_, H_R_);
+  a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
+  a->imul(scratchReg1_, W_R_);
+  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+  // rows are split by the vertical padding H_PAD_, as in both corners
+  for (int r = 0; r < R_ - H_PAD_; ++r) {
+    // int h_in = H_-2*H_PAD_ + r;
+    for (int s = 0; s < S_; ++s) {
+      a->vbroadcastsd(
+          actRegAvx2_,
+          x86::dword_ptr(
+              in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+    }
+    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+    a->add(scratchReg1_, scratchReg2_);
+  }
+
+  if (!isZeroPointZero_) {
+    for (int r = R_ - H_PAD_; r < R_; ++r) {
+      for (int s = 0; s < S_; ++s) {
+        gen8bitFMA<inst_set_t::avx2>(
+            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+      }
+    }
+  }
+
+  a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+  // storeResult<inst_set_t::avx2>(a, ((H_-1)*W_+1)*K_*sizeof(int32_t));
+  storeResult<inst_set_t::avx2>(a);
+
+  a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+  a->inc(loopR2_);
+  a->mov(loopR1_, W_R_);
+  a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_));
+  a->cmp(loopR2_, loopR1_);
+  a->jl(LoopBottomEdge);
+  a->mov(scratchReg1_, W_R_);
+  a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * W_PAD_));
+  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+  a->sub(in_acts_R_, scratchReg1_);
+  // a->sub(in_acts_R_, (W_ - 2*W_PAD_)*C_*sizeof(uint8_t));
+  // a->sub(out_acts_R_, (W_ - 2*W_PAD_)*K_*sizeof(int32_t));
+
+  // bottom-right corner
+  // zero out
+  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+  // input start point
+  // ((H_-(R_-H_PAD_))*W_+(W_-(S_-W_PAD_)))*C_*sizeof(uint8_t)
+  a->mov(scratchReg1_, H_R_);
+  a->sub(scratchReg1_, static_cast<asmjit::Imm>(R_ - H_PAD_));
+  a->imul(scratchReg1_, W_R_);
+  a->add(scratchReg1_, W_R_);
+  a->sub(scratchReg1_, static_cast<asmjit::Imm>(S_ - W_PAD_));
+  a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+  for (int r = 0; r < R_ - H_PAD_; ++r) {
+    for (int s = 0; s < S_ - W_PAD_; ++s) {
+      a->vbroadcastsd(
+          actRegAvx2_,
+          x86::dword_ptr(
+              in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+    }
+    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+    a->add(scratchReg1_, scratchReg2_);
+    if (!isZeroPointZero_) {
+      for (int s = S_ - W_PAD_; s < S_; ++s) {
+        gen8bitFMA<inst_set_t::avx2>(
+            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+      }
+    }
+  }
+
+  if (!isZeroPointZero_) {
+    for (int r = R_ - H_PAD_; r < R_; ++r) {
+      for (int s = 0; s < S_; ++s) {
+        gen8bitFMA<inst_set_t::avx2>(
+            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+      }
+    }
+  }
+
+  storeResult<inst_set_t::avx2>(a);
+  // storeResult<inst_set_t::avx2>(a, ((H_-1)*W_+W_-1)*K_*sizeof(int32_t));
+  // reset output pointer
+  a->mov(scratchReg1_, H_R_);
+  a->sub(scratchReg1_, 1);
+  a->imul(scratchReg1_, W_R_);
+  a->add(scratchReg1_, W_R_);
+  a->sub(scratchReg1_, 1);
+  a->imul(scratchReg1_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+  a->sub(out_acts_R_, scratchReg1_);
+}
+
+// Emits the H x W loop nest over the interior output positions, where every
+// filter tap reads a real activation.
+template <>
+template <>
+void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>(
+    asmjit::X86Emitter* a) {
+  // main compute
+  asmjit::Label LoopH = a->newLabel();
+  asmjit::Label LoopW = a->newLabel();
+  // base for output
+  a->mov(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
+  a->imul(scratchReg2_, W_R_);
+  a->add(scratchReg2_, static_cast<asmjit::Imm>(W_PAD_));
+  a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+  a->add(out_acts_R_, scratchReg2_);
+
+  // upper bound of the W loop, kept in scratchReg1_ for the whole nest
+  a->mov(scratchReg1_, W_R_);
+  a->sub(scratchReg1_, static_cast<asmjit::Imm>(W_PAD_));
+
+  // H loop
+  a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+  a->bind(LoopH);
+  // W loop
+  a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+  a->bind(LoopW);
+  // zero out
+  a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+  // compute on all filters
+  for (int r = 0; r < R_; ++r) {
+    for (int s = 0; s < S_; ++s) {
+      a->vbroadcastsd(
+          actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t)));
+      gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+    }
+    a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+    a->add(in_acts_R_, scratchReg2_);
+  }
+  // rewind the R_ row steps taken above
+  a->imul(
+      scratchReg2_, W_R_, static_cast<asmjit::Imm>(R_ * C_ * sizeof(uint8_t)));
+  a->sub(in_acts_R_, scratchReg2_);
+  // a->add(scratchReg1_, C_*sizeof(uint8_t));
+  a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+
+  // storeResult<inst_set_t::avx2>(a, (W_+1)*K_*sizeof(int32_t));
+  storeResult<inst_set_t::avx2>(a);
+
+  a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+
+  a->inc(loopR2_);
+  a->cmp(loopR2_, scratchReg1_);
+  a->jl(LoopW);
+  // add (W_ - 2*W_PAD_)*C_*sizeof(uint8_t) and subtract W_*C_*sizeof(uint8_t)
+  a->add(
+      in_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
+  // a->sub(in_acts_R_, (W_ - 2*W_PAD_)*C_*sizeof(uint8_t));
+  // a->add(in_acts_R_, W_*C_*sizeof(uint8_t));
+  a->add(
+      out_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * K_ * sizeof(int32_t)));
+  // a->sub(out_acts_R_, (W_ - 2*W_PAD_)*K_*sizeof(int32_t));
+  // a->add(out_acts_R_, W_*K_*sizeof(int32_t));
+
+  a->inc(loopR1_);
+  a->mov(scratchReg2_, H_R_);
+  a->sub(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
+  a->cmp(loopR1_, scratchReg2_);
+  a->jl(LoopH);
+}
+
+/**
+ * @brief Generate the AVX2 convolution kernel and store it in codeCache_.
+ *
+ * @param conv_param Convolution description; used for the cache key.
+ * @return The generated kernel, or nullptr if the JIT runtime rejects it.
+ */
+template <>
+template <>
+jit_conv_kernel_fp GenConvKernel<int32_t>::getOrCreate<inst_set_t::avx2>(
+    const conv_param_t<>& conv_param) {
+  code_.reset(false);
+  code_.init(rt_.getCodeInfo());
+  asmjit::X86Assembler assembler(&code_);
+  asmjit::X86Emitter* a = assembler.asEmitter();
+
+#if defined(FBGEMM_LOG_CODE)
+  // log code to a file
+  FILE* codeLogfile =
+      fopen(getCodeLoggingFile<inst_set_t::avx2>(false).c_str(), "w");
+  asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
+  if (codeLogger) {
+    code_.setLogger(codeLogger);
+  }
+#endif
+
+  // arguments to the function created
+  in_acts_R_ = a->zdi();
+  wghts_R_ = a->zsi();
+  out_acts_R_ = a->zdx();
+  a_zero_pt_R_ = a->zcx();
+  H_R_ = a->gpzRef(8);
+  W_R_ = a->gpzRef(9);
+  row_offset_R_ = a->gpzRef(10);
+
+  // register for temporary use
+  scratchReg1_ = a->gpzRef(12);
+  scratchReg2_ = a->gpzRef(13);
+
+  asmjit::FuncDetail func;
+  func.init(asmjit::FuncSignature6<
+            void,
+            uint8_t*,
+            int8_t*,
+            int32_t*,
+            int32_t,
+            int32_t,
+            int32_t>(asmjit::CallConv::kIdHost));
+
+  asmjit::FuncFrameInfo ffi;
+  ffi.setDirtyRegs(
+      asmjit::X86Reg::kKindVec,
+      asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+          asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+  ffi.setDirtyRegs(
+      asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+
+  asmjit::FuncArgsMapper args(&func);
+  args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
+
+  args.updateFrameInfo(ffi);
+
+  asmjit::FuncFrameLayout layout;
+  layout.init(func, ffi);
+
+  asmjit::FuncUtils::emitProlog(a, layout);
+  asmjit::FuncUtils::allocArgs(a, layout, args);
+
+  createVector16BitOne<inst_set_t::avx2>(a);
+
+  loopR1_ = a->gpzRef(14);
+  loopR2_ = a->gpzRef(15);
+
+  if (!isZeroPointZero_) {
+    setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+  }
+
+  genForLoadingWeights<inst_set_t::avx2>(a);
+
+  genConstForPermutations<inst_set_t::avx2>(a);
+
+  genForTopEdge<inst_set_t::avx2>(a);
+  genForLeftEdge<inst_set_t::avx2>(a);
+  genForRightEdge<inst_set_t::avx2>(a);
+  genForBottomEdge<inst_set_t::avx2>(a);
+
+  genCoreInsts<inst_set_t::avx2>(a);
+
+  asmjit::FuncUtils::emitEpilog(a, layout);
+
+  jit_conv_kernel_fp fn;
+  asmjit::Error err = rt_.add(&fn, &code_);
+  if (err) {
+    std::cout << "Error: in fn add" << std::endl;
+#if defined(FBGEMM_LOG_CODE)
+    // release the logging resources on the failure path as well
+    if (codeLogfile) {
+      fclose(codeLogfile);
+    }
+    delete codeLogger;
+#endif
+    return nullptr;
+  }
+  auto kernelSig = getKernelSig(conv_param, isZeroPointZero_);
+  codeCache_[kernelSig] = fn;
+
+#if defined(FBGEMM_LOG_CODE)
+  // fopen may have failed; fclose on a null stream is undefined
+  if (codeLogfile) {
+    fclose(codeLogfile);
+  }
+  delete codeLogger;
+#endif
+
+  return fn;
+}
+
+template <> +template <> +void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>( + asmjit::X86Emitter* a) { + // top-left corner code + // zero out the results register + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + for (int r = 0; r < R_; ++r) { + int h_in = -H_PAD_ + r; + if (h_in >= 0) { + a->imul( + scratchReg1_, + W_R_, + static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t))); + } + for (int s = 0; s < S_; ++s) { + int w_in = -W_PAD_ + s; + if (h_in >= 0 && w_in >= 0) { + a->vmovaps( + actRegAvx2_, + x86::dword_ptr( + in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t))); + gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + } else { + if (!isZeroPointZero_) { + gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_); + } + } + } + } + // store results + storeResultRowoffset<inst_set_t::avx2>(a); + + // for C_per_G == 4 and K_per_G == 4, 8 groups processed at a time + a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + + // top edge excluding corners + asmjit::Label LoopTopEdge = a->newLabel(); + a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_)); + a->bind(LoopTopEdge); + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + if (!isZeroPointZero_) { + for (int r = 0; r < H_PAD_; ++r) { + for (int s = 0; s < S_; ++s) { + gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_); + } + } + } + for (int r = H_PAD_; r < R_; ++r) { + int h_in = -H_PAD_ + r; + a->imul( + scratchReg1_, + W_R_, + static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t))); + for (int s = 0; s < S_; ++s) { + a->vmovaps( + actRegAvx2_, + x86::dword_ptr( + in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t))); + gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + } + } + a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + + // store results + storeResultRowoffset<inst_set_t::avx2>(a); + + a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->mov(loopR1_, W_R_); + a->sub(loopR1_, 
static_cast<asmjit::Imm>(W_PAD_)); + a->inc(loopR2_); + a->cmp(loopR2_, loopR1_); + a->jl(LoopTopEdge); + a->mov(scratchReg2_, W_R_); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + a->sub( + scratchReg2_, + static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t))); + a->sub(in_acts_R_, scratchReg2_); + + // top-right corner code + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + if (!isZeroPointZero_) { + for (int r = 0; r < H_PAD_; ++r) { + for (int s = 0; s < S_; ++s) { + gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_); + } + } + } + for (int r = H_PAD_; r < R_; ++r) { + int h_in = -H_PAD_ + r; + for (int s = 0; s < S_ - W_PAD_; ++s) { + a->imul( + scratchReg1_, + W_R_, + static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t))); + a->mov(scratchReg2_, W_R_); + a->sub(scratchReg2_, static_cast<asmjit::Imm>(R_ - W_PAD_ - s)); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + a->add(scratchReg1_, scratchReg2_); + a->vmovaps(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_)); + gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + } + if (!isZeroPointZero_) { + for (int s = S_ - W_PAD_; s < S_; ++s) { + gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_); + } + } + } + + // store results + storeResultRowoffset<inst_set_t::avx2>(a); + + a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + + // reset output pointer + a->imul(scratchReg1_, W_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->sub(row_offset_R_, scratchReg1_); +} + +template <> +template <> +void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>( + asmjit::X86Emitter* a) { + // left edge excluding corners + asmjit::Label LoopLeftEdge = a->newLabel(); + a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_)); + a->bind(LoopLeftEdge); + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + a->mov(scratchReg1_, loopR1_); + a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_)); + 
a->imul(scratchReg1_, W_R_); + a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + for (int r = 0; r < R_; ++r) { + if (!isZeroPointZero_) { + for (int s = 0; s < W_PAD_; ++s) { + gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_); + } + } + for (int s = W_PAD_; s < S_; ++s) { + a->vmovaps( + actRegAvx2_, + x86::dword_ptr( + in_acts_R_, + scratchReg1_, + 0, + (s - W_PAD_) * C_ * sizeof(uint8_t))); + gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + } + a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + a->add(scratchReg1_, scratchReg2_); + } + + a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->add(row_offset_R_, scratchReg2_); + storeResultRowoffset<inst_set_t::avx2>(a); + + a->inc(loopR1_); + a->mov(loopR2_, H_R_); + a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_)); + a->cmp(loopR1_, loopR2_); + a->jl(LoopLeftEdge); + + // reset output pointer + a->mov(scratchReg2_, H_R_); + a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_)); + a->imul(scratchReg2_, W_R_); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->sub(row_offset_R_, scratchReg2_); +} + +template <> +template <> +void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>( + asmjit::X86Emitter* a) { + // right edge excluding corners + asmjit::Label LoopRightEdge = a->newLabel(); + + // output pointer to the right edge + // (W_ + W_ - 1)*8*sizeof(int32_t) + a->mov(scratchReg2_, W_R_); + a->imul(scratchReg2_, 2); + a->sub(scratchReg2_, 1); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->add(row_offset_R_, scratchReg2_); + + a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_)); + a->bind(LoopRightEdge); + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + a->mov(scratchReg1_, loopR1_); + a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_)); + a->imul(scratchReg1_, W_R_); + a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * 
sizeof(uint8_t))); + + a->mov(scratchReg2_, W_R_); + a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * W_PAD_)); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + a->add(scratchReg1_, scratchReg2_); + for (int r = 0; r < R_; ++r) { + for (int s = 0; s < S_ - W_PAD_; ++s) { + a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_)); + a->vmovaps(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_)); + gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + } + if (!isZeroPointZero_) { + for (int s = S_ - W_PAD_; s < S_; ++s) { + gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_); + } + } + + a->sub( + scratchReg1_, + static_cast<asmjit::Imm>((S_ - W_PAD_) * C_ * sizeof(uint8_t))); + a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + a->add(scratchReg1_, scratchReg2_); + } + + storeResultRowoffset<inst_set_t::avx2>(a); + + a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->add(row_offset_R_, scratchReg2_); + a->mov(loopR2_, H_R_); + a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_)); + a->inc(loopR1_); + a->cmp(loopR1_, loopR2_); + a->jl(LoopRightEdge); + + // reset base + a->mov(scratchReg2_, W_R_); + a->imul(scratchReg2_, 2); + a->sub(scratchReg2_, 1); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->sub(row_offset_R_, scratchReg2_); + + // reset increments done in the loop + //(H_ - 2*H_PAD_)*W_*8*sizeof(int32_t) + a->mov(scratchReg2_, H_R_); + a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_)); + a->imul(scratchReg2_, W_R_); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->sub(row_offset_R_, scratchReg2_); +} + +template <> +template <> +void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>( + asmjit::X86Emitter* a) { + // bottom-left corner + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + 
a->mov(scratchReg1_, H_R_); + a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_)); + a->imul(scratchReg1_, W_R_); + a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + for (int r = 0; r < R_ - H_PAD_; ++r) { + if (!isZeroPointZero_) { + for (int s = 0; s < W_PAD_; ++s) { + gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_); + } + } + for (int s = W_PAD_; s < S_; ++s) { + a->vmovaps( + actRegAvx2_, + x86::dword_ptr( + in_acts_R_, + scratchReg1_, + 0, + (s - W_PAD_) * C_ * sizeof(uint8_t))); + gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + } + a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + a->add(scratchReg1_, scratchReg2_); + } + if (!isZeroPointZero_) { + for (int r = R_ - H_PAD_; r < R_; ++r) { + for (int s = 0; s < S_; ++s) { + gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_); + } + } + } + + // we updating the last row + a->mov(scratchReg1_, H_R_); + a->sub(scratchReg1_, 1); + a->imul(scratchReg1_, W_R_); + a->imul(scratchReg1_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->add(row_offset_R_, scratchReg1_); + storeResultRowoffset<inst_set_t::avx2>(a); + a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + + // bottom edge excluding corners + asmjit::Label LoopBottomEdge = a->newLabel(); + a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_)); + a->bind(LoopBottomEdge); + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + a->mov(scratchReg1_, H_R_); + a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_)); + a->imul(scratchReg1_, W_R_); + a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + for (int r = 0; r < R_ - W_PAD_; ++r) { + // int h_in = H_-2*H_PAD_ + r; + for (int s = 0; s < S_; ++s) { + a->vmovaps( + actRegAvx2_, + x86::dword_ptr( + in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t))); + gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + } + a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + 
a->add(scratchReg1_, scratchReg2_); + } + + if (!isZeroPointZero_) { + for (int r = R_ - W_PAD_; r < R_; ++r) { + for (int s = 0; s < S_; ++s) { + gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_); + } + } + } + + a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + // storeResult<inst_set_t::avx2>(a, ((H_-1)*W_+1)*8*sizeof(int32_t)); + storeResultRowoffset<inst_set_t::avx2>(a); + + a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->inc(loopR2_); + a->mov(loopR1_, W_R_); + a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_)); + a->cmp(loopR2_, loopR1_); + a->jl(LoopBottomEdge); + a->mov(scratchReg1_, W_R_); + a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * W_PAD_)); + a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + a->sub(in_acts_R_, scratchReg1_); + // a->sub(in_acts_R_, (W_ - 2*W_PAD_)*C_*sizeof(uint8_t)); + // a->sub(out_acts_R_, (W_ - 2*W_PAD_)*8*sizeof(int32_t)); + + // bottom-right corner + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + // input start point + // ((H_-(R_-H_PAD_))*W_+(W_-(S_-W_PAD_)))*C_*sizeof(uint8_t) + a->mov(scratchReg1_, H_R_); + a->sub(scratchReg1_, static_cast<asmjit::Imm>(R_ - H_PAD_)); + a->imul(scratchReg1_, W_R_); + a->add(scratchReg1_, W_R_); + a->sub(scratchReg1_, static_cast<asmjit::Imm>(S_ - W_PAD_)); + a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + for (int r = 0; r < R_ - H_PAD_; ++r) { + for (int s = 0; s < S_ - W_PAD_; ++s) { + a->vmovaps( + actRegAvx2_, + x86::dword_ptr( + in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t))); + gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + } + a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + a->add(scratchReg1_, scratchReg2_); + if (!isZeroPointZero_) { + for (int s = S_ - W_PAD_; s < S_; ++s) { + gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_); + } + } + } + + if (!isZeroPointZero_) { + for (int r = R_ - H_PAD_; r < R_; ++r) { + for (int s 
= 0; s < S_; ++s) { + gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_); + } + } + } + + storeResultRowoffset<inst_set_t::avx2>(a); + // reset output pointer + a->mov(scratchReg1_, H_R_); + a->sub(scratchReg1_, 1); + a->imul(scratchReg1_, W_R_); + a->add(scratchReg1_, W_R_); + a->sub(scratchReg1_, 1); + a->imul(scratchReg1_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->sub(row_offset_R_, scratchReg1_); +} + +template <> +template <> +void GenConvKernel<int32_t>::genRowoffsetCore<inst_set_t::avx2>( + asmjit::X86Emitter* a) { + // number of uint8 elements in input channels should be a multiple of 32 + assert(C_ % 32 == 0); + + asmjit::Label LoopH = a->newLabel(); + asmjit::Label LoopW = a->newLabel(); + // base for output + a->mov(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_)); + a->imul(scratchReg2_, W_R_); + a->add(scratchReg2_, static_cast<asmjit::Imm>(W_PAD_)); + a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + a->add(row_offset_R_, scratchReg2_); + + a->mov(scratchReg1_, W_R_); + a->sub(scratchReg1_, static_cast<asmjit::Imm>(W_PAD_)); + + // H loop + a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_)); + a->bind(LoopH); + // W loop + a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_)); + a->bind(LoopW); + + // zero out + a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); + for (int r = 0; r < R_; ++r) { + for (int s = 0; s < S_; ++s) { + a->vmovaps( + actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t))); + gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_); + } + a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + a->add(in_acts_R_, scratchReg2_); + } + a->imul( + scratchReg2_, W_R_, static_cast<asmjit::Imm>(R_ * C_ * sizeof(uint8_t))); + a->sub(in_acts_R_, scratchReg2_); + // store results + storeResultRowoffset<inst_set_t::avx2>(a); + + a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t))); + a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t))); + + 
a->inc(loopR2_); + a->cmp(loopR2_, scratchReg1_); + a->jl(LoopW); + a->add( + in_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t))); + a->add( + row_offset_R_, + static_cast<asmjit::Imm>(2 * W_PAD_ * 8 * sizeof(int32_t))); + a->inc(loopR1_); + a->mov(scratchReg2_, H_R_); + a->sub(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_)); + a->cmp(loopR1_, scratchReg2_); + a->jl(LoopH); +} + +template <> +template <> +jit_rowoffset_kernel_fp +GenConvKernel<int32_t>::getOrCreateRowOffset<inst_set_t::avx2>( + const conv_param_t<>& conv_param) { + code_.reset(false); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); + +#if defined(FBGEMM_LOG_CODE) + // log code to a file + FILE* codeLogfile = + fopen(getCodeLoggingFile<inst_set_t::avx2>(true).c_str(), "w"); + asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); + if (codeLogger) { + code_.setLogger(codeLogger); + } +#endif + + // arguments to the function created + in_acts_R_ = a->zdi(); + a_zero_pt_R_ = a->zsi(); + H_R_ = a->zdx(); + W_R_ = a->zcx(); + row_offset_R_ = a->gpzRef(8); + + // register for temporary use + scratchReg1_ = a->gpzRef(12); + scratchReg2_ = a->gpzRef(13); + + loopR1_ = a->gpzRef(14); + loopR2_ = a->gpzRef(15); + + asmjit::FuncDetail func; + func.init( + asmjit:: + FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>( + asmjit::CallConv::kIdHost)); + + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15)); + + asmjit::FuncArgsMapper args(&func); + args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_); + + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); + + asmjit::FuncUtils::emitProlog(a, layout); + 
asmjit::FuncUtils::allocArgs(a, layout, args); + + // This uses xmm10 register temporarily. Should come before + // createVector8BitOne + if (!isZeroPointZero_) { + setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_); + } + + createVector16BitOne<inst_set_t::avx2>(a); + // we set ymm10 to contain 8-bit 1s + createVector8BitOne<inst_set_t::avx2>(a); + + genForTopEdgeRowoffset<inst_set_t::avx2>(a); + genForLeftEdgeRowoffset<inst_set_t::avx2>(a); + genForRightEdgeRowoffset<inst_set_t::avx2>(a); + genForBottomEdgeRowoffset<inst_set_t::avx2>(a); + + genRowoffsetCore<inst_set_t::avx2>(a); + + asmjit::FuncUtils::emitEpilog(a, layout); + + jit_rowoffset_kernel_fp fn; + asmjit::Error err = rt_.add(&fn, &code_); + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + auto kernelSig = getKernelSig(conv_param, isZeroPointZero_); + codeCacheRowOffset_[kernelSig] = fn; + +#if defined(FBGEMM_LOG_CODE) + delete codeLogger; + fclose(codeLogfile); +#endif + + return fn; +} + +template < + typename packed_W, + typename outType, + typename processOutputType, + int SPATIAL_DIM> +void fbgemmGroupwiseConv( + const conv_param_t<SPATIAL_DIM>& conv_param, + const std::uint8_t* activations, + std::int32_t a_zero_point, + std::int32_t* rowOffsetBuf, + packed_W& packed_weights, + outType* out, + int32_t* outBuffer, + const processOutputType& outProcess, + int thread_id, + int num_threads) { + + int MB = conv_param.MB; + int H = conv_param.OUT_DIM[0]; + int W = conv_param.OUT_DIM[1]; + int G = conv_param.G; + int K_per_G = conv_param.OC / G; + int C_per_G = conv_param.IC / G; + int oh_ow = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1]; + + static_assert(SPATIAL_DIM == 2, "3D conv not supported yet"); + + int32_t* rowOffsetTrDest = + rowOffsetBuf + 8 * conv_param.IN_DIM[0] * conv_param.IN_DIM[1]; + if (fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)) { + assert(G % 8 == 0); + // generate convolution kernel + jit_conv_kernel_fp fpConv = + 
getOrCreateConvKernel<>(conv_param, a_zero_point); + // generate row offset kernel + jit_rowoffset_kernel_fp fpRowoffset = + getOrCreateRowOffsetKernel(conv_param, a_zero_point); + for (int i = 0; i < MB; ++i) { + const uint8_t* actStartBatch = activations + + i * conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * conv_param.IC; + for (int gOuter = 0; gOuter < G; gOuter += 8) { + // for C_per_G == 4 and K_per_G == 4, row offset is calcualted for 8 + // groups at a time The result is row offsets in the format IH*IW x G + fpRowoffset( + actStartBatch + gOuter * C_per_G, + a_zero_point, + H, + W, + rowOffsetBuf); + // Transpose to get row offsets in the format G x IH*IW + internal::transpose_8x8( + conv_param.IN_DIM[0] * conv_param.IN_DIM[1], + 8, + (const float*)rowOffsetBuf, + 8, + (float*)rowOffsetTrDest, + conv_param.IN_DIM[0] * conv_param.IN_DIM[1]); + int gLimit = gOuter + 8; + for (int g = gOuter; g < gLimit; g += 2) { + int32_t* currOutBuf = + outBuffer + i * oh_ow * conv_param.OC + g * K_per_G; + const uint8_t* actStartGroup = actStartBatch + g * C_per_G; + + fpConv( + actStartGroup, + packed_weights.getBuf() + g * K_per_G * C_per_G, + currOutBuf, + a_zero_point, + H, + W); + + // Output processing should be called for each group + for (int j = 0; j < 2; ++j) { + // calculateRowOffsets( + // conv_param, actStartGroup, rowOffsetBuf, a_zero_point, j); + int32_t* rowOffsetForCurG = rowOffsetTrDest + + ((g - gOuter) + j) * conv_param.IN_DIM[0] * + conv_param.IN_DIM[1]; + // compare_buffers(rowOffsetBuf, rowOffsetForCurG, + // conv_param.IN_DIM[0]*conv_param.IN_DIM[1], 1, 1, 100); + + // outProcess expects rowOffsetBuf to contain row offsets for the + // current group + memcpy( + rowOffsetBuf, + rowOffsetForCurG, + conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * sizeof(int32_t)); + + if (cpuinfo_has_x86_avx512f()) { + // Currently use avx2 code + outProcess.template f<inst_set_t::avx2>( + out, + currOutBuf + j * K_per_G, + {i * oh_ow, oh_ow, (g + j) * K_per_G, K_per_G}, 
+ K_per_G * G, + K_per_G * G); + } else if (cpuinfo_has_x86_avx2()) { + outProcess.template f<inst_set_t::avx2>( + out, + currOutBuf + j * K_per_G, + {i * oh_ow, oh_ow, (g + j) * K_per_G, K_per_G}, + K_per_G * G, + K_per_G * G); + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecure"); + } + } // j loop + } + } + } + } else { + // for the not supported cases, just execute the naive C implementation + conv_ref( + conv_param, + activations, + a_zero_point, + packed_weights.getBuf(), + outBuffer); + for (int i = 0; i < conv_param.MB; ++i) { + for (int g = 0; g < conv_param.G; ++g) { + calculateRowOffsets( + conv_param, + activations + + i * conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * conv_param.IC, + rowOffsetBuf, + a_zero_point, + g); + outProcess.template f<inst_set_t::anyarch>( + out, + outBuffer + i * oh_ow * conv_param.OC + g * K_per_G, + {i * oh_ow, oh_ow, g * K_per_G, K_per_G}, + K_per_G * G, + K_per_G * G); + } + } + } +} + +jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel( + const conv_param_t<>& conv_param, + int a_zero_point) { + // Note: Wrong code is generated if it's not one of the supported convolution + assert(fbgemmOptimizedGConv<2>(conv_param)); + auto kernelSig = getKernelSig(conv_param, a_zero_point == 0); + if (GenConvKernel<int32_t>::codeCacheRowOffset_.find(kernelSig) != + GenConvKernel<int32_t>::codeCacheRowOffset_.end()) { + return GenConvKernel<int32_t>::codeCacheRowOffset_[kernelSig]; + } else { + auto genObj = GenConvKernel<int32_t>(conv_param, a_zero_point); + // TODO: Instruction set based dispatch + return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param); + } +} + +template <int SPATIAL_DIM> +int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) { + // row offset buffer should be a able to hold row offsets for however + // number of groups we process at a time. 
+ if (cpuinfo_initialize()) { + if (cpuinfo_has_x86_avx512f()) { + int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1]; + int C_per_G = conv_param.IC / conv_param.G; + int K_per_G = conv_param.OC / conv_param.G; + if (C_per_G == 4 && K_per_G == 4) { + return 2 * 8 * bufferSize; + } else { + return conv_param.G * bufferSize; + } + } else if (cpuinfo_has_x86_avx2()) { + int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1]; + int C_per_G = conv_param.IC / conv_param.G; + int K_per_G = conv_param.OC / conv_param.G; + if (C_per_G == 4 && K_per_G == 4) { + // row offset is calculated for 8 groups at a time + // 2x is needed for transposing + return 2 * 8 * bufferSize; + } else { + return conv_param.G * bufferSize; + } + } else { + // TODO: Have default slower path + assert(0 && "unsupported architecture"); + return -1; + } + } else { + throw std::runtime_error("Failed to initialize cpuinfo!"); + } +} + +template int rowOffsetBufferSizeGConv<2>(const conv_param_t<2>& conv_param); + +#define INSTANTIATE_BASE(RELU, Q_GRAN) \ + template void fbgemmGroupwiseConv( \ + const conv_param_t<2>& conv_param, \ + const uint8_t* activations, \ + int32_t a_zero_point, \ + std::int32_t* rowOffsetBuf, \ + PackWeightMatrixForGConv<int8_t, int32_t, 2>& packed_weights, \ + uint8_t* out, \ + int32_t* outBuffer, \ + const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \ + int thread_id, \ + int num_threads); + +#define INSTANTIATE_Q_GRANS(RELU) \ + INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR); \ + INSTANTIATE_BASE(RELU, QuantizationGranularity::GROUP); \ + INSTANTIATE_BASE(RELU, QuantizationGranularity::OUT_CHANNEL); + +INSTANTIATE_Q_GRANS(false); +INSTANTIATE_Q_GRANS(true); + +#undef INSTANTIATE_Q_GRANS +#undef INSTANTIATE_BASE + +template void fbgemmGroupwiseConv( + const conv_param_t<2>& conv_param, + const uint8_t* activations, + int32_t a_zero_point, + std::int32_t* rowOffsetBuf, + PackWeightMatrixForGConv<int8_t, int32_t, 2>& packed_weights, + int32_t* out, 
+ int32_t* outBuffer, + const DoNothing<int32_t, int32_t>& outProcess, + int thread_id, + int num_threads); + +} // namespace fbgemm diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc new file mode 100644 index 0000000..e6c9b7d --- /dev/null +++ b/src/PackWeightMatrixForGConv.cc @@ -0,0 +1,103 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <cpuinfo.h> +#include <cassert> +#include <iomanip> +#include "RefImplementations.h" +#include "fbgemm/Fbgemm.h" + +namespace fbgemm { + +template <typename T, typename accT, int SPATIAL_DIM> +PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::PackWeightMatrixForGConv( + matrix_op_t trans, + const conv_param_t<SPATIAL_DIM>& conv_param, + const T* sdata, + T* pdata) + : trans_(trans), conv_param_(conv_param), sdata_(sdata) { + static_assert(SPATIAL_DIM == 2, "3D conv not supported yet"); + + if (!pdata) { + bufAllocatedHere_ = true; + pdata_ = static_cast<T*>(fbgemmAlignedAlloc( + 64, + conv_param_.G * conv_param_.K[0] * conv_param_.K[1] * + (conv_param_.OC / conv_param_.G) * + (conv_param_.IC / conv_param_.G) * sizeof(T))); + } else { + bufAllocatedHere_ = false; + pdata_ = pdata; + } + pack(); +} + +/** + * @brief Pack weight tensor in a suitable format required for the optimized + * kernel. + * + * Let IC_per_G be number of input channels per group and OC_per_G be number of + * output channels per group. + * + * For IC_per_G == 4 && OC_per_G == 4 optimized + * kernel works on 2 groups at a time hence input channels for g and g+1 group + * are laid out sequentially for each output channel, i.e., the layout is R S + * (G/2) K (2C) + * We work on two groups at a time to fully utilize the avx2 SIMD width of + * 256-bits. 
+ * + * For IC_per_G == 8, 16, 32 && OC_per_G == 8, 16, 32 there is no need to work + * on 2 groups at a time and full SIMD width can be efficiently utilized even + * while working on 1 group at a time. + */ +template <typename T, typename accT, int SPATIAL_DIM> +void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() { + // filters are assumed to be in G RS C/G K/G format + int R = conv_param_.K[0]; + int S = conv_param_.K[1]; + int G = conv_param_.G; + int IC_per_G = conv_param_.IC / conv_param_.G; + int OC_per_G = conv_param_.OC / conv_param_.G; + + // If transpose option is set, the weight matrix is in layout G K/G (R S C/G) + // instead of G (R S C/G) K/G + bool tr = (trans_ == matrix_op_t::Transpose); + if (fbgemmOptimizedGConv(conv_param_)) { + // currently only this case is supported + for (int r = 0; r < R; ++r) { + for (int s = 0; s < S; ++s) { + for (int k = 0; k < OC_per_G; ++k) { + for (int g = 0; g < G; ++g) { + for (int c = 0; c < IC_per_G; ++c) { + inpType b = tr + ? 
sdata_ + [(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c] + : sdata_ + [(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k]; + pdata_ + [((((r * S + s) * (G / 2) + (g / 2)) * OC_per_G + k) * 2 + + (g % 2)) * + IC_per_G + + c] = b; + } + } + } + } + } + } else { + if (tr) { + // conv_ref expects weights to be in G (R S C/G) K/G format + transposeConvWeights(conv_param_, sdata_, pdata_); + } else { + // just copy the data for not supported cases + memcpy(pdata_, sdata_, G * R * S * OC_per_G * IC_per_G * sizeof(inpType)); + } + } +} + +template class PackWeightMatrixForGConv<int8_t, int32_t, 2>; +template class PackWeightMatrixForGConv<int8_t, int16_t, 2>; +} // namespace fbgemm diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index 5168a15..5c6cf1b 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -486,6 +486,31 @@ void conv3d_ref( } // for each n } +void transposeConvWeights( + const conv_param_t<>& conv_p, + const std::int8_t* src, + std::int8_t* dest) { + int R = conv_p.K[0]; + int S = conv_p.K[1]; + int G = conv_p.G; + int IC_per_G = conv_p.IC / conv_p.G; + int OC_per_G = conv_p.OC / conv_p.G; + + // Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format. + for (int r = 0; r < R; ++r) { + for (int s = 0; s < S; ++s) { + for (int k = 0; k < OC_per_G; ++k) { + for (int g = 0; g < G; ++g) { + for (int c = 0; c < IC_per_G; ++c) { + dest[(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k] = + src[(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c]; + } + } + } + } + } +} + void depthwise_3x3_pad_1_ref( int N, int H, diff --git a/src/RefImplementations.h b/src/RefImplementations.h index fce68e6..62f17e9 100644 --- a/src/RefImplementations.h +++ b/src/RefImplementations.h @@ -176,6 +176,14 @@ FBGEMM_API void conv3d_ref( std::int32_t* C); /* + * @brief Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format. 
+ */ +FBGEMM_API void transposeConvWeights( + const conv_param_t<>& conv_p, + const std::int8_t* src, + std::int8_t* dest); + +/* * @brief Reference implementation of im2col operation. * The input A is assumed to be in NHiWiC format. * The output A is assumed to be in NHoWoRSC format. |