Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDaya Khudia <dskhudia@fb.com>2019-06-05 22:44:57 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-06-05 22:50:08 +0300
commit8197494f3ae280941639c72bc1342a9faa8e2ad6 (patch)
treee0f7fb9b8aad61c46b7d2a79f019255514c03985
parent77868418c7963572167690ef069b06cbfe67de1f (diff)
Unified convolution interface
Summary: We want to combine three different convolution interfaces under one top level function. Reviewed By: protonu Differential Revision: D15399811 fbshipit-source-id: 7390616d92783506fc156f0f6017f10b5f7f8e30
-rw-r--r--CMakeLists.txt2
-rw-r--r--bench/ConvUnifiedBenchmark.cc309
-rw-r--r--bench/Depthwise3DBenchmark.cc2
-rw-r--r--bench/DepthwiseBenchmark.cc2
-rw-r--r--include/fbgemm/ConvUtils.h2
-rw-r--r--include/fbgemm/Fbgemm.h107
-rw-r--r--include/fbgemm/FbgemmI8DepthwiseAvx2.h175
-rw-r--r--include/fbgemm/Utils.h5
-rw-r--r--src/FbgemmConv.cc222
-rw-r--r--src/FbgemmI8DepthwiseAvx2.cc2
-rw-r--r--src/FbgemmI8DepthwiseAvx2.h4
-rw-r--r--src/GroupwiseConv.h10
-rw-r--r--src/GroupwiseConvAcc32Avx2.cc193
-rw-r--r--src/PackWeightMatrixForGConv.cc4
-rw-r--r--src/PackWeightsForConv.cc71
-rw-r--r--src/RefImplementations.cc28
-rw-r--r--src/RefImplementations.h31
-rw-r--r--test/I8DepthwiseTest.cc2
-rw-r--r--test/Im2ColFusedRequantizeTest.cc4
-rw-r--r--test/UniConvPackingTest.cc148
20 files changed, 1205 insertions, 118 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 80de824..b575e17 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -29,6 +29,7 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
src/ExecuteKernelU8S8.cc
src/Fbgemm.cc
src/FbgemmFP16.cc
+ src/FbgemmConv.cc
src/FbgemmI8Spmdm.cc
src/GenerateKernelU8S8S32ACC16.cc
src/GenerateKernelU8S8S32ACC16Avx512.cc
@@ -42,6 +43,7 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
src/PackAWithQuantRowOffset.cc
src/PackAWithRowOffset.cc
src/PackWeightMatrixForGConv.cc
+ src/PackWeightsForConv.cc
src/QuantUtils.cc
src/RefImplementations.cc
src/Utils.cc)
diff --git a/bench/ConvUnifiedBenchmark.cc b/bench/ConvUnifiedBenchmark.cc
new file mode 100644
index 0000000..59079c7
--- /dev/null
+++ b/bench/ConvUnifiedBenchmark.cc
@@ -0,0 +1,309 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <iomanip>
+#include <iostream>
+#include <random>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#include "BenchUtils.h"
+#include "fbgemm/Fbgemm.h"
+#include "src/RefImplementations.h"
+
+using namespace std;
+using namespace fbgemm;
+
+// 2D conv shapes
+vector<conv_param_t<2>> shapes_2d = {
+ // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w,
+ // pad_h_top, pad_w_left, pad_h_bottom, pad_w_right
+ // 2D convolutions
+ // regular
+ conv_param_t<>(1, 128, 128, {56, 56}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+ // groupwise
+ conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+
+ // DW
+ conv_param_t<>(1, 272, 272, {47, 125}, 272, {3, 3}, {1, 1}, {1, 1, 1, 1}),
+};
+
+// 3D conv shapes
+vector<conv_param_t<3>> shapes_3d = {
+ // MB, IC, OC, {IT, IH, IW}, G, {KT, KH, KW}, {stride_t, stride_h, stride_w},
+ // {pad_prev, pad_h_top, pad_w_left, pad_next, pad_h_bottom, pad_w_right}
+ // Regular
+ conv_param_t<3>(1, 64, 64, {32, 56, 56}, 1, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1}),
+ // Depthwise
+ conv_param_t<3>(1, 64, 64, {32, 56, 56}, 64, {3, 3, 3}, {1, 1, 1}, {1, 1, 1, 1, 1, 1})
+};
+
+template <int SPATIAL_DIM, typename Acc_t>
+void performance_test(const vector<conv_param_t<SPATIAL_DIM>>& shapes) {
+ bool flush = true;
+ std::vector<char> llc;
+
+ if (flush) {
+ llc.resize(128 * 1024 * 1024, 1.0);
+ }
+
+ constexpr int NWARMUP = 4;
+ constexpr int NITER = 10;
+
+ string header = "MB, IC, OC, ";
+ if (SPATIAL_DIM == 3) {
+ header += "IT, ";
+ }
+ header += "IH, IW, G, ";
+ if (SPATIAL_DIM == 3) {
+ header += "KT, ";
+ }
+ header += "KH, KW, ";
+ if (SPATIAL_DIM == 3) {
+ header += "stride_t, ";
+ }
+ header += "stride_h, stride_w, ";
+ if (SPATIAL_DIM == 3) {
+ header += "pad_t, ";
+ }
+ header += "pad_h, pad_w, ";
+
+ header += "Type, M, N, K, ";
+
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << "WARNING: the timer may be inaccurate when used by multiple threads."
+ << endl;
+ cout << header << "Im2Col (ms), "
+ << "Packing (ms), "
+ << "Kernel (ms), "
+ << "Postprocessing (ms), "
+ << "fbgemmPacked (ms), "
+ << "Total (ms), "
+ << "GOPS" << endl;
+#else
+ cout << setw(6) << header << setw(5) << "GOPS" << endl;
+#endif
+
+ chrono::time_point<chrono::high_resolution_clock> begin, end;
+
+ for (auto conv_p : shapes) {
+ if (conv_p.IC % conv_p.G != 0 || conv_p.OC % conv_p.G != 0) {
+ // invalid shapes
+ continue;
+ }
+ int im_in_dim = accumulate(
+ conv_p.IN_DIM.begin(), conv_p.IN_DIM.end(), 1, multiplies<int>());
+ aligned_vector<uint8_t> Aint8(conv_p.MB * im_in_dim * conv_p.IC);
+
+ int kernel_dim =
+ accumulate(conv_p.K.begin(), conv_p.K.end(), 1, multiplies<int>());
+ aligned_vector<int8_t> Bint8(
+ kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
+ int im_out_dim = accumulate(
+ conv_p.OUT_DIM.begin(), conv_p.OUT_DIM.end(), 1, multiplies<int>());
+ aligned_vector<int32_t> Cint32_ref(conv_p.MB * im_out_dim * conv_p.OC);
+ aligned_vector<uint8_t> Cint8_ref(Cint32_ref.size(), 0);
+ aligned_vector<int32_t> Cint32_fb(Cint32_ref.size());
+ aligned_vector<uint8_t> Cint8_fb(Cint32_ref.size(), 0);
+ aligned_vector<uint8_t> Cint8_fb2(Cint32_ref.size(), 0);
+ aligned_vector<int32_t> Cint32_fb2(Cint32_ref.size());
+
+ // A matrix (input activations)
+ randFill<uint8_t>(Aint8, 0, 5);
+ int32_t Aint8_zero_point = 4;
+
+ // B matrix (weights)
+ randFill<int8_t>(Bint8, -4, 4);
+ aligned_vector<int32_t> Bint8_zero_point(1);
+ randFill(Bint8_zero_point, -3, -1);
+
+ aligned_vector<float> C_multiplier(Bint8_zero_point.size());
+ randFill(C_multiplier, 0.1234f / 2, 0.1234f * 3 / 2);
+ int32_t C_zero_point = 5;
+
+ aligned_vector<float> Bfp32(Bint8.begin(), Bint8.end());
+
+ // reference implementation
+ conv_ref(
+ conv_p,
+ Aint8.data(),
+ Aint8_zero_point,
+ Bint8.data(),
+ Cint32_ref.data());
+
+ // matrix dimensions after im2col
+ int MDim = conv_p.MB * im_out_dim;
+ int NDim = conv_p.OC / conv_p.G;
+ int KDim = kernel_dim * conv_p.IC;
+ int KDimPerGroup = KDim / conv_p.G;
+
+ int OC_per_G = conv_p.OC / conv_p.G;
+
+ // computing row offset
+ vector<int32_t> row_offsets(MDim);
+ vector<uint8_t> Aint8_im2col(MDim * KDim);
+ im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());
+
+ // computing column offset
+ vector<int32_t> col_offsets(conv_p.OC);
+ for (int g = 0; g < conv_p.G; ++g) {
+ col_offsets_with_zero_pt_s8acc32_ref(
+ KDimPerGroup,
+ OC_per_G,
+ OC_per_G,
+ Bint8.data() + g * KDimPerGroup * OC_per_G,
+ Bint8_zero_point.data(),
+ col_offsets.data() + g * OC_per_G,
+ conv_p.OC);
+ }
+
+ for (int g = 0; g < conv_p.G; ++g) {
+ row_offsets_u8acc32_ref(
+ MDim,
+ KDimPerGroup,
+ KDim,
+ Aint8_im2col.data() + g * KDimPerGroup,
+ row_offsets.data());
+
+ requantize_u8acc32_ref(
+ MDim,
+ NDim,
+ conv_p.G * NDim,
+ Cint32_ref.data() + g * NDim,
+ Cint8_ref.data() + g * NDim,
+ C_multiplier.data() + g * NDim / conv_p.OC,
+ C_zero_point,
+ Aint8_zero_point,
+ Bint8_zero_point.data() + g * NDim / conv_p.OC,
+ row_offsets.data(),
+ col_offsets.data() + g * NDim,
+ nullptr,
+ conv_p.OC);
+ }
+
+ double nops = 2.0 * static_cast<double>(NITER) * MDim * NDim * KDim;
+ double ttot = 0.0;
+ string runType;
+
+ PackWeightsForConv<SPATIAL_DIM> packedB(conv_p, Bint8.data());
+
+ // no-op output process objects
+ DoNothing<> doNothingObj{};
+ ReQuantizeOutput<false, QuantizationGranularity::TENSOR> outputProcObj(
+ doNothingObj,
+ C_multiplier.data(),
+ C_zero_point,
+ Aint8_zero_point,
+ Bint8_zero_point.data(),
+ nullptr, // row offsets
+ col_offsets.data(),
+ nullptr, // bias
+ conv_p.OC,
+ conv_p.G);
+
+ runType = "UniConv";
+ ttot = 0;
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ double im2col_time = 0.0;
+ double total_im2col_time = 0.0;
+ double total_packing_time = 0.0;
+ double total_computing_time = 0.0;
+ double total_kernel_time = 0.0;
+ double total_postprocessing_time = 0.0;
+ double total_run_time = 0.0;
+#endif
+ for (auto i = 0; i < NWARMUP + NITER; ++i) {
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ packing_time = 0.0;
+ computing_time = 0.0;
+ kernel_time = 0.0;
+ postprocessing_time = 0.0;
+ run_time = 0.0;
+#endif
+ llc_flush(llc);
+ begin = chrono::high_resolution_clock::now();
+ fbgemmConv(
+ conv_p,
+ Aint8.data(),
+ packedB,
+ Cint8_fb.data(),
+ Cint32_fb.data(),
+ outputProcObj,
+ 0,
+ 1);
+ end = chrono::high_resolution_clock::now();
+
+ if (i >= NWARMUP) {
+ auto dur = chrono::duration_cast<chrono::nanoseconds>(end - begin);
+ ttot += dur.count();
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ total_packing_time += packing_time;
+ total_computing_time += computing_time;
+ total_kernel_time += kernel_time;
+ total_postprocessing_time += postprocessing_time;
+ total_run_time += run_time;
+#endif
+ }
+ }
+
+ cout << conv_p.MB << ", " << conv_p.IC << ", " << conv_p.OC << ", ";
+ for (int i = 0; i < SPATIAL_DIM; ++i) {
+ cout << conv_p.IN_DIM[i] << ", ";
+ }
+ cout << conv_p.G << ", ";
+ for (int i = 0; i < SPATIAL_DIM; ++i) {
+ cout << conv_p.K[i] << ", ";
+ }
+ for (int i = 0; i < SPATIAL_DIM; ++i) {
+ cout << conv_p.stride[i] << ", ";
+ }
+ for (int i = 0; i < SPATIAL_DIM; ++i) {
+ cout << conv_p.pad[i] << ", ";
+ }
+
+ cout << setw(13) << runType << ", " << setw(5) << fixed << setw(5)
+ << setw(6) << MDim << ", " << setw(6) << NDim << ", " << setw(6)
+ << KDim << ", ";
+#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
+ cout << fixed << setprecision(6) << setw(8) << 0 << ", "
+ << total_packing_time / (double)NITER / 1e6 << ", "
+ << total_kernel_time / (double)NITER / 1e6 << ", "
+ << total_postprocessing_time / (double)NITER / 1e6 << ", "
+ << total_run_time / (double)NITER / 1e6 << ", "
+ << ttot / (double)NITER / 1e6 << ", ";
+#endif
+ cout << setprecision(2) << nops / ttot << endl;
+
+ compare_buffers(
+ Cint8_ref.data(),
+ Cint8_fb.data(),
+ MDim,
+ NDim * conv_p.G,
+ NDim * conv_p.G,
+ 5);
+ } // shapes
+}
+
+int main() {
+#ifdef _OPENMP
+ // Use 1 thread unless OMP_NUM_THREADS is explicit set.
+ const char* val = getenv("OMP_NUM_THREADS");
+ if (val == nullptr || !*val) {
+ omp_set_num_threads(1);
+ }
+#endif
+ // performance_test<int16_t>();
+ performance_test<2, int32_t>(shapes_2d);
+ performance_test<3, int32_t>(shapes_3d);
+ return 0;
+}
diff --git a/bench/Depthwise3DBenchmark.cc b/bench/Depthwise3DBenchmark.cc
index c5f8ed9..0efdcac 100644
--- a/bench/Depthwise3DBenchmark.cc
+++ b/bench/Depthwise3DBenchmark.cc
@@ -20,7 +20,7 @@
#include "AlignedVec.h"
#include "BenchUtils.h"
#include "fbgemm/Utils.h"
-#include "src/FbgemmI8DepthwiseAvx2.h"
+#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
#include "src/RefImplementations.h"
using namespace std;
diff --git a/bench/DepthwiseBenchmark.cc b/bench/DepthwiseBenchmark.cc
index 780d83c..96921a1 100644
--- a/bench/DepthwiseBenchmark.cc
+++ b/bench/DepthwiseBenchmark.cc
@@ -18,7 +18,7 @@
#include "AlignedVec.h"
#include "BenchUtils.h"
#include "fbgemm/Utils.h"
-#include "src/FbgemmI8DepthwiseAvx2.h"
+#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
#include "src/RefImplementations.h"
using namespace std;
diff --git a/include/fbgemm/ConvUtils.h b/include/fbgemm/ConvUtils.h
index 1c8251e..11f3dcc 100644
--- a/include/fbgemm/ConvUtils.h
+++ b/include/fbgemm/ConvUtils.h
@@ -101,7 +101,7 @@ struct conv_param_t {
std::to_string(stride[d]) + ", ";
}
for (int d = 0; d < SPATIAL_DIM * 2; ++d) {
- out += "pad_" + dim_string[3 - (SPATIAL_DIM % 3) + d] + ":" +
+ out += "pad_" + dim_string[3 - SPATIAL_DIM + (d % SPATIAL_DIM)] + ":" +
std::to_string(pad[d]);
if (d < SPATIAL_DIM * 2 - 1) {
out += ", ";
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 720f681..721b12f 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -18,6 +18,7 @@
#include "FbgemmBuild.h"
#include "FbgemmI8Spmdm.h"
#include "QuantUtilsAvx2.h"
+#include "FbgemmI8DepthwiseAvx2.h"
#include "Types.h"
#include "Utils.h"
@@ -524,6 +525,62 @@ class FBGEMM_API PackWeightMatrixForGConv {
};
/**
+ * @brief A container class to keep packed weight tensor for convolution.
+ * The source tensor should already be quantized.
+ *
+ * @tparam SPATIAL_DIM is equal to 2 for 2D convolutions and 3 for 3D
+ * convolutions. Default value is 2.
+ * @tparam T is the datatype for source tensor. Default value is int8.
+ * @tparam accT is the datatype to accumulate into. Default value is int32.
+ */
+template <
+ int SPATIAL_DIM = 2,
+ typename T = std::int8_t,
+ typename accT = std::int32_t>
+class FBGEMM_API PackWeightsForConv {
+ public:
+ using This = PackWeightsForConv<SPATIAL_DIM, T, accT>;
+ using inpType = T;
+ using accType = accT;
+
+ PackWeightsForConv() = delete; // no default constructor
+
+ PackWeightsForConv(
+ const conv_param_t<SPATIAL_DIM>& conv_param,
+ const inpType* sdata,
+ const BlockingFactors* blocking_params = nullptr);
+
+ std::shared_ptr<PackBMatrix<T, accT>> getPackedWForIm2col() {
+ return W_im2col_packed_;
+ }
+
+ std::shared_ptr<Packed3x3ConvMatrix> getPackedWFor2DDW() {
+ return W_dw_2D_packed_;
+ }
+
+ std::shared_ptr<Packed3x3x3ConvMatrix> getPackedWFor3DDW() {
+ return W_dw_3D_packed_;
+ }
+
+ std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
+ getPackedWForGroupwise() {
+ return W_gconv_packed_;
+ }
+
+ private:
+ // Packed weights if we use im2col based convolution implementation
+ std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_;
+ // Packed weights if we use 2D depthwise convolution implementation
+ std::shared_ptr<Packed3x3ConvMatrix> W_dw_2D_packed_;
+ // Packed weights if we use 3D depthwise convolution implementation
+ std::shared_ptr<Packed3x3x3ConvMatrix> W_dw_3D_packed_;
+ // Packed weights if we use groupwise (small channels per group) convolution
+ // implementation
+ std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
+ W_gconv_packed_;
+};
+
+/**
* @brief Matrix packed for the first input matrix in GEMM (usually activation),
* and row offsets used for requantization is computed during packing.
* Im2col is fused with packing here. The source matrix is already
@@ -1002,6 +1059,7 @@ template <
typename nextOPType = DoNothing<outT, outT>>
class FBGEMM_API ReQuantizeOutput {
public:
+ static constexpr int RELU_FUSED = FUSE_RELU;
using outType = outT;
using inpType = inT;
/**
@@ -1056,12 +1114,18 @@ class FBGEMM_API ReQuantizeOutput {
const float* getCMultiplier() const {
return C_multiplier_;
}
+ std::int32_t getAZeroPoint() const {
+ return Aq_zero_point_;
+ }
std::int32_t getCZeroPoint() const {
return C_zero_point_;
}
const std::int32_t* getBZeroPoint() const {
return Bq_zero_point_;
}
+ const std::int32_t* getRowOffsets() const {
+ return q_row_offsets_;
+ }
const std::int32_t* getColOffsets() const {
return q_col_offsets_;
}
@@ -1072,6 +1136,10 @@ class FBGEMM_API ReQuantizeOutput {
return ncols_;
}
+ void setRowOffsets(const std::int32_t* row_offsets) {
+ q_row_offsets_ = row_offsets;
+ }
+
private:
nextOPType& nextop_;
const float* C_multiplier_;
@@ -1273,6 +1341,12 @@ void convDepthwiseSeparable(
const processOutputType& output);
/**
+ * @brief Is this depthwise convolution optimized?
+ */
+template <int SPATIAL_DIM = 2, typename ACC_T = std::int32_t>
+bool takeDepthWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p);
+
+/**
* @brief Is this groupwise convolution supported?
*/
template <int SPATIAL_DIM>
@@ -1351,4 +1425,37 @@ static void fbgemmGetRange(
}
}
+/**
+ * @brief Performs convolution using fastest path available.
+ *
+ * @tparam SPATIAL_DIM It's 2 for 2D convolutions and 3 for 3D convolutions.
+ */
+template <
+ typename processOutputType,
+ int SPATIAL_DIM = 2,
+ typename ACC_T = std::int32_t>
+FBGEMM_API int fbgemmConv(
+ const conv_param_t<SPATIAL_DIM>& conv_p,
+ const std::uint8_t* activations,
+ PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights,
+ typename processOutputType::outType* out,
+ std::int32_t* outBuffer,
+ processOutputType& outProcess,
+ int thread_id,
+ int num_threads,
+ const BlockingFactors* blocking_params = nullptr);
+
+/**
+ * @brief Returns which fast path to take
+ *
+ * @tparam SPATIAL_DIM It's 2 for 2D convolutions and 3 for 3D convolutions.
+ *
+ * @return optimized_conv_t::depthwise, optimized_conv_t::groupwise or
+ * optimized_conv_t::im2col
+ *
+ */
+template <int SPATIAL_DIM = 2, typename ACC_T = std::int32_t>
+FBGEMM_API optimized_conv_t
+ConvFastPath(const conv_param_t<SPATIAL_DIM>& conv_p);
+
} // namespace fbgemm
diff --git a/include/fbgemm/FbgemmI8DepthwiseAvx2.h b/include/fbgemm/FbgemmI8DepthwiseAvx2.h
new file mode 100644
index 0000000..432687c
--- /dev/null
+++ b/include/fbgemm/FbgemmI8DepthwiseAvx2.h
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#ifndef I8DEPTHWISE_H
+#define I8DEPTHWISE_H
+#include <cstdint>
+#include "fbgemm/FbgemmBuild.h"
+
+namespace fbgemm {
+
+// KERNEL_PROD is the product of all kernels.
+// For example, KERNEL_PROD = 9 for 3x3, and 27 for 3x3x3.
+template <int KERNEL_PROD>
+class FBGEMM_API PackedDepthWiseConvMatrix {
+ public:
+ // smat in RSG layout
+ PackedDepthWiseConvMatrix(int K, const std::int8_t* smat);
+ virtual ~PackedDepthWiseConvMatrix();
+
+ const std::int8_t* PackedMat() const {
+ return pmat_;
+ }
+
+ private:
+ int K_;
+ std::int8_t* pmat_;
+}; // Packed3x3ConvMatrix
+
+using Packed3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3>;
+using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>;
+
+/**
+ * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8
+ * @params A The input image in NHWK layout
+ * @params Bp The pre-packed filter
+ */
+FBGEMM_API void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const Packed3x3ConvMatrix& Bp,
+ std::int32_t* C,
+ int thread_id = 0,
+ int num_threads = 1);
+
+/**
+ * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8
+ * This version is fused with requantization.
+ *
+ * @col_offsets nullptr if col_offsets are folded into bias
+ */
+FBGEMM_API void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ std::int32_t B_zero_point,
+ const Packed3x3ConvMatrix& Bp,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias,
+ bool fuse_relu = false,
+ int thread_id = 0,
+ int num_threads = 1);
+
+/**
+ * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8
+ * This version is fused with requantization and uses per-channel quantization.
+ *
+ * @col_offsets nullptr if col_offsets are folded into bias
+ */
+FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int32_t* B_zero_point,
+ const Packed3x3ConvMatrix& Bp,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias,
+ bool fuse_relu = false,
+ int thread_id = 0,
+ int num_threads = 1);
+
+FBGEMM_API void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const Packed3x3x3ConvMatrix& Bp,
+ std::int32_t* C,
+ int thread_id = 0,
+ int num_threads = 1);
+
+/**
+ * @col_offsets nullptr if col_offsets are folded into bias
+ */
+FBGEMM_API void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ std::int32_t B_zero_point,
+ const Packed3x3x3ConvMatrix& Bp,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias,
+ bool fuse_relu = false,
+ int thread_id = 0,
+ int num_threads = 1);
+
+/**
+ * @col_offsets nullptr if col_offsets are folded into bias
+ */
+FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int32_t* B_zero_point,
+ const Packed3x3x3ConvMatrix& Bp,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias,
+ bool fuse_relu = false,
+ int thread_id = 0,
+ int num_threads = 1);
+
+} // namespace fbgemm
+
+#endif
diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h
index 58232e7..ef1d4ab 100644
--- a/include/fbgemm/Utils.h
+++ b/include/fbgemm/Utils.h
@@ -32,6 +32,11 @@ enum class matrix_op_t { NoTranspose, Transpose };
enum class inst_set_t { anyarch, avx2, avx512 };
/**
+ * @brief Typed enum for optimized paths for convolutions
+ */
+enum class optimized_conv_t { depthwise, groupwise, im2col };
+
+/**
* @brief Typed enum for implementation type.
*
* ref is reference and opt is optimized.
diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc
new file mode 100644
index 0000000..5db63f6
--- /dev/null
+++ b/src/FbgemmConv.cc
@@ -0,0 +1,222 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <algorithm>
+#include <iostream>
+#include <vector>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm {
+
+template <int SPATIAL_DIM, typename ACC_T>
+bool takeDepthWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
+ // Note: Depthwise convolutions (both 2D and 3D) are optimized for the most
+ // common case.
+ return std::is_same<ACC_T, std::int32_t>::value && conv_p.G == conv_p.IC &&
+ conv_p.G == conv_p.OC && conv_p.G % 8 == 0 &&
+ std::all_of(
+ conv_p.stride.begin(),
+ conv_p.stride.end(),
+ [](int i) { return i == 1 || i == 2; }) &&
+ std::all_of(
+ conv_p.K.begin(), conv_p.K.end(), [](int i) { return i == 3; }) &&
+ std::all_of(
+ conv_p.dilation.begin(),
+ conv_p.dilation.end(),
+ [](int i) { return i == 1; }) &&
+ std::all_of(conv_p.pad.begin(), conv_p.pad.end(), [](int i) {
+ return i == 1;
+ });
+}
+
+template <int SPATIAL_DIM, typename ACC_T>
+optimized_conv_t ConvFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
+ if (takeDepthWiseFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
+ return optimized_conv_t::depthwise;
+ } else if (fbgemmOptimizedGConv<SPATIAL_DIM>(conv_p)) {
+ return optimized_conv_t::groupwise;
+ } else {
+ return optimized_conv_t::im2col;
+ }
+}
+
+template <typename processOutputType, int SPATIAL_DIM, typename ACC_T>
+int fbgemmConv(
+ const conv_param_t<SPATIAL_DIM>& conv_p,
+ const std::uint8_t* activations,
+ PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights,
+ typename processOutputType::outType* out,
+ std::int32_t* outBuffer,
+ processOutputType& outProcess,
+ int thread_id,
+ int num_threads,
+ const BlockingFactors* blocking_params) {
+ static_assert(
+ SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
+ "Only 2D and 3D convolutions are supported");
+ switch (ConvFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
+ case optimized_conv_t::depthwise: {
+ // 2D and 3D depthwise fast path
+ // std::cout << "Depthwise fast path" << std::endl;
+ const std::int32_t* B_zero_point = outProcess.getBZeroPoint();
+ const float* C_multiplier = outProcess.getCMultiplier();
+ if (SPATIAL_DIM == 3) {
+ static_assert(
+ std::is_same<typename processOutputType::outType, std::uint8_t>::
+ value,
+ "For depthwise, only requantized output is supported");
+ depthwise_3x3x3_pad_1(
+ conv_p.MB, // mini batch
+ conv_p.IN_DIM[0], // T
+ conv_p.IN_DIM[1], // H
+ conv_p.IN_DIM[2], // W
+ conv_p.OC, // output channels
+ conv_p.stride[0], // stride_t
+ conv_p.stride[1], // stride_h
+ conv_p.stride[2], // stride_w
+ outProcess.getAZeroPoint(),
+ activations,
+ B_zero_point[0],
+ *(packed_weights.getPackedWFor3DDW()),
+ C_multiplier[0],
+ outProcess.getCZeroPoint(),
+ out,
+ outProcess.getColOffsets(),
+ outProcess.getBias(),
+ outProcess.RELU_FUSED, // fuse_relu
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3_pad_1(
+ conv_p.MB, // mini batch
+ conv_p.IN_DIM[0], // H
+ conv_p.IN_DIM[1], // W
+ conv_p.OC, // output channels
+ conv_p.stride[0], // stride_h
+ conv_p.stride[1], // stride_w
+ outProcess.getAZeroPoint(),
+ activations,
+ B_zero_point[0],
+ *(packed_weights.getPackedWFor2DDW()),
+ C_multiplier[0],
+ outProcess.getCZeroPoint(),
+ out,
+ outProcess.getColOffsets(),
+ outProcess.getBias(),
+ outProcess.RELU_FUSED, // fuse_relu
+ thread_id,
+ num_threads);
+ }
+ break;
+ }
+ case optimized_conv_t::groupwise: {
+ // optimized groupwise convolution
+ // std::cout << "Groupwise fast path" << std::endl;
+ assert(
+ SPATIAL_DIM == 2 && "Only 2D groupwise convolutions are supported");
+ std::vector<int32_t> row_offset_buf(
+ rowOffsetBufferSizeGConv<SPATIAL_DIM>(conv_p));
+ outProcess.setRowOffsets(row_offset_buf.data());
+ fbgemmGroupwiseConv(
+ conv_p,
+ activations,
+ outProcess.getAZeroPoint(),
+ row_offset_buf.data(),
+ *(packed_weights.getPackedWForGroupwise()),
+ out,
+ outBuffer,
+ outProcess,
+ thread_id,
+ num_threads);
+ break;
+ }
+ case optimized_conv_t::im2col: {
+ // All other convolutions go through im2col-based implementation
+ // std::cout << "Im2col path" << std::endl;
+ std::vector<int32_t> row_offset_buf(
+ PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM>::rowOffsetBufferSize());
+
+ const std::int32_t* b_zero_point = outProcess.getBZeroPoint();
+ bool b_symmetric = b_zero_point[0] == 0;
+ PackAWithIm2Col<uint8_t, ACC_T, SPATIAL_DIM> packA(
+ conv_p,
+ activations,
+ nullptr, /* buffer for packed matrix */
+ outProcess.getAZeroPoint(),
+ row_offset_buf.data(),
+ b_symmetric,
+ blocking_params);
+
+ outProcess.setRowOffsets(row_offset_buf.data());
+ fbgemmPacked(
+ packA,
+ *(packed_weights.getPackedWForIm2col()),
+ out,
+ outBuffer,
+ conv_p.OC,
+ outProcess,
+ thread_id,
+ num_threads,
+ blocking_params);
+ break;
+ }
+ } // switch
+
+ return 0;
+}
+
+#define INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, SPATIAL_DIM) \
+ template int fbgemmConv( \
+ const conv_param_t<SPATIAL_DIM>& conv_p, \
+ const std::uint8_t* activations, \
+ PackWeightsForConv<SPATIAL_DIM, std::int8_t, ACC_T>& packed_weights, \
+ std::uint8_t* out, \
+ std::int32_t* outBuffer, \
+ ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
+ int thread_id, \
+ int num_threads, \
+ const BlockingFactors* blocking_params);
+
+#define INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, RELU) \
+ INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, 2); \
+ INSTANTIATE_BASE(ACC_T, Q_GRAN, RELU, 3);
+
+#define INSTANTIATE_RELU(ACC_T, Q_GRAN) \
+ INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, true); \
+ INSTANTIATE_SPATIAL_DIM(ACC_T, Q_GRAN, false);
+
+#define INSTANTIATE_Q_GRANS(ACC_T) \
+ INSTANTIATE_RELU(ACC_T, QuantizationGranularity::TENSOR); \
+ INSTANTIATE_RELU(ACC_T, QuantizationGranularity::GROUP); \
+ INSTANTIATE_RELU(ACC_T, QuantizationGranularity::OUT_CHANNEL);
+
+INSTANTIATE_Q_GRANS(std::int32_t);
+
+#undef INSTANTIATE_Q_GRANS
+#undef INSTANTIATE_RELU
+#undef INSTANTIATE_SPATIAL_DIM
+#undef INSTANTIATE_BASE
+
+template bool takeDepthWiseFastPath<2, std::int32_t>(
+ const conv_param_t<2>& conv_p);
+template bool takeDepthWiseFastPath<3, std::int32_t>(
+ const conv_param_t<3>& conv_p);
+template bool takeDepthWiseFastPath<2, std::int16_t>(
+ const conv_param_t<2>& conv_p);
+template bool takeDepthWiseFastPath<3, std::int16_t>(
+ const conv_param_t<3>& conv_p);
+
+template optimized_conv_t ConvFastPath<2, std::int32_t>(
+ const conv_param_t<2>& conv_p);
+template optimized_conv_t ConvFastPath<3, std::int32_t>(
+ const conv_param_t<3>& conv_p);
+template optimized_conv_t ConvFastPath<2, std::int16_t>(
+ const conv_param_t<2>& conv_p);
+template optimized_conv_t ConvFastPath<3, std::int16_t>(
+ const conv_param_t<3>& conv_p);
+
+} // namespace fbgemm
diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc
index 2620e43..ee39faf 100644
--- a/src/FbgemmI8DepthwiseAvx2.cc
+++ b/src/FbgemmI8DepthwiseAvx2.cc
@@ -4,7 +4,7 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
-#include "FbgemmI8DepthwiseAvx2.h"
+#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
#include <algorithm> // for min and max
#include <cassert>
diff --git a/src/FbgemmI8DepthwiseAvx2.h b/src/FbgemmI8DepthwiseAvx2.h
index 069ff77..e2730df 100644
--- a/src/FbgemmI8DepthwiseAvx2.h
+++ b/src/FbgemmI8DepthwiseAvx2.h
@@ -4,7 +4,8 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
-#pragma once
+#ifndef I8DEPTHWISE_H
+#define I8DEPTHWISE_H
#include <cstdint>
#include "fbgemm/FbgemmBuild.h"
@@ -170,3 +171,4 @@ FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1(
int num_threads = 1);
} // namespace fbgemm
+#endif
diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h
index 4e539f2..1e6324e 100644
--- a/src/GroupwiseConv.h
+++ b/src/GroupwiseConv.h
@@ -36,10 +36,12 @@ using jit_rowoffset_kernel_fp = void (*)(
int32_t width,
int32_t* row_offset);
-template <typename accT = int32_t>
+template <int SPATIAL_DIM = 2, typename accT = int32_t>
class GenConvKernel {
public:
- GenConvKernel(const conv_param_t<>& conv_param, std::int32_t a_zero_point)
+ GenConvKernel(
+ const conv_param_t<SPATIAL_DIM>& conv_param,
+ std::int32_t a_zero_point)
: WRegs_avx2_{x86::ymm0,
x86::ymm1,
x86::ymm2,
@@ -119,11 +121,11 @@ class GenConvKernel {
~GenConvKernel() {}
template <inst_set_t instSet>
- jit_conv_kernel_fp getOrCreate(const conv_param_t<>& conv_param);
+ jit_conv_kernel_fp getOrCreate(const conv_param_t<SPATIAL_DIM>& conv_param);
template <inst_set_t instSet>
jit_rowoffset_kernel_fp getOrCreateRowOffset(
- const conv_param_t<>& conv_param);
+ const conv_param_t<SPATIAL_DIM>& conv_param);
template <inst_set_t instSet>
void createVector16BitOne(asmjit::X86Emitter* a);
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index 0032e72..e789695 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -21,24 +21,25 @@ namespace fbgemm {
using namespace std;
-template <typename accT>
-thread_local asmjit::JitRuntime GenConvKernel<accT>::rt_;
+template <int SPATIAL_DIM, typename accT>
+thread_local asmjit::JitRuntime GenConvKernel<SPATIAL_DIM, accT>::rt_;
-template <typename accT>
-thread_local asmjit::CodeHolder GenConvKernel<accT>::code_;
+template <int SPATIAL_DIM, typename accT>
+thread_local asmjit::CodeHolder GenConvKernel<SPATIAL_DIM, accT>::code_;
-template <typename accT>
+template <int SPATIAL_DIM, typename accT>
thread_local std::map<std::tuple<bool, int, int, int>, jit_conv_kernel_fp>
- GenConvKernel<accT>::codeCache_;
+ GenConvKernel<SPATIAL_DIM, accT>::codeCache_;
-template <typename accT>
+template <int SPATIAL_DIM, typename accT>
thread_local std::map<std::tuple<bool, int, int, int>, jit_rowoffset_kernel_fp>
- GenConvKernel<accT>::codeCacheRowOffset_;
+ GenConvKernel<SPATIAL_DIM, accT>::codeCacheRowOffset_;
namespace x86 = asmjit::x86;
+template <int SPATIAL_DIM>
void calculateRowOffsets(
- const conv_param_t<>& conv_param,
+ const conv_param_t<SPATIAL_DIM>& conv_param,
const uint8_t* activations,
int32_t* rowOffsetBuf,
int32_t a_zero_point,
@@ -72,8 +73,9 @@ void calculateRowOffsets(
}
}
+template <int SPATIAL_DIM = 2>
tuple<bool, int, int, int> getKernelSig(
- const conv_param_t<>& conv_param,
+ const conv_param_t<SPATIAL_DIM>& conv_param,
bool isAZeroPointZero) {
int C_per_G = conv_param.IC / conv_param.G;
int K_per_G = conv_param.OC / conv_param.G;
@@ -82,18 +84,18 @@ tuple<bool, int, int, int> getKernelSig(
return kernelSig;
}
-template <typename accT = int32_t>
+template <int SPATIAL_DIM = 2, typename accT = int32_t>
jit_conv_kernel_fp getOrCreateConvKernel(
- const conv_param_t<>& conv_param,
+ const conv_param_t<SPATIAL_DIM>& conv_param,
int a_zero_point) {
// Note: Wrong code is generated if it's not one of the supported convolution
- assert(fbgemmOptimizedGConv<2>(conv_param));
+ assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
- if (GenConvKernel<accT>::codeCache_.find(kernelSig) !=
- GenConvKernel<accT>::codeCache_.end()) {
- return GenConvKernel<accT>::codeCache_[kernelSig];
+ if (GenConvKernel<SPATIAL_DIM, accT>::codeCache_.find(kernelSig) !=
+ GenConvKernel<SPATIAL_DIM, accT>::codeCache_.end()) {
+ return GenConvKernel<SPATIAL_DIM, accT>::codeCache_[kernelSig];
} else {
- auto genObj = GenConvKernel<accT>(conv_param, a_zero_point);
+ auto genObj = GenConvKernel<SPATIAL_DIM, accT>(conv_param, a_zero_point);
// TODO: Instruction set based dispatch
return genObj.template getOrCreate<inst_set_t::avx2>(conv_param);
}
@@ -101,7 +103,7 @@ jit_conv_kernel_fp getOrCreateConvKernel(
template <>
template <>
-void GenConvKernel<int32_t>::createVector8BitOne<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::createVector8BitOne<inst_set_t::avx2>(
asmjit::X86Emitter* a) {
// create 8-bit 1s
// i.e., oneReg16BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains
@@ -112,7 +114,7 @@ void GenConvKernel<int32_t>::createVector8BitOne<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::createVector16BitOne<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::createVector16BitOne<inst_set_t::avx2>(
asmjit::X86Emitter* a) {
// create 16-bit 1s
// i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31]
@@ -122,7 +124,7 @@ void GenConvKernel<int32_t>::createVector16BitOne<inst_set_t::avx2>(
}
template <>
template <>
-void GenConvKernel<int32_t>::setToZeroPt<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::setToZeroPt<inst_set_t::avx2>(
asmjit::X86Emitter* a,
asmjit::X86Ymm destReg) {
// make destReg all zeros
@@ -140,7 +142,7 @@ void GenConvKernel<int32_t>::setToZeroPt<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::genConstForPermutations<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::genConstForPermutations<inst_set_t::avx2>(
asmjit::X86Emitter* a) {
asmjit::X86Gp permute_const_reg = a->gpzRef(12);
asmjit::X86Xmm const_reg_xmm = x86::xmm10;
@@ -157,7 +159,7 @@ void GenConvKernel<int32_t>::genConstForPermutations<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::storeResult<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::storeResult<inst_set_t::avx2>(
asmjit::X86Emitter* a) {
if (C_per_G_ == 4) {
// store with permutation
@@ -168,7 +170,7 @@ void GenConvKernel<int32_t>::storeResult<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::storeResultRowoffset<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::storeResultRowoffset<inst_set_t::avx2>(
asmjit::X86Emitter* a,
int offset) {
// store
@@ -195,8 +197,9 @@ void GenConvKernel<int32_t>::storeResultRowoffset<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::genForLoadingWeights<inst_set_t::avx2>(
- asmjit::X86Emitter* a, int c_offset) {
+void GenConvKernel<2, int32_t>::genForLoadingWeights<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int c_offset) {
// load weights
for (int r = 0; r < R_; ++r) {
for (int s = 0; s < S_; ++s) {
@@ -221,7 +224,7 @@ void GenConvKernel<int32_t>::genForLoadingWeights<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::gen8bitFMA<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::gen8bitFMA<inst_set_t::avx2>(
asmjit::X86Emitter* a,
asmjit::X86Ymm aReg,
asmjit::X86Ymm wReg) {
@@ -232,7 +235,7 @@ void GenConvKernel<int32_t>::gen8bitFMA<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::gen8BitSumX4<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::gen8BitSumX4<inst_set_t::avx2>(
asmjit::X86Emitter* a,
asmjit::X86Ymm aReg) {
a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_);
@@ -242,7 +245,7 @@ void GenConvKernel<int32_t>::gen8BitSumX4<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::gen8BitSumX8<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::gen8BitSumX8<inst_set_t::avx2>(
asmjit::X86Emitter* a,
asmjit::X86Ymm aReg,
asmjit::X86Ymm bReg) {
@@ -263,7 +266,7 @@ void GenConvKernel<int32_t>::gen8BitSumX8<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::gen8BitSumX16<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::gen8BitSumX16<inst_set_t::avx2>(
asmjit::X86Emitter* a,
asmjit::X86Ymm aReg,
asmjit::X86Ymm bReg,
@@ -315,7 +318,7 @@ void GenConvKernel<int32_t>::gen8BitSumX16<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::gen8BitSum<inst_set_t::avx2>(
asmjit::X86Emitter* a,
int act_offset,
bool use_scratch_reg1 /*=true*/) {
@@ -381,7 +384,7 @@ void GenConvKernel<int32_t>::gen8BitSum<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::genZeroPtSum<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::genZeroPtSum<inst_set_t::avx2>(
asmjit::X86Emitter* a,
int multiplier) {
a->mov(scratchReg1_, static_cast<asmjit::Imm>(multiplier));
@@ -395,8 +398,9 @@ void GenConvKernel<int32_t>::genZeroPtSum<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a, int c_offset) {
+void GenConvKernel<2, int32_t>::genForTopEdge<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int c_offset) {
// top-left corner code
if (c_offset == 0) {
// zero out the results register
@@ -554,8 +558,9 @@ void GenConvKernel<int32_t>::genForTopEdge<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a, int c_offset) {
+void GenConvKernel<2, int32_t>::genForLeftEdge<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int c_offset) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
a->mov(loopR1_, static_cast<asmjit::Imm>(H_PAD_));
@@ -620,8 +625,9 @@ void GenConvKernel<int32_t>::genForLeftEdge<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a, int c_offset) {
+void GenConvKernel<2, int32_t>::genForRightEdge<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int c_offset) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -707,8 +713,9 @@ void GenConvKernel<int32_t>::genForRightEdge<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>(
- asmjit::X86Emitter* a, int c_offset) {
+void GenConvKernel<2, int32_t>::genForBottomEdge<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int c_offset) {
// bottom-left corner
// we updating the last row
a->mov(scratchReg1_, H_R_);
@@ -898,8 +905,9 @@ void GenConvKernel<int32_t>::genForBottomEdge<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>(
- asmjit::X86Emitter* a, int c_offset) {
+void GenConvKernel<2, int32_t>::genCoreInsts<inst_set_t::avx2>(
+ asmjit::X86Emitter* a,
+ int c_offset) {
// main compute
asmjit::Label LoopH = a->newLabel();
asmjit::Label LoopW = a->newLabel();
@@ -1000,8 +1008,8 @@ void GenConvKernel<int32_t>::genCoreInsts<inst_set_t::avx2>(
template <>
template <>
-jit_conv_kernel_fp GenConvKernel<int32_t>::getOrCreate<inst_set_t::avx2>(
- const conv_param_t<>& conv_param) {
+jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate<inst_set_t::avx2>(
+ const conv_param_t<2>& conv_param) {
code_.reset(false);
code_.init(rt_.getCodeInfo());
asmjit::X86Assembler assembler(&code_);
@@ -1108,7 +1116,7 @@ jit_conv_kernel_fp GenConvKernel<int32_t>::getOrCreate<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
asmjit::X86Emitter* a) {
// top-left corner code
// zero out the results register
@@ -1204,7 +1212,7 @@ void GenConvKernel<int32_t>::genForTopEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
asmjit::X86Emitter* a) {
// left edge excluding corners
asmjit::Label LoopLeftEdge = a->newLabel();
@@ -1247,7 +1255,7 @@ void GenConvKernel<int32_t>::genForLeftEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
asmjit::X86Emitter* a) {
// right edge excluding corners
asmjit::Label LoopRightEdge = a->newLabel();
@@ -1317,7 +1325,7 @@ void GenConvKernel<int32_t>::genForRightEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
asmjit::X86Emitter* a) {
// bottom-left corner
// zero out
@@ -1420,7 +1428,7 @@ void GenConvKernel<int32_t>::genForBottomEdgeRowoffset<inst_set_t::avx2>(
template <>
template <>
-void GenConvKernel<int32_t>::genRowoffsetCore<inst_set_t::avx2>(
+void GenConvKernel<2, int32_t>::genRowoffsetCore<inst_set_t::avx2>(
asmjit::X86Emitter* a) {
// number of uint8 elements in input channels should be a multiple of 32
assert(C_ % 32 == 0);
@@ -1480,8 +1488,8 @@ void GenConvKernel<int32_t>::genRowoffsetCore<inst_set_t::avx2>(
template <>
template <>
jit_rowoffset_kernel_fp
-GenConvKernel<int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
- const conv_param_t<>& conv_param) {
+GenConvKernel<2, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
+ const conv_param_t<2>& conv_param) {
code_.reset(false);
code_.init(rt_.getCodeInfo());
asmjit::X86Assembler assembler(&code_);
@@ -1598,7 +1606,6 @@ void fbgemmGroupwiseConvBase_(
const processOutputType& outProcess,
int thread_id,
int num_threads) {
-
int MB = conv_param.MB;
int H = conv_param.OUT_DIM[0];
int W = conv_param.OUT_DIM[1];
@@ -1608,14 +1615,14 @@ void fbgemmGroupwiseConvBase_(
int oh_ow = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
int ih_iw = conv_param.IN_DIM[0] * conv_param.IN_DIM[1];
- static_assert(SPATIAL_DIM == 2, "3D conv not supported yet");
+ assert(SPATIAL_DIM == 2 && "3D groupwise conv not supported yet");
int32_t* rowOffsetTrDest = rowOffsetBuf ? rowOffsetBuf + 8 * ih_iw : nullptr;
if (fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param)) {
assert(G % 8 == 0);
// generate convolution kernel
jit_conv_kernel_fp fpConv =
- getOrCreateConvKernel<>(conv_param, a_zero_point);
+ getOrCreateConvKernel<SPATIAL_DIM>(conv_param, a_zero_point);
// generate row offset kernel
jit_rowoffset_kernel_fp fpRowoffset =
getOrCreateRowOffsetKernel(conv_param, a_zero_point);
@@ -1670,8 +1677,7 @@ void fbgemmGroupwiseConvBase_(
for (int j = 0; j < gDelta; ++j) {
// calculateRowOffsets(
// conv_param, actStartGroup, rowOffsetBuf, a_zero_point, j);
- int32_t* rowOffsetForCurG =
- rowOffsetTrDest
+ int32_t* rowOffsetForCurG = rowOffsetTrDest
? rowOffsetTrDest + ((g - gOuter) + j) * ih_iw
: nullptr;
// compare_buffers(rowOffsetBuf, rowOffsetForCurG,
@@ -1735,7 +1741,7 @@ void fbgemmGroupwiseConvBase_(
}
}
-}
+} // namespace
template <
typename packed_W,
@@ -1820,13 +1826,13 @@ void fbgemmGroupwiseConv(
int oh_ow = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
int ih_iw = conv_param.IN_DIM[0] * conv_param.IN_DIM[1];
- static_assert(SPATIAL_DIM == 2, "3D conv not supported yet");
+ assert(SPATIAL_DIM == 2 && "3D conv not supported yet");
int32_t* rowOffsetTrDest = rowOffsetBuf ? rowOffsetBuf + 8 * ih_iw : nullptr;
assert(G % 8 == 0);
// generate convolution kernel
jit_conv_kernel_fp fpConv =
- getOrCreateConvKernel<>(conv_param, a_zero_point);
+ getOrCreateConvKernel<SPATIAL_DIM>(conv_param, a_zero_point);
// generate row offset kernel
jit_rowoffset_kernel_fp fpRowoffset =
getOrCreateRowOffsetKernel(conv_param, a_zero_point);
@@ -2132,17 +2138,37 @@ void fbgemmGroupwiseConv(
} // i loop
}
+// 3D not implemented yet
+template <>
+template <>
+jit_conv_kernel_fp GenConvKernel<3, int32_t>::getOrCreate<inst_set_t::avx2>(
+ const conv_param_t<3>& /* unused */) {
+ assert(0 && "not implemented yet");
+ return nullptr;
+}
+
+template <>
+template <>
+jit_rowoffset_kernel_fp
+GenConvKernel<3, int32_t>::getOrCreateRowOffset<inst_set_t::avx2>(
+ const conv_param_t<3>& conv_param) {
+ assert(0 && "not implemented yet");
+ return nullptr;
+}
+
+template <int SPATIAL_DIM = 2>
jit_rowoffset_kernel_fp getOrCreateRowOffsetKernel(
- const conv_param_t<>& conv_param,
+ const conv_param_t<SPATIAL_DIM>& conv_param,
int a_zero_point) {
// Note: Wrong code is generated if it's not one of the supported convolution
- assert(fbgemmOptimizedGConv<2>(conv_param));
+ assert(fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param));
auto kernelSig = getKernelSig(conv_param, a_zero_point == 0);
- if (GenConvKernel<int32_t>::codeCacheRowOffset_.find(kernelSig) !=
- GenConvKernel<int32_t>::codeCacheRowOffset_.end()) {
- return GenConvKernel<int32_t>::codeCacheRowOffset_[kernelSig];
+ if (GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.find(
+ kernelSig) !=
+ GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_.end()) {
+ return GenConvKernel<SPATIAL_DIM, int32_t>::codeCacheRowOffset_[kernelSig];
} else {
- auto genObj = GenConvKernel<int32_t>(conv_param, a_zero_point);
+ auto genObj = GenConvKernel<SPATIAL_DIM, int32_t>(conv_param, a_zero_point);
// TODO: Instruction set based dispatch
return genObj.template getOrCreateRowOffset<inst_set_t::avx2>(conv_param);
}
@@ -2152,6 +2178,7 @@ template <int SPATIAL_DIM>
int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) {
// row offset buffer should be a able to hold row offsets for however
// number of groups we process at a time.
+ assert(SPATIAL_DIM == 2 && "Only 2D is supported currently");
if (cpuinfo_initialize()) {
if (fbgemmHasAvx512Support()) {
int bufferSize = conv_param.OUT_DIM[0] * conv_param.OUT_DIM[1];
@@ -2186,29 +2213,35 @@ int rowOffsetBufferSizeGConv(const conv_param_t<SPATIAL_DIM>& conv_param) {
}
template int rowOffsetBufferSizeGConv<2>(const conv_param_t<2>& conv_param);
-
-#define INSTANTIATE_BASE(RELU, Q_GRAN) \
- template void fbgemmGroupwiseConv( \
- const conv_param_t<2>& conv_param, \
- const uint8_t* activations, \
- int32_t a_zero_point, \
- std::int32_t* rowOffsetBuf, \
- PackWeightMatrixForGConv<int8_t, int32_t, 2>& packed_weights, \
- uint8_t* out, \
- int32_t* outBuffer, \
- const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
- int thread_id, \
+template int rowOffsetBufferSizeGConv<3>(const conv_param_t<3>& conv_param);
+
+#define INSTANTIATE_BASE(RELU, Q_GRAN, SPATIAL_DIM) \
+ template void fbgemmGroupwiseConv( \
+ const conv_param_t<SPATIAL_DIM>& conv_param, \
+ const uint8_t* activations, \
+ int32_t a_zero_point, \
+ std::int32_t* rowOffsetBuf, \
+ PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>& packed_weights, \
+ uint8_t* out, \
+ int32_t* outBuffer, \
+ const ReQuantizeOutput<RELU, Q_GRAN>& outProcess, \
+ int thread_id, \
int num_threads);
-#define INSTANTIATE_Q_GRANS(RELU) \
- INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR); \
- INSTANTIATE_BASE(RELU, QuantizationGranularity::GROUP); \
- INSTANTIATE_BASE(RELU, QuantizationGranularity::OUT_CHANNEL);
+#define INSTANTIATE_SPATIAL_DIM(RELU, Q_GRAN) \
+ INSTANTIATE_BASE(RELU, Q_GRAN, 2); \
+ INSTANTIATE_BASE(RELU, Q_GRAN, 3);
+
+#define INSTANTIATE_Q_GRANS(RELU) \
+ INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::TENSOR); \
+ INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::GROUP); \
+ INSTANTIATE_SPATIAL_DIM(RELU, QuantizationGranularity::OUT_CHANNEL);
INSTANTIATE_Q_GRANS(false);
INSTANTIATE_Q_GRANS(true);
#undef INSTANTIATE_Q_GRANS
+#undef INSTANTIATE_SPATIAL_DIM
#undef INSTANTIATE_BASE
template void fbgemmGroupwiseConv(
diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc
index 5870fa5..0fb0e2c 100644
--- a/src/PackWeightMatrixForGConv.cc
+++ b/src/PackWeightMatrixForGConv.cc
@@ -19,7 +19,7 @@ PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::PackWeightMatrixForGConv(
const T* sdata,
T* pdata)
: trans_(trans), conv_param_(conv_param), sdata_(sdata) {
- static_assert(SPATIAL_DIM == 2, "3D conv not supported yet");
+ assert(SPATIAL_DIM == 2 && "3D conv not supported yet");
if (!pdata) {
bufAllocatedHere_ = true;
@@ -111,4 +111,6 @@ void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack() {
template class PackWeightMatrixForGConv<int8_t, int32_t, 2>;
template class PackWeightMatrixForGConv<int8_t, int16_t, 2>;
+template class PackWeightMatrixForGConv<int8_t, int32_t, 3>;
+template class PackWeightMatrixForGConv<int8_t, int16_t, 3>;
} // namespace fbgemm
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
new file mode 100644
index 0000000..c811144
--- /dev/null
+++ b/src/PackWeightsForConv.cc
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <memory>
+#include "fbgemm/Fbgemm.h"
+
+namespace fbgemm {
+
+template <int SPATIAL_DIM, typename T, typename accT>
+PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
+ const conv_param_t<SPATIAL_DIM>& conv_p,
+ const T* sdata,
+ const BlockingFactors* blocking_params) {
+ static_assert(
+ SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
+ "Only 2D and 3D convolutions are supported");
+ // Note: The following logic should *exactly* match with what we have in
+ // FbgemmConv.cc
+ switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) {
+ case optimized_conv_t::depthwise: {
+ if (SPATIAL_DIM == 3) {
+ W_im2col_packed_ = nullptr;
+ W_dw_2D_packed_ = nullptr;
+ W_dw_3D_packed_ =
+ std::make_shared<Packed3x3x3ConvMatrix>(conv_p.G, sdata);
+ W_gconv_packed_ = nullptr;
+ } else {
+ W_im2col_packed_ = nullptr;
+ W_dw_2D_packed_ =
+ std::make_shared<Packed3x3ConvMatrix>(conv_p.G, sdata);
+ W_dw_3D_packed_ = nullptr;
+ W_gconv_packed_ = nullptr;
+ }
+ break;
+ }
+ case optimized_conv_t::groupwise: {
+ W_im2col_packed_ = nullptr;
+ W_dw_2D_packed_ = nullptr;
+ W_dw_3D_packed_ = nullptr;
+ W_gconv_packed_ =
+ std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>(
+ matrix_op_t::NoTranspose, conv_p, sdata, nullptr);
+ break;
+ }
+ case optimized_conv_t::im2col: {
+ int NDim = conv_p.OC / conv_p.G;
+ int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
+ W_im2col_packed_ = std::make_shared<PackBMatrix<T, accT>>(
+ matrix_op_t::NoTranspose,
+ KDim,
+ NDim,
+ sdata,
+ NDim,
+ nullptr,
+ conv_p.G,
+ blocking_params);
+ W_dw_2D_packed_ = nullptr;
+ W_dw_3D_packed_ = nullptr;
+ W_gconv_packed_ = nullptr;
+ break;
+ }
+ } // switch
+}
+
+template class PackWeightsForConv<2, int8_t, int32_t>;
+template class PackWeightsForConv<3, int8_t, int32_t>;
+
+} // namespace fbgemm
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index 72ef93f..b4b0c2b 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -285,8 +285,9 @@ int32_t clip_16bit(int32_t x) {
* A: NHWC: NH_0W_0 x C_0
* Ao: NHWC: NH_1W_1 x G RS C_0/G
*/
+template <>
void im2col_ref(
- const conv_param_t<>& conv_p,
+ const conv_param_t<2>& conv_p,
const uint8_t* A,
int32_t A_zero_point,
uint8_t* Ao) {
@@ -346,7 +347,8 @@ void im2col_ref(
* A: NHWC: NT_0H_0W_0 x C_0
* Ao: NHWC: NT_1H_1W_1 x G QRS C_0/G
*/
-void im2col3d_ref(
+template <>
+void im2col_ref(
const conv_param_t<3>& conv_p,
const uint8_t* A,
int32_t A_zero_point,
@@ -422,8 +424,10 @@ void im2col3d_ref(
} // for each n
}
+// 2D Conv
+template <>
void conv_ref(
- const conv_param_t<>& conv_p,
+ const conv_param_t<2>& conv_p,
const uint8_t* A,
int32_t A_zero_point,
const int8_t* B,
@@ -471,7 +475,9 @@ void conv_ref(
} // for each n
}
-void conv3d_ref(
+// 3D Conv
+template <>
+void conv_ref(
const conv_param_t<3>& conv_p,
const uint8_t* A,
int32_t A_zero_point,
@@ -531,10 +537,12 @@ void conv3d_ref(
} // for each n
}
+template <int SPATIAL_DIM>
void transposeConvWeights(
- const conv_param_t<>& conv_p,
+ const conv_param_t<SPATIAL_DIM>& conv_p,
const std::int8_t* src,
std::int8_t* dest) {
+ assert(SPATIAL_DIM == 2 && "Only 2D supported currently");
int R = conv_p.K[0];
int S = conv_p.K[1];
int G = conv_p.G;
@@ -956,4 +964,14 @@ void depthwise_3x3x3_per_channel_quantization_pad_1_ref(
}
};
+template void transposeConvWeights(
+ const conv_param_t<2>& conv_p,
+ const std::int8_t* src,
+ std::int8_t* dest);
+
+template void transposeConvWeights(
+ const conv_param_t<3>& conv_p,
+ const std::int8_t* src,
+ std::int8_t* dest);
+
} // namespace fbgemm
diff --git a/src/RefImplementations.h b/src/RefImplementations.h
index 117c8a1..082bdf1 100644
--- a/src/RefImplementations.h
+++ b/src/RefImplementations.h
@@ -180,15 +180,9 @@ int32_t clip_16bit(int32_t x);
* The filters B are assumed to be in RSCK format.
* The output C is assumed to be in NHoWoC format.
*/
+template <int SPATIAL_DIM = 2>
FBGEMM_API void conv_ref(
- const conv_param_t<>& conv_p,
- const std::uint8_t* A,
- std::int32_t A_zero_point,
- const std::int8_t* B,
- std::int32_t* C);
-
-FBGEMM_API void conv3d_ref(
- const conv_param_t<3>& conv_p,
+ const conv_param_t<SPATIAL_DIM>& conv_p,
const std::uint8_t* A,
std::int32_t A_zero_point,
const std::int8_t* B,
@@ -197,29 +191,26 @@ FBGEMM_API void conv3d_ref(
/*
* @brief Transforms weights from G K/G (R S C/G) to G (R S C/G) K/G format.
*/
+template <int SPATIAL_DIM = 2>
FBGEMM_API void transposeConvWeights(
- const conv_param_t<>& conv_p,
+ const conv_param_t<SPATIAL_DIM>& conv_p,
const std::int8_t* src,
std::int8_t* dest);
/*
* @brief Reference implementation of im2col operation.
+ *
+ * For 2D:
* The input A is assumed to be in NHiWiC format.
* The output A is assumed to be in NHoWoRSC format.
- */
-FBGEMM_API void im2col_ref(
- const conv_param_t<>& conv_p,
- const std::uint8_t* A,
- std::int32_t A_zero_point,
- std::uint8_t* Ao);
-
-/*
- * @brief Reference implementation of im2col 3D operation.
+ *
+ * For 3D:
* The input A is assumed to be in NTiHiWiC format.
* The output A is assumed to be in NToHoWoK0K1K2C format.
*/
-FBGEMM_API void im2col3d_ref(
- const conv_param_t<3>& conv_p,
+template <int SPATIAL_DIM = 2>
+FBGEMM_API void im2col_ref(
+ const conv_param_t<SPATIAL_DIM>& conv_p,
const std::uint8_t* A,
std::int32_t A_zero_point,
std::uint8_t* Ao);
diff --git a/test/I8DepthwiseTest.cc b/test/I8DepthwiseTest.cc
index 7843aae..11bd625 100644
--- a/test/I8DepthwiseTest.cc
+++ b/test/I8DepthwiseTest.cc
@@ -14,7 +14,7 @@
#include "TestUtils.h"
#include "bench/AlignedVec.h"
#include "bench/BenchUtils.h"
-#include "src/FbgemmI8DepthwiseAvx2.h"
+#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
#include "src/RefImplementations.h"
using namespace std;
diff --git a/test/Im2ColFusedRequantizeTest.cc b/test/Im2ColFusedRequantizeTest.cc
index 1ec2a06..d9c2f75 100644
--- a/test/Im2ColFusedRequantizeTest.cc
+++ b/test/Im2ColFusedRequantizeTest.cc
@@ -610,7 +610,7 @@ static void Im2col3DTest(bool b_symmetric) {
// computing row offset
vector<int32_t> row_offsets(MDim);
vector<uint8_t> Aint8_im2col(MDim * KDim);
- im2col3d_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());
+ im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());
// computing column offset
vector<int32_t> col_offsets(conv_p.G * NDim);
@@ -625,7 +625,7 @@ static void Im2col3DTest(bool b_symmetric) {
ncols_per_quant_group);
}
- conv3d_ref(
+ conv_ref(
conv_p,
Aint8.data(),
Aint8_zero_point,
diff --git a/test/UniConvPackingTest.cc b/test/UniConvPackingTest.cc
new file mode 100644
index 0000000..77552af
--- /dev/null
+++ b/test/UniConvPackingTest.cc
@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <random>
+#include <iostream>
+
+
+#include <gtest/gtest.h>
+
+#include "QuantizationHelpers.h"
+#include "TestUtils.h"
+#include "bench/BenchUtils.h"
+#include "fbgemm/Fbgemm.h"
+#include "src/RefImplementations.h"
+
+using namespace std;
+using namespace fbgemm;
+
+namespace {
+
+// tuple represents MB, IC, OC, IT, IH, IW, KH/KW, stride, pad
+class convPackingTest
+ : public testing::TestWithParam<
+ tuple<int, int, int, int, int, int, int, int, int, int>> {};
+
+}; // namespace
+
+INSTANTIATE_TEST_CASE_P(
+ InstantiationName,
+ convPackingTest,
+ ::testing::Combine(
+ ::testing::ValuesIn({1, 2}), // MB
+ ::testing::ValuesIn({16, 32}), // IC
+ ::testing::ValuesIn({16, 32}), // OC
+ ::testing::ValuesIn({17}), // IT
+ ::testing::ValuesIn({10, 30, 55}), // IH
+ ::testing::ValuesIn({10, 30, 55}), // IW
+ ::testing::ValuesIn({1, 4, 16}), // G
+ ::testing::ValuesIn({3, 7}), // kernel
+ ::testing::ValuesIn({1, 2}), // stride
+ ::testing::ValuesIn({1, 2}))); // pad
+
+/**
+ * Test for conv packing
+ */
+TEST_P(convPackingTest, packingTest) {
+ int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad;
+ tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam();
+
+ conv_param_t<2> conv_p_2d(
+ MB,
+ IC,
+ OC,
+ {IH, IW},
+ G,
+ {kernel, kernel},
+ {stride, stride},
+ {pad, pad, pad, pad});
+
+ int kernel_dim_2d = kernel * kernel;
+ aligned_vector<int8_t> Bint8_2d(
+ kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G));
+ PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data());
+
+ switch (ConvFastPath<2, int32_t>(conv_p_2d)) {
+ case optimized_conv_t::depthwise: {
+ ASSERT_NE(packedB_2D.getPackedWFor2DDW(), nullptr)
+ << "2D depthwise packed matrix is null";
+ ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
+ << "im2col packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
+ << "3D depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
+ << "groupwise packed matrix should be null";
+ break;
+ }
+ case optimized_conv_t::groupwise: {
+ ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
+ << "im2col packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr)
+ << "2D depthwise packed matrix is null";
+ ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
+ << "3D depthwise packed matrix should be null";
+ ASSERT_NE(packedB_2D.getPackedWForGroupwise(), nullptr)
+ << "Groupwise packed matrix is null";
+ break;
+ }
+ case optimized_conv_t::im2col: {
+ ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr)
+ << "2D depthwise packed matrix is null";
+ ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
+ << "3D depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
+ << "groupwise packed matrix should be null";
+ ASSERT_NE(packedB_2D.getPackedWForIm2col(), nullptr)
+ << "im2col packed matrix is null";
+ break;
+ }
+ }
+
+ conv_param_t<3> conv_p_3d(
+ MB,
+ IC,
+ OC,
+ {IT, IH, IW},
+ G,
+ {kernel, kernel, kernel},
+ {stride, stride, stride},
+ {pad, pad, pad, pad, pad, pad});
+
+ int kernel_dim_3d = kernel * kernel * kernel;
+ aligned_vector<int8_t> Bint8_3d(
+ kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G));
+ PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data());
+
+ switch (ConvFastPath<3, int32_t>(conv_p_3d)) {
+ case optimized_conv_t::depthwise: {
+ ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr)
+ << "2D depthwise packed matrix is null";
+ ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr)
+ << "im2col packed matrix should be null";
+ ASSERT_NE(packedB_3D.getPackedWFor3DDW(), nullptr)
+ << "3D depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
+ << "groupwise packed matrix should be null";
+ break;
+ }
+ case optimized_conv_t::groupwise: {
+ ASSERT_TRUE(false) << "groupwise are not supported for 3D";
+ break;
+ }
+ case optimized_conv_t::im2col: {
+ ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr)
+ << "2D depthwise packed matrix is null";
+ ASSERT_EQ(packedB_3D.getPackedWFor3DDW(), nullptr)
+ << "3D depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
+ << "groupwise packed matrix should be null";
+ ASSERT_NE(packedB_3D.getPackedWForIm2col(), nullptr)
+ << "im2col packed matrix is null";
+ break;
+ }
+ }
+}