github.com/marian-nmt/FBGEMM.git
Diffstat (limited to 'src')
-rw-r--r--  src/Fbgemm.cc                   |   16
-rw-r--r--  src/GroupwiseConv.h             |  248
-rw-r--r--  src/GroupwiseConvAcc32Avx2.cc   | 1552
-rw-r--r--  src/PackWeightMatrixForGConv.cc |  103
-rw-r--r--  src/RefImplementations.cc       |   25
-rw-r--r--  src/RefImplementations.h        |    8
6 files changed, 1952 insertions(+), 0 deletions(-)
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index 45108d0..9384af6 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -192,6 +192,22 @@ void fbgemmPacked(
#endif
}
+template <int SPATIAL_DIM>
+FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p) {
+ int C_per_G = conv_p.IC / conv_p.G;
+ int K_per_G = conv_p.OC / conv_p.G;
+
+ return (SPATIAL_DIM == 2) && (C_per_G == K_per_G) && (C_per_G == 4) &&
+ (conv_p.G % 8 == 0) && (conv_p.K[0] == conv_p.K[1]) &&
+ (conv_p.K[0] == 3) && (conv_p.pad[0] == 1) && (conv_p.pad[1] == 1) &&
+ (conv_p.pad[0] == conv_p.pad[2]) && (conv_p.pad[1] == conv_p.pad[3]) &&
+ (conv_p.dilation[0] == 1) && (conv_p.dilation[0] == conv_p.dilation[1]) &&
+ (conv_p.stride[0] == 1) && (conv_p.stride[0] == conv_p.stride[1]);
+}
+
+template bool fbgemmOptimizedGConv(const conv_param_t<2>& conv_p);
+template bool fbgemmOptimizedGConv(const conv_param_t<3>& conv_p);
+
bool fbgemmSupportedCPU() {
return (cpuinfo_initialize() && cpuinfo_has_x86_avx2());
}
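fbgemmOptimizedGConv above is the gate for the JIT groupwise path: 2D, 3x3 filters, stride 1, symmetric pad 1, no dilation, C_per_G == K_per_G == 4, and G a multiple of 8. A minimal sketch of a shape that passes the check (the conv_param_t constructor-argument order is an assumption based on ConvUtils.h):

    conv_param_t<2> conv_p(
        1,             // MB: batch size
        128,           // IC: input channels  -> C_per_G = 128/32 = 4
        128,           // OC: output channels -> K_per_G = 4
        {56, 56},      // IN_DIM: input height/width
        32,            // G: groups, a multiple of 8
        {3, 3},        // K: filter height/width
        {1, 1},        // stride
        {1, 1, 1, 1}); // pad: top, left, bottom, right
    assert(fbgemmOptimizedGConv(conv_p)); // this shape takes the JIT path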
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
new file mode 100644
index 0000000..a46a895
--- /dev/null
+++ b/src/GroupwiseConv.h
@@ -0,0 +1,248 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <asmjit/asmjit.h>
+#include <cpuinfo.h>
+#include <cassert>
+#include <cstdint>
+#include <map>
+#include <string>
+#include <tuple>
+#include "fbgemm/ConvUtils.h"
+#include "fbgemm/Fbgemm.h"
+#include "fbgemm/Utils.h"
+/*#define FBGEMM_LOG_CODE 1*/
+
+namespace fbgemm {
+
+namespace x86 = asmjit::x86;
+
+using jit_conv_kernel_fp = void (*)(
+ const uint8_t* in_acts,
+ int8_t* wghts,
+ int32_t* out_acts,
+ int32_t a_zero_pt,
+ int32_t height,
+ int32_t width);
+
+using jit_rowoffset_kernel_fp = void (*)(
+ const uint8_t* in_acts,
+ int32_t a_zero_pt,
+ int32_t height,
+ int32_t width,
+ int32_t* row_offset);
+
+template <typename accT = int32_t>
+class GenConvKernel {
+ public:
+ GenConvKernel(const conv_param_t<>& conv_param, std::int32_t zero_point)
+ : WRegs_avx2_{x86::ymm0,
+ x86::ymm1,
+ x86::ymm2,
+ x86::ymm3,
+ x86::ymm4,
+ x86::ymm5,
+ x86::ymm6,
+ x86::ymm7,
+ x86::ymm8} {
+ // vector width in bits
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ vectorWidth_ = 512;
+ } else if (cpuinfo_has_x86_avx2()) {
+ vectorWidth_ = 256;
+ } else {
+ // TODO: Have default path
+ assert(0 && "unsupported architecture");
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+ zeroPTRegAvx2_ = x86::ymm9;
+ oneReg8BitAvx2_ = x86::ymm10;
+ tmpReg1Avx2_ = x86::ymm11;
+ stPermRegAvx2_ = x86::ymm12;
+ actRegAvx2_ = x86::ymm13;
+ resultRegAvx2_ = x86::ymm14;
+ oneReg16BitAvx2_ = x86::ymm15;
+
+    // vector width in elements; each element is int8 or uint8
+ VLEN_ = vectorWidth_ / 8;
+
+ if (zero_point == 0) {
+ isZeroPointZero_ = true;
+ } else {
+ isZeroPointZero_ = false;
+ }
+
+ G_ = conv_param.G;
+ K_per_G_ = conv_param.OC / conv_param.G;
+ K_ = conv_param.OC;
+ C_per_G_ = conv_param.IC / conv_param.G;
+ C_ = conv_param.IC;
+ R_ = conv_param.K[0];
+ S_ = conv_param.K[1];
+ H_ = conv_param.OUT_DIM[0];
+ W_ = conv_param.OUT_DIM[1];
+ H_PAD_ = conv_param.pad[0];
+ W_PAD_ = conv_param.pad[1];
+
+ assert(fbgemmOptimizedGConv(conv_param));
+ }
+
+ template <inst_set_t instSet>
+ std::string getCodeLoggingFile(bool rowOffsetKernel = false) {
+ std::string fileName = "conv_";
+ fileName += "G-" + std::to_string(G_);
+ fileName += "_K-" + std::to_string(K_);
+ fileName += "_C-" + std::to_string(C_);
+ fileName += "_R-" + std::to_string(R_);
+ fileName += "_S-" + std::to_string(S_);
+ fileName += "_PADH-" + std::to_string(H_PAD_);
+ fileName += "_PADW-" + std::to_string(W_PAD_);
+ fileName += "_isZeroPointZero-" + std::to_string(isZeroPointZero_);
+ if (rowOffsetKernel) {
+ fileName += "_rowOffset";
+ }
+
+ if (instSet == inst_set_t::avx512) {
+ fileName += "_avx512";
+ } else if (instSet == inst_set_t::avx2) {
+ fileName += "_avx2";
+ }
+ fileName += ".txt";
+ return fileName;
+ }
+
+ ~GenConvKernel() {}
+
+ template <inst_set_t instSet>
+ jit_conv_kernel_fp getOrCreate(const conv_param_t<>& conv_param);
+
+ template <inst_set_t instSet>
+ jit_rowoffset_kernel_fp getOrCreateRowOffset(
+ const conv_param_t<>& conv_param);
+
+ template <inst_set_t instSet>
+ void createVector16BitOne(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void createVector8BitOne(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg);
+
+ template <inst_set_t instSet>
+ void
+ gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg);
+
+ template <inst_set_t instSet>
+ void genForLoadingWeights(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genConstForPermutations(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genForTopEdge(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genForLeftEdge(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genForRightEdge(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genForBottomEdge(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genCoreInsts(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void storeResult(asmjit::X86Emitter* a, int offset = 0);
+
+ // for Rowoffset kernel
+ template <inst_set_t instSet>
+ void gen8BitSum(asmjit::X86Emitter* a, asmjit::X86Ymm aReg);
+
+ template <inst_set_t instSet>
+ void genForTopEdgeRowoffset(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genForLeftEdgeRowoffset(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genForRightEdgeRowoffset(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genForBottomEdgeRowoffset(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genRowoffsetCorners(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genRowoffsetCore(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0);
+
+ static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
+ static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
+ static thread_local std::
+ map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+ codeCache_; ///< JIT Code Cache for reuse.
+ static thread_local std::
+ map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+ codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel.
+
+ private:
+ int vectorWidth_; ///< Vector width in bits.
+ int VLEN_; ///< Vector width in elements.
+ // avx2 specific
+ asmjit::X86Ymm
+ WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel.
+ asmjit::X86Ymm zeroPTRegAvx2_;
+ asmjit::X86Ymm tmpReg1Avx2_;
+ asmjit::X86Ymm stPermRegAvx2_;
+ asmjit::X86Ymm actRegAvx2_;
+ asmjit::X86Ymm resultRegAvx2_;
+ asmjit::X86Ymm oneReg8BitAvx2_;
+ asmjit::X86Ymm oneReg16BitAvx2_;
+
+ // arguments to the function created
+ asmjit::X86Gp in_acts_R_;
+ asmjit::X86Gp wghts_R_;
+ asmjit::X86Gp out_acts_R_;
+ asmjit::X86Gp a_zero_pt_R_;
+ asmjit::X86Gp H_R_;
+ asmjit::X86Gp W_R_;
+ asmjit::X86Gp row_offset_R_;
+
+ // Used registers
+ asmjit::X86Gp loopR1_;
+ asmjit::X86Gp loopR2_;
+ asmjit::X86Gp scratchReg1_;
+ asmjit::X86Gp scratchReg2_;
+
+ // Other parameters
+ bool isZeroPointZero_;
+
+ // current conv parameters
+ int G_; ///< Number of groups
+ int K_; ///< Number of output channels
+ int K_per_G_; ///< Number of output channels per group
+ int C_; ///< Number of input channels
+ int C_per_G_; ///< Number of input channels per group
+ int R_; ///< Filter/Kernel height
+ int S_; ///< Filter/Kernel width
+ int H_;
+ int W_;
+ int H_PAD_; ///< Padding for height (top and bottom)
+ int W_PAD_; ///< Padding for width (left and right)
+};
+
+} // namespace fbgemm
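The jit_conv_kernel_fp and jit_rowoffset_kernel_fp typedefs above fix the ABI of the generated kernels. A hedged sketch of fetching and invoking a conv kernel (buffer names are hypothetical; getOrCreateConvKernel is defined in GroupwiseConvAcc32Avx2.cc below, and for the supported 3x3/stride-1/pad-1 shapes the spatial arguments equal the input dimensions, which match the output dimensions):

    jit_conv_kernel_fp fpConv =
        getOrCreateConvKernel<int32_t>(conv_p, a_zero_point);
    fpConv(
        activations,   // uint8 input in H x W x G x C_per_G layout
        packedWeights, // int8 filters packed by PackWeightMatrixForGConv
        out32,         // one int32 accumulator per output element
        a_zero_point,  // activation zero point
        conv_p.IN_DIM[0],
        conv_p.IN_DIM[1]);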
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
new file mode 100644
index 0000000..8298f4c
--- /dev/null
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -0,0 +1,1552 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <asmjit/asmjit.h>
+#include <cpuinfo.h>
+#include <immintrin.h>
+#include <array>
+#include <iostream>
+#include <map>
+#include <stdexcept>
+#include <tuple>
+#include "GroupwiseConv.h"
+#include "RefImplementations.h"
+#include "TransposeUtils.h"
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm {
+
+using namespace std;
+
+template <typename accT>
+thread_local asmjit::JitRuntime GenConvKernel<accT>::rt_;
+
+template <typename accT>
+thread_local asmjit::CodeHolder GenConvKernel<accT>::code_;
+
+template <typename accT>
+thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+ GenConvKernel<accT>::codeCache_;
+
+template <typename accT>
+thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+ GenConvKernel<accT>::codeCacheRowOffset_;
+
+namespace x86 = asmjit::x86;
+
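// Reference (scalar) row-offset computation: for every output position, sum
// all input activations in its receptive field, with out-of-bounds taps
// contributing the activation zero point. Requantization later multiplies
// these sums by the weight zero point to compensate for it.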
+void calculateRowOffsets(
+ const conv_param_t<>& conv_param,
+ const uint8_t* activations,
+ int32_t* rowOffsetBuf,
+ int32_t a_zero_point,
+ int groupNum) {
+ int H = conv_param.OUT_DIM[0];
+ int W = conv_param.OUT_DIM[1];
+ int G = conv_param.G;
+ int C_per_G = conv_param.IC / conv_param.G;
+ int H_PAD = conv_param.pad[0];
+ int W_PAD = conv_param.pad[1];
+ // calculate row offset
+ for (int h = 0; h < H; ++h) {
+ for (int w = 0; w < W; ++w) {
+ int32_t sum = 0;
+ for (int r = 0; r < conv_param.K[0]; ++r) {
+ int h_in = -H_PAD + h + r;
+ for (int s = 0; s < conv_param.K[1]; ++s) {
+ int w_in = -W_PAD + w + s;
+ for (int c = 0; c < C_per_G; ++c) {
+ if (h_in < 0 || h_in >= H || w_in < 0 || w_in >= W) {
+ sum += a_zero_point;
+ } else {
+ sum +=
+ activations[((h_in * W + w_in) * G + groupNum) * C_per_G + c];
+ }
+ }
+ }
+ }
+ rowOffsetBuf[h * W + w] = sum;
+ }
+ }
+}
+
+tuple<bool, int, int, int> getKernelSig(
+ const conv_param_t<>& conv_param,
+ bool isZeroPointZero) {
+ int C_per_G = conv_param.IC / conv_param.G;
+ int K_per_G = conv_param.OC / conv_param.G;
+ auto kernelSig =
+ std::make_tuple(isZeroPointZero, conv_param.G, C_per_G, K_per_G);
+ return kernelSig;
+}
+
+template <typename accT = int32_t>
+jit_conv_kernel_fp getOrCreateConvKernel(
+ const conv_param_t<>& conv_param,
+ int a_zero_point) {
+  // Note: Wrong code is generated if it's not one of the supported convolutions
+ assert(fbgemmOptimizedGConv<2>(conv_param));
+ auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
+ if (GenConvKernel<accT>::codeCache_.find(kernelSig) !=
+ GenConvKernel<accT>::codeCache_.end()) {
+ return GenConvKernel<accT>::codeCache_[kernelSig];
+ } else {
+ auto genObj = GenConvKernel<accT>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
+ }
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::createVector8BitOne<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // create 8-bit 1s
+  // i.e., oneReg8BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains
+  // 0x01 and so on
+ a->vpcmpeqw(oneReg8BitAvx2_, oneReg8BitAvx2_, oneReg8BitAvx2_);
+ a->vpabsb(oneReg8BitAvx2_, oneReg8BitAvx2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::createVector16BitOne<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // create 16-bit 1s
+ // i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31]
+ // contains 0x0001 and so on
+ a->vpcmpeqw(oneReg16BitAvx2_, oneReg16BitAvx2_, oneReg16BitAvx2_);
+ a->vpsrlw(oneReg16BitAvx2_, oneReg16BitAvx2_, 15);
+}
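// For reference, the same all-ones idioms with AVX2 intrinsics (an
// illustration only, not part of the generated code): comparing any register
// x with itself yields all 1-bits; shifting each 16-bit lane right by 15
// leaves 0x0001, and vpabsb of all-0xFF bytes (each byte == -1) leaves 0x01.
//
//   __m256i allOnes = _mm256_cmpeq_epi16(x, x);       // 0xFFFF in every lane
//   __m256i ones16  = _mm256_srli_epi16(allOnes, 15); // 0x0001 per 16 bits
//   __m256i ones8   = _mm256_abs_epi8(allOnes);       // 0x01 per byte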
+template <>
+template <>
+void GenConvKernel<int32_t>::setToZeroPt<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm destReg) {
+ // make destReg all zeros
+ a->vxorps(destReg, destReg, destReg);
+ asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+ // move zero point to xmm10
+ a->movq(const_reg_xmm, a_zero_pt_R_);
+ // make copies of zero point
+ a->vbroadcastsd(x86::ymm10, const_reg_xmm);
+ // shuffle
+ // overall impact is that destReg contains 32 8-bit values equal to the lower
+ // 8-bits of a_zero_pt_R_
+ a->vpshufb(destReg, x86::ymm10, destReg);
+}
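// Net effect as an intrinsics sketch (illustration only): destReg ends up
// holding 32 copies of the low byte of the zero point, because the all-zero
// shuffle control selects byte 0 of each 128-bit lane of the broadcast value:
//
//   __m256i zp8 = _mm256_set1_epi8(static_cast<char>(a_zero_pt));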
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genConstForPermutations<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ asmjit::X86Gp permute_const_reg = a->gpzRef(12);
+ asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+ // We have 1st group in even lanes and 2nd group in odd lanes.
+ // Permute to put 1st group to lower 128-bit and 2nd group in upper
+ // 128-bit.
+ // load 7, 5, 3, 1, 6, 4, 2, 0 in a 64-bit reg
+ a->mov(permute_const_reg, 0x0705030106040200);
+ a->movq(const_reg_xmm, permute_const_reg);
+ // Zero extend 8 packed 8-bit integers in the low 8 bytes of const_reg_xmm to
+ // 8 packed 32-bit integers in stPermRegAvx2_
+ a->vpmovzxbd(stPermRegAvx2_, const_reg_xmm);
+}
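// Equivalent intrinsics sketch (illustration only): the bytes of
// 0x0705030106040200, zero-extended to dwords, form the index vector
// {0, 2, 4, 6, 1, 3, 5, 7}, so the vpermd in storeResult gathers the even
// dword lanes into the lower 128 bits and the odd lanes into the upper half:
//
//   __m256i idx = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
//   __m256i out = _mm256_permutevar8x32_epi32(in, idx);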
+
+template <>
+template <>
+void GenConvKernel<int32_t>::storeResult<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int offset) {
+ // store with permutation
+ a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_);
+ a->vmovups(x86::dword_ptr(out_acts_R_, offset), resultRegAvx2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::storeResultRowoffset<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int offset) {
+ // store
+ a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForLoadingWeights<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // load weights
+ for (int r = 0; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ a->vmovaps(
+ WRegs_avx2_[r * S_ + s],
+ x86::dword_ptr(
+ wghts_R_,
+ (r * S_ + s) * G_ * K_per_G_ * C_per_G_ * sizeof(int8_t)));
+ }
+ }
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::gen8bitFMA<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg,
+ asmjit::X86Ymm wReg) {
+ a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg);
+ a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_);
+ a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
+}
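// The three instructions above are the standard AVX2 u8*s8 dot-product idiom.
// An equivalent intrinsics sketch (illustration only, not generated code):
//
//   __m256i p16 = _mm256_maddubs_epi16(a_u8, w_s8);             // u8*s8 pairs -> s16
//   __m256i p32 = _mm256_madd_epi16(p16, _mm256_set1_epi16(1)); // s16 pairs -> s32
//   acc = _mm256_add_epi32(acc, p32);                           // accumulate
//
// gen8BitSum below reuses the same pattern with the weights replaced by the
// vector of 8-bit 1s, which reduces it to a per-group sum of activations.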
+
+template <>
+template <>
+void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg) {
+ a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_);
+ a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_);
+ a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // top-left corner code
+ // zero out the results register
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ for (int r = 0; r < R_; ++r) {
+ int h_in = -H_PAD_ + r;
+ if (h_in >= 0) {
+ a->imul(
+ scratchReg1_,
+ W_R_,
+ static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+ }
+ for (int s = 0; s < S_; ++s) {
+ int w_in = -W_PAD_ + s;
+ if (h_in >= 0 && w_in >= 0) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ } else {
+ if (!isZeroPointZero_) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+ }
+ storeResult<inst_set_t::avx2>(a);
+
+ a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+
+ // top edge excluding corners
+ asmjit::Label LoopTopEdge = a->newLabel();
+ a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->bind(LoopTopEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ for (int r = 0; r < H_PAD_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(a, zeroPTRegAvx2_, WRegs_avx2_[s]);
+ }
+ }
+ }
+ for (int r = H_PAD_; r < R_; ++r) {
+ int h_in = -H_PAD_ + r;
+ a->imul(
+ scratchReg1_,
+ W_R_,
+ static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+ for (int s = 0; s < S_; ++s) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+
+ storeResult<inst_set_t::avx2>(a);
+
+ a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->mov(loopR1_, W_R_);
+ a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_));
+ a->inc(loopR2_);
+ a->cmp(loopR2_, loopR1_);
+ a->jl(LoopTopEdge);
+ a->mov(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->sub(
+ scratchReg2_,
+ static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
+ a->sub(in_acts_R_, scratchReg2_);
+
+ // top-right corner code
+
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ for (int r = 0; r < H_PAD_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+ for (int r = H_PAD_; r < R_; ++r) {
+ int h_in = -H_PAD_ + r;
+ for (int s = 0; s < S_ - W_PAD_; ++s) {
+ a->imul(
+ scratchReg1_,
+ W_R_,
+ static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+ a->mov(scratchReg2_, W_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(R_ - W_PAD_ - s));
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ if (!isZeroPointZero_) {
+ for (int s = S_ - W_PAD_; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+ storeResult<inst_set_t::avx2>(a);
+ a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+
+ // reset output activation pointer
+ a->imul(scratchReg1_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->sub(out_acts_R_, scratchReg1_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // left edge excluding corners
+ asmjit::Label LoopLeftEdge = a->newLabel();
+ a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->bind(LoopLeftEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, loopR1_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_; ++r) {
+ if (!isZeroPointZero_) {
+ for (int s = 0; s < W_PAD_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ for (int s = W_PAD_; s < S_; ++s) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (s - W_PAD_) * C_ * sizeof(uint8_t)));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->add(out_acts_R_, scratchReg2_);
+ storeResult<inst_set_t::avx2>(a);
+
+ a->inc(loopR1_);
+ a->mov(loopR2_, H_R_);
+ a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->cmp(loopR1_, loopR2_);
+ a->jl(LoopLeftEdge);
+
+ // reset output activation pointer
+ a->mov(scratchReg2_, H_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->sub(out_acts_R_, scratchReg2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // right edge excluding corners
+ asmjit::Label LoopRightEdge = a->newLabel();
+
+ // output pointer to the right edge
+ // (W_ + W_ - 1)*K_*sizeof(int32_t)
+ a->mov(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, 2);
+ a->sub(scratchReg2_, 1);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->add(out_acts_R_, scratchReg2_);
+
+ a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->bind(LoopRightEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, loopR1_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+
+ a->mov(scratchReg2_, W_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * W_PAD_));
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ for (int r = 0; r < R_; ++r) {
+ for (int s = 0; s < S_ - W_PAD_; ++s) {
+ a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ }
+ if (!isZeroPointZero_) {
+ for (int s = S_ - W_PAD_; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+
+ a->sub(
+ scratchReg1_,
+ static_cast<asmjit::Imm>((S_ - W_PAD_) * C_ * sizeof(uint8_t)));
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+
+ // storeResult<inst_set_t::avx2>(a, (W_+W_-1)*K_*sizeof(int32_t));
+ storeResult<inst_set_t::avx2>(a);
+
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->add(out_acts_R_, scratchReg2_);
+ a->mov(loopR2_, H_R_);
+ a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->inc(loopR1_);
+ a->cmp(loopR1_, loopR2_);
+ a->jl(LoopRightEdge);
+
+ // reset base
+ a->mov(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, 2);
+ a->sub(scratchReg2_, 1);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->sub(out_acts_R_, scratchReg2_);
+
+ // reset loop increments
+ //(H_ - 2*H_PAD_)*W_*K_*sizeof(int32_t)
+ a->mov(scratchReg2_, H_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->sub(out_acts_R_, scratchReg2_);
+ // a->sub(out_acts_R_, (H_ - 2*H_PAD_)*W_*K_*sizeof(int32_t));
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // bottom-left corner
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_ - H_PAD_; ++r) {
+ if (!isZeroPointZero_) {
+ for (int s = 0; s < W_PAD_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ for (int s = W_PAD_; s < S_; ++s) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (s - W_PAD_) * C_ * sizeof(uint8_t)));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+ if (!isZeroPointZero_) {
+ for (int r = R_ - H_PAD_; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+
+  // we are updating the last row
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, 1);
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->add(out_acts_R_, scratchReg1_);
+ // storeResult<inst_set_t::avx2>(a, (H_-1)*W_*K_*sizeof(int32_t));
+ storeResult<inst_set_t::avx2>(a);
+ a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+
+ // bottom edge excluding corners
+ asmjit::Label LoopBottomEdge = a->newLabel();
+ a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->bind(LoopBottomEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_ - W_PAD_; ++r) {
+ // int h_in = H_-2*H_PAD_ + r;
+ for (int s = 0; s < S_; ++s) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+
+ if (!isZeroPointZero_) {
+ for (int r = R_ - W_PAD_; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+
+ a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ // storeResult<inst_set_t::avx2>(a, ((H_-1)*W_+1)*K_*sizeof(int32_t));
+ storeResult<inst_set_t::avx2>(a);
+
+ a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->inc(loopR2_);
+ a->mov(loopR1_, W_R_);
+ a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_));
+ a->cmp(loopR2_, loopR1_);
+ a->jl(LoopBottomEdge);
+ a->mov(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * W_PAD_));
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->sub(in_acts_R_, scratchReg1_);
+ // a->sub(in_acts_R_, (W_ - 2*W_PAD_)*C_*sizeof(uint8_t));
+ // a->sub(out_acts_R_, (W_ - 2*W_PAD_)*K_*sizeof(int32_t));
+
+ // bottom-right corner
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ // input start point
+ // ((H_-(R_-H_PAD_))*W_+(W_-(S_-W_PAD_)))*C_*sizeof(uint8_t)
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(R_ - H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->add(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(S_ - W_PAD_));
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_ - H_PAD_; ++r) {
+ for (int s = 0; s < S_ - W_PAD_; ++s) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ if (!isZeroPointZero_) {
+ for (int s = S_ - W_PAD_; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+
+ if (!isZeroPointZero_) {
+ for (int r = R_ - H_PAD_; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+
+ storeResult<inst_set_t::avx2>(a);
+ // storeResult<inst_set_t::avx2>(a, ((H_-1)*W_+W_-1)*K_*sizeof(int32_t));
+ // reset output pointer
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, 1);
+ a->imul(scratchReg1_, W_R_);
+ a->add(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, 1);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->sub(out_acts_R_, scratchReg1_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // main compute
+ asmjit::Label LoopH = a->newLabel();
+ asmjit::Label LoopW = a->newLabel();
+ // base for output
+ a->mov(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->imul(scratchReg2_, W_R_);
+ a->add(scratchReg2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->add(out_acts_R_, scratchReg2_);
+
+ a->mov(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(W_PAD_));
+
+ // H loop
+ a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->bind(LoopH);
+ // W loop
+ a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->bind(LoopW);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ // compute on all filters
+ for (int r = 0; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ a->vbroadcastsd(
+ actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t)));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(in_acts_R_, scratchReg2_);
+ }
+ a->imul(
+ scratchReg2_, W_R_, static_cast<asmjit::Imm>(R_ * C_ * sizeof(uint8_t)));
+ a->sub(in_acts_R_, scratchReg2_);
+ // a->add(scratchReg1_, C_*sizeof(uint8_t));
+ a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+
+ // storeResult<inst_set_t::avx2>(a, (W_+1)*K_*sizeof(int32_t));
+ storeResult<inst_set_t::avx2>(a);
+
+ a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+
+ a->inc(loopR2_);
+ a->cmp(loopR2_, scratchReg1_);
+ a->jl(LoopW);
+  // the W loop already advanced in_acts_R_ by (W_ - 2*W_PAD_)*C_, so adding
+  // 2*W_PAD_*C_ gives a net advance of one full input row (W_*C_)
+ a->add(
+ in_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
+ // a->sub(in_acts_R_, (W_ - 2*W_PAD_)*C_*sizeof(uint8_t));
+ // a->add(in_acts_R_, W_*C_*sizeof(uint8_t));
+ a->add(
+ out_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * K_ * sizeof(int32_t)));
+ // a->sub(out_acts_R_, (W_ - 2*W_PAD_)*K_*sizeof(int32_t));
+ // a->add(out_acts_R_, W_*K_*sizeof(int32_t));
+
+ a->inc(loopR1_);
+ a->mov(scratchReg2_, H_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->cmp(loopR1_, scratchReg2_);
+ a->jl(LoopH);
+}
+
+template <>
+template <>
+jit_conv_kernel_fp GenConvKernel<int32_t>::getOrCreate<inst_set_t::avx2>(
+ const conv_param_t<>& conv_param) {
+ code_.reset(false);
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
+
+#if defined(FBGEMM_LOG_CODE)
+ // log code to a file
+ FILE* codeLogfile =
+ fopen(getCodeLoggingFile<inst_set_t::avx2>(false).c_str(), "w");
+ asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code_.setLogger(codeLogger);
+ }
+#endif
+
+ // arguments to the function created
+ in_acts_R_ = a->zdi();
+ wghts_R_ = a->zsi();
+ out_acts_R_ = a->zdx();
+ a_zero_pt_R_ = a->zcx();
+ H_R_ = a->gpzRef(8);
+ W_R_ = a->gpzRef(9);
+ row_offset_R_ = a->gpzRef(10);
+
+ // register for temporary use
+ scratchReg1_ = a->gpzRef(12);
+ scratchReg2_ = a->gpzRef(13);
+
+ asmjit::FuncDetail func;
+ func.init(asmjit::FuncSignature6<
+ void,
+ uint8_t*,
+ int8_t*,
+ int32_t*,
+ int32_t,
+ int32_t,
+ int32_t>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsMapper args(&func);
+ args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
+
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
+
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
+
+ createVector16BitOne<inst_set_t::avx2>(a);
+
+ loopR1_ = a->gpzRef(14);
+ loopR2_ = a->gpzRef(15);
+
+ if (!isZeroPointZero_) {
+ setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+
+ genForLoadingWeights<inst_set_t::avx2>(a);
+
+ genConstForPermutations<inst_set_t::avx2>(a);
+
+ genForTopEdge<inst_set_t::avx2>(a);
+ genForLeftEdge<inst_set_t::avx2>(a);
+ genForRightEdge<inst_set_t::avx2>(a);
+ genForBottomEdge<inst_set_t::avx2>(a);
+
+ genCoreInsts<inst_set_t::avx2>(a);
+
+ asmjit::FuncUtils::emitEpilog(a, layout);
+
+ jit_conv_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ auto kernelSig = getKernelSig(conv_param, isZeroPointZero_);
+ codeCache_[kernelSig] = fn;
+
+#if defined(FBGEMM_LOG_CODE)
+ fclose(codeLogfile);
+ delete codeLogger;
+#endif
+
+ return fn;
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // top-left corner code
+ // zero out the results register
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ for (int r = 0; r < R_; ++r) {
+ int h_in = -H_PAD_ + r;
+ if (h_in >= 0) {
+ a->imul(
+ scratchReg1_,
+ W_R_,
+ static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+ }
+ for (int s = 0; s < S_; ++s) {
+ int w_in = -W_PAD_ + s;
+ if (h_in >= 0 && w_in >= 0) {
+ a->vmovaps(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ } else {
+ if (!isZeroPointZero_) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+ }
+ // store results
+ storeResultRowoffset<inst_set_t::avx2>(a);
+
+  // for C_per_G == 4 and K_per_G == 4, 8 groups are processed at a time
+ a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+
+ // top edge excluding corners
+ asmjit::Label LoopTopEdge = a->newLabel();
+ a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->bind(LoopTopEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ for (int r = 0; r < H_PAD_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+ for (int r = H_PAD_; r < R_; ++r) {
+ int h_in = -H_PAD_ + r;
+ a->imul(
+ scratchReg1_,
+ W_R_,
+ static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+ for (int s = 0; s < S_; ++s) {
+ a->vmovaps(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ }
+ }
+ a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+
+ // store results
+ storeResultRowoffset<inst_set_t::avx2>(a);
+
+ a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->mov(loopR1_, W_R_);
+ a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_));
+ a->inc(loopR2_);
+ a->cmp(loopR2_, loopR1_);
+ a->jl(LoopTopEdge);
+ a->mov(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->sub(
+ scratchReg2_,
+ static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
+ a->sub(in_acts_R_, scratchReg2_);
+
+ // top-right corner code
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ for (int r = 0; r < H_PAD_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+ for (int r = H_PAD_; r < R_; ++r) {
+ int h_in = -H_PAD_ + r;
+ for (int s = 0; s < S_ - W_PAD_; ++s) {
+ a->imul(
+ scratchReg1_,
+ W_R_,
+ static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+ a->mov(scratchReg2_, W_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(R_ - W_PAD_ - s));
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ a->vmovaps(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ }
+ if (!isZeroPointZero_) {
+ for (int s = S_ - W_PAD_; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+
+ // store results
+ storeResultRowoffset<inst_set_t::avx2>(a);
+
+ a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+
+ // reset output pointer
+ a->imul(scratchReg1_, W_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->sub(row_offset_R_, scratchReg1_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // left edge excluding corners
+ asmjit::Label LoopLeftEdge = a->newLabel();
+ a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->bind(LoopLeftEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, loopR1_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_; ++r) {
+ if (!isZeroPointZero_) {
+ for (int s = 0; s < W_PAD_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ for (int s = W_PAD_; s < S_; ++s) {
+ a->vmovaps(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (s - W_PAD_) * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->add(row_offset_R_, scratchReg2_);
+ storeResultRowoffset<inst_set_t::avx2>(a);
+
+ a->inc(loopR1_);
+ a->mov(loopR2_, H_R_);
+ a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->cmp(loopR1_, loopR2_);
+ a->jl(LoopLeftEdge);
+
+ // reset output pointer
+ a->mov(scratchReg2_, H_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->sub(row_offset_R_, scratchReg2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // right edge excluding corners
+ asmjit::Label LoopRightEdge = a->newLabel();
+
+ // output pointer to the right edge
+ // (W_ + W_ - 1)*8*sizeof(int32_t)
+ a->mov(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, 2);
+ a->sub(scratchReg2_, 1);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->add(row_offset_R_, scratchReg2_);
+
+ a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->bind(LoopRightEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, loopR1_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+
+ a->mov(scratchReg2_, W_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * W_PAD_));
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ for (int r = 0; r < R_; ++r) {
+ for (int s = 0; s < S_ - W_PAD_; ++s) {
+      a->vmovaps(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ }
+ if (!isZeroPointZero_) {
+ for (int s = S_ - W_PAD_; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+
+ a->sub(
+ scratchReg1_,
+ static_cast<asmjit::Imm>((S_ - W_PAD_) * C_ * sizeof(uint8_t)));
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+
+ storeResultRowoffset<inst_set_t::avx2>(a);
+
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->add(row_offset_R_, scratchReg2_);
+ a->mov(loopR2_, H_R_);
+ a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->inc(loopR1_);
+ a->cmp(loopR1_, loopR2_);
+ a->jl(LoopRightEdge);
+
+ // reset base
+ a->mov(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, 2);
+ a->sub(scratchReg2_, 1);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->sub(row_offset_R_, scratchReg2_);
+
+ // reset increments done in the loop
+ //(H_ - 2*H_PAD_)*W_*8*sizeof(int32_t)
+ a->mov(scratchReg2_, H_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->sub(row_offset_R_, scratchReg2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // bottom-left corner
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_ - H_PAD_; ++r) {
+ if (!isZeroPointZero_) {
+ for (int s = 0; s < W_PAD_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ for (int s = W_PAD_; s < S_; ++s) {
+ a->vmovaps(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (s - W_PAD_) * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+ if (!isZeroPointZero_) {
+ for (int r = R_ - H_PAD_; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+
+  // we are updating the last row
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, 1);
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->add(row_offset_R_, scratchReg1_);
+ storeResultRowoffset<inst_set_t::avx2>(a);
+ a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+
+ // bottom edge excluding corners
+ asmjit::Label LoopBottomEdge = a->newLabel();
+ a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->bind(LoopBottomEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_ - W_PAD_; ++r) {
+ // int h_in = H_-2*H_PAD_ + r;
+ for (int s = 0; s < S_; ++s) {
+ a->vmovaps(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+
+ if (!isZeroPointZero_) {
+ for (int r = R_ - W_PAD_; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+
+ a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ // storeResult<inst_set_t::avx2>(a, ((H_-1)*W_+1)*8*sizeof(int32_t));
+ storeResultRowoffset<inst_set_t::avx2>(a);
+
+ a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->inc(loopR2_);
+ a->mov(loopR1_, W_R_);
+ a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_));
+ a->cmp(loopR2_, loopR1_);
+ a->jl(LoopBottomEdge);
+ a->mov(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * W_PAD_));
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->sub(in_acts_R_, scratchReg1_);
+ // a->sub(in_acts_R_, (W_ - 2*W_PAD_)*C_*sizeof(uint8_t));
+ // a->sub(out_acts_R_, (W_ - 2*W_PAD_)*8*sizeof(int32_t));
+
+ // bottom-right corner
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ // input start point
+ // ((H_-(R_-H_PAD_))*W_+(W_-(S_-W_PAD_)))*C_*sizeof(uint8_t)
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(R_ - H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->add(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(S_ - W_PAD_));
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_ - H_PAD_; ++r) {
+ for (int s = 0; s < S_ - W_PAD_; ++s) {
+ a->vmovaps(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ if (!isZeroPointZero_) {
+ for (int s = S_ - W_PAD_; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+
+ if (!isZeroPointZero_) {
+ for (int r = R_ - H_PAD_; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+
+ storeResultRowoffset<inst_set_t::avx2>(a);
+ // reset output pointer
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, 1);
+ a->imul(scratchReg1_, W_R_);
+ a->add(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, 1);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->sub(row_offset_R_, scratchReg1_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genRowoffsetCore<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // number of uint8 elements in input channels should be a multiple of 32
+ assert(C_ % 32 == 0);
+
+ asmjit::Label LoopH = a->newLabel();
+ asmjit::Label LoopW = a->newLabel();
+ // base for output
+ a->mov(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->imul(scratchReg2_, W_R_);
+ a->add(scratchReg2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->add(row_offset_R_, scratchReg2_);
+
+ a->mov(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(W_PAD_));
+
+ // H loop
+ a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->bind(LoopH);
+ // W loop
+ a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->bind(LoopW);
+
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ for (int r = 0; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ a->vmovaps(
+ actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(in_acts_R_, scratchReg2_);
+ }
+ a->imul(
+ scratchReg2_, W_R_, static_cast<asmjit::Imm>(R_ * C_ * sizeof(uint8_t)));
+ a->sub(in_acts_R_, scratchReg2_);
+ // store results
+ storeResultRowoffset<inst_set_t::avx2>(a);
+
+ a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+
+ a->inc(loopR2_);
+ a->cmp(loopR2_, scratchReg1_);
+ a->jl(LoopW);
+ a->add(
+ in_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
+ a->add(
+ row_offset_R_,
+ static_cast<asmjit::Imm>(2 * W_PAD_ * 8 * sizeof(int32_t)));
+ a->inc(loopR1_);
+ a->mov(scratchReg2_, H_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->cmp(loopR1_, scratchReg2_);
+ a->jl(LoopH);
+}
+
+template <>
+template <>
+jit_rowoffset_kernel_fp
+GenConvKernel<int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
+ const conv_param_t<>& conv_param) {
+ code_.reset(false);
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
+
+#if defined(FBGEMM_LOG_CODE)
+ // log code to a file
+ FILE* codeLogfile =
+ fopen(getCodeLoggingFile<inst_set_t::avx2>(true).c_str(), "w");
+ asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code_.setLogger(codeLogger);
+ }
+#endif
+
+ // arguments to the function created
+ in_acts_R_ = a->zdi();
+ a_zero_pt_R_ = a->zsi();
+ H_R_ = a->zdx();
+ W_R_ = a->zcx();
+ row_offset_R_ = a->gpzRef(8);
+
+ // register for temporary use
+ scratchReg1_ = a->gpzRef(12);
+ scratchReg2_ = a->gpzRef(13);
+
+ loopR1_ = a->gpzRef(14);
+ loopR2_ = a->gpzRef(15);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
+ asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsMapper args(&func);
+ args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);
+
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
+
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
+
+  // This uses the xmm10 register temporarily, so it must come before
+  // createVector8BitOne, which sets ymm10 to all 8-bit 1s.
+ if (!isZeroPointZero_) {
+ setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+
+ createVector16BitOne<inst_set_t::avx2>(a);
+ // we set ymm10 to contain 8-bit 1s
+ createVector8BitOne<inst_set_t::avx2>(a);
+
+ genForTopEdgeRowoffset<inst_set_t::avx2>(a);
+ genForLeftEdgeRowoffset<inst_set_t::avx2>(a);
+ genForRightEdgeRowoffset<inst_set_t::avx2>(a);
+ genForBottomEdgeRowoffset<inst_set_t::avx2>(a);
+
+ genRowoffsetCore<inst_set_t::avx2>(a);
+
+ asmjit::FuncUtils::emitEpilog(a, layout);
+
+ jit_rowoffset_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ auto kernelSig = getKernelSig(conv_param, isZeroPointZero_);
+ codeCacheRowOffset_[kernelSig] = fn;
+
+#if defined(FBGEMM_LOG_CODE)
+ delete codeLogger;
+ fclose(codeLogfile);
+#endif
+
+ return fn;
+}
+
+template <
+ typename packed_W,
+ typename outType,
+ typename processOutputType,
+ int SPATIAL_DIM>
+void fbgemmGroupwiseConv(
+ const conv_param_t<SPATIAL_DIM>& conv_param,
+ const std::uint8_t* activations,
+ std::int32_t a_zero_point,
+ std::int32_t* rowOffsetBuf,
+ packed_W& packed_weights,
+ outType* out,
+ int32_t* outBuffer,
+ const processOutputType& outProcess,
+ int thread_id,
+ int num_threads) {
+
+ int MB = conv_param.MB;
+ int H = conv_param.OUT_DIM[0];
+ int W = conv_param.OUT_DIM[1];
+ int G = conv_param.G;
+ int K_per_G = conv_param.OC / G;
+ int C_per_G = conv_param.IC / G;
+ int oh_ow = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
+
+ static_assert(SPATIAL_DIM == 2, "3D conv not supported yet");
+
+ int32_t* rowOffsetTrDest =
+ rowOffsetBuf + 8 * conv_param.IN_DIM[0] * conv_param.IN_DIM[1];
+ if (fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)) {
+ assert(G % 8 == 0);
+ // generate convolution kernel
+ jit_conv_kernel_fp fpConv =
+ getOrCreateConvKernel<>(conv_param, a_zero_point);
+ // generate row offset kernel
+ jit_rowoffset_kernel_fp fpRowoffset =
+ getOrCreateRowOffsetKernel(conv_param, a_zero_point);
+ for (int i = 0; i < MB; ++i) {
+ const uint8_t* actStartBatch = activations +
+ i * conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * conv_param.IC;
+ for (int gOuter = 0; gOuter < G; gOuter += 8) {
+      // For C_per_G == 4 and K_per_G == 4, row offsets are calculated for 8
+      // groups at a time. The result is row offsets in the format IH*IW x G.
+ fpRowoffset(
+ actStartBatch + gOuter * C_per_G,
+ a_zero_point,
+ H,
+ W,
+ rowOffsetBuf);
+ // Transpose to get row offsets in the format G x IH*IW
+ internal::transpose_8x8(
+ conv_param.IN_DIM[0] * conv_param.IN_DIM[1],
+ 8,
+ (const float*)rowOffsetBuf,
+ 8,
+ (float*)rowOffsetTrDest,
+ conv_param.IN_DIM[0] * conv_param.IN_DIM[1]);
+ int gLimit = gOuter + 8;
+ for (int g = gOuter; g < gLimit; g += 2) {
+ int32_t* currOutBuf =
+ outBuffer + i * oh_ow * conv_param.OC + g * K_per_G;
+ const uint8_t* actStartGroup = actStartBatch + g * C_per_G;
+
+ fpConv(
+ actStartGroup,
+ packed_weights.getBuf() + g * K_per_G * C_per_G,
+ currOutBuf,
+ a_zero_point,
+ H,
+ W);
+
+ // Output processing should be called for each group
+ for (int j = 0; j < 2; ++j) {
+ // calculateRowOffsets(
+ // conv_param, actStartGroup, rowOffsetBuf, a_zero_point, j);
+ int32_t* rowOffsetForCurG = rowOffsetTrDest +
+ ((g - gOuter) + j) * conv_param.IN_DIM[0] *
+ conv_param.IN_DIM[1];
+ // compare_buffers(rowOffsetBuf, rowOffsetForCurG,
+ // conv_param.IN_DIM[0]*conv_param.IN_DIM[1], 1, 1, 100);
+
+ // outProcess expects rowOffsetBuf to contain row offsets for the
+ // current group
+ memcpy(
+ rowOffsetBuf,
+ rowOffsetForCurG,
+ conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * sizeof(int32_t));
+
+ if (cpuinfo_has_x86_avx512f()) {
+ // Currently use avx2 code
+ outProcess.template f<inst_set_t::avx2>(
+ out,
+ currOutBuf + j * K_per_G,
+ {i * oh_ow, oh_ow, (g + j) * K_per_G, K_per_G},
+ K_per_G * G,
+ K_per_G * G);
+ } else if (cpuinfo_has_x86_avx2()) {
+ outProcess.template f<inst_set_t::avx2>(
+ out,
+ currOutBuf + j * K_per_G,
+ {i * oh_ow, oh_ow, (g + j) * K_per_G, K_per_G},
+ K_per_G * G,
+ K_per_G * G);
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecure");
+ }
+ } // j loop
+ }
+ }
+ }
+ } else {
+    // for unsupported cases, just execute the naive reference implementation
+ conv_ref(
+ conv_param,
+ activations,
+ a_zero_point,
+ packed_weights.getBuf(),
+ outBuffer);
+ for (int i = 0; i < conv_param.MB; ++i) {
+ for (int g = 0; g < conv_param.G; ++g) {
+ calculateRowOffsets(
+ conv_param,
+ activations +
+ i * conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * conv_param.IC,
+ rowOffsetBuf,
+ a_zero_point,
+ g);
+ outProcess.template f<inst_set_t::anyarch>(
+ out,
+ outBuffer + i * oh_ow * conv_param.OC + g * K_per_G,
+ {i * oh_ow, oh_ow, g * K_per_G, K_per_G},
+ K_per_G * G,
+ K_per_G * G);
+ }
+ }
+ }
+}
+
+jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel(
+ const conv_param_t<>& conv_param,
+ int a_zero_point) {
+  // Note: Wrong code is generated if it's not one of the supported convolutions
+ assert(fbgemmOptimizedGConv<2>(conv_param));
+ auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
+ if (GenConvKernel<int32_t>::codeCacheRowOffset_.find(kernelSig) !=
+ GenConvKernel<int32_t>::codeCacheRowOffset_.end()) {
+ return GenConvKernel<int32_t>::codeCacheRowOffset_[kernelSig];
+ } else {
+ auto genObj = GenConvKernel<int32_t>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param);
+ }
+}
+
+template <int SPATIAL_DIM>
+int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) {
+  // The row offset buffer should be able to hold row offsets for however
+  // many groups we process at a time.
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
+ int C_per_G = conv_param.IC / conv_param.G;
+ int K_per_G = conv_param.OC / conv_param.G;
+ if (C_per_G == 4 && K_per_G == 4) {
+ return 2 * 8 * bufferSize;
+ } else {
+ return conv_param.G * bufferSize;
+ }
+ } else if (cpuinfo_has_x86_avx2()) {
+ int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
+ int C_per_G = conv_param.IC / conv_param.G;
+ int K_per_G = conv_param.OC / conv_param.G;
+ if (C_per_G == 4 && K_per_G == 4) {
+ // row offset is calculated for 8 groups at a time
+ // 2x is needed for transposing
+ return 2 * 8 * bufferSize;
+ } else {
+ return conv_param.G * bufferSize;
+ }
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecture");
+ return -1;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
+
+template int rowOffsetBufferSizeGConv<2>(const conv_param_t<2>& conv_param);
+
+#define INSTANTIATE_BASE(RELU, Q_GRAN) \
+ template void fbgemmGroupwiseConv( \
+ const conv_param_t<2>& conv_param, \
+ const uint8_t* activations, \
+ int32_t a_zero_point, \
+ std::int32_t* rowOffsetBuf, \
+ PackWeightMatrixForGConv<int8_t, int32_t, 2>& packed_weights, \
+ uint8_t* out, \
+ int32_t* outBuffer, \
+ const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
+ int thread_id, \
+ int num_threads);
+
+#define INSTANTIATE_Q_GRANS(RELU) \
+ INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR); \
+ INSTANTIATE_BASE(RELU, QuantizationGranularity::GROUP); \
+ INSTANTIATE_BASE(RELU, QuantizationGranularity::OUT_CHANNEL);
+
+INSTANTIATE_Q_GRANS(false);
+INSTANTIATE_Q_GRANS(true);
+
+#undef INSTANTIATE_Q_GRANS
+#undef INSTANTIATE_BASE
+
+template void fbgemmGroupwiseConv(
+ const conv_param_t<2>& conv_param,
+ const uint8_t* activations,
+ int32_t a_zero_point,
+ std::int32_t* rowOffsetBuf,
+ PackWeightMatrixForGConv<int8_t, int32_t, 2>& packed_weights,
+ int32_t* out,
+ int32_t* outBuffer,
+ const DoNothing<int32_t, int32_t>& outProcess,
+ int thread_id,
+ int num_threads);
+
+} // namespace fbgemm
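Putting the pieces of this file together, a hedged end-to-end sketch of the groupwise path (buffer sizes, the zero point, and the conv_param_t constructor-argument order are assumptions for illustration, not part of this diff):

    conv_param_t<2> conv_p(
        1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1});

    // pack the int8 filters into the R S (G/2) K (2C) layout
    std::vector<int8_t> filters(32 * 3 * 3 * 4 * 4);
    PackWeightMatrixForGConv<int8_t, int32_t, 2> packedW(
        matrix_op_t::NoTranspose, conv_p, filters.data(), nullptr);

    std::vector<uint8_t> acts(1 * 56 * 56 * 128);
    std::vector<int32_t> outBuf(1 * 56 * 56 * 128);
    std::vector<int32_t> out(outBuf.size());
    std::vector<int32_t> rowOffsets(rowOffsetBufferSizeGConv(conv_p));

    DoNothing<int32_t, int32_t> doNothing{};
    fbgemmGroupwiseConv(
        conv_p, acts.data(), /*a_zero_point=*/3, rowOffsets.data(),
        packedW, out.data(), outBuf.data(), doNothing,
        /*thread_id=*/0, /*num_threads=*/1);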
diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc
new file mode 100644
index 0000000..e6c9b7d
--- /dev/null
+++ b/src/PackWeightMatrixForGConv.cc
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <cassert>
+#include <iomanip>
+#include "RefImplementations.h"
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm {
+
+template <typename T, typename accT, int SPATIAL_DIM>
+PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::PackWeightMatrixForGConv(
+ matrix_op_t trans,
+ const conv_param_t<SPATIAL_DIM>& conv_param,
+ const T* sdata,
+ T* pdata)
+ : trans_(trans), conv_param_(conv_param), sdata_(sdata) {
+ static_assert(SPATIAL_DIM == 2, "3D conv not supported yet");
+
+ if (!pdata) {
+ bufAllocatedHere_ = true;
+ pdata_ = static_cast<T*>(fbgemmAlignedAlloc(
+ 64,
+ conv_param_.G * conv_param_.K[0] * conv_param_.K[1] *
+ (conv_param_.OC / conv_param_.G) *
+ (conv_param_.IC / conv_param_.G) * sizeof(T)));
+ } else {
+ bufAllocatedHere_ = false;
+ pdata_ = pdata;
+ }
+ pack();
+}
+
+/**
+ * @brief Pack weight tensor in a suitable format required for the optimized
+ * kernel.
+ *
+ * Let IC_per_G be the number of input channels per group and OC_per_G be the
+ * number of output channels per group.
+ *
+ * For IC_per_G == 4 && OC_per_G == 4, the optimized kernel works on 2 groups
+ * at a time, so the input channels for groups g and g+1 are laid out
+ * sequentially for each output channel, i.e., the layout is R S (G/2) K (2C).
+ * We work on two groups at a time to fully utilize the AVX2 SIMD width of
+ * 256 bits. A worked index example follows the definition of pack() below.
+ *
+ * For IC_per_G == 8, 16, 32 && OC_per_G == 8, 16, 32 there is no need to work
+ * on 2 groups at a time and the full SIMD width can be efficiently utilized
+ * even while working on 1 group at a time.
+ */
+template <typename T, typename accT, int SPATIAL_DIM>
+void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() {
+ // filters are assumed to be in G RS C/G K/G format
+ int R = conv_param_.K[0];
+ int S = conv_param_.K[1];
+ int G = conv_param_.G;
+ int IC_per_G = conv_param_.IC / conv_param_.G;
+ int OC_per_G = conv_param_.OC / conv_param_.G;
+
+ // If transpose option is set, the weight matrix is in layout G K/G (R S C/G)
+ // instead of G (R S C/G) K/G
+ bool tr = (trans_ == matrix_op_t::Transpose);
+ if (fbgemmOptimizedGConv(conv_param_)) {
+ // currently only this case is supported
+ for (int r = 0; r < R; ++r) {
+ for (int s = 0; s < S; ++s) {
+ for (int k = 0; k < OC_per_G; ++k) {
+ for (int g = 0; g < G; ++g) {
+ for (int c = 0; c < IC_per_G; ++c) {
+ inpType b = tr
+ ? sdata_
+ [(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c]
+ : sdata_
+ [(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k];
+ pdata_
+ [((((r * S + s) * (G / 2) + (g / 2)) * OC_per_G + k) * 2 +
+ (g % 2)) *
+ IC_per_G +
+ c] = b;
+ }
+ }
+ }
+ }
+ }
+ } else {
+ if (tr) {
+ // conv_ref expects weights to be in G (R S C/G) K/G format
+ transposeConvWeights(conv_param_, sdata_, pdata_);
+ } else {
+      // just copy the data for unsupported cases
+ memcpy(pdata_, sdata_, G * R * S * OC_per_G * IC_per_G * sizeof(inpType));
+ }
+ }
+}
+
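// Worked index example for the packed layout (a sketch; numbers assume the
// supported R = S = 3, IC_per_G = OC_per_G = 4 case with G = 32): source
// element (g, r, s, c, k) lands at
//   ((((r*S + s)*(G/2) + g/2)*OC_per_G + k)*2 + g%2)*IC_per_G + c
// e.g. (g = 5, r = 1, s = 2, c = 3, k = 0) maps to
//   ((((1*3 + 2)*16 + 2)*4 + 0)*2 + 1)*4 + 3 = 2631.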
+template class PackWeightMatrixForGConv<int8_t, int32_t, 2>;
+template class PackWeightMatrixForGConv<int8_t, int16_t, 2>;
+} // namespace fbgemm
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index 5168a15..5c6cf1b 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -486,6 +486,31 @@ void conv3d_ref(
} // for each n
}
+void transposeConvWeights(
+ const conv_param_t<>& conv_p,
+ const std::int8_t* src,
+ std::int8_t* dest) {
+ int R = conv_p.K[0];
+ int S = conv_p.K[1];
+ int G = conv_p.G;
+ int IC_per_G = conv_p.IC / conv_p.G;
+ int OC_per_G = conv_p.OC / conv_p.G;
+
+ // Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format.
+ for (int r = 0; r < R; ++r) {
+ for (int s = 0; s < S; ++s) {
+ for (int k = 0; k < OC_per_G; ++k) {
+ for (int g = 0; g < G; ++g) {
+ for (int c = 0; c < IC_per_G; ++c) {
+ dest[(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k] =
+ src[(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c];
+ }
+ }
+ }
+ }
+ }
+}
+
void depthwise_3x3_pad_1_ref(
int N,
int H,
diff --git a/src/RefImplementations.h b/src/RefImplementations.h
index fce68e6..62f17e9 100644
--- a/src/RefImplementations.h
+++ b/src/RefImplementations.h
@@ -176,6 +176,14 @@ FBGEMM_API void conv3d_ref(
std::int32_t* C);
/*
+ * @brief Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format.
+ */
+FBGEMM_API void transposeConvWeights(
+ const conv_param_t<>& conv_p,
+ const std::int8_t* src,
+ std::int8_t* dest);
+
+/*
* @brief Reference implementation of im2col operation.
* The input A is assumed to be in NHiWiC format.
* The output A is assumed to be in NHoWoRSC format.