github.com/marian-nmt/FBGEMM.git
 CMakeLists.txt                            |    2 +
 bench/GroupwiseConvRequantizeBenchmark.cc |  507 ++++++
 include/fbgemm/Fbgemm.h                   |   85 ++
 src/Fbgemm.cc                             |   16 +
 src/GroupwiseConv.h                       |  248 ++++
 src/GroupwiseConvAcc32Avx2.cc             | 1552 +++++++++++++++
 src/PackWeightMatrixForGConv.cc           |  103 ++
 src/RefImplementations.cc                 |   25 +
 src/RefImplementations.h                  |    8 +
 test/GConvTest.cc                         |  382 +++++
 10 files changed, 2928 insertions(+), 0 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5d889cd..dfb5623 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -34,12 +34,14 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
src/GenerateKernelU8S8S32ACC16Avx512.cc
src/GenerateKernelU8S8S32ACC32.cc
src/GenerateKernelU8S8S32ACC32Avx512.cc
+ src/GroupwiseConvAcc32Avx2.cc
src/PackAMatrix.cc
src/PackAWithIm2Col.cc
src/PackBMatrix.cc
src/PackMatrix.cc
src/PackAWithQuantRowOffset.cc
src/PackAWithRowOffset.cc
+ src/PackWeightMatrixForGConv.cc
src/QuantUtils.cc
src/RefImplementations.cc
src/Utils.cc)
diff --git a/bench/GroupwiseConvRequantizeBenchmark.cc b/bench/GroupwiseConvRequantizeBenchmark.cc
new file mode 100644
index 0000000..032d0d3
--- /dev/null
+++ b/bench/GroupwiseConvRequantizeBenchmark.cc
@@ -0,0 +1,507 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <iomanip>
+#include <iostream>
+#include <random>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#include "BenchUtils.h"
+#include "fbgemm/Fbgemm.h"
+#include "src/RefImplementations.h"
+
+using namespace std;
+using namespace fbgemm;
+
+void performance_test() {
+ vector<conv_param_t<>> shapes = {
+ // MB, IC, OC, {IH, IW}, G, {KH, KW}, {stride_h, stride_w}, pad_t, pad_l,
+ // pad_b, pad_r
+ // conv_param_t<>(1, 16, 16, {16, 14}, 4, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 128, 128, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(2, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ // conv_param_t<>(1, 256, 256, {56, 56}, 64, {3, 3}, {1, 1}, {1, 1, 1,
+ // 1}),
+ // conv_param_t<>(1, 3, 64, {224, 224}, 1, {7, 7}, {2, 2}, {3, 3, 3, 3}),
+ // conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1,
+ // 1}),
+ // conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1,
+ // 1}),
+ // conv_param_t<>(1, 256, 256, {56, 56}, 32, {3, 3}, {2, 2}, {1, 1, 1,
+ // 1}),
+ // conv_param_t<>(1, 256, 256, {28, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1,
+ // 1}),
+ // conv_param_t<>(1, 512, 512, {28, 28}, 32, {3, 3}, {2, 2}, {1, 1, 1,
+ // 1}),
+ // conv_param_t<>(1, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1,
+ // 1}),
+ // conv_param_t<>(1, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1,
+ // 1}),
+ // conv_param_t<>(1, 1024, 1024, {14, 14}, 32, {3, 3}, {2, 2}, {1, 1, 1,
+ // 1}),
+ // conv_param_t<>(1, 1024, 1024, {7, 7}, 32, {3, 3}, {1, 1}, {1, 1, 1,
+ // 1}),
+ // conv_param_t<>(1, 1024, 1024, {7, 7}, 32, {3, 3}, {1, 1}, {1, 1, 1,
+ // 1}),
+ // BatchSize > 1
+ // conv_param_t<>(2, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1,
+ // 1}),
+ };
+
+ bool flush = true;
+ std::vector<char> llc;
+
+ if (flush) {
+    llc.resize(128 * 1024 * 1024, 1);
+ }
+
+ constexpr int NWARMUP = 4;
+ constexpr int NITER = 10;
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << "WARNING: the timer may be inaccurate when used by multiple threads."
+ << endl;
+ cout << "MB, "
+ << "IC, "
+ << "OC, "
+ << "IH, "
+       << "IW, "
+       << "G, "
+ << "KH, "
+ << "KW, "
+ << "stride_h, "
+ << "stride_w, "
+ << "pad_h, "
+ << "pad_w, "
+ << "Type, "
+ << "M, "
+ << "N, "
+ << "K, "
+ << "Im2Col (ms), "
+ << "Packing (ms), "
+ << "Kernel (ms), "
+ << "Postprocessing (ms), "
+ << "fbgemmPacked (ms), "
+ << "Total (ms), "
+ << "GOPS" << endl;
+#else
+ cout << setw(8) << "MB, "
+ << "IC, "
+ << "OC, "
+ << "IH, "
+       << "IW, "
+       << "G, "
+ << "KH, "
+ << "KW, "
+ << "stride_h, "
+ << "stride_w, "
+ << "pad_h, "
+ << "pad_w, "
+ << "Type, "
+ << "M, "
+ << "N, "
+ << "K, " << setw(5) << "GOPS" << endl;
+#endif
+
+ chrono::time_point<chrono::high_resolution_clock> begin, end;
+ for (auto conv_p : shapes) {
+ if (conv_p.IC % conv_p.G != 0) {
+ cout << "Error: Number of input channels " << conv_p.IC
+ << " is not a multiple of groups " << conv_p.G << endl;
+ continue;
+ }
+ if (conv_p.OC % conv_p.G != 0) {
+ cout << "Error: Number of output channels " << conv_p.OC
+ << " is not a multiple of groups " << conv_p.G << endl;
+ continue;
+ }
+
+ int IC_per_G = conv_p.IC / conv_p.G;
+ int OC_per_G = conv_p.OC / conv_p.G;
+
+ aligned_vector<uint8_t> Aint8(
+ conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0);
+
+ // aligned_vector<uint8_t> Aint8_im2col(
+ // conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.K[0] *
+ // conv_p.K[1] * conv_p.IC,
+ // 0);
+
+ aligned_vector<int8_t> Bint8(
+ conv_p.K[0] * conv_p.K[1] * conv_p.G * IC_per_G * OC_per_G, 0);
+ aligned_vector<int8_t> Bp(
+ conv_p.K[0] * conv_p.K[1] * conv_p.G * IC_per_G * OC_per_G, 0);
+
+ aligned_vector<int32_t> Cint32_ref(
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
+
+ aligned_vector<uint8_t> Cint8_ref(
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
+
+ aligned_vector<int32_t> Cint32_fb_fused(
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
+
+ aligned_vector<uint8_t> Cint8_fb_fused(
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
+
+ aligned_vector<int32_t> Cint32_fb_direct(
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
+
+ aligned_vector<uint8_t> Cint8_fb_direct(
+ conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC, 0);
+
+ // cout << conv_p.toString() << endl;
+
+ // A matrix (input activations)
+ randFill<uint8_t>(Aint8, 0, 5);
+ int32_t Aint8_zero_point = 4;
+
+ // B matrix (weights)
+ randFill<int8_t>(Bint8, -4, 4);
+ aligned_vector<int32_t> Bint8_zero_point(1);
+ randFill(Bint8_zero_point, -3, -1);
+
+ aligned_vector<float> C_multiplier(Bint8_zero_point.size());
+ randFill(C_multiplier, 0.1234f / 2, 0.1234f * 3 / 2);
+ int32_t C_zero_pt = 5;
+
+ int R = conv_p.K[0];
+ int S = conv_p.K[1];
+
+ // reference implementation
+ conv_ref(
+ conv_p,
+ Aint8.data(),
+ Aint8_zero_point,
+ Bint8.data(),
+ Cint32_ref.data());
+
+ // matrix dimensions after im2col
+ int MDim = conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1];
+ int NDim = conv_p.OC / conv_p.G;
+ int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
+
+ // computing row offset
+ vector<int32_t> row_offsets(MDim);
+ vector<uint8_t> Aint8_im2col(MDim * KDim);
+ im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());
+
+ // computing column offset
+ vector<int32_t> col_offsets(conv_p.OC);
+ for (int g = 0; g < conv_p.G; ++g) {
+ col_offsets_with_zero_pt_s8acc32_ref(
+ R * S * IC_per_G,
+ OC_per_G,
+ OC_per_G,
+ Bint8.data() + g * R * S * IC_per_G * OC_per_G,
+ Bint8_zero_point.data(),
+ col_offsets.data() + g * OC_per_G,
+ conv_p.OC);
+ }
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ row_offsets_u8acc32_ref(
+ MDim,
+ R * S * IC_per_G,
+ KDim,
+ Aint8_im2col.data() + g * R * S * IC_per_G,
+ row_offsets.data());
+
+ requantize_u8acc32_ref(
+ MDim,
+ NDim,
+ conv_p.G * NDim,
+ Cint32_ref.data() + g * NDim,
+ Cint8_ref.data() + g * NDim,
+ C_multiplier.data() + g * NDim / conv_p.OC,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data() + g * NDim / conv_p.OC,
+ row_offsets.data(),
+ col_offsets.data() + g * NDim,
+ nullptr,
+ conv_p.OC);
+ }
+ // printMatrix(matrix_op_t::NoTranspose, Cint8_ref.data(), MDim, NDim, NDim,
+ // "B unpacked");
+
+ // printMatrix(matrix_op_t::NoTranspose, Bint8.data(), KDim, NDim, NDim,
+ // "B unpacked");
+ // packedB.printPackedMatrix("B Packed");
+
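+    // 2 ops (multiply + add) per MAC: each of the G groups runs an
+    // MDim x NDim GEMM with reduction KDim / G, so the total is
+    // G * 2 * MDim * NDim * (KDim / G) = 2 * MDim * NDim * KDim.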
+ double nops = 2.0 * static_cast<double>(NITER) * MDim * NDim * KDim;
+ double ttot = 0.0;
+ string runType;
+
+ vector<int32_t> row_offset_buf;
+ row_offset_buf.resize(
+ PackAWithIm2Col<uint8_t, int32_t>::rowOffsetBufferSize());
+
+ PackAWithIm2Col<uint8_t, int32_t> packA(
+ conv_p, Aint8.data(), nullptr, Aint8_zero_point, row_offset_buf.data());
+
+ PackBMatrix<int8_t, int32_t> packedB(
+ matrix_op_t::NoTranspose,
+ KDim,
+ NDim,
+ Bint8.data(),
+ NDim,
+ nullptr,
+ conv_p.G);
+
+ // no-op output process objects
+ DoNothing<> doNothingObj{};
+ ReQuantizeOutput<false, QuantizationGranularity::TENSOR> outputProcObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ packA.getRowOffsetBuffer(),
+ col_offsets.data(),
+ nullptr,
+ conv_p.G * NDim,
+ conv_p.G);
+
+ runType = "FusedIm2Col";
+ ttot = 0;
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ double im2col_time = 0.0;
+ double total_im2col_time = 0.0;
+ double total_packing_time = 0.0;
+ double total_computing_time = 0.0;
+ double total_kernel_time = 0.0;
+ double total_postprocessing_time = 0.0;
+ double total_run_time = 0.0;
+#endif
+ for (auto i = 0; i < NWARMUP + NITER; ++i) {
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ packing_time = 0.0;
+ computing_time = 0.0;
+ kernel_time = 0.0;
+ postprocessing_time = 0.0;
+ run_time = 0.0;
+#endif
+ llc_flush(llc);
+ begin = chrono::high_resolution_clock::now();
+ fbgemmPacked(
+ packA,
+ packedB,
+ Cint8_fb_fused.data(),
+ Cint32_fb_fused.data(),
+ conv_p.G * NDim,
+ outputProcObj,
+ 0,
+ 1);
+ end = chrono::high_resolution_clock::now();
+
+ if (i >= NWARMUP) {
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin);
+ ttot += dur.count();
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ total_packing_time += packing_time;
+ total_computing_time += computing_time;
+ total_kernel_time += kernel_time;
+ total_postprocessing_time += postprocessing_time;
+ total_run_time += run_time;
+#endif
+ }
+ }
+
+ cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC
+ << ", " << conv_p.IN_DIM[0] << ", " << conv_p.IN_DIM[1] << ", "
+ << conv_p.G << ", " << conv_p.K[0] << ", " << conv_p.K[1] << ", "
+ << conv_p.stride[0] << ", " << conv_p.stride[1] << ", "
+ << conv_p.pad[0] << ", " << conv_p.pad[1] << ", ";
+
+    cout << setw(13) << runType << ", " << fixed << setw(6) << MDim << ", "
+         << setw(6) << NDim << ", " << setw(6) << KDim << ", ";
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << fixed << setprecision(6) << setw(8) << 0 << ", "
+ << total_packing_time / (double)NITER / 1e6 << ", "
+ << total_kernel_time / (double)NITER / 1e6 << ", "
+ << total_postprocessing_time / (double)NITER / 1e6 << ", "
+ << total_run_time / (double)NITER / 1e6 << ", "
+ << ttot / (double)NITER / 1e6 << ", ";
+#endif
+ cout << setprecision(2) << nops / ttot << endl;
+
+ // correctness check
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) {
+ for (int k = 0; k < conv_p.OC; ++k) {
+ int32_t expected = Cint8_ref
+ [((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) *
+ conv_p.OC +
+ k];
+ int32_t actual = Cint8_fb_fused
+ [((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) *
+ conv_p.OC +
+ k];
+ if (expected != actual) {
+ cout << "Im2Col fused results differ at (" << n << ", " << h
+ << ", " << w << ", " << k << ")."
+ << " expected:" << expected << " actual:" << actual << endl;
+ }
+ }
+ }
+ }
+ }
+ // compare_buffers(Cint32_ref.data(), Cint32_fb_fused.data(), MDim, NDim *
+ // conv_p.G, NDim*conv_p.G, 5);
+
+ runType = "direct";
+ ttot = 0;
+
+ vector<int32_t> row_offset_buf_direct(rowOffsetBufferSizeGConv(conv_p));
+
+ PackWeightMatrixForGConv<int8_t> packedWeights(
+ matrix_op_t::NoTranspose, conv_p, Bint8.data(), nullptr);
+
+ ReQuantizeOutput<false, QuantizationGranularity::TENSOR> reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ row_offset_buf_direct.data(),
+ col_offsets.data(),
+ nullptr,
+ conv_p.OC,
+ conv_p.G);
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ total_im2col_time = 0.0;
+ total_packing_time = 0.0;
+ total_computing_time = 0.0;
+ total_kernel_time = 0.0;
+ total_postprocessing_time = 0.0;
+ total_run_time = 0.0;
+#endif
+ for (auto i = 0; i < NWARMUP + NITER; ++i) {
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ im2col_time = 0.0;
+ packing_time = 0.0;
+ computing_time = 0.0;
+ kernel_time = 0.0;
+ postprocessing_time = 0.0;
+ run_time = 0.0;
+#endif
+ llc_flush(llc);
+ begin = chrono::high_resolution_clock::now();
+
+ // im2col_ref(conv_p, Aint8.data(), Aint8_zero_point,
+ // Aint8_im2col.data());
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ end = chrono::high_resolution_clock::now();
+ im2col_time =
+ chrono::duration_cast<chrono::nanoseconds>(end - begin).count();
+#endif
+
+ // printMatrix(matrix_op_t::NoTranspose, Aint8_im2col.data(), MDim, KDim,
+ // KDim, "A_out after im2col unpacked");
+
+ fbgemmGroupwiseConv(
+ conv_p,
+ Aint8.data(),
+ Aint8_zero_point,
+ row_offset_buf_direct.data(),
+ packedWeights,
+ Cint8_fb_direct.data(),
+ Cint32_fb_direct.data(),
+ reqObj,
+ 0,
+ 1);
+
+ end = chrono::high_resolution_clock::now();
+
+ if (i >= NWARMUP) {
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin);
+ ttot += dur.count();
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ total_im2col_time += im2col_time;
+ total_packing_time += packing_time;
+ total_computing_time += computing_time;
+ total_kernel_time += kernel_time;
+ total_postprocessing_time += postprocessing_time;
+ total_run_time += run_time;
+#endif
+ }
+ }
+
+  // reference llc so the compiler does not optimize the flush buffer away
+  ((volatile char*)(llc.data()));
+
+ // packedB.printPackedMatrix("bench B Packed");
+ // printMatrix(matrix_op_t::NoTranspose, Cint32_fb_fused.data(), MDim, NDim,
+ // NDim, "C fb fp32"); printMatrix(matrix_op_t::NoTranspose,
+ // Cint32_fb_direct.data(), MDim, NDim, NDim, "C fb2 fp32");
+ // printMatrix(matrix_op_t::NoTranspose,
+ // Cint32_ref.data(), MDim, NDim, NDim, "C ref fp32");
+
+ cout << setw(4) << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC
+ << ", " << conv_p.IN_DIM[0] << ", " << conv_p.IN_DIM[1] << ", "
+ << conv_p.G << ", " << conv_p.K[0] << ", " << conv_p.K[1] << ", "
+ << conv_p.stride[0] << ", " << conv_p.stride[1] << ", "
+ << conv_p.pad[0] << ", " << conv_p.pad[1] << ", ";
+
+    cout << setw(13) << runType << ", " << fixed << setw(6) << MDim << ", "
+         << setw(6) << NDim << ", " << setw(6) << KDim << ", ";
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << fixed << setprecision(6) << setw(8)
+ << total_im2col_time / (double)NITER / 1e6 << ", "
+ << total_packing_time / (double)NITER / 1e6 << ", "
+ << total_kernel_time / (double)NITER / 1e6 << ", "
+ << total_postprocessing_time / (double)NITER / 1e6 << ", "
+ << total_run_time / (double)NITER / 1e6 << ", "
+ << ttot / (double)NITER / 1e6 << ", ";
+#endif
+ cout << setprecision(2) << nops / ttot << endl;
+
+ // correctness check
+ for (int n = 0; n < conv_p.MB; ++n) {
+ for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) {
+ for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) {
+ for (int k = 0; k < conv_p.OC; ++k) {
+ int32_t expected = Cint8_ref
+ [((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) *
+ conv_p.OC +
+ k];
+ int32_t actual = Cint8_fb_direct
+ [((n * conv_p.OUT_DIM[0] + h) * conv_p.OUT_DIM[1] + w) *
+ conv_p.OC +
+ k];
+ if (expected != actual) {
+ cout << "direct conv results differ at (" << n << ", " << h
+ << ", " << w << ", " << k << ")."
+ << " expected:" << expected << " actual:" << actual << endl;
+ }
+ }
+ }
+ }
+ }
+ // compare_buffers(Cint32_ref.data(), Cint32_fb_direct.data(), MDim,
+ // NDim*conv_p.G, NDim*conv_p.G, 5);
+ } // shapes
+}
+
+int main() {
+#ifdef _OPENMP
+ omp_set_num_threads(1);
+#endif
+ performance_test();
+ return 0;
+}
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index bca5347..f49da57 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -453,6 +453,56 @@ class FBGEMM_API PackBMatrix final
};
/**
+ * @brief Matrix packed for direct group convolution.
+ * The source matrix is already quantized. Default accumulation
+ * type is int32.
+ */
+template <typename T, typename accT = std::int32_t, int SPATIAL_DIM = 2>
+class FBGEMM_API PackWeightMatrixForGConv {
+ public:
+ using This = PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>;
+ using inpType = T;
+ using accType = accT;
+
+ PackWeightMatrixForGConv() = delete; // no default constructor
+
+  /**
+   * @param pmat if nullptr, a buffer is allocated and owned by this class.
+   */
+ PackWeightMatrixForGConv(
+ matrix_op_t trans,
+ const conv_param_t<SPATIAL_DIM>& conv_param,
+ const inpType* sdata,
+ inpType* pdata = nullptr);
+
+ /**
+ * @brief Packs a block of source matrix into pmat buffer.
+ */
+ void pack();
+
+ /**
+ * @brief Return packed data
+ */
+ inpType* getBuf() {
+ return pdata_;
+ }
+
+ ~PackWeightMatrixForGConv() {
+ if (bufAllocatedHere_) {
+ free(pdata_);
+ }
+ }
+
+ private:
+ matrix_op_t trans_;
+ const conv_param_t<SPATIAL_DIM> conv_param_;
+ const T* sdata_;
+ T* pdata_;
+ bool bufAllocatedHere_;
+};
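+
+// Example (sketch): packing 3x3 groupwise weights for a shape accepted by
+// fbgemmOptimizedGConv. `weights` stands for a hypothetical buffer of
+// K[0]*K[1]*G*(IC/G)*(OC/G) int8 values:
+//
+//   conv_param_t<> conv_p(
+//       1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1});
+//   PackWeightMatrixForGConv<int8_t> packedW(
+//       matrix_op_t::NoTranspose, conv_p, weights.data(), nullptr);
+//   // pmat == nullptr: the packed buffer is allocated and owned internally.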
+
+/**
* @brief Matrix packed for the first input matrix in GEMM (usually activation),
* and row offsets used for requantization is computed during packing.
* Im2col is fused with packing here. The source matrix is already
@@ -1106,6 +1156,35 @@ FBGEMM_API void fbgemmPacked(
int num_threads);
/**
+ * @brief Perform small-channels-per-group groupwise convolution
+ *
+ */
+
+template <
+ typename packed_W,
+ typename outType,
+ typename processOutputType,
+ int SPATIAL_DIM = 2>
+FBGEMM_API void fbgemmGroupwiseConv(
+ const conv_param_t<SPATIAL_DIM>& conv_param,
+ const std::uint8_t* activations,
+ std::int32_t a_zero_point,
+ std::int32_t* rowOffsetBuf,
+ packed_W& packed_weights,
+ outType* out,
+ std::int32_t* outBuffer,
+ const processOutputType& outProcess,
+ int thread_id,
+ int num_threads);
+/**
+ * @return Size of row offset buffer in number of elements needed for
+ * fbgemmGroupwiseConv
+ */
+template <int SPATIAL_DIM = 2>
+FBGEMM_API int rowOffsetBufferSizeGConv(
+ const conv_param_t<SPATIAL_DIM>& conv_param);
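+
+// A minimal end-to-end sketch (conv_p and packedW as in the
+// PackWeightMatrixForGConv example above; `activations`, `out`, `outBuffer`
+// and a ReQuantizeOutput<> instance `reqObj` are assumed to exist):
+//
+//   std::vector<std::int32_t> rowOffsetBuf(rowOffsetBufferSizeGConv(conv_p));
+//   fbgemmGroupwiseConv(
+//       conv_p, activations, a_zero_point, rowOffsetBuf.data(), packedW,
+//       out, outBuffer, reqObj, 0 /* thread_id */, 1 /* num_threads */);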
+
+/**
* @brief Perform depthwise separable convolution
*/
template <
@@ -1122,6 +1201,12 @@ void convDepthwiseSeparable(
const processOutputType& output);
/**
+ * @brief Is this groupwise convolution supported?
+ */
+template <int SPATIAL_DIM>
+FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p);
+
+/**
* @brief Allocate __size bytes of uninitialized storage whose alignment is
* specified by __align.
*/
diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc
index 45108d0..9384af6 100644
--- a/src/Fbgemm.cc
+++ b/src/Fbgemm.cc
@@ -192,6 +192,22 @@ void fbgemmPacked(
#endif
}
+template <int SPATIAL_DIM>
+FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p) {
+ int C_per_G = conv_p.IC / conv_p.G;
+ int K_per_G = conv_p.OC / conv_p.G;
+
+ return (SPATIAL_DIM == 2) && (C_per_G == K_per_G) && (C_per_G == 4) &&
+ (conv_p.G % 8 == 0) && (conv_p.K[0] == conv_p.K[1]) &&
+ (conv_p.K[0] == 3) && (conv_p.pad[0] == 1) && (conv_p.pad[1] == 1) &&
+ (conv_p.pad[0] == conv_p.pad[2]) && (conv_p.pad[1] == conv_p.pad[3]) &&
+ (conv_p.dilation[0] == 1) && (conv_p.dilation[0] == conv_p.dilation[1]) &&
+ (conv_p.stride[0] == 1) && (conv_p.stride[0] == conv_p.stride[1]);
+}
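+// For example, G = 32 with IC = OC = 128 and a 3x3 filter (stride 1,
+// symmetric unit padding, dilation 1) qualifies: C_per_G == K_per_G == 4
+// and G is a multiple of 8.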
+
+template bool fbgemmOptimizedGConv(const conv_param_t<2>& conv_p);
+template bool fbgemmOptimizedGConv(const conv_param_t<3>& conv_p);
+
bool fbgemmSupportedCPU() {
return (cpuinfo_initialize() && cpuinfo_has_x86_avx2());
}
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
new file mode 100644
index 0000000..a46a895
--- /dev/null
+++ b/src/GroupwiseConv.h
@@ -0,0 +1,248 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <asmjit/asmjit.h>
+#include <cpuinfo.h>
+#include <cassert>
+#include <cstdint>
+#include <map>
+#include <string>
+#include <tuple>
+#include "fbgemm/ConvUtils.h"
+#include "fbgemm/Fbgemm.h"
+#include "fbgemm/Utils.h"
+/*#define FBGEMM_LOG_CODE 1*/
+
+namespace fbgemm {
+
+namespace x86 = asmjit::x86;
+
+using jit_conv_kernel_fp = void (*)(
+ const uint8_t* in_acts,
+ int8_t* wghts,
+ int32_t* out_acts,
+ int32_t a_zero_pt,
+ int32_t height,
+ int32_t width);
+
+using jit_rowoffset_kernel_fp = void (*)(
+ const uint8_t* in_acts,
+ int32_t a_zero_pt,
+ int32_t height,
+ int32_t width,
+ int32_t* row_offset);
+
+template <typename accT = int32_t>
+class GenConvKernel {
+ public:
+ GenConvKernel(const conv_param_t<>& conv_param, std::int32_t zero_point)
+ : WRegs_avx2_{x86::ymm0,
+ x86::ymm1,
+ x86::ymm2,
+ x86::ymm3,
+ x86::ymm4,
+ x86::ymm5,
+ x86::ymm6,
+ x86::ymm7,
+ x86::ymm8} {
+ // vector width in bits
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ vectorWidth_ = 512;
+ } else if (cpuinfo_has_x86_avx2()) {
+ vectorWidth_ = 256;
+ } else {
+ // TODO: Have default path
+ assert(0 && "unsupported architecture");
+ return;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+ zeroPTRegAvx2_ = x86::ymm9;
+ oneReg8BitAvx2_ = x86::ymm10;
+ tmpReg1Avx2_ = x86::ymm11;
+ stPermRegAvx2_ = x86::ymm12;
+ actRegAvx2_ = x86::ymm13;
+ resultRegAvx2_ = x86::ymm14;
+ oneReg16BitAvx2_ = x86::ymm15;
+
+    // vector width in elements; each element is int8 or uint8
+ VLEN_ = vectorWidth_ / 8;
+
+    isZeroPointZero_ = zero_point == 0;
+
+ G_ = conv_param.G;
+ K_per_G_ = conv_param.OC / conv_param.G;
+ K_ = conv_param.OC;
+ C_per_G_ = conv_param.IC / conv_param.G;
+ C_ = conv_param.IC;
+ R_ = conv_param.K[0];
+ S_ = conv_param.K[1];
+ H_ = conv_param.OUT_DIM[0];
+ W_ = conv_param.OUT_DIM[1];
+ H_PAD_ = conv_param.pad[0];
+ W_PAD_ = conv_param.pad[1];
+
+ assert(fbgemmOptimizedGConv(conv_param));
+ }
+
+ template <inst_set_t instSet>
+ std::string getCodeLoggingFile(bool rowOffsetKernel = false) {
+ std::string fileName = "conv_";
+ fileName += "G-" + std::to_string(G_);
+ fileName += "_K-" + std::to_string(K_);
+ fileName += "_C-" + std::to_string(C_);
+ fileName += "_R-" + std::to_string(R_);
+ fileName += "_S-" + std::to_string(S_);
+ fileName += "_PADH-" + std::to_string(H_PAD_);
+ fileName += "_PADW-" + std::to_string(W_PAD_);
+ fileName += "_isZeroPointZero-" + std::to_string(isZeroPointZero_);
+ if (rowOffsetKernel) {
+ fileName += "_rowOffset";
+ }
+
+ if (instSet == inst_set_t::avx512) {
+ fileName += "_avx512";
+ } else if (instSet == inst_set_t::avx2) {
+ fileName += "_avx2";
+ }
+ fileName += ".txt";
+ return fileName;
+ }
+
+ ~GenConvKernel() {}
+
+ template <inst_set_t instSet>
+ jit_conv_kernel_fp getOrCreate(const conv_param_t<>& conv_param);
+
+ template <inst_set_t instSet>
+ jit_rowoffset_kernel_fp getOrCreateRowOffset(
+ const conv_param_t<>& conv_param);
+
+ template <inst_set_t instSet>
+ void createVector16BitOne(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void createVector8BitOne(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg);
+
+ template <inst_set_t instSet>
+ void
+ gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg);
+
+ template <inst_set_t instSet>
+ void genForLoadingWeights(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genConstForPermutations(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genForTopEdge(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genForLeftEdge(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genForRightEdge(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genForBottomEdge(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genCoreInsts(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void storeResult(asmjit::X86Emitter* a, int offset = 0);
+
+ // for Rowoffset kernel
+ template <inst_set_t instSet>
+ void gen8BitSum(asmjit::X86Emitter* a, asmjit::X86Ymm aReg);
+
+ template <inst_set_t instSet>
+ void genForTopEdgeRowoffset(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genForLeftEdgeRowoffset(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genForRightEdgeRowoffset(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genForBottomEdgeRowoffset(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genRowoffsetCorners(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void genRowoffsetCore(asmjit::X86Emitter* a);
+
+ template <inst_set_t instSet>
+ void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0);
+
+ static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit.
+ static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit.
+ static thread_local std::
+ map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+ codeCache_; ///< JIT Code Cache for reuse.
+ static thread_local std::
+ map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+ codeCacheRowOffset_; ///< JIT Code Cache for row offset kernel.
+
+ private:
+ int vectorWidth_; ///< Vector width in bits.
+ int VLEN_; ///< Vector width in elements.
+ // avx2 specific
+ asmjit::X86Ymm
+ WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel.
+ asmjit::X86Ymm zeroPTRegAvx2_;
+ asmjit::X86Ymm tmpReg1Avx2_;
+ asmjit::X86Ymm stPermRegAvx2_;
+ asmjit::X86Ymm actRegAvx2_;
+ asmjit::X86Ymm resultRegAvx2_;
+ asmjit::X86Ymm oneReg8BitAvx2_;
+ asmjit::X86Ymm oneReg16BitAvx2_;
+
+ // arguments to the function created
+ asmjit::X86Gp in_acts_R_;
+ asmjit::X86Gp wghts_R_;
+ asmjit::X86Gp out_acts_R_;
+ asmjit::X86Gp a_zero_pt_R_;
+ asmjit::X86Gp H_R_;
+ asmjit::X86Gp W_R_;
+ asmjit::X86Gp row_offset_R_;
+
+ // Used registers
+ asmjit::X86Gp loopR1_;
+ asmjit::X86Gp loopR2_;
+ asmjit::X86Gp scratchReg1_;
+ asmjit::X86Gp scratchReg2_;
+
+ // Other parameters
+ bool isZeroPointZero_;
+
+ // current conv parameters
+ int G_; ///< Number of groups
+ int K_; ///< Number of output channels
+ int K_per_G_; ///< Number of output channels per group
+ int C_; ///< Number of input channels
+ int C_per_G_; ///< Number of input channels per group
+ int R_; ///< Filter/Kernel height
+ int S_; ///< Filter/Kernel width
+  int H_; ///< Output image height (equals input height for supported shapes)
+  int W_; ///< Output image width (equals input width for supported shapes)
+ int H_PAD_; ///< Padding for height (top and bottom)
+ int W_PAD_; ///< Padding for width (left and right)
+};
+
+} // namespace fbgemm
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
new file mode 100644
index 0000000..8298f4c
--- /dev/null
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -0,0 +1,1552 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <asmjit/asmjit.h>
+#include <cpuinfo.h>
+#include <immintrin.h>
+#include <array>
+#include <iostream>
+#include <map>
+#include <stdexcept>
+#include <tuple>
+#include "GroupwiseConv.h"
+#include "RefImplementations.h"
+#include "TransposeUtils.h"
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm {
+
+using namespace std;
+
+template <typename accT>
+thread_local asmjit::JitRuntime GenConvKernel<accT>::rt_;
+
+template <typename accT>
+thread_local asmjit::CodeHolder GenConvKernel<accT>::code_;
+
+template <typename accT>
+thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
+ GenConvKernel<accT>::codeCache_;
+
+template <typename accT>
+thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
+ GenConvKernel<accT>::codeCacheRowOffset_;
+
+namespace x86 = asmjit::x86;
+
+void calculateRowOffsets(
+ const conv_param_t<>& conv_param,
+ const uint8_t* activations,
+ int32_t* rowOffsetBuf,
+ int32_t a_zero_point,
+ int groupNum) {
+ int H = conv_param.OUT_DIM[0];
+ int W = conv_param.OUT_DIM[1];
+ int G = conv_param.G;
+ int C_per_G = conv_param.IC / conv_param.G;
+ int H_PAD = conv_param.pad[0];
+ int W_PAD = conv_param.pad[1];
+ // calculate row offset
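+  // For every output position (h, w), sum the K[0] x K[1] x C_per_G input
+  // patch of group `groupNum`, substituting a_zero_point for padded
+  // positions. Using OUT_DIM for the input bounds is valid here because the
+  // optimized path only supports stride 1 with symmetric unit padding, so
+  // IN_DIM == OUT_DIM.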
+ for (int h = 0; h < H; ++h) {
+ for (int w = 0; w < W; ++w) {
+ int32_t sum = 0;
+ for (int r = 0; r < conv_param.K[0]; ++r) {
+ int h_in = -H_PAD + h + r;
+ for (int s = 0; s < conv_param.K[1]; ++s) {
+ int w_in = -W_PAD + w + s;
+ for (int c = 0; c < C_per_G; ++c) {
+ if (h_in < 0 || h_in >= H || w_in < 0 || w_in >= W) {
+ sum += a_zero_point;
+ } else {
+ sum +=
+ activations[((h_in * W + w_in) * G + groupNum) * C_per_G + c];
+ }
+ }
+ }
+ }
+ rowOffsetBuf[h * W + w] = sum;
+ }
+ }
+}
+
+tuple<bool, int, int, int> getKernelSig(
+ const conv_param_t<>& conv_param,
+ bool isZeroPointZero) {
+ int C_per_G = conv_param.IC / conv_param.G;
+ int K_per_G = conv_param.OC / conv_param.G;
+ auto kernelSig =
+ std::make_tuple(isZeroPointZero, conv_param.G, C_per_G, K_per_G);
+ return kernelSig;
+}
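+// Note: height and width are runtime arguments of the generated function
+// (see jit_conv_kernel_fp), so the cache key only needs the zero-point flag,
+// G, C_per_G and K_per_G; one cached kernel serves all spatial sizes.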
+
+template <typename accT = int32_t>
+jit_conv_kernel_fp getOrCreateConvKernel(
+ const conv_param_t<>& conv_param,
+ int a_zero_point) {
+  // Note: Wrong code is generated if it's not one of the supported convolutions
+ assert(fbgemmOptimizedGConv<2>(conv_param));
+ auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
+ if (GenConvKernel<accT>::codeCache_.find(kernelSig) !=
+ GenConvKernel<accT>::codeCache_.end()) {
+ return GenConvKernel<accT>::codeCache_[kernelSig];
+ } else {
+ auto genObj = GenConvKernel<accT>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
+ }
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::createVector8BitOne<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // create 8-bit 1s
+  // i.e., oneReg8BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains
+  // 0x01 and so on
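+  // vpcmpeqw sets every bit (each byte reads as -1); vpabsb then maps each
+  // -1 to +1, leaving 0x01 in all 32 bytes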
+ a->vpcmpeqw(oneReg8BitAvx2_, oneReg8BitAvx2_, oneReg8BitAvx2_);
+ a->vpabsb(oneReg8BitAvx2_, oneReg8BitAvx2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::createVector16BitOne<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // create 16-bit 1s
+ // i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31]
+ // contains 0x0001 and so on
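+  // all-ones words shifted right by 15 leave 0x0001 in each 16-bit lane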
+ a->vpcmpeqw(oneReg16BitAvx2_, oneReg16BitAvx2_, oneReg16BitAvx2_);
+ a->vpsrlw(oneReg16BitAvx2_, oneReg16BitAvx2_, 15);
+}
+template <>
+template <>
+void GenConvKernel<int32_t>::setToZeroPt<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm destReg) {
+ // make destReg all zeros
+ a->vxorps(destReg, destReg, destReg);
+ asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+ // move zero point to xmm10
+ a->movq(const_reg_xmm, a_zero_pt_R_);
+ // make copies of zero point
+ a->vbroadcastsd(x86::ymm10, const_reg_xmm);
+ // shuffle
+ // overall impact is that destReg contains 32 8-bit values equal to the lower
+ // 8-bits of a_zero_pt_R_
+ a->vpshufb(destReg, x86::ymm10, destReg);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genConstForPermutations<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ asmjit::X86Gp permute_const_reg = a->gpzRef(12);
+ asmjit::X86Xmm const_reg_xmm = x86::xmm10;
+ // We have 1st group in even lanes and 2nd group in odd lanes.
+ // Permute to put 1st group to lower 128-bit and 2nd group in upper
+ // 128-bit.
+ // load 7, 5, 3, 1, 6, 4, 2, 0 in a 64-bit reg
+ a->mov(permute_const_reg, 0x0705030106040200);
+ a->movq(const_reg_xmm, permute_const_reg);
+ // Zero extend 8 packed 8-bit integers in the low 8 bytes of const_reg_xmm to
+ // 8 packed 32-bit integers in stPermRegAvx2_
+ a->vpmovzxbd(stPermRegAvx2_, const_reg_xmm);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::storeResult<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int offset) {
+ // store with permutation
+ a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_);
+ a->vmovups(x86::dword_ptr(out_acts_R_, offset), resultRegAvx2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::storeResultRowoffset<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int offset) {
+ // store
+ a->vmovups(x86::dword_ptr(row_offset_R_, offset), resultRegAvx2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForLoadingWeights<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // load weights
+ for (int r = 0; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ a->vmovaps(
+ WRegs_avx2_[r * S_ + s],
+ x86::dword_ptr(
+ wghts_R_,
+ (r * S_ + s) * G_ * K_per_G_ * C_per_G_ * sizeof(int8_t)));
+ }
+ }
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::gen8bitFMA<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg,
+ asmjit::X86Ymm wReg) {
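+  // u8 x s8 pairwise products summed into s16 (vpmaddubsw), widened to s32
+  // by multiplying with 16-bit ones (vpmaddwd), then accumulated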
+ a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg);
+ a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_);
+ a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ asmjit::X86Ymm aReg) {
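+  // sums 4 consecutive u8 activations: pair-sums of u8 * 1 into s16, then
+  // pair-sums of s16 * 1 into s32, accumulated into resultRegAvx2_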
+ a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_);
+ a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_);
+ a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // top-left corner code
+ // zero out the results register
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ for (int r = 0; r < R_; ++r) {
+ int h_in = -H_PAD_ + r;
+ if (h_in >= 0) {
+ a->imul(
+ scratchReg1_,
+ W_R_,
+ static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+ }
+ for (int s = 0; s < S_; ++s) {
+ int w_in = -W_PAD_ + s;
+ if (h_in >= 0 && w_in >= 0) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ } else {
+ if (!isZeroPointZero_) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+ }
+ storeResult<inst_set_t::avx2>(a);
+
+ a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+
+ // top edge excluding corners
+ asmjit::Label LoopTopEdge = a->newLabel();
+ a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->bind(LoopTopEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ for (int r = 0; r < H_PAD_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+        gen8bitFMA<inst_set_t::avx2>(
+            a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+ for (int r = H_PAD_; r < R_; ++r) {
+ int h_in = -H_PAD_ + r;
+ a->imul(
+ scratchReg1_,
+ W_R_,
+ static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+ for (int s = 0; s < S_; ++s) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+
+ storeResult<inst_set_t::avx2>(a);
+
+ a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->mov(loopR1_, W_R_);
+ a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_));
+ a->inc(loopR2_);
+ a->cmp(loopR2_, loopR1_);
+ a->jl(LoopTopEdge);
+ a->mov(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->sub(
+ scratchReg2_,
+ static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
+ a->sub(in_acts_R_, scratchReg2_);
+
+ // top-right corner code
+
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ for (int r = 0; r < H_PAD_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+ for (int r = H_PAD_; r < R_; ++r) {
+ int h_in = -H_PAD_ + r;
+ for (int s = 0; s < S_ - W_PAD_; ++s) {
+ a->imul(
+ scratchReg1_,
+ W_R_,
+ static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+ a->mov(scratchReg2_, W_R_);
+      a->sub(scratchReg2_, static_cast<asmjit::Imm>(S_ - W_PAD_ - s));
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ if (!isZeroPointZero_) {
+ for (int s = S_ - W_PAD_; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+ storeResult<inst_set_t::avx2>(a);
+ a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+
+ // reset output activation pointer
+ a->imul(scratchReg1_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->sub(out_acts_R_, scratchReg1_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // left edge excluding corners
+ asmjit::Label LoopLeftEdge = a->newLabel();
+ a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->bind(LoopLeftEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, loopR1_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_; ++r) {
+ if (!isZeroPointZero_) {
+ for (int s = 0; s < W_PAD_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ for (int s = W_PAD_; s < S_; ++s) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (s - W_PAD_) * C_ * sizeof(uint8_t)));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->add(out_acts_R_, scratchReg2_);
+ storeResult<inst_set_t::avx2>(a);
+
+ a->inc(loopR1_);
+ a->mov(loopR2_, H_R_);
+ a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->cmp(loopR1_, loopR2_);
+ a->jl(LoopLeftEdge);
+
+ // reset output activation pointer
+ a->mov(scratchReg2_, H_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->sub(out_acts_R_, scratchReg2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // right edge excluding corners
+ asmjit::Label LoopRightEdge = a->newLabel();
+
+ // output pointer to the right edge
+ // (W_ + W_ - 1)*K_*sizeof(int32_t)
+ a->mov(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, 2);
+ a->sub(scratchReg2_, 1);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->add(out_acts_R_, scratchReg2_);
+
+ a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->bind(LoopRightEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, loopR1_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+
+ a->mov(scratchReg2_, W_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * W_PAD_));
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ for (int r = 0; r < R_; ++r) {
+ for (int s = 0; s < S_ - W_PAD_; ++s) {
+ a->vbroadcastsd(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ }
+ if (!isZeroPointZero_) {
+ for (int s = S_ - W_PAD_; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+
+ a->sub(
+ scratchReg1_,
+ static_cast<asmjit::Imm>((S_ - W_PAD_) * C_ * sizeof(uint8_t)));
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+
+ // storeResult<inst_set_t::avx2>(a, (W_+W_-1)*K_*sizeof(int32_t));
+ storeResult<inst_set_t::avx2>(a);
+
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->add(out_acts_R_, scratchReg2_);
+ a->mov(loopR2_, H_R_);
+ a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->inc(loopR1_);
+ a->cmp(loopR1_, loopR2_);
+ a->jl(LoopRightEdge);
+
+ // reset base
+ a->mov(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, 2);
+ a->sub(scratchReg2_, 1);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->sub(out_acts_R_, scratchReg2_);
+
+ // reset loop increments
+ //(H_ - 2*H_PAD_)*W_*K_*sizeof(int32_t)
+ a->mov(scratchReg2_, H_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->sub(out_acts_R_, scratchReg2_);
+ // a->sub(out_acts_R_, (H_ - 2*H_PAD_)*W_*K_*sizeof(int32_t));
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // bottom-left corner
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_ - H_PAD_; ++r) {
+ if (!isZeroPointZero_) {
+ for (int s = 0; s < W_PAD_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ for (int s = W_PAD_; s < S_; ++s) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (s - W_PAD_) * C_ * sizeof(uint8_t)));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+ if (!isZeroPointZero_) {
+ for (int r = R_ - H_PAD_; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+
+  // we are updating the last row
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, 1);
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->add(out_acts_R_, scratchReg1_);
+ // storeResult<inst_set_t::avx2>(a, (H_-1)*W_*K_*sizeof(int32_t));
+ storeResult<inst_set_t::avx2>(a);
+ a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+
+ // bottom edge excluding corners
+ asmjit::Label LoopBottomEdge = a->newLabel();
+ a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->bind(LoopBottomEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+  for (int r = 0; r < R_ - H_PAD_; ++r) {
+ // int h_in = H_-2*H_PAD_ + r;
+ for (int s = 0; s < S_; ++s) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+
+ if (!isZeroPointZero_) {
+    for (int r = R_ - H_PAD_; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+
+ a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ // storeResult<inst_set_t::avx2>(a, ((H_-1)*W_+1)*K_*sizeof(int32_t));
+ storeResult<inst_set_t::avx2>(a);
+
+ a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->inc(loopR2_);
+ a->mov(loopR1_, W_R_);
+ a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_));
+ a->cmp(loopR2_, loopR1_);
+ a->jl(LoopBottomEdge);
+ a->mov(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * W_PAD_));
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->sub(in_acts_R_, scratchReg1_);
+ // a->sub(in_acts_R_, (W_ - 2*W_PAD_)*C_*sizeof(uint8_t));
+ // a->sub(out_acts_R_, (W_ - 2*W_PAD_)*K_*sizeof(int32_t));
+
+ // bottom-right corner
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ // input start point
+ // ((H_-(R_-H_PAD_))*W_+(W_-(S_-W_PAD_)))*C_*sizeof(uint8_t)
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(R_ - H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->add(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(S_ - W_PAD_));
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_ - H_PAD_; ++r) {
+ for (int s = 0; s < S_ - W_PAD_; ++s) {
+ a->vbroadcastsd(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ if (!isZeroPointZero_) {
+ for (int s = S_ - W_PAD_; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+
+ if (!isZeroPointZero_) {
+ for (int r = R_ - H_PAD_; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8bitFMA<inst_set_t::avx2>(
+ a, zeroPTRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ }
+ }
+
+ storeResult<inst_set_t::avx2>(a);
+ // storeResult<inst_set_t::avx2>(a, ((H_-1)*W_+W_-1)*K_*sizeof(int32_t));
+ // reset output pointer
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, 1);
+ a->imul(scratchReg1_, W_R_);
+ a->add(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, 1);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->sub(out_acts_R_, scratchReg1_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // main compute
+ asmjit::Label LoopH = a->newLabel();
+ asmjit::Label LoopW = a->newLabel();
+ // base for output
+ a->mov(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->imul(scratchReg2_, W_R_);
+ a->add(scratchReg2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+ a->add(out_acts_R_, scratchReg2_);
+
+ a->mov(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(W_PAD_));
+
+ // H loop
+ a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->bind(LoopH);
+ // W loop
+ a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->bind(LoopW);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ // compute on all filters
+ for (int r = 0; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ a->vbroadcastsd(
+ actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t)));
+ gen8bitFMA<inst_set_t::avx2>(a, actRegAvx2_, WRegs_avx2_[r * S_ + s]);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(in_acts_R_, scratchReg2_);
+ }
+ a->imul(
+ scratchReg2_, W_R_, static_cast<asmjit::Imm>(R_ * C_ * sizeof(uint8_t)));
+ a->sub(in_acts_R_, scratchReg2_);
+ // a->add(scratchReg1_, C_*sizeof(uint8_t));
+ a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+
+ // storeResult<inst_set_t::avx2>(a, (W_+1)*K_*sizeof(int32_t));
+ storeResult<inst_set_t::avx2>(a);
+
+ a->add(out_acts_R_, static_cast<asmjit::Imm>(K_ * sizeof(int32_t)));
+
+ a->inc(loopR2_);
+ a->cmp(loopR2_, scratchReg1_);
+ a->jl(LoopW);
+  // the W loop advanced in_acts_R_ by (W_ - 2*W_PAD_)*C_ in total; add
+  // 2*W_PAD_*C_ so it points at the start of the next input row
+ a->add(
+ in_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
+ // a->sub(in_acts_R_, (W_ - 2*W_PAD_)*C_*sizeof(uint8_t));
+ // a->add(in_acts_R_, W_*C_*sizeof(uint8_t));
+ a->add(
+ out_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * K_ * sizeof(int32_t)));
+ // a->sub(out_acts_R_, (W_ - 2*W_PAD_)*K_*sizeof(int32_t));
+ // a->add(out_acts_R_, W_*K_*sizeof(int32_t));
+
+ a->inc(loopR1_);
+ a->mov(scratchReg2_, H_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->cmp(loopR1_, scratchReg2_);
+ a->jl(LoopH);
+}
+
+template <>
+template <>
+jit_conv_kernel_fp GenConvKernel<int32_t>::getOrCreate<inst_set_t::avx2>(
+ const conv_param_t<>& conv_param) {
+ code_.reset(false);
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
+
+#if defined(FBGEMM_LOG_CODE)
+ // log code to a file
+ FILE* codeLogfile =
+ fopen(getCodeLoggingFile<inst_set_t::avx2>(false).c_str(), "w");
+ asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code_.setLogger(codeLogger);
+ }
+#endif
+
+ // arguments to the function created
+ in_acts_R_ = a->zdi();
+ wghts_R_ = a->zsi();
+ out_acts_R_ = a->zdx();
+ a_zero_pt_R_ = a->zcx();
+ H_R_ = a->gpzRef(8);
+ W_R_ = a->gpzRef(9);
+ row_offset_R_ = a->gpzRef(10);
+
+ // register for temporary use
+ scratchReg1_ = a->gpzRef(12);
+ scratchReg2_ = a->gpzRef(13);
+
+ asmjit::FuncDetail func;
+ func.init(asmjit::FuncSignature6<
+ void,
+ uint8_t*,
+ int8_t*,
+ int32_t*,
+ int32_t,
+ int32_t,
+ int32_t>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsMapper args(&func);
+ args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_);
+
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
+
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
+
+ createVector16BitOne<inst_set_t::avx2>(a);
+
+ loopR1_ = a->gpzRef(14);
+ loopR2_ = a->gpzRef(15);
+
+ if (!isZeroPointZero_) {
+ setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+
+ genForLoadingWeights<inst_set_t::avx2>(a);
+
+ genConstForPermutations<inst_set_t::avx2>(a);
+
+ genForTopEdge<inst_set_t::avx2>(a);
+ genForLeftEdge<inst_set_t::avx2>(a);
+ genForRightEdge<inst_set_t::avx2>(a);
+ genForBottomEdge<inst_set_t::avx2>(a);
+
+ genCoreInsts<inst_set_t::avx2>(a);
+
+ asmjit::FuncUtils::emitEpilog(a, layout);
+
+ jit_conv_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ auto kernelSig = getKernelSig(conv_param, isZeroPointZero_);
+ codeCache_[kernelSig] = fn;
+
+#if defined(FBGEMM_LOG_CODE)
+ fclose(codeLogfile);
+ delete codeLogger;
+#endif
+
+ return fn;
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // top-left corner code
+ // zero out the results register
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ for (int r = 0; r < R_; ++r) {
+ int h_in = -H_PAD_ + r;
+ if (h_in >= 0) {
+ a->imul(
+ scratchReg1_,
+ W_R_,
+ static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+ }
+ for (int s = 0; s < S_; ++s) {
+ int w_in = -W_PAD_ + s;
+ if (h_in >= 0 && w_in >= 0) {
+ a->vmovaps(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, w_in * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ } else {
+ if (!isZeroPointZero_) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+ }
+ // store results
+ storeResultRowoffset<inst_set_t::avx2>(a);
+
+  // for C_per_G == 4 and K_per_G == 4, 8 groups are processed at a time
+ a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+
+ // top edge excluding corners
+ asmjit::Label LoopTopEdge = a->newLabel();
+ a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->bind(LoopTopEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ for (int r = 0; r < H_PAD_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+ for (int r = H_PAD_; r < R_; ++r) {
+ int h_in = -H_PAD_ + r;
+ a->imul(
+ scratchReg1_,
+ W_R_,
+ static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+ for (int s = 0; s < S_; ++s) {
+ a->vmovaps(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ }
+ }
+ a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+
+ // store results
+ storeResultRowoffset<inst_set_t::avx2>(a);
+
+ a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->mov(loopR1_, W_R_);
+ a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_));
+ a->inc(loopR2_);
+ a->cmp(loopR2_, loopR1_);
+ a->jl(LoopTopEdge);
+ a->mov(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->sub(
+ scratchReg2_,
+ static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
+ a->sub(in_acts_R_, scratchReg2_);
+
+ // top-right corner code
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ if (!isZeroPointZero_) {
+ for (int r = 0; r < H_PAD_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+ for (int r = H_PAD_; r < R_; ++r) {
+ int h_in = -H_PAD_ + r;
+ for (int s = 0; s < S_ - W_PAD_; ++s) {
+ a->imul(
+ scratchReg1_,
+ W_R_,
+ static_cast<asmjit::Imm>(h_in * C_ * sizeof(uint8_t)));
+ a->mov(scratchReg2_, W_R_);
+      a->sub(scratchReg2_, static_cast<asmjit::Imm>(S_ - W_PAD_ - s));
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ a->vmovaps(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ }
+ if (!isZeroPointZero_) {
+ for (int s = S_ - W_PAD_; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+
+ // store results
+ storeResultRowoffset<inst_set_t::avx2>(a);
+
+ a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+
+ // reset output pointer
+ a->imul(scratchReg1_, W_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->sub(row_offset_R_, scratchReg1_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // left edge excluding corners
+ asmjit::Label LoopLeftEdge = a->newLabel();
+ a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->bind(LoopLeftEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, loopR1_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_; ++r) {
+ if (!isZeroPointZero_) {
+ for (int s = 0; s < W_PAD_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ for (int s = W_PAD_; s < S_; ++s) {
+ a->vmovaps(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (s - W_PAD_) * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->add(row_offset_R_, scratchReg2_);
+ storeResultRowoffset<inst_set_t::avx2>(a);
+
+ a->inc(loopR1_);
+ a->mov(loopR2_, H_R_);
+ a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->cmp(loopR1_, loopR2_);
+ a->jl(LoopLeftEdge);
+
+ // reset output pointer
+ a->mov(scratchReg2_, H_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->sub(row_offset_R_, scratchReg2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // right edge excluding corners
+ asmjit::Label LoopRightEdge = a->newLabel();
+
+ // output pointer to the right edge
+ // (W_ + W_ - 1)*8*sizeof(int32_t)
+ a->mov(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, 2);
+ a->sub(scratchReg2_, 1);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->add(row_offset_R_, scratchReg2_);
+
+ a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->bind(LoopRightEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, loopR1_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+
+ a->mov(scratchReg2_, W_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * W_PAD_));
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ for (int r = 0; r < R_; ++r) {
+ for (int s = 0; s < S_ - W_PAD_; ++s) {
+ a->vmovaps(actRegAvx2_, x86::dword_ptr(in_acts_R_, scratchReg1_));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ a->add(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ }
+ if (!isZeroPointZero_) {
+ for (int s = S_ - W_PAD_; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+
+ a->sub(
+ scratchReg1_,
+ static_cast<asmjit::Imm>((S_ - W_PAD_) * C_ * sizeof(uint8_t)));
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+
+ storeResultRowoffset<inst_set_t::avx2>(a);
+
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->add(row_offset_R_, scratchReg2_);
+ a->mov(loopR2_, H_R_);
+ a->sub(loopR2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->inc(loopR1_);
+ a->cmp(loopR1_, loopR2_);
+ a->jl(LoopRightEdge);
+
+ // reset base
+ a->mov(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, 2);
+ a->sub(scratchReg2_, 1);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->sub(row_offset_R_, scratchReg2_);
+
+ // reset increments done in the loop
+ //(H_ - 2*H_PAD_)*W_*8*sizeof(int32_t)
+ a->mov(scratchReg2_, H_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg2_, W_R_);
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->sub(row_offset_R_, scratchReg2_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // bottom-left corner
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_ - H_PAD_; ++r) {
+ if (!isZeroPointZero_) {
+ for (int s = 0; s < W_PAD_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ for (int s = W_PAD_; s < S_; ++s) {
+ a->vmovaps(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_,
+ scratchReg1_,
+ 0,
+ (s - W_PAD_) * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+ if (!isZeroPointZero_) {
+ for (int r = R_ - H_PAD_; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+
+ // we are updating the last row
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, 1);
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->add(row_offset_R_, scratchReg1_);
+ storeResultRowoffset<inst_set_t::avx2>(a);
+ a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+
+ // bottom edge excluding corners
+ asmjit::Label LoopBottomEdge = a->newLabel();
+ a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->bind(LoopBottomEdge);
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_ - H_PAD_; ++r) {
+ // int h_in = H_-2*H_PAD_ + r;
+ for (int s = 0; s < S_; ++s) {
+ a->vmovaps(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ }
+
+ if (!isZeroPointZero_) {
+ for (int r = R_ - H_PAD_; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+
+ a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ // storeResult<inst_set_t::avx2>(a, ((H_-1)*W_+1)*8*sizeof(int32_t));
+ storeResultRowoffset<inst_set_t::avx2>(a);
+
+ a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->inc(loopR2_);
+ a->mov(loopR1_, W_R_);
+ a->sub(loopR1_, static_cast<asmjit::Imm>(W_PAD_));
+ a->cmp(loopR2_, loopR1_);
+ a->jl(LoopBottomEdge);
+ a->mov(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(2 * W_PAD_));
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->sub(in_acts_R_, scratchReg1_);
+ // a->sub(in_acts_R_, (W_ - 2*W_PAD_)*C_*sizeof(uint8_t));
+ // a->sub(out_acts_R_, (W_ - 2*W_PAD_)*8*sizeof(int32_t));
+
+ // bottom-right corner
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ // input start point
+ // ((H_-(R_-H_PAD_))*W_+(W_-(S_-W_PAD_)))*C_*sizeof(uint8_t)
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(R_ - H_PAD_));
+ a->imul(scratchReg1_, W_R_);
+ a->add(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(S_ - W_PAD_));
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ for (int r = 0; r < R_ - H_PAD_; ++r) {
+ for (int s = 0; s < S_ - W_PAD_; ++s) {
+ a->vmovaps(
+ actRegAvx2_,
+ x86::dword_ptr(
+ in_acts_R_, scratchReg1_, 0, s * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(scratchReg1_, scratchReg2_);
+ if (!isZeroPointZero_) {
+ for (int s = S_ - W_PAD_; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+
+ if (!isZeroPointZero_) {
+ for (int r = R_ - H_PAD_; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ gen8BitSum<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+ }
+ }
+
+ storeResultRowoffset<inst_set_t::avx2>(a);
+ // reset output pointer
+ a->mov(scratchReg1_, H_R_);
+ a->sub(scratchReg1_, 1);
+ a->imul(scratchReg1_, W_R_);
+ a->add(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, 1);
+ a->imul(scratchReg1_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->sub(row_offset_R_, scratchReg1_);
+}
+
+template <>
+template <>
+void GenConvKernel<int32_t>::genRowoffsetCore<inst_set_t::avx2>(
+ asmjit::X86Emitter* a) {
+ // number of uint8 elements in input channels should be a multiple of 32
+ assert(C_ % 32 == 0);
+
+ asmjit::Label LoopH = a->newLabel();
+ asmjit::Label LoopW = a->newLabel();
+ // base for output
+ a->mov(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->imul(scratchReg2_, W_R_);
+ a->add(scratchReg2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->imul(scratchReg2_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+ a->add(row_offset_R_, scratchReg2_);
+
+ a->mov(scratchReg1_, W_R_);
+ a->sub(scratchReg1_, static_cast<asmjit::Imm>(W_PAD_));
+
+ // H loop
+ a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
+ a->bind(LoopH);
+ // W loop
+ a->mov(loopR2_, static_cast<asmjit::Imm>(W_PAD_));
+ a->bind(LoopW);
+
+ // zero out
+ a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_);
+ for (int r = 0; r < R_; ++r) {
+ for (int s = 0; s < S_; ++s) {
+ a->vmovaps(
+ actRegAvx2_, x86::dword_ptr(in_acts_R_, s * C_ * sizeof(uint8_t)));
+ gen8BitSum<inst_set_t::avx2>(a, actRegAvx2_);
+ }
+ a->imul(scratchReg2_, W_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(in_acts_R_, scratchReg2_);
+ }
+ a->imul(
+ scratchReg2_, W_R_, static_cast<asmjit::Imm>(R_ * C_ * sizeof(uint8_t)));
+ a->sub(in_acts_R_, scratchReg2_);
+ // store results
+ storeResultRowoffset<inst_set_t::avx2>(a);
+
+ a->add(in_acts_R_, static_cast<asmjit::Imm>(C_ * sizeof(uint8_t)));
+ a->add(row_offset_R_, static_cast<asmjit::Imm>(8 * sizeof(int32_t)));
+
+ a->inc(loopR2_);
+ a->cmp(loopR2_, scratchReg1_);
+ a->jl(LoopW);
+ a->add(
+ in_acts_R_, static_cast<asmjit::Imm>(2 * W_PAD_ * C_ * sizeof(uint8_t)));
+ a->add(
+ row_offset_R_,
+ static_cast<asmjit::Imm>(2 * W_PAD_ * 8 * sizeof(int32_t)));
+ a->inc(loopR1_);
+ a->mov(scratchReg2_, H_R_);
+ a->sub(scratchReg2_, static_cast<asmjit::Imm>(H_PAD_));
+ a->cmp(loopR1_, scratchReg2_);
+ a->jl(LoopH);
+}
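+
+// For reference, a scalar sketch of what the generated row-offset kernel
+// computes (a hypothetical helper, compiled out; not part of the library
+// API). It assumes the case this JIT path handles: C_per_G == 4, 8 groups
+// per call, and stride 1 so that OUT_DIM == IN_DIM. For each pixel and each
+// of the 8 groups it sums the uint8 activations over the R x S window,
+// substituting the activation zero point for padded positions; the result
+// layout is (H * W) x 8, which fbgemmGroupwiseConv later transposes.
+#if 0
+static void rowOffsetsScalarSketch(
+ const uint8_t* acts,
+ int32_t a_zero_pt,
+ int H,
+ int W,
+ int R,
+ int S,
+ int H_PAD,
+ int W_PAD,
+ int32_t* row_offsets) {
+ constexpr int C_PER_G = 4;
+ constexpr int GROUPS = 8;
+ constexpr int C = C_PER_G * GROUPS;
+ // iterate over every pixel (OUT_DIM == IN_DIM for stride 1)
+ for (int h = 0; h < H; ++h) {
+ for (int w = 0; w < W; ++w) {
+ for (int g = 0; g < GROUPS; ++g) {
+ int32_t sum = 0;
+ for (int r = 0; r < R; ++r) {
+ for (int s = 0; s < S; ++s) {
+ int h_in = h - H_PAD + r;
+ int w_in = w - W_PAD + s;
+ bool pad = h_in < 0 || h_in >= H || w_in < 0 || w_in >= W;
+ for (int c = 0; c < C_PER_G; ++c) {
+ sum += pad ? a_zero_pt
+ : acts[(h_in * W + w_in) * C + g * C_PER_G + c];
+ }
+ }
+ }
+ row_offsets[(h * W + w) * GROUPS + g] = sum;
+ }
+ }
+ }
+}
+#endif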
+
+template <>
+template <>
+jit_rowoffset_kernel_fp
+GenConvKernel<int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
+ const conv_param_t<>& conv_param) {
+ code_.reset(false);
+ code_.init(rt_.getCodeInfo());
+ asmjit::X86Assembler assembler(&code_);
+ asmjit::X86Emitter* a = assembler.asEmitter();
+
+#if defined(FBGEMM_LOG_CODE)
+ // log code to a file
+ FILE* codeLogfile =
+ fopen(getCodeLoggingFile<inst_set_t::avx2>(true).c_str(), "w");
+ asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
+ if (codeLogger) {
+ code_.setLogger(codeLogger);
+ }
+#endif
+
+ // arguments to the function created
+ in_acts_R_ = a->zdi();
+ a_zero_pt_R_ = a->zsi();
+ H_R_ = a->zdx();
+ W_R_ = a->zcx();
+ row_offset_R_ = a->gpzRef(8);
+
+ // register for temporary use
+ scratchReg1_ = a->gpzRef(12);
+ scratchReg2_ = a->gpzRef(13);
+
+ loopR1_ = a->gpzRef(14);
+ loopR2_ = a->gpzRef(15);
+
+ asmjit::FuncDetail func;
+ func.init(
+ asmjit::
+ FuncSignature5<void, uint8_t*, int32_t, int32_t, int32_t, int32_t*>(
+ asmjit::CallConv::kIdHost));
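+ // The generated function thus has the C signature
+ //   void fn(uint8_t* in_acts, int32_t a_zero_pt,
+ //           int32_t H, int32_t W, int32_t* row_offsets);
+ // with the arguments mapped to the registers assigned above.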
+
+ asmjit::FuncFrameInfo ffi;
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindVec,
+ asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15));
+ ffi.setDirtyRegs(
+ asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsMapper args(&func);
+ args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_);
+
+ args.updateFrameInfo(ffi);
+
+ asmjit::FuncFrameLayout layout;
+ layout.init(func, ffi);
+
+ asmjit::FuncUtils::emitProlog(a, layout);
+ asmjit::FuncUtils::allocArgs(a, layout, args);
+
+ // This uses the xmm10 register temporarily, so it should come before
+ // createVector8BitOne.
+ if (!isZeroPointZero_) {
+ setToZeroPt<inst_set_t::avx2>(a, zeroPTRegAvx2_);
+ }
+
+ createVector16BitOne<inst_set_t::avx2>(a);
+ // we set ymm10 to contain 8-bit 1s
+ createVector8BitOne<inst_set_t::avx2>(a);
+
+ genForTopEdgeRowoffset<inst_set_t::avx2>(a);
+ genForLeftEdgeRowoffset<inst_set_t::avx2>(a);
+ genForRightEdgeRowoffset<inst_set_t::avx2>(a);
+ genForBottomEdgeRowoffset<inst_set_t::avx2>(a);
+
+ genRowoffsetCore<inst_set_t::avx2>(a);
+
+ asmjit::FuncUtils::emitEpilog(a, layout);
+
+ jit_rowoffset_kernel_fp fn;
+ asmjit::Error err = rt_.add(&fn, &code_);
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+ auto kernelSig = getKernelSig(conv_param, isZeroPointZero_);
+ codeCacheRowOffset_[kernelSig] = fn;
+
+#if defined(FBGEMM_LOG_CODE)
+ delete codeLogger;
+ fclose(codeLogfile);
+#endif
+
+ return fn;
+}
+
+template <
+ typename packed_W,
+ typename outType,
+ typename processOutputType,
+ int SPATIAL_DIM>
+void fbgemmGroupwiseConv(
+ const conv_param_t<SPATIAL_DIM>& conv_param,
+ const std::uint8_t* activations,
+ std::int32_t a_zero_point,
+ std::int32_t* rowOffsetBuf,
+ packed_W& packed_weights,
+ outType* out,
+ int32_t* outBuffer,
+ const processOutputType& outProcess,
+ int thread_id,
+ int num_threads) {
+
+ int MB = conv_param.MB;
+ int H = conv_param.OUT_DIM[0];
+ int W = conv_param.OUT_DIM[1];
+ int G = conv_param.G;
+ int K_per_G = conv_param.OC / G;
+ int C_per_G = conv_param.IC / G;
+ int oh_ow = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
+
+ static_assert(SPATIAL_DIM == 2, "3D conv not supported yet");
+
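+ // rowOffsetBuf is expected to be sized by rowOffsetBufferSizeGConv, which
+ // reserves 2x the space of 8 groups' row offsets; the second half is used
+ // below as the destination of the 8x8 transpose.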
+ int32_t* rowOffsetTrDest =
+ rowOffsetBuf + 8 * conv_param.IN_DIM[0] * conv_param.IN_DIM[1];
+ if (fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)) {
+ assert(G % 8 == 0);
+ // generate convolution kernel
+ jit_conv_kernel_fp fpConv =
+ getOrCreateConvKernel<>(conv_param, a_zero_point);
+ // generate row offset kernel
+ jit_rowoffset_kernel_fp fpRowoffset =
+ getOrCreateRowOffsetKernel(conv_param, a_zero_point);
+ for (int i = 0; i < MB; ++i) {
+ const uint8_t* actStartBatch = activations +
+ i * conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * conv_param.IC;
+ for (int gOuter = 0; gOuter < G; gOuter += 8) {
+ // For C_per_G == 4 and K_per_G == 4, row offsets are calculated for 8
+ // groups at a time. The result is row offsets in the format IH*IW x G.
+ fpRowoffset(
+ actStartBatch + gOuter * C_per_G,
+ a_zero_point,
+ H,
+ W,
+ rowOffsetBuf);
+ // Transpose to get row offsets in the format G x IH*IW
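+ // transpose_8x8 only moves 32-bit lanes, so the int32 buffers can be
+ // reinterpreted as float without changing any bits.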
+ internal::transpose_8x8(
+ conv_param.IN_DIM[0] * conv_param.IN_DIM[1],
+ 8,
+ (const float*)rowOffsetBuf,
+ 8,
+ (float*)rowOffsetTrDest,
+ conv_param.IN_DIM[0] * conv_param.IN_DIM[1]);
+ int gLimit = gOuter + 8;
+ for (int g = gOuter; g < gLimit; g += 2) {
+ int32_t* currOutBuf =
+ outBuffer + i * oh_ow * conv_param.OC + g * K_per_G;
+ const uint8_t* actStartGroup = actStartBatch + g * C_per_G;
+
+ fpConv(
+ actStartGroup,
+ packed_weights.getBuf() + g * K_per_G * C_per_G,
+ currOutBuf,
+ a_zero_point,
+ H,
+ W);
+
+ // Output processing should be called for each group
+ for (int j = 0; j < 2; ++j) {
+ // calculateRowOffsets(
+ // conv_param, actStartGroup, rowOffsetBuf, a_zero_point, j);
+ int32_t* rowOffsetForCurG = rowOffsetTrDest +
+ ((g - gOuter) + j) * conv_param.IN_DIM[0] *
+ conv_param.IN_DIM[1];
+ // compare_buffers(rowOffsetBuf, rowOffsetForCurG,
+ // conv_param.IN_DIM[0]*conv_param.IN_DIM[1], 1, 1, 100);
+
+ // outProcess expects rowOffsetBuf to contain row offsets for the
+ // current group
+ memcpy(
+ rowOffsetBuf,
+ rowOffsetForCurG,
+ conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * sizeof(int32_t));
+
+ if (cpuinfo_has_x86_avx512f()) {
+ // Currently, the avx2 code path is used even when avx512 is available
+ outProcess.template f<inst_set_t::avx2>(
+ out,
+ currOutBuf + j * K_per_G,
+ {i * oh_ow, oh_ow, (g + j) * K_per_G, K_per_G},
+ K_per_G * G,
+ K_per_G * G);
+ } else if (cpuinfo_has_x86_avx2()) {
+ outProcess.template f<inst_set_t::avx2>(
+ out,
+ currOutBuf + j * K_per_G,
+ {i * oh_ow, oh_ow, (g + j) * K_per_G, K_per_G},
+ K_per_G * G,
+ K_per_G * G);
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecture");
+ }
+ } // j loop
+ }
+ }
+ }
+ } else {
+ // for unsupported cases, just execute the naive C implementation
+ conv_ref(
+ conv_param,
+ activations,
+ a_zero_point,
+ packed_weights.getBuf(),
+ outBuffer);
+ for (int i = 0; i < conv_param.MB; ++i) {
+ for (int g = 0; g < conv_param.G; ++g) {
+ calculateRowOffsets(
+ conv_param,
+ activations +
+ i * conv_param.IN_DIM[0] * conv_param.IN_DIM[1] * conv_param.IC,
+ rowOffsetBuf,
+ a_zero_point,
+ g);
+ outProcess.template f<inst_set_t::anyarch>(
+ out,
+ outBuffer + i * oh_ow * conv_param.OC + g * K_per_G,
+ {i * oh_ow, oh_ow, g * K_per_G, K_per_G},
+ K_per_G * G,
+ K_per_G * G);
+ }
+ }
+ }
+}
+
+jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel(
+ const conv_param_t<>& conv_param,
+ int a_zero_point) {
+ // Note: wrong code is generated if the convolution shape is not supported
+ assert(fbgemmOptimizedGConv<2>(conv_param));
+ auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
+ if (GenConvKernel<int32_t>::codeCacheRowOffset_.find(kernelSig) !=
+ GenConvKernel<int32_t>::codeCacheRowOffset_.end()) {
+ return GenConvKernel<int32_t>::codeCacheRowOffset_[kernelSig];
+ } else {
+ auto genObj = GenConvKernel<int32_t>(conv_param, a_zero_point);
+ // TODO: Instruction set based dispatch
+ return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param);
+ }
+}
+
+template <int SPATIAL_DIM>
+int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) {
+ // The row offset buffer should be able to hold row offsets for however
+ // many groups we process at a time.
+ if (cpuinfo_initialize()) {
+ if (cpuinfo_has_x86_avx512f()) {
+ int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
+ int C_per_G = conv_param.IC / conv_param.G;
+ int K_per_G = conv_param.OC / conv_param.G;
+ if (C_per_G == 4 && K_per_G == 4) {
+ return 2 * 8 * bufferSize;
+ } else {
+ return conv_param.G * bufferSize;
+ }
+ } else if (cpuinfo_has_x86_avx2()) {
+ int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
+ int C_per_G = conv_param.IC / conv_param.G;
+ int K_per_G = conv_param.OC / conv_param.G;
+ if (C_per_G == 4 && K_per_G == 4) {
+ // row offset is calculated for 8 groups at a time
+ // 2x is needed for transposing
+ return 2 * 8 * bufferSize;
+ } else {
+ return conv_param.G * bufferSize;
+ }
+ } else {
+ // TODO: Have default slower path
+ assert(0 && "unsupported architecture");
+ return -1;
+ }
+ } else {
+ throw std::runtime_error("Failed to initialize cpuinfo!");
+ }
+}
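+
+// Example: for a 3x3 stride-1 conv with OUT_DIM = {56, 56} and
+// C_per_G == K_per_G == 4, this returns 2 * 8 * 56 * 56 = 50176 int32
+// entries: row offsets for 8 groups plus an equally sized half used as
+// scratch for the transpose.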
+
+template int rowOffsetBufferSizeGConv<2>(const conv_param_t<2>& conv_param);
+
+#define INSTANTIATE_BASE(RELU, Q_GRAN) \
+ template void fbgemmGroupwiseConv( \
+ const conv_param_t<2>& conv_param, \
+ const uint8_t* activations, \
+ int32_t a_zero_point, \
+ std::int32_t* rowOffsetBuf, \
+ PackWeightMatrixForGConv<int8_t, int32_t, 2>& packed_weights, \
+ uint8_t* out, \
+ int32_t* outBuffer, \
+ const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
+ int thread_id, \
+ int num_threads);
+
+#define INSTANTIATE_Q_GRANS(RELU) \
+ INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR); \
+ INSTANTIATE_BASE(RELU, QuantizationGranularity::GROUP); \
+ INSTANTIATE_BASE(RELU, QuantizationGranularity::OUT_CHANNEL);
+
+INSTANTIATE_Q_GRANS(false);
+INSTANTIATE_Q_GRANS(true);
+
+#undef INSTANTIATE_Q_GRANS
+#undef INSTANTIATE_BASE
+
+template void fbgemmGroupwiseConv(
+ const conv_param_t<2>& conv_param,
+ const uint8_t* activations,
+ int32_t a_zero_point,
+ std::int32_t* rowOffsetBuf,
+ PackWeightMatrixForGConv<int8_t, int32_t, 2>& packed_weights,
+ int32_t* out,
+ int32_t* outBuffer,
+ const DoNothing<int32_t, int32_t>& outProcess,
+ int thread_id,
+ int num_threads);
+
+} // namespace fbgemm
diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc
new file mode 100644
index 0000000..e6c9b7d
--- /dev/null
+++ b/src/PackWeightMatrixForGConv.cc
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <cpuinfo.h>
+#include <cassert>
+#include <iomanip>
+#include "RefImplementations.h"
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm {
+
+template <typename T, typename accT, int SPATIAL_DIM>
+PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::PackWeightMatrixForGConv(
+ matrix_op_t trans,
+ const conv_param_t<SPATIAL_DIM>& conv_param,
+ const T* sdata,
+ T* pdata)
+ : trans_(trans), conv_param_(conv_param), sdata_(sdata) {
+ static_assert(SPATIAL_DIM == 2, "3D conv not supported yet");
+
+ if (!pdata) {
+ bufAllocatedHere_ = true;
+ pdata_ = static_cast<T*>(fbgemmAlignedAlloc(
+ 64,
+ conv_param_.G * conv_param_.K[0] * conv_param_.K[1] *
+ (conv_param_.OC / conv_param_.G) *
+ (conv_param_.IC / conv_param_.G) * sizeof(T)));
+ } else {
+ bufAllocatedHere_ = false;
+ pdata_ = pdata;
+ }
+ pack();
+}
+
+/**
+ * @brief Pack the weight tensor into the layout required by the optimized
+ * kernel.
+ *
+ * Let IC_per_G be the number of input channels per group and OC_per_G the
+ * number of output channels per group.
+ *
+ * For IC_per_G == 4 && OC_per_G == 4, the optimized kernel works on 2 groups
+ * at a time, so the input channels for groups g and g+1 are laid out
+ * sequentially for each output channel, i.e., the layout is R S (G/2) K (2C).
+ * Working on two groups at a time fully utilizes the 256-bit avx2 SIMD width.
+ *
+ * For IC_per_G == 8, 16, 32 && OC_per_G == 8, 16, 32, there is no need to
+ * work on 2 groups at a time; the full SIMD width can be utilized efficiently
+ * even while working on 1 group at a time.
+ */
+template <typename T, typename accT, int SPATIAL_DIM>
+void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() {
+ // filters are assumed to be in G (R S C/G) K/G format
+ int R = conv_param_.K[0];
+ int S = conv_param_.K[1];
+ int G = conv_param_.G;
+ int IC_per_G = conv_param_.IC / conv_param_.G;
+ int OC_per_G = conv_param_.OC / conv_param_.G;
+
+ // If transpose option is set, the weight matrix is in layout G K/G (R S C/G)
+ // instead of G (R S C/G) K/G
+ bool tr = (trans_ == matrix_op_t::Transpose);
+ if (fbgemmOptimizedGConv(conv_param_)) {
+ // currently only this case is supported
+ for (int r = 0; r < R; ++r) {
+ for (int s = 0; s < S; ++s) {
+ for (int k = 0; k < OC_per_G; ++k) {
+ for (int g = 0; g < G; ++g) {
+ for (int c = 0; c < IC_per_G; ++c) {
+ inpType b = tr
+ ? sdata_
+ [(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c]
+ : sdata_
+ [(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k];
+ pdata_
+ [((((r * S + s) * (G / 2) + (g / 2)) * OC_per_G + k) * 2 +
+ (g % 2)) *
+ IC_per_G +
+ c] = b;
+ }
+ }
+ }
+ }
+ }
+ } else {
+ if (tr) {
+ // conv_ref expects weights to be in G (R S C/G) K/G format
+ transposeConvWeights(conv_param_, sdata_, pdata_);
+ } else {
+ // just copy the data for unsupported cases
+ memcpy(pdata_, sdata_, G * R * S * OC_per_G * IC_per_G * sizeof(inpType));
+ }
+ }
+}
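+
+// The packed offset computed by the innermost store above, i.e. the
+// R S (G/2) K (2C) layout, written out as a standalone expression. A
+// hypothetical helper for illustration only (compiled out; pack() inlines
+// the same arithmetic):
+#if 0
+static inline int packedIndexForGConv(
+ int r, int s, int g, int k, int c, int S, int G, int OC_per_G,
+ int IC_per_G) {
+ return ((((r * S + s) * (G / 2) + (g / 2)) * OC_per_G + k) * 2 + (g % 2)) *
+ IC_per_G +
+ c;
+}
+#endif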
+
+template class PackWeightMatrixForGConv<int8_t, int32_t, 2>;
+template class PackWeightMatrixForGConv<int8_t, int16_t, 2>;
+} // namespace fbgemm
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index 5168a15..5c6cf1b 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -486,6 +486,31 @@ void conv3d_ref(
} // for each n
}
+void transposeConvWeights(
+ const conv_param_t<>& conv_p,
+ const std::int8_t* src,
+ std::int8_t* dest) {
+ int R = conv_p.K[0];
+ int S = conv_p.K[1];
+ int G = conv_p.G;
+ int IC_per_G = conv_p.IC / conv_p.G;
+ int OC_per_G = conv_p.OC / conv_p.G;
+
+ // Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format.
+ for (int r = 0; r < R; ++r) {
+ for (int s = 0; s < S; ++s) {
+ for (int k = 0; k < OC_per_G; ++k) {
+ for (int g = 0; g < G; ++g) {
+ for (int c = 0; c < IC_per_G; ++c) {
+ dest[(((g * R + r) * S + s) * IC_per_G + c) * OC_per_G + k] =
+ src[(((g * OC_per_G + k) * R + r) * S + s) * IC_per_G + c];
+ }
+ }
+ }
+ }
+ }
+}
+
void depthwise_3x3_pad_1_ref(
int N,
int H,
diff --git a/src/RefImplementations.h b/src/RefImplementations.h
index fce68e6..62f17e9 100644
--- a/src/RefImplementations.h
+++ b/src/RefImplementations.h
@@ -176,6 +176,14 @@ FBGEMM_API void conv3d_ref(
std::int32_t* C);
/*
+ * @brief Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format.
+ */
+FBGEMM_API void transposeConvWeights(
+ const conv_param_t<>& conv_p,
+ const std::int8_t* src,
+ std::int8_t* dest);
+
+/*
* @brief Reference implementation of im2col operation.
* The input A is assumed to be in NHiWiC format.
* The output A is assumed to be in NHoWoRSC format.
diff --git a/test/GConvTest.cc b/test/GConvTest.cc
new file mode 100644
index 0000000..34042c6
--- /dev/null
+++ b/test/GConvTest.cc
@@ -0,0 +1,382 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <random>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#include <gtest/gtest.h>
+
+#include "QuantizationHelpers.h"
+#include "TestUtils.h"
+#include "bench/BenchUtils.h"
+#include "fbgemm/Fbgemm.h"
+#include "src/RefImplementations.h"
+
+using namespace std;
+using namespace fbgemm;
+
+vector<matrix_op_t> transposeVals{matrix_op_t::NoTranspose,
+ matrix_op_t::Transpose};
+
+vector<QuantizationGranularity> qGranularityVals{
+ QuantizationGranularity::TENSOR,
+ QuantizationGranularity::GROUP,
+ QuantizationGranularity::OUT_CHANNEL};
+
+namespace {
+class fbgemmGConvAcc32Test
+ : public testing::TestWithParam<tuple<matrix_op_t, matrix_op_t>> {};
+class fbgemmGConvAcc32WithQuantGranularityTest
+ : public testing::TestWithParam<
+ tuple<matrix_op_t, matrix_op_t, QuantizationGranularity>> {};
+} // namespace
+
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ fbgemmGConvAcc32Test,
+ ::testing::Combine(
+ ::testing::Values(matrix_op_t::NoTranspose),
+ ::testing::ValuesIn(transposeVals)));
+
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ fbgemmGConvAcc32WithQuantGranularityTest,
+ ::testing::Combine(
+ ::testing::Values(matrix_op_t::NoTranspose),
+ ::testing::ValuesIn(transposeVals),
+ ::testing::ValuesIn(qGranularityVals)));
+/**
+ * @brief Shapes for unit test.
+ */
+static vector<conv_param_t<>> GetShapes_() {
+ vector<conv_param_t<>> shapes = {
+ // MB, IC, OC, {IH, IW}, G, {KH, KW}, {stride_h, stride_w}, {pad_t, pad_l,
+ // pad_b, pad_r}
+ conv_param_t<>(1, 32, 32, {3, 3}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {4, 4}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {3, 5}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 32, 32, {5, 3}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 8, 8, {5, 5}, 2, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 128, 128, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ conv_param_t<>(2, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ };
+ return shapes;
+}
+
+/**
+ * @brief Unit test for uint8 activations, int8 weights, and 32-bit
+ * accumulation. Output processing: requantization -> nothing
+ */
+TEST_P(fbgemmGConvAcc32WithQuantGranularityTest, requantizeTest) {
+ vector<conv_param_t<>> shapes(GetShapes_());
+ matrix_op_t atrans, btrans;
+ QuantizationGranularity q_granularity;
+ tie(atrans, btrans, q_granularity) = GetParam();
+
+ for (auto conv_p : shapes) {
+ int R = conv_p.K[0];
+ int S = conv_p.K[1];
+ int G = conv_p.G;
+ int OC = conv_p.OC;
+ int OH = conv_p.OUT_DIM[0];
+ int OW = conv_p.OUT_DIM[1];
+ int IC_per_G = conv_p.IC / conv_p.G;
+ int OC_per_G = conv_p.OC / conv_p.G;
+
+ // activations
+ aligned_vector<uint8_t> Aint8(
+ conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0);
+
+ // weights
+ // when btrans == Transpose, the weight matrix is in layout G K/G (R S C/G)
+ // instead of G (R S C/G) K/G
+ aligned_vector<int8_t> Bint8(R * S * conv_p.G * IC_per_G * OC_per_G, 0);
+ aligned_vector<int8_t> Bint8_tr(R * S * G * IC_per_G * OC_per_G, 0);
+
+ aligned_vector<int32_t> Cint32_ref(conv_p.MB * OH * OW * OC, 0);
+ aligned_vector<int32_t> Cint32_fb(Cint32_ref.size(), 0);
+ aligned_vector<uint8_t> Cint8_ref(Cint32_ref.size(), 0);
+ aligned_vector<uint8_t> Cint8_fb(Cint32_ref.size(), 0);
+
+ randFill<uint8_t>(Aint8, 0, 5);
+ int32_t Aint8_zero_point = 4;
+
+ randFill<int8_t>(Bint8, -4, 4);
+
+ // computing column offsets
+ vector<int32_t> col_offsets(G * OC_per_G);
+
+ int ncols_per_quant_group = G * OC_per_G;
+ if (q_granularity == QuantizationGranularity::GROUP) {
+ ncols_per_quant_group = OC_per_G;
+ } else if (q_granularity == QuantizationGranularity::OUT_CHANNEL) {
+ ncols_per_quant_group = 1;
+ }
+
+ aligned_vector<int32_t> Bint8_zero_point(
+ G * OC_per_G / ncols_per_quant_group);
+ randFill(Bint8_zero_point, -3, -1);
+
+ // matrix dimensions after im2col for each GEMM.
+ // For each group, there is one GEMM of the following dimensions
+ int MDim = conv_p.MB * OH * OW;
+ int NDim = OC_per_G;
+ int KDim = R * S * IC_per_G;
+
+ vector<uint8_t> Aint8_im2col(MDim * KDim * G);
+ im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());
+
+ vector<int32_t> row_offsets(MDim);
+
+ aligned_vector<float> C_multiplier(Bint8_zero_point.size());
+ randFill(C_multiplier, 0.1234f / 2, 0.1234f * 3 / 2);
+ int32_t C_zero_pt = 5;
+
+ // reference implementation
+ // conv_ref expects weights to be in G (R S C/G) K/G
+ int8_t* rightBData = Bint8.data();
+ if (btrans == matrix_op_t::Transpose) {
+ transposeConvWeights(conv_p, Bint8.data(), Bint8_tr.data());
+ rightBData = Bint8_tr.data();
+ }
+ for (int g = 0; g < G; ++g) {
+ col_offsets_with_zero_pt_s8acc32_ref(
+ R * S * IC_per_G,
+ OC_per_G,
+ OC_per_G,
+ rightBData + g * R * S * IC_per_G * OC_per_G,
+ Bint8_zero_point.data() + g * OC_per_G / ncols_per_quant_group,
+ col_offsets.data() + g * OC_per_G,
+ ncols_per_quant_group);
+ }
+ conv_ref(
+ conv_p, Aint8.data(), Aint8_zero_point, rightBData, Cint32_ref.data());
+
+ for (int g = 0; g < G; ++g) {
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDim,
+ KDim * G,
+ Aint8_im2col.data() + g * KDim,
+ row_offsets.data());
+
+ requantize_u8acc32_ref(
+ MDim,
+ NDim,
+ G * NDim,
+ Cint32_ref.data() + g * NDim,
+ Cint8_ref.data() + g * NDim,
+ C_multiplier.data() + g * NDim / ncols_per_quant_group,
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data() + g * NDim / ncols_per_quant_group,
+ row_offsets.data(),
+ col_offsets.data() + g * NDim,
+ nullptr,
+ ncols_per_quant_group);
+ }
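+
+ // In essence, the reference requantization above computes, per output
+ // element,
+ //   out = clamp(round(C_multiplier * (acc - Aint8_zero_point * col_offset
+ //                                     - Bint8_zero_point * row_offset))
+ //               + C_zero_pt, 0, 255)
+ // with one multiplier and one B zero point per quantization group.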
+
+ PackWeightMatrixForGConv<int8_t> packedWeights(
+ btrans, conv_p, Bint8.data(), nullptr);
+
+ // TODO: Uncomment once we support multiple threads in fbgemmGroupwiseConv
+ // #ifdef _OPENMP
+ // #pragma omp parallel
+ // #endif
+ {
+ vector<int32_t> row_offset_buf(rowOffsetBufferSizeGConv(conv_p));
+
+ DoNothing<> doNothingObj{};
+
+ int num_threads = fbgemm_get_num_threads();
+ int tid = fbgemm_get_thread_num();
+
+ if (q_granularity == QuantizationGranularity::TENSOR) {
+ ReQuantizeOutput<false, QuantizationGranularity::TENSOR> reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ row_offset_buf.data(),
+ col_offsets.data(),
+ nullptr,
+ G * NDim,
+ G);
+
+ fbgemmGroupwiseConv(
+ conv_p,
+ Aint8.data(),
+ Aint8_zero_point,
+ row_offset_buf.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+
+ } else if (q_granularity == QuantizationGranularity::GROUP) {
+ ReQuantizeOutput<false, QuantizationGranularity::GROUP> reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ row_offset_buf.data(),
+ col_offsets.data(),
+ nullptr,
+ G * NDim,
+ G);
+
+ fbgemmGroupwiseConv(
+ conv_p,
+ Aint8.data(),
+ Aint8_zero_point,
+ row_offset_buf.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+
+ } else {
+ ReQuantizeOutput<false, QuantizationGranularity::OUT_CHANNEL> reqObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_pt,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ row_offset_buf.data(),
+ col_offsets.data(),
+ nullptr,
+ G * NDim,
+ G);
+
+ fbgemmGroupwiseConv(
+ conv_p,
+ Aint8.data(),
+ Aint8_zero_point,
+ row_offset_buf.data(),
+ packedWeights,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ reqObj,
+ tid,
+ num_threads);
+ }
+ } // omp parallel
+
+ compare_validate_buffers(
+ Cint8_ref.data(),
+ Cint8_fb.data(),
+ MDim,
+ NDim * G,
+ NDim * G,
+ static_cast<uint8_t>(0));
+ } // for each shape
+}
+
+/**
+ * @brief Unit test for uint8 activations, int8 weights, and 32-bit
+ * accumulation. Output processing: nothing
+ */
+TEST_P(fbgemmGConvAcc32Test, NoRequantizeTest) {
+ vector<conv_param_t<>> shapes(GetShapes_());
+ matrix_op_t atrans, btrans;
+ tie(atrans, btrans) = GetParam();
+
+ for (auto conv_p : shapes) {
+ int R = conv_p.K[0];
+ int S = conv_p.K[1];
+ int G = conv_p.G;
+ int OC = conv_p.OC;
+ int OH = conv_p.OUT_DIM[0];
+ int OW = conv_p.OUT_DIM[1];
+ int IC_per_G = conv_p.IC / conv_p.G;
+ int OC_per_G = conv_p.OC / conv_p.G;
+
+ // activations
+ aligned_vector<uint8_t> Aint8(
+ conv_p.MB * conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * conv_p.IC, 0);
+
+ // weights
+ // when btrans == Transpose, the weight matrix is in layout G K/G (R S C/G)
+ // instead of G (R S C/G) K/G
+ aligned_vector<int8_t> Bint8(R * S * conv_p.G * IC_per_G * OC_per_G, 0);
+ aligned_vector<int8_t> Bint8_tr(R * S * conv_p.G * IC_per_G * OC_per_G, 0);
+
+ aligned_vector<int32_t> Cint32_ref(conv_p.MB * OH * OW * OC, 0);
+ aligned_vector<int32_t> Cint32_fb(Cint32_ref.size(), 0);
+
+ randFill<uint8_t>(Aint8, 0, 5);
+ int32_t Aint8_zero_point = 4;
+
+ randFill<int8_t>(Bint8, -4, 4);
+
+ // matrix dimensions after im2col for each GEMM.
+ // For each group, there is one GEMM of the following dimensions
+ int MDim = conv_p.MB * OH * OW;
+ int NDim = OC_per_G;
+ // int KDim = R * S * IC_per_G;
+
+ // reference implementation
+ // conv_ref expects weights to be in G (R S C/G) K/G
+ int8_t* rightBData = Bint8.data();
+ if (btrans == matrix_op_t::Transpose) {
+ transposeConvWeights(conv_p, Bint8.data(), Bint8_tr.data());
+ rightBData = Bint8_tr.data();
+ }
+ conv_ref(
+ conv_p, Aint8.data(), Aint8_zero_point, rightBData, Cint32_ref.data());
+
+ PackWeightMatrixForGConv<int8_t> packedWeights(
+ btrans, conv_p, Bint8.data(), nullptr);
+
+ // TODO: Uncomment once we support multiple threads in fbgemmGroupwiseConv
+ // #ifdef _OPENMP
+ // #pragma omp parallel
+ // #endif
+ {
+ vector<int32_t> row_offset_buf(rowOffsetBufferSizeGConv(conv_p));
+
+ DoNothing<int32_t, int32_t> doNothingObj{};
+
+ int num_threads = fbgemm_get_num_threads();
+ int tid = fbgemm_get_thread_num();
+
+ fbgemmGroupwiseConv(
+ conv_p,
+ Aint8.data(),
+ Aint8_zero_point,
+ row_offset_buf.data(),
+ packedWeights,
+ Cint32_fb.data(),
+ Cint32_fb.data(),
+ doNothingObj,
+ tid,
+ num_threads);
+ }
+
+ compare_validate_buffers(
+ Cint32_ref.data(),
+ Cint32_fb.data(),
+ MDim,
+ NDim * G,
+ NDim * G,
+ static_cast<int32_t>(0));
+ } // for each shape
+}