Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJongsoo Park <jongsoo@fb.com>2019-09-24 17:06:47 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-09-24 17:25:20 +0300
commit518d8a1832cf1eb1dda2feace1a278e9e4f302ba (patch)
tree532f3e479fa8a96644689c65fe7891b9ce30bcf0
parent53f0c0d175ae4283609a5b251052f9c6598b8aee (diff)
remove template parameter from PackedDepthWiseConvMatrix (#128)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/128 We don't really need to have KERNEL_PROD as a compile time constant template parameter in PackedDepthWiseConvMatrix for performance. Removing the template parameter will make generalizing depth-wise convolution to non 3x3 cases easier. This diff only changes fbgemm while maintaining the old interface. The follow-up diff will change Caffe2 code using the old interface and remove the old interface. This diff also splits FbgemmI8DepthwiseAvx2.cc into FbgemmI8Depthwise3DAvx2.cc and PackDepthwiseConvMatrixAvx2.cc to avoid compilation timeouts in OSS build tests. Reviewed By: dskhudia Differential Revision: D17514003 fbshipit-source-id: 2214637ac0762a585f619f0035d3449cc4f7669e
-rw-r--r--CMakeLists.txt2
-rw-r--r--bench/Depthwise3DBenchmark.cc2
-rw-r--r--bench/DepthwiseBenchmark.cc2
-rw-r--r--include/fbgemm/Fbgemm.h18
-rw-r--r--include/fbgemm/FbgemmI8DepthwiseAvx2.h60
-rw-r--r--src/FbgemmConv.cc8
-rw-r--r--src/FbgemmI8Depthwise3DAvx2.cc1415
-rw-r--r--src/FbgemmI8DepthwiseAvx2-inl.h709
-rw-r--r--src/FbgemmI8DepthwiseAvx2.cc2308
-rw-r--r--src/PackDepthwiseConvMatrixAvx2.cc203
-rw-r--r--src/PackWeightsForConv.cc33
-rw-r--r--test/I8DepthwiseTest.cc40
-rw-r--r--test/UniConvTest.cc42
13 files changed, 2433 insertions, 2409 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 817f699..8bf6371 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -75,8 +75,10 @@ endif()
#All the source files that either use avx2 instructions statically
set(FBGEMM_AVX2_SRCS
src/FbgemmFP16UKernelsAvx2.cc
+ src/FbgemmI8Depthwise3DAvx2.cc
src/FbgemmI8DepthwiseAvx2.cc
src/OptimizedKernelsAvx2.cc
+ src/PackDepthwiseConvMatrixAvx2.cc
src/QuantUtilsAvx2.cc
src/UtilsAvx2.cc)
diff --git a/bench/Depthwise3DBenchmark.cc b/bench/Depthwise3DBenchmark.cc
index cd31524..ff2be6f 100644
--- a/bench/Depthwise3DBenchmark.cc
+++ b/bench/Depthwise3DBenchmark.cc
@@ -159,7 +159,7 @@ int main() {
K);
}
- Packed3x3x3ConvMatrix Bp(K, B.data());
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data());
double ttot = 0;
double bytes = double(NITER) *
diff --git a/bench/DepthwiseBenchmark.cc b/bench/DepthwiseBenchmark.cc
index 6b2f8b8..6c2ee17 100644
--- a/bench/DepthwiseBenchmark.cc
+++ b/bench/DepthwiseBenchmark.cc
@@ -235,7 +235,7 @@ int main() {
K);
}
- Packed3x3ConvMatrix Bp(K, B.data());
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data());
double ttot = 0;
double bytes = double(NITER) *
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 4efd181..3680a48 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -588,12 +588,16 @@ class FBGEMM_API PackWeightsForConv {
return W_im2col_packed_;
}
- std::shared_ptr<Packed3x3ConvMatrix> getPackedWFor2DDW() {
- return W_dw_2D_packed_;
+ std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWForDepthwise() {
+ return W_dw_packed_;
}
- std::shared_ptr<Packed3x3x3ConvMatrix> getPackedWFor3DDW() {
- return W_dw_3D_packed_;
+ std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWFor2DDW() {
+ return W_dw_packed_;
+ }
+
+ std::shared_ptr<PackedDepthWiseConvMatrix> getPackedWFor3DDW() {
+ return W_dw_packed_;
}
std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
@@ -642,10 +646,8 @@ class FBGEMM_API PackWeightsForConv {
const conv_param_t<SPATIAL_DIM> conv_param_;
// Packed weights if we use im2col based convolution implementation
std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_;
- // Packed weights if we use 2D depthwise convolution implementation
- std::shared_ptr<Packed3x3ConvMatrix> W_dw_2D_packed_;
- // Packed weights if we use 3D depthwise convolution implementation
- std::shared_ptr<Packed3x3x3ConvMatrix> W_dw_3D_packed_;
+ // Packed weights if we use depthwise convolution implementation
+ std::shared_ptr<PackedDepthWiseConvMatrix> W_dw_packed_;
// Packed weights if we use groupwise (small channels per group) convolution
// implementation
std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
diff --git a/include/fbgemm/FbgemmI8DepthwiseAvx2.h b/include/fbgemm/FbgemmI8DepthwiseAvx2.h
index 19946cf..c454b16 100644
--- a/include/fbgemm/FbgemmI8DepthwiseAvx2.h
+++ b/include/fbgemm/FbgemmI8DepthwiseAvx2.h
@@ -11,19 +11,26 @@
namespace fbgemm {
-// KERNEL_PROD is the product of all kernels.
-// For example, KERNEL_PROD = 9 for 3x3, and 27 for 3x3x3.
-template <int KERNEL_PROD>
class FBGEMM_API PackedDepthWiseConvMatrix {
public:
- // smat in GRS layout
- PackedDepthWiseConvMatrix(int K, const std::int8_t* smat);
+ /**
+ * @params K the number of channels (same as the number of groups because
+ * depth-wise convolution has one input/output channel per group)
+ * @params kernel_prod the product of all kernels. For example, kernel_prod =
+ * 9 for 3x3 conv, and 27 for 3x3x3 conv.
+ * @param smat the source unpacked weight in GRS layout
+ */
+ PackedDepthWiseConvMatrix(int K, int kernel_prod, const std::int8_t* smat);
virtual ~PackedDepthWiseConvMatrix();
const std::int8_t* PackedMat() const {
return pmat_;
}
+ int GetKernelProduct() const {
+ return kernel_prod_;
+ }
+
/**
* @brief Unpacks pmat_ into unpack_data.
* Used for recovering the weight matrix into the original format
@@ -36,19 +43,22 @@ class FBGEMM_API PackedDepthWiseConvMatrix {
int addr(int r, int c);
private:
- int K_;
- std::int8_t* pmat_;
-}; // Packed3x3ConvMatrix
+ const int K_; /**< the number of channels */
+ const int kernel_prod_; /** the product of all kernel dims */
+ std::int8_t* pmat_; /** packed weight */
+}; // PackedDepthWiseConvMatrix
-using Packed3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3>;
-using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>;
-using Packed1ConvMatrix = PackedDepthWiseConvMatrix<1>;
-using Packed2ConvMatrix = PackedDepthWiseConvMatrix<2>;
-using Packed3ConvMatrix = PackedDepthWiseConvMatrix<3>;
-using Packed4ConvMatrix = PackedDepthWiseConvMatrix<4>;
-using Packed5ConvMatrix = PackedDepthWiseConvMatrix<5>;
-using Packed10ConvMatrix = PackedDepthWiseConvMatrix<10>;
-using Packed11ConvMatrix = PackedDepthWiseConvMatrix<11>;
+class FBGEMM_API Packed3x3ConvMatrix : public PackedDepthWiseConvMatrix {
+ public:
+ Packed3x3ConvMatrix(int K, const std::int8_t* smat)
+ : PackedDepthWiseConvMatrix(K, 3 * 3, smat) {}
+};
+
+class FBGEMM_API Packed3x3x3ConvMatrix : public PackedDepthWiseConvMatrix {
+ public:
+ Packed3x3x3ConvMatrix(int K, const std::int8_t* smat)
+ : PackedDepthWiseConvMatrix(K, 3 * 3 * 3, smat) {}
+};
/** To be removed. Keeping it just to make sure we don't change C2 files and
* fbgemm files in a single diff
@@ -64,7 +74,7 @@ FBGEMM_API void depthwise_3x3_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
std::int32_t B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
float C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
@@ -93,7 +103,7 @@ FBGEMM_API void depthwise_3x3_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
std::int32_t B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
float C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
@@ -121,7 +131,7 @@ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
const std::int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
@@ -145,7 +155,7 @@ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
const std::int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
@@ -171,7 +181,7 @@ FBGEMM_API void depthwise_3x3x3_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
std::int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
float C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
@@ -196,7 +206,7 @@ FBGEMM_API void depthwise_3x3x3_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
std::int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
float C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
@@ -223,7 +233,7 @@ FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
const std::int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
@@ -249,7 +259,7 @@ FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1(
std::int32_t A_zero_point,
const std::uint8_t* A,
const std::int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
std::int32_t C_zero_point,
std::uint8_t* C,
diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc
index 5a486c0..de833d2 100644
--- a/src/FbgemmConv.cc
+++ b/src/FbgemmConv.cc
@@ -109,7 +109,7 @@ int fbgemmConv(
outProcess.getAZeroPoint(),
activations,
B_zero_point[0],
- *(packed_weights.getPackedWFor3DDW()),
+ *(packed_weights.getPackedWForDepthwise()),
C_multiplier[0],
outProcess.getCZeroPoint(),
out,
@@ -135,7 +135,7 @@ int fbgemmConv(
outProcess.getAZeroPoint(),
activations,
B_zero_point,
- *(packed_weights.getPackedWFor3DDW()),
+ *(packed_weights.getPackedWForDepthwise()),
C_multiplier,
outProcess.getCZeroPoint(),
out,
@@ -163,7 +163,7 @@ int fbgemmConv(
outProcess.getAZeroPoint(),
activations,
B_zero_point[0],
- *(packed_weights.getPackedWFor2DDW()),
+ *(packed_weights.getPackedWForDepthwise()),
C_multiplier[0],
outProcess.getCZeroPoint(),
out,
@@ -188,7 +188,7 @@ int fbgemmConv(
outProcess.getAZeroPoint(),
activations,
B_zero_point,
- *(packed_weights.getPackedWFor2DDW()),
+ *(packed_weights.getPackedWForDepthwise()),
C_multiplier,
outProcess.getCZeroPoint(),
out,
diff --git a/src/FbgemmI8Depthwise3DAvx2.cc b/src/FbgemmI8Depthwise3DAvx2.cc
new file mode 100644
index 0000000..925d265
--- /dev/null
+++ b/src/FbgemmI8Depthwise3DAvx2.cc
@@ -0,0 +1,1415 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+
+#include <string>
+#include <tuple> // for tie
+
+#include "FbgemmI8DepthwiseAvx2-inl.h"
+
+using namespace std;
+
+namespace fbgemm {
+
+template <
+ bool SUM_A,
+ bool REMAINDER = false,
+ bool PER_CHANNEL_QUANTIZATION = false>
+static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_(
+ int T,
+ int H,
+ int W,
+ int K,
+ int t_in,
+ int h_in,
+ int w_in,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ const int8_t* Bp,
+ const int32_t* B_zero_point,
+ int32_t* C,
+ int remainder,
+ int32_t* row_offsets) {
+ __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
+ __m256i mask_v = _mm256_setzero_si256();
+ if (REMAINDER) {
+ mask_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(masks[remainder / 4]));
+ }
+
+ // The code below can be written as a simple R*S loop but the compiler
+ // doesn't unroll so we're manually unrolling it.
+ // constexpr int R = 3, S = 3;
+ // array<__m256i, R * S> a_v;
+ // for (int r = 0; r < R; ++r) {
+ // for (int s = 0; s < S; ++s) {
+ // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
+ // if (REMAINDER) {
+ // a_v[r * S + s] =
+ // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
+ // mask_v);
+ // } else {
+ // a_v[r * S + s] =
+ // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
+ // }
+ // } else {
+ // a_v[r * S + s] = A_zero_point_v;
+ // }
+ // }
+ // }
+ __m256i a_v[8];
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in >= 0 && t_in < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v);
+ }
+ }
+ }
+
+ __m256i a_sum[4];
+ inner_prod_packed_<8, SUM_A, REMAINDER>(
+ a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum);
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in >= 0 && t_in < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ if (t_in + 1 >= 0 && t_in + 1 < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v);
+ }
+ }
+ }
+
+ __m256i a_sum_temp[4];
+ inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
+ a_v, reinterpret_cast<const __m256i*>(Bp) + 8, C, remainder, a_sum_temp);
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+ }
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+ a_v[3] = A_zero_point_v;
+ a_v[4] = A_zero_point_v;
+ a_v[5] = A_zero_point_v;
+ a_v[6] = A_zero_point_v;
+ a_v[7] = A_zero_point_v;
+
+ if (t_in + 1 >= 0 && t_in + 1 < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ if (t_in + 2 >= 0 && t_in + 2 < T) {
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
+ a_v, reinterpret_cast<const __m256i*>(Bp) + 16, C, remainder, a_sum_temp);
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+ }
+
+ a_v[0] = A_zero_point_v;
+ a_v[1] = A_zero_point_v;
+ a_v[2] = A_zero_point_v;
+
+ if (t_in + 2 >= 0 && t_in + 2 < T) {
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v);
+ }
+ }
+ }
+
+ inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>(
+ a_v, reinterpret_cast<const __m256i*>(Bp) + 24, C, remainder, a_sum_temp);
+
+ if (SUM_A) {
+ a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
+ a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
+ a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
+ a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
+
+ __m256i B_zero_point_v;
+ for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
+ if (PER_CHANNEL_QUANTIZATION) {
+ B_zero_point_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
+ }
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
+ _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
+ }
+ }
+}
+
+template <
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_(
+ int T,
+ int H,
+ int W,
+ int K,
+ int t,
+ int h,
+ int w,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const int8_t* Bp,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale) {
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int t_in = -PAD_P + t * stride_t;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
+ int k;
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/>(
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ &B_zero_point,
+ C_int32 + k,
+ 0,
+ B_SYMMETRIC ? nullptr : &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/, true>(
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ &B_zero_point,
+ C_int32 + k,
+ remainder,
+ B_SYMMETRIC ? nullptr : &row_offsets[k]);
+ }
+
+ requantize_<
+ FUSE_RELU,
+ HAS_BIAS,
+ false, /*PER_CHAN_QUANT*/
+ A_SYMMETRIC,
+ B_SYMMETRIC>(
+ A_zero_point,
+ &C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
+ K,
+ row_offsets,
+ col_offsets,
+ bias,
+ &act_times_w_scale);
+}
+
+template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void
+depthwise_3x3x3_per_channel_quantization_kernel_(
+ int T,
+ int H,
+ int W,
+ int K,
+ int t,
+ int h,
+ int w,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const int8_t* Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale) {
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int t_in = -PAD_P + t * stride_t;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
+ int k;
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_3x3x3_packed_<
+ true, /*SUM_A*/
+ false, /*remainder*/
+ true /*per-channel*/>(
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ B_zero_point + k,
+ C_int32 + k,
+ 0,
+ &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_3x3x3_packed_<
+ true, /*SUM_A*/
+ true, /*remainder*/
+ true /*per-channel*/>(
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ B_zero_point + k,
+ C_int32 + k,
+ remainder,
+ &row_offsets[k]);
+ }
+ requantize_<
+ FUSE_RELU,
+ HAS_BIAS,
+ true, /*PER_CHAN_QUANT*/
+ A_SYMMETRIC,
+ false /*B_SYMM*/>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
+ K,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+}
+
+template <
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+ const int8_t* Bp = B.PackedMat();
+
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+
+ int n_begin, n_end;
+ int t_begin, t_end, h_begin, h_end;
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ t_begin = 0;
+ t_end = T_OUT;
+ h_begin = 0;
+ h_end = H_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+ // n is processed by num_threads_t * num_threads_h 2D grid of threads
+ int num_threads_t, num_threads_h;
+ // num_threads_w <= num_threads_h
+ tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_t = tid_within_n / num_threads_h;
+ int tid_h = tid_within_n % num_threads_h;
+
+ int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
+ t_begin = std::min(tid_t * t_per_thread, T_OUT);
+ t_end = std::min(t_begin + t_per_thread, T_OUT);
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+ }
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const uint8_t* A_base = A + n * T * H * W * K;
+ uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
+
+ for (int t = t_begin; t < t_end; ++t) {
+ for (int h = h_begin; h < h_end; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
+ } // h
+ } // t
+ } // for each n
+};
+
+template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void
+depthwise_3x3x3_per_channel_quantization_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+ const int8_t* Bp = B.PackedMat();
+
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+
+ int n_begin, n_end;
+ int t_begin, t_end, h_begin, h_end;
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ t_begin = 0;
+ t_end = T_OUT;
+ h_begin = 0;
+ h_end = H_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+ // n is processed by num_threads_t * num_threads_h 2D grid of threads
+ int num_threads_t, num_threads_h;
+ // num_threads_w <= num_threads_h
+ tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_t = tid_within_n / num_threads_h;
+ int tid_h = tid_within_n % num_threads_h;
+
+ int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
+ t_begin = std::min(tid_t * t_per_thread, T_OUT);
+ t_end = std::min(t_begin + t_per_thread, T_OUT);
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+ }
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const uint8_t* A_base = A + n * T * H * W * K;
+ uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
+
+ for (int t = t_begin; t < t_end; ++t) {
+ for (int h = h_begin; h < h_end; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ depthwise_3x3x3_per_channel_quantization_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ BIAS_TYPE>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
+ } // h
+ } // t
+ } // for each n
+};
+
+// Dispatch A_SYMMETRIC and B_SYMMETRIC
+template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
+static void depthwise_3x3x3_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ if (A_zero_point == 0 || col_offsets == nullptr) {
+ if (B_zero_point == 0) {
+ depthwise_3x3x3_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ true /*A_symmetric*/,
+ true /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ true /*A_symmetric*/,
+ false /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+ } else {
+ if (B_zero_point == 0) {
+ depthwise_3x3x3_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ false /*A_symmetric*/,
+ true /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ false /*A_symmetric*/,
+ false /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+ }
+}
+
+// Dispatch HAS_BIAS
+template <bool FUSE_RELU, typename BIAS_TYPE>
+static void depthwise_3x3x3_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (bias) {
+ depthwise_3x3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+// Dispatch FUSE_RELU
+template <typename BIAS_TYPE>
+void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (B.GetKernelProduct() != 3 * 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3 * 3) + " but has " + to_string(B.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(
+ 0 &&
+ "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we should just early return.
+ return;
+ }
+ if (fuse_relu) {
+ depthwise_3x3x3_pad_1_<true /*FUSE_RELU*/, BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_<false /*FUSE_RELU*/, BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+// Dispatch A_SYMMETRIC
+template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
+static void depthwise_3x3x3_per_channel_quantization_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ if (A_zero_point == 0 || col_offsets == nullptr) {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ true /*A_SYMM*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ FUSE_RELU,
+ HAS_BIAS,
+ false /*A_SYMM*/,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+// Dispatch HAS_BIAS
+template <bool FUSE_RELU, typename BIAS_TYPE>
+static void depthwise_3x3x3_per_channel_quantization_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (bias) {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ FUSE_RELU,
+ true /* HAS_BIAS */,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ FUSE_RELU,
+ false /* HAS_BIAS */,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+// Dispatch FUSE_RELU
+template <typename BIAS_TYPE>
+void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (B.GetKernelProduct() != 3 * 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3 * 3) + " but has " + to_string(B.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(
+ 0 &&
+ "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we should just early return.
+ return;
+ }
+ if (fuse_relu) {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+// To be removed
+void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ int thread_id,
+ int num_threads) {
+ depthwise_3x3x3_pad_1<int32_t>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ fuse_relu,
+ 1.0f, // act_scale * weight_scale
+ thread_id,
+ num_threads);
+}
+
+void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ int thread_id,
+ int num_threads) {
+ depthwise_3x3x3_per_channel_quantization_pad_1(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ fuse_relu,
+ nullptr, // act_scale * weight_scale
+ thread_id,
+ num_threads);
+}
+
+template void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3x3_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+} // namespace fbgemm
diff --git a/src/FbgemmI8DepthwiseAvx2-inl.h b/src/FbgemmI8DepthwiseAvx2-inl.h
new file mode 100644
index 0000000..7ad39fc
--- /dev/null
+++ b/src/FbgemmI8DepthwiseAvx2-inl.h
@@ -0,0 +1,709 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <algorithm> // for min and max
+#include <cassert>
+#include <cmath> // for lrintf and sqrt
+#include <cstdint>
+#include <type_traits> // for is_same
+
+#include <immintrin.h>
+
+namespace fbgemm {
+
+// clang-format off
+static int masks[8][8] = {
+ // NOTE: clang-format wants to use a different formatting but the current
+ // formatting should be easier to read.
+ { 0, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, -1, 0, },
+};
+// clang-format on
+
+// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[16:20]
+// c1_v: c[4:8], c[20:24]
+// c2_v: c[8:12], c[24:28]
+// c3_v: c[12:16], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline)) void madd_epi16x4_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ __m256i a2_v,
+ __m256i a3_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+ __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v);
+ __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
+ __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
+ __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
+ __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+ __m256i b2_v = _mm256_load_si256(b + 2);
+ __m256i b3_v = _mm256_load_si256(b + 3);
+
+ __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
+ __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
+ __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
+ __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
+
+ __m256i one_v = _mm256_set1_epi16(1);
+ *c0_v = _mm256_madd_epi16(ab0, one_v);
+ *c1_v = _mm256_madd_epi16(ab1, one_v);
+ *c2_v = _mm256_madd_epi16(ab2, one_v);
+ *c3_v = _mm256_madd_epi16(ab3, one_v);
+}
+
+// c = a0 * b0 + a1 * b1 + a2 * b2
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[16:20]
+// c1_v: c[4:8], c[20:24]
+// c2_v: c[8:12], c[24:28]
+// c3_v: c[12:16], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline)) void madd_epi16x3_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ __m256i a2_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i zero_v = _mm256_setzero_si256();
+
+ __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+ __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v);
+ __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
+ __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
+ __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
+ __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+ __m256i b2_v = _mm256_load_si256(b + 2);
+ __m256i b3_v = _mm256_load_si256(b + 3);
+
+ __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
+ __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
+ __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
+ __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
+
+ __m256i one_v = _mm256_set1_epi16(1);
+ *c0_v = _mm256_madd_epi16(ab0, one_v);
+ *c1_v = _mm256_madd_epi16(ab1, one_v);
+ *c2_v = _mm256_madd_epi16(ab2, one_v);
+ *c3_v = _mm256_madd_epi16(ab3, one_v);
+}
+
+// c = a0 * b0 + a1 * b1
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[4:8]
+// c1_v: c[8:12], c[12:16]
+// c2_v: c[16:20], c[20:24]
+// c3_v: c[24:28], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline)) void madd_epi16x2_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
+ __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+
+ __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
+ __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
+
+ *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
+ *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
+ *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
+ *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
+}
+
+// c = a0 * b0
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[4:8]
+// c1_v: c[8:12], c[12:16]
+// c2_v: c[16:20], c[20:24]
+// c3_v: c[24:28], c[28:32]
+template <bool SUM_A = false>
+static inline __attribute__((always_inline)) void madd_epi16_packed(
+ __m256i a_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
+ __m256i zero_v = _mm256_setzero_si256();
+
+ __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v);
+ __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v);
+
+ if (SUM_A) {
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
+ a_sum[0] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
+ a_sum[1] =
+ _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
+ }
+
+ __m256i b0_v = _mm256_load_si256(b + 0);
+ __m256i b1_v = _mm256_load_si256(b + 1);
+
+ __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
+ __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
+
+ *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
+ *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
+ *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
+ *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
+}
+
+// K is the number of accumulations we're doing
+template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false>
+static inline __attribute__((always_inline)) void inner_prod_packed_(
+ const __m256i* a_v,
+ const __m256i* Bp,
+ std::int32_t* C,
+ int remainder,
+ __m256i* a_sum = nullptr) {
+ __m256i c[4], c_temp[4];
+ __m256i a_sum_temp[2] = {0, 0};
+
+ int k = 0;
+ if (K >= 4) {
+ madd_epi16x4_packed<SUM_A>(
+ a_v[0],
+ a_v[1],
+ a_v[2],
+ a_v[3],
+ Bp,
+ &c[0],
+ &c[1],
+ &c[2],
+ &c[3],
+ a_sum_temp);
+
+ for (k = 4; k < K / 4 * 4; k += 4) {
+ madd_epi16x4_packed<SUM_A>(
+ a_v[k + 0],
+ a_v[k + 1],
+ a_v[k + 2],
+ a_v[k + 3],
+ Bp + k,
+ &c_temp[0],
+ &c_temp[1],
+ &c_temp[2],
+ &c_temp[3],
+ a_sum_temp);
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+ } else {
+ c[0] = _mm256_setzero_si256();
+ c[1] = _mm256_setzero_si256();
+ c[2] = _mm256_setzero_si256();
+ c[3] = _mm256_setzero_si256();
+ }
+
+ if (K - k == 3) {
+ madd_epi16x3_packed<SUM_A>(
+ a_v[k],
+ a_v[k + 1],
+ a_v[k + 2],
+ Bp + k,
+ &c_temp[0],
+ &c_temp[1],
+ &c_temp[2],
+ &c_temp[3],
+ a_sum_temp);
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+
+ c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20);
+ c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20);
+ c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31);
+ c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31);
+
+ if (K - k == 0 || K - k == 3) {
+ c[0] = c_temp[0];
+ c[1] = c_temp[1];
+ c[2] = c_temp[2];
+ c[3] = c_temp[3];
+ } else {
+ if (K - k == 1) {
+ madd_epi16_packed<SUM_A>(
+ a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
+ } else if (K - k == 2) {
+ madd_epi16x2_packed<SUM_A>(
+ a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
+ }
+
+ c[0] = _mm256_add_epi32(c[0], c_temp[0]);
+ c[1] = _mm256_add_epi32(c[1], c_temp[1]);
+ c[2] = _mm256_add_epi32(c[2], c_temp[2]);
+ c[3] = _mm256_add_epi32(c[3], c_temp[3]);
+ }
+
+ if (REMAINDER) {
+ for (int r = 0; r < remainder / 8; ++r) {
+ if (ACC) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C + r * 8),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)),
+ c[r]));
+ } else {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]);
+ }
+ }
+ } else {
+ if (ACC) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C + 8),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C + 16),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2]));
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C + 24),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3]));
+ } else {
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]);
+ }
+ }
+
+ if (SUM_A) {
+ a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0]));
+ a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1]));
+ a_sum[2] =
+ _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1));
+ a_sum[3] =
+ _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1));
+ }
+}
+
+// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different
+// row_offsets for each row because of depth-wise convolution
+template <
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool PER_CHANNEL_QUANTIZATION,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void requantize_(
+ std::int32_t A_zero_point,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ const std::int32_t* C_int32,
+ std::uint8_t* C_uint8,
+ int n,
+ const std::int32_t* row_offsets,
+ const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale = nullptr) {
+ __m256 multiplier_v = _mm256_setzero_ps();
+ // Broadcasted reciprocal of act_times_w_scale
+ __m256 act_times_w_rcp_v = _mm256_setzero_ps();
+ if (!PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_set1_ps(*C_multiplier);
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ act_times_w_rcp_v = _mm256_set1_ps(1.0f / (*act_times_w_scale));
+ }
+ }
+
+ __m256i min_v = _mm256_set1_epi8(static_cast<std::uint8_t>(0));
+ __m256i max_v = _mm256_set1_epi8(static_cast<std::uint8_t>(255));
+
+ if (A_SYMMETRIC) {
+ assert(A_zero_point == 0 || col_offsets == nullptr);
+ }
+ __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point);
+ __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point);
+ __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point);
+
+ __m256i permute_mask_v =
+ _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
+
+ constexpr int VLEN = 8;
+ int j = 0;
+ for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
+ __m256i x_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
+ __m256i y_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + VLEN));
+ __m256i z_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN));
+ __m256i w_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN));
+
+ __m256i row_offset_v;
+ if (!B_SYMMETRIC) {
+ row_offset_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
+ x_v = _mm256_sub_epi32(x_v, row_offset_v);
+ }
+ __m256i col_off_v;
+ if (!A_SYMMETRIC) {
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j)));
+ x_v = _mm256_sub_epi32(x_v, col_off_v);
+ }
+
+ if (!B_SYMMETRIC) {
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + VLEN));
+ y_v = _mm256_sub_epi32(y_v, row_offset_v);
+ }
+ if (!A_SYMMETRIC) {
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + VLEN)));
+ y_v = _mm256_sub_epi32(y_v, col_off_v);
+ }
+
+ if (!B_SYMMETRIC) {
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
+ z_v = _mm256_sub_epi32(z_v, row_offset_v);
+ }
+ if (!A_SYMMETRIC) {
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN)));
+ z_v = _mm256_sub_epi32(z_v, col_off_v);
+ }
+
+ if (!B_SYMMETRIC) {
+ row_offset_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
+ w_v = _mm256_sub_epi32(w_v, row_offset_v);
+ }
+ if (!A_SYMMETRIC) {
+ col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN)));
+ w_v = _mm256_sub_epi32(w_v, col_off_v);
+ }
+
+ // convert to float
+ __m256 xf_v, yf_v, zf_v, wf_v;
+ if (HAS_BIAS) { // static if
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
+ if (PER_CHANNEL_QUANTIZATION) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
+ _mm256_loadu_ps(act_times_w_scale + j + 0 * VLEN));
+ y_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
+ _mm256_loadu_ps(act_times_w_scale + j + 1 * VLEN));
+ z_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
+ _mm256_loadu_ps(act_times_w_scale + j + 2 * VLEN));
+ w_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
+ _mm256_loadu_ps(act_times_w_scale + j + 3 * VLEN));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
+ act_times_w_rcp_v);
+ y_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
+ act_times_w_rcp_v);
+ z_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
+ act_times_w_rcp_v);
+ w_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(
+ reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
+ zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
+ wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 0 * VLEN)));
+ y_v = _mm256_add_epi32(
+ y_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 1 * VLEN)));
+ z_v = _mm256_add_epi32(
+ z_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN)));
+ w_v = _mm256_add_epi32(
+ w_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ yf_v = _mm256_cvtepi32_ps(y_v);
+ zf_v = _mm256_cvtepi32_ps(z_v);
+ wf_v = _mm256_cvtepi32_ps(w_v);
+ }
+
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 0 * VLEN);
+ }
+ __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 1 * VLEN);
+ }
+ __m256 y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN);
+ }
+ __m256 z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN);
+ }
+ __m256 w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
+
+ __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+ __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
+ __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
+ __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
+
+ __m256i xy_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
+ __m256i zw_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
+ __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
+ __m256i xyzw_clamped_v = _mm256_max_epu8(
+ FUSE_RELU ? C_zero_point_epi8_v : min_v,
+ _mm256_min_epu8(xyzw_packed_v, max_v));
+
+ xyzw_clamped_v =
+ _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
+
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
+ } // j loop vectorized and unrolled 4x
+
+ for (; j < n / VLEN * VLEN; j += VLEN) {
+ __m256i x_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
+
+ if (!B_SYMMETRIC) {
+ __m256i row_offset_v =
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
+ x_v = _mm256_sub_epi32(x_v, row_offset_v);
+ }
+ if (!A_SYMMETRIC) {
+ __m256i col_off_v = _mm256_mullo_epi32(
+ A_zero_point_v,
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j)));
+ x_v = _mm256_sub_epi32(x_v, col_off_v);
+ }
+
+ // Convert to float
+ __m256 xf_v;
+ if (HAS_BIAS) { // static if
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ __m256 x_bias_v;
+ if (PER_CHANNEL_QUANTIZATION) {
+ x_bias_v = _mm256_div_ps(
+ _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
+ _mm256_loadu_ps(act_times_w_scale + j));
+ } else {
+ x_bias_v = _mm256_mul_ps(
+ _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
+ act_times_w_rcp_v);
+ }
+ xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
+ } else {
+ x_v = _mm256_add_epi32(
+ x_v,
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ }
+ } else {
+ xf_v = _mm256_cvtepi32_ps(x_v);
+ }
+
+ if (PER_CHANNEL_QUANTIZATION) {
+ multiplier_v = _mm256_loadu_ps(C_multiplier + j);
+ }
+ __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
+ __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
+
+ __m256i x_packed_v = _mm256_adds_epi16(
+ _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
+ C_zero_point_epi16_v);
+ x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
+ __m256i x_clamped_v = _mm256_max_epu8(
+ FUSE_RELU ? C_zero_point_epi8_v : min_v,
+ _mm256_min_epu8(x_packed_v, max_v));
+
+ x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
+
+ _mm_storel_epi64(
+ reinterpret_cast<__m128i*>(C_uint8 + j),
+ _mm256_castsi256_si128(x_clamped_v));
+ } // j loop vectorized
+
+ for (; j < n; ++j) {
+ std::int32_t raw = C_int32[j];
+ if (!B_SYMMETRIC) {
+ raw -= row_offsets[j];
+ }
+ if (!A_SYMMETRIC) {
+ raw -= A_zero_point * col_offsets[j];
+ }
+ float raw_f;
+ if (HAS_BIAS) { // static if
+ if (std::is_same<BIAS_TYPE, float>::value) {
+ raw_f = raw;
+ raw_f += bias[j] / act_times_w_scale[PER_CHANNEL_QUANTIZATION ? j : 0];
+ } else {
+ raw += bias[j];
+ raw_f = raw;
+ }
+ } else {
+ raw_f = raw;
+ }
+
+ float ab = raw_f * C_multiplier[PER_CHANNEL_QUANTIZATION ? j : 0];
+ long rounded = lrintf(ab) + C_zero_point;
+
+ C_uint8[j] = std::max(
+ FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
+ std::min(255l, rounded));
+ }
+}
+
+template <bool REMAINDER>
+static inline __attribute__((always_inline)) __m256i load_a(
+ const std::uint8_t* A,
+ __m256i mask_v) {
+ if (REMAINDER) {
+ return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v);
+ } else {
+ return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A));
+ }
+}
+
+static inline std::pair<int, int> closest_factors_(int n) {
+ int a = static_cast<int>(std::sqrt(n));
+ while (n % a != 0) {
+ a--;
+ }
+ return {a, n / a}; // a <= n / a
+}
+
+} // namespace fbgemm
diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc
index aa7b90e..9d57d5b 100644
--- a/src/FbgemmI8DepthwiseAvx2.cc
+++ b/src/FbgemmI8DepthwiseAvx2.cc
@@ -6,562 +6,15 @@
*/
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
-#include <algorithm> // for min and max
-#include <cassert>
-#include <cmath> // for lrintf and sqrt
+#include <string>
#include <tuple> // for tie
-#include <type_traits> // for is_same
-#include <immintrin.h>
+#include "FbgemmI8DepthwiseAvx2-inl.h"
using namespace std;
namespace fbgemm {
-// clang-format off
-static int masks[8][8] = {
- // NOTE: clang-format wants to use a different formatting but the current
- // formatting should be easier to read.
- { 0, 0, 0, 0, 0, 0, 0, 0, },
- { -1, 0, 0, 0, 0, 0, 0, 0, },
- { -1, -1, 0, 0, 0, 0, 0, 0, },
- { -1, -1, -1, 0, 0, 0, 0, 0, },
- { -1, -1, -1, -1, 0, 0, 0, 0, },
- { -1, -1, -1, -1, -1, 0, 0, 0, },
- { -1, -1, -1, -1, -1, -1, 0, 0, },
- { -1, -1, -1, -1, -1, -1, -1, 0, },
-};
-// clang-format on
-
-template <int KERNEL_PROD>
-PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
- int K,
- const int8_t* smat)
- : K_(K) {
- // Transpose the input matrix to make packing faster.
- alignas(64) int8_t smat_transposed[K * KERNEL_PROD];
- for (int i = 0; i < KERNEL_PROD; ++i) {
- for (int j = 0; j < K; ++j) {
- smat_transposed[i * K + j] = smat[i + j * KERNEL_PROD];
- }
- }
-
- // Allocate packed arrays
- constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2;
- // pmat_ = static_cast<int8_t *>(fbgemmAlignedAlloc(
- // 64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t)));
- posix_memalign(
- (void**)&pmat_,
- 64,
- ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t));
-
- // Pack input matrix
- // The layout is optimized to use vpmaddubsw efficiently (see
- // madd_epi16x4_packed function).
- // For a group of 32 channels, we have 10 32B SIMD registers.
- // Denote ith channel jth filter as (i, j)
- // 0th SIMD register:
- // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3)
- // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3)
- // 1st SIMD register:
- // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3)
- // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3)
- // 2nd SIMD register:
- // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3)
- // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3)
- // 3rd SIMD register:
- // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3)
- // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3)
- // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter
- // coefficients
- // ...
- //
- // REMAINDER
- // If KERNEL_PROD % 4 == 1 for example when KERNEL_PROD == 9
- // 8th SIMD register:
- // (0, 8), zero, ..., (7, 8), zero
- // (16, 8), zero, ..., (23, 8), zero
- // 9th SIMD register:
- // (8, 8), zero, ..., (15, 8), zero
- // (24, 8), zero, ..., (31, 8), zero
- // We use madd_epi16_packed for this case
- //
- // If KERNEL_PROD % 4 == 2 for example when KERNEL_PROD == 10
- // 8th SIMD register:
- // (0, 8), (0, 9), ..., (7, 8), (7, 9)
- // (16, 8), (16, 9), ..., (23, 8), (23, 9)
- // 9th SIMD register:
- // (8, 8), (8, 9), ..., (15, 8), (15, 9)
- // (24, 8), (24, 9), ..., (31, 8), (31, 9)
- //
- // If KERNEL_PROD % 4 == 3 for example when KERNEL_PROD == 11
- // 8th SIMD register:
- // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero
- // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero
- // 9th SIMD register:
- // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero
- // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero
- // 10th SIMD register:
- // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero
- // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero
- // 11th SIMD register:
- // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero
- // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero
- for (int k1 = 0; k1 < K; k1 += 32) {
- __m256i b_v[KERNEL_PROD];
- int remainder = K - k1;
- if (remainder < 32) {
- __m256i mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(masks[remainder / 4]));
- for (int i = 0; i < KERNEL_PROD; ++i) {
- b_v[i] = _mm256_maskload_epi32(
- reinterpret_cast<const int*>(smat_transposed + i * K + k1), mask_v);
- }
- } else {
- for (int i = 0; i < KERNEL_PROD; ++i) {
- b_v[i] = _mm256_lddqu_si256(
- reinterpret_cast<const __m256i*>(smat_transposed + i * K + k1));
- }
- }
-
- // Interleave 2 SIMD registers
- __m256i b_interleaved_epi16[KERNEL_PROD_ALIGNED];
- __m256i zero_v = _mm256_setzero_si256();
- for (int i = 0; i < KERNEL_PROD_ALIGNED / 2; ++i) {
- if (2 * i + 1 >= KERNEL_PROD) {
- b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v);
- b_interleaved_epi16[2 * i + 1] =
- _mm256_unpackhi_epi8(b_v[2 * i], zero_v);
- } else {
- b_interleaved_epi16[2 * i] =
- _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]);
- b_interleaved_epi16[2 * i + 1] =
- _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]);
- }
- }
-
- // Interleave 4 SIMD registers
- __m256i b_interleaved_epi32[KERNEL_PROD_ALIGNED];
- for (int i = 0; i < KERNEL_PROD_ALIGNED / 4; ++i) {
- b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16(
- b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
- b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16(
- b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
- b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16(
- b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
- b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16(
- b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
- }
- for (int i = KERNEL_PROD_ALIGNED / 4 * 4; i < KERNEL_PROD_ALIGNED; ++i) {
- b_interleaved_epi32[i] = b_interleaved_epi16[i];
- }
-
- for (int i = 0; i < KERNEL_PROD_ALIGNED; ++i) {
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(
- &pmat_[((k1 / 32) * KERNEL_PROD_ALIGNED + i) * 32]),
- b_interleaved_epi32[i]);
- }
- }
-}
-
-template <int KERNEL_PROD>
-int PackedDepthWiseConvMatrix<KERNEL_PROD>::addr(int r, int c) {
- constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2;
- if (c >= KERNEL_PROD / 4 * 4 &&
- (KERNEL_PROD % 4 == 1 || KERNEL_PROD % 4 == 2)) {
- int kBlock = r / 32;
- int reg_idx = (r % 16) / 8 + c / 4 * 4;
-
- int blk_idx = kBlock * KERNEL_PROD_ALIGNED + reg_idx;
-
- int r_ = r % 8;
- int c_ = c % 4;
-
- int in_blk_idx = (r % 32) / 16 * 16 + 2 * r_ + c_;
- return blk_idx * 32 + in_blk_idx;
-
- } else {
- int kBlock = r / 32;
- int reg_idx = (r % 16) / 4 + c / 4 * 4;
-
- int blk_idx = kBlock * KERNEL_PROD_ALIGNED + reg_idx;
-
- int r_ = r % 4;
- int c_ = c % 4;
-
- int in_blk_idx = (r % 32) / 16 * 16 + 4 * r_ + c_;
- return blk_idx * 32 + in_blk_idx;
- }
-}
-
-template <int KERNEL_PROD>
-void PackedDepthWiseConvMatrix<KERNEL_PROD>::unpack(int8_t* unpacked_data) {
- for (int r = 0; r < K_; ++r) {
- for (int c = 0; c < KERNEL_PROD; ++c) {
- unpacked_data[r * KERNEL_PROD + c] = pmat_[addr(r, c)];
- }
- }
-}
-
-template <int KERNEL_PROD>
-PackedDepthWiseConvMatrix<KERNEL_PROD>::~PackedDepthWiseConvMatrix() {
- free(pmat_);
-}
-
-template class PackedDepthWiseConvMatrix<3 * 3>;
-template class PackedDepthWiseConvMatrix<3 * 3 * 3>;
-template class PackedDepthWiseConvMatrix<1>;
-template class PackedDepthWiseConvMatrix<2>;
-template class PackedDepthWiseConvMatrix<3>;
-template class PackedDepthWiseConvMatrix<4>;
-template class PackedDepthWiseConvMatrix<5>;
-template class PackedDepthWiseConvMatrix<5 * 2>;
-template class PackedDepthWiseConvMatrix<11 * 1>;
-
-// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[16:20]
-// c1_v: c[4:8], c[20:24]
-// c2_v: c[8:12], c[24:28]
-// c3_v: c[12:16], c[28:32]
-template <bool SUM_A = false>
-static inline __attribute__((always_inline)) void madd_epi16x4_packed(
- __m256i a0_v,
- __m256i a1_v,
- __m256i a2_v,
- __m256i a3_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
- __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v);
- __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
- __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
- __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
- __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
- __m256i b2_v = _mm256_load_si256(b + 2);
- __m256i b3_v = _mm256_load_si256(b + 3);
-
- __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
- __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
- __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
- __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
-
- __m256i one_v = _mm256_set1_epi16(1);
- *c0_v = _mm256_madd_epi16(ab0, one_v);
- *c1_v = _mm256_madd_epi16(ab1, one_v);
- *c2_v = _mm256_madd_epi16(ab2, one_v);
- *c3_v = _mm256_madd_epi16(ab3, one_v);
-}
-
-// c = a0 * b0 + a1 * b1 + a2 * b2
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[16:20]
-// c1_v: c[4:8], c[20:24]
-// c2_v: c[8:12], c[24:28]
-// c3_v: c[12:16], c[28:32]
-template <bool SUM_A = false>
-static inline __attribute__((always_inline)) void madd_epi16x3_packed(
- __m256i a0_v,
- __m256i a1_v,
- __m256i a2_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i zero_v = _mm256_setzero_si256();
-
- __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
- __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v);
- __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
- __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
- __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
- __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
- __m256i b2_v = _mm256_load_si256(b + 2);
- __m256i b3_v = _mm256_load_si256(b + 3);
-
- __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
- __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
- __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
- __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
-
- __m256i one_v = _mm256_set1_epi16(1);
- *c0_v = _mm256_madd_epi16(ab0, one_v);
- *c1_v = _mm256_madd_epi16(ab1, one_v);
- *c2_v = _mm256_madd_epi16(ab2, one_v);
- *c3_v = _mm256_madd_epi16(ab3, one_v);
-}
-
-// c = a0 * b0 + a1 * b1
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[4:8]
-// c1_v: c[8:12], c[12:16]
-// c2_v: c[16:20], c[20:24]
-// c3_v: c[24:28], c[28:32]
-template <bool SUM_A = false>
-static inline __attribute__((always_inline)) void madd_epi16x2_packed(
- __m256i a0_v,
- __m256i a1_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
-
- __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
- __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
-
- *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
- *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
- *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
- *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
-}
-
-// c = a0 * b0
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[4:8]
-// c1_v: c[8:12], c[12:16]
-// c2_v: c[16:20], c[20:24]
-// c3_v: c[24:28], c[28:32]
-template <bool SUM_A = false>
-static inline __attribute__((always_inline)) void madd_epi16_packed(
- __m256i a_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i zero_v = _mm256_setzero_si256();
-
- __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v);
- __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
-
- __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
- __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
-
- *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
- *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
- *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
- *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
-}
-
-// K is the number of accumulations we're doing
-template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false>
-static inline __attribute__((always_inline)) void inner_prod_packed_(
- const __m256i* a_v,
- const __m256i* Bp,
- int32_t* C,
- int remainder,
- __m256i* a_sum = nullptr) {
- __m256i c[4], c_temp[4];
- __m256i a_sum_temp[2] = {0, 0};
-
- int k = 0;
- if (K >= 4) {
- madd_epi16x4_packed<SUM_A>(
- a_v[0],
- a_v[1],
- a_v[2],
- a_v[3],
- Bp,
- &c[0],
- &c[1],
- &c[2],
- &c[3],
- a_sum_temp);
-
- for (k = 4; k < K / 4 * 4; k += 4) {
- madd_epi16x4_packed<SUM_A>(
- a_v[k + 0],
- a_v[k + 1],
- a_v[k + 2],
- a_v[k + 3],
- Bp + k,
- &c_temp[0],
- &c_temp[1],
- &c_temp[2],
- &c_temp[3],
- a_sum_temp);
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
- } else {
- c[0] = _mm256_setzero_si256();
- c[1] = _mm256_setzero_si256();
- c[2] = _mm256_setzero_si256();
- c[3] = _mm256_setzero_si256();
- }
-
- if (K - k == 3) {
- madd_epi16x3_packed<SUM_A>(
- a_v[k],
- a_v[k + 1],
- a_v[k + 2],
- Bp + k,
- &c_temp[0],
- &c_temp[1],
- &c_temp[2],
- &c_temp[3],
- a_sum_temp);
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
-
- c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20);
- c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20);
- c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31);
- c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31);
-
- if (K - k == 0 || K - k == 3) {
- c[0] = c_temp[0];
- c[1] = c_temp[1];
- c[2] = c_temp[2];
- c[3] = c_temp[3];
- } else {
- if (K - k == 1) {
- madd_epi16_packed<SUM_A>(
- a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
- } else if (K - k == 2) {
- madd_epi16x2_packed<SUM_A>(
- a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
- }
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
-
- if (REMAINDER) {
- for (int r = 0; r < remainder / 8; ++r) {
- if (ACC) {
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + r * 8),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)),
- c[r]));
- } else {
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]);
- }
- }
- } else {
- if (ACC) {
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 8),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 16),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 24),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3]));
- } else {
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]);
- }
- }
-
- if (SUM_A) {
- a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0]));
- a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1]));
- a_sum[2] =
- _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1));
- a_sum[3] =
- _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1));
- }
-}
-
template <bool SUM_A = false, bool REMAINDER = false>
static inline __attribute__((always_inline)) void inner_prod_3x3_packed_(
const __m256i* a_v,
@@ -572,330 +25,6 @@ static inline __attribute__((always_inline)) void inner_prod_3x3_packed_(
return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder, a_sum);
}
-// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different
-// row_offsets for each row because of depth-wise convolution
-template <
- bool FUSE_RELU,
- bool HAS_BIAS,
- bool PER_CHANNEL_QUANTIZATION,
- bool A_SYMMETRIC,
- bool B_SYMMETRIC,
- typename BIAS_TYPE>
-static inline __attribute__((always_inline)) void requantize_(
- int32_t A_zero_point,
- const float* C_multiplier,
- int32_t C_zero_point,
- const int32_t* C_int32,
- uint8_t* C_uint8,
- int n,
- const int32_t* row_offsets,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- const float* act_times_w_scale = nullptr) {
- __m256 multiplier_v = _mm256_setzero_ps();
- // Broadcasted reciprocal of act_times_w_scale
- __m256 act_times_w_rcp_v = _mm256_setzero_ps();
- if (!PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_set1_ps(*C_multiplier);
- if (is_same<BIAS_TYPE, float>::value) {
- act_times_w_rcp_v = _mm256_set1_ps(1.0f / (*act_times_w_scale));
- }
- }
-
- __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
- __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
-
- if (A_SYMMETRIC) {
- assert(A_zero_point == 0 || col_offsets == nullptr);
- }
- __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point);
- __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point);
- __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point);
-
- __m256i permute_mask_v =
- _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
-
- constexpr int VLEN = 8;
- int j = 0;
- for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
- __m256i x_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
- __m256i y_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(C_int32 + j + VLEN));
- __m256i z_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN));
- __m256i w_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN));
-
- __m256i row_offset_v;
- if (!B_SYMMETRIC) {
- row_offset_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
- x_v = _mm256_sub_epi32(x_v, row_offset_v);
- }
- __m256i col_off_v;
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j)));
- x_v = _mm256_sub_epi32(x_v, col_off_v);
- }
-
- if (!B_SYMMETRIC) {
- row_offset_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(row_offsets + j + VLEN));
- y_v = _mm256_sub_epi32(y_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j + VLEN)));
- y_v = _mm256_sub_epi32(y_v, col_off_v);
- }
-
- if (!B_SYMMETRIC) {
- row_offset_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
- z_v = _mm256_sub_epi32(z_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN)));
- z_v = _mm256_sub_epi32(z_v, col_off_v);
- }
-
- if (!B_SYMMETRIC) {
- row_offset_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
- w_v = _mm256_sub_epi32(w_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN)));
- w_v = _mm256_sub_epi32(w_v, col_off_v);
- }
-
- // convert to float
- __m256 xf_v, yf_v, zf_v, wf_v;
- if (HAS_BIAS) { // static if
- if (is_same<BIAS_TYPE, float>::value) {
- __m256 x_bias_v, y_bias_v, z_bias_v, w_bias_v;
- if (PER_CHANNEL_QUANTIZATION) {
- x_bias_v = _mm256_div_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
- _mm256_loadu_ps(act_times_w_scale + j + 0 * VLEN));
- y_bias_v = _mm256_div_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
- _mm256_loadu_ps(act_times_w_scale + j + 1 * VLEN));
- z_bias_v = _mm256_div_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
- _mm256_loadu_ps(act_times_w_scale + j + 2 * VLEN));
- w_bias_v = _mm256_div_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
- _mm256_loadu_ps(act_times_w_scale + j + 3 * VLEN));
- } else {
- x_bias_v = _mm256_mul_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(bias + j + 0 * VLEN)),
- act_times_w_rcp_v);
- y_bias_v = _mm256_mul_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(bias + j + 1 * VLEN)),
- act_times_w_rcp_v);
- z_bias_v = _mm256_mul_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(bias + j + 2 * VLEN)),
- act_times_w_rcp_v);
- w_bias_v = _mm256_mul_ps(
- _mm256_loadu_ps(
- reinterpret_cast<const float*>(bias + j + 3 * VLEN)),
- act_times_w_rcp_v);
- }
- xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
- yf_v = _mm256_add_ps(_mm256_cvtepi32_ps(y_v), y_bias_v);
- zf_v = _mm256_add_ps(_mm256_cvtepi32_ps(z_v), z_bias_v);
- wf_v = _mm256_add_ps(_mm256_cvtepi32_ps(w_v), w_bias_v);
- } else {
- x_v = _mm256_add_epi32(
- x_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(bias + j + 0 * VLEN)));
- y_v = _mm256_add_epi32(
- y_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(bias + j + 1 * VLEN)));
- z_v = _mm256_add_epi32(
- z_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN)));
- w_v = _mm256_add_epi32(
- w_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN)));
- xf_v = _mm256_cvtepi32_ps(x_v);
- yf_v = _mm256_cvtepi32_ps(y_v);
- zf_v = _mm256_cvtepi32_ps(z_v);
- wf_v = _mm256_cvtepi32_ps(w_v);
- }
- } else {
- xf_v = _mm256_cvtepi32_ps(x_v);
- yf_v = _mm256_cvtepi32_ps(y_v);
- zf_v = _mm256_cvtepi32_ps(z_v);
- wf_v = _mm256_cvtepi32_ps(w_v);
- }
-
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j + 0 * VLEN);
- }
- __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j + 1 * VLEN);
- }
- __m256 y_scaled_v = _mm256_mul_ps(yf_v, multiplier_v);
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN);
- }
- __m256 z_scaled_v = _mm256_mul_ps(zf_v, multiplier_v);
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN);
- }
- __m256 w_scaled_v = _mm256_mul_ps(wf_v, multiplier_v);
-
- __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
- __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
- __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
- __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
-
- __m256i xy_packed_v = _mm256_adds_epi16(
- _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
- __m256i zw_packed_v = _mm256_adds_epi16(
- _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
- __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
- __m256i xyzw_clamped_v = _mm256_max_epu8(
- FUSE_RELU ? C_zero_point_epi8_v : min_v,
- _mm256_min_epu8(xyzw_packed_v, max_v));
-
- xyzw_clamped_v =
- _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
-
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
- } // j loop vectorized and unrolled 4x
-
- for (; j < n / VLEN * VLEN; j += VLEN) {
- __m256i x_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
-
- if (!B_SYMMETRIC) {
- __m256i row_offset_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
- x_v = _mm256_sub_epi32(x_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- __m256i col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j)));
- x_v = _mm256_sub_epi32(x_v, col_off_v);
- }
-
- // Convert to float
- __m256 xf_v;
- if (HAS_BIAS) { // static if
- if (is_same<BIAS_TYPE, float>::value) {
- __m256 x_bias_v;
- if (PER_CHANNEL_QUANTIZATION) {
- x_bias_v = _mm256_div_ps(
- _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
- _mm256_loadu_ps(act_times_w_scale + j));
- } else {
- x_bias_v = _mm256_mul_ps(
- _mm256_loadu_ps(reinterpret_cast<const float*>(bias + j)),
- act_times_w_rcp_v);
- }
- xf_v = _mm256_add_ps(_mm256_cvtepi32_ps(x_v), x_bias_v);
- } else {
- x_v = _mm256_add_epi32(
- x_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
- xf_v = _mm256_cvtepi32_ps(x_v);
- }
- } else {
- xf_v = _mm256_cvtepi32_ps(x_v);
- }
-
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j);
- }
- __m256 x_scaled_v = _mm256_mul_ps(xf_v, multiplier_v);
- __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
-
- __m256i x_packed_v = _mm256_adds_epi16(
- _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
- C_zero_point_epi16_v);
- x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
- __m256i x_clamped_v = _mm256_max_epu8(
- FUSE_RELU ? C_zero_point_epi8_v : min_v,
- _mm256_min_epu8(x_packed_v, max_v));
-
- x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
-
- _mm_storel_epi64(
- reinterpret_cast<__m128i*>(C_uint8 + j),
- _mm256_castsi256_si128(x_clamped_v));
- } // j loop vectorized
-
- for (; j < n; ++j) {
- int32_t raw = C_int32[j];
- if (!B_SYMMETRIC) {
- raw -= row_offsets[j];
- }
- if (!A_SYMMETRIC) {
- raw -= A_zero_point * col_offsets[j];
- }
- float raw_f;
- if (HAS_BIAS) { // static if
- if (is_same<BIAS_TYPE, float>::value) {
- raw_f = raw;
- raw_f += bias[j] / act_times_w_scale[PER_CHANNEL_QUANTIZATION ? j : 0];
- } else {
- raw += bias[j];
- raw_f = raw;
- }
- } else {
- raw_f = raw;
- }
-
- float ab = raw_f * C_multiplier[PER_CHANNEL_QUANTIZATION ? j : 0];
- long rounded = lrintf(ab) + C_zero_point;
-
- C_uint8[j] = std::max(
- FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
- std::min(255l, rounded));
- }
-}
-
-template <bool REMAINDER>
-static inline __attribute__((always_inline)) __m256i load_a(
- const uint8_t* A,
- __m256i mask_v) {
- if (REMAINDER) {
- return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v);
- } else {
- return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A));
- }
-}
-
template <
bool SUM_A,
bool REMAINDER = false,
@@ -1008,257 +137,6 @@ static inline __attribute__((always_inline)) void inner_prod_3x3_packed_(
}
template <
- bool SUM_A,
- bool REMAINDER = false,
- bool PER_CHANNEL_QUANTIZATION = false>
-static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_(
- int T,
- int H,
- int W,
- int K,
- int t_in,
- int h_in,
- int w_in,
- const uint8_t* A,
- int32_t A_zero_point,
- const int8_t* Bp,
- const int32_t* B_zero_point,
- int32_t* C,
- int remainder,
- int32_t* row_offsets) {
- __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
- __m256i mask_v = _mm256_setzero_si256();
- if (REMAINDER) {
- mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(masks[remainder / 4]));
- }
-
- // The code below can be written as a simple R*S loop but the compiler
- // doesn't unroll so we're manually unrolling it.
- // constexpr int R = 3, S = 3;
- // array<__m256i, R * S> a_v;
- // for (int r = 0; r < R; ++r) {
- // for (int s = 0; s < S; ++s) {
- // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
- // if (REMAINDER) {
- // a_v[r * S + s] =
- // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
- // mask_v);
- // } else {
- // a_v[r * S + s] =
- // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
- // }
- // } else {
- // a_v[r * S + s] = A_zero_point_v;
- // }
- // }
- // }
- __m256i a_v[8];
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in >= 0 && t_in < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v);
- }
- }
- }
-
- __m256i a_sum[4];
- inner_prod_packed_<8, SUM_A, REMAINDER>(
- a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum);
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in >= 0 && t_in < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- if (t_in + 1 >= 0 && t_in + 1 < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v);
- }
- }
- }
-
- __m256i a_sum_temp[4];
- inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 8, C, remainder, a_sum_temp);
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
- }
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in + 1 >= 0 && t_in + 1 < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- if (t_in + 2 >= 0 && t_in + 2 < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v);
- }
- }
- }
-
- inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 16, C, remainder, a_sum_temp);
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
- }
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
-
- if (t_in + 2 >= 0 && t_in + 2 < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 24, C, remainder, a_sum_temp);
-
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
-
- __m256i B_zero_point_v;
- for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
- if (PER_CHANNEL_QUANTIZATION) {
- B_zero_point_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
- } else {
- B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
- }
- _mm256_store_si256(
- reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
- _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
- }
- }
-}
-
-template <
bool FUSE_RELU,
bool HAS_BIAS,
bool A_SYMMETRIC,
@@ -1342,98 +220,6 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_(
&act_times_w_scale);
}
-template <
- bool FUSE_RELU,
- bool HAS_BIAS,
- bool A_SYMMETRIC,
- bool B_SYMMETRIC,
- typename BIAS_TYPE>
-static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_(
- int T,
- int H,
- int W,
- int K,
- int t,
- int h,
- int w,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const int8_t* Bp,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- float act_times_w_scale) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- int t_in = -PAD_P + t * stride_t;
- int h_in = -PAD_T + h * stride_h;
- int w_in = -PAD_L + w * stride_w;
-
- int k;
- for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- &B_zero_point,
- C_int32 + k,
- 0,
- B_SYMMETRIC ? nullptr : &row_offsets[k]);
- }
- int remainder = K - k;
- if (remainder) {
- inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/, true>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- &B_zero_point,
- C_int32 + k,
- remainder,
- B_SYMMETRIC ? nullptr : &row_offsets[k]);
- }
-
- requantize_<
- FUSE_RELU,
- HAS_BIAS,
- false, /*PER_CHAN_QUANT*/
- A_SYMMETRIC,
- B_SYMMETRIC>(
- A_zero_point,
- &C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
- K,
- row_offsets,
- col_offsets,
- bias,
- &act_times_w_scale);
-}
-
template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
static inline __attribute__((always_inline)) void
depthwise_3x3_per_channel_quantization_kernel_(
@@ -1520,107 +306,6 @@ depthwise_3x3_per_channel_quantization_kernel_(
act_times_w_scale);
}
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
-static inline __attribute__((always_inline)) void
-depthwise_3x3x3_per_channel_quantization_kernel_(
- int T,
- int H,
- int W,
- int K,
- int t,
- int h,
- int w,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const int8_t* Bp,
- const float* C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- const float* act_times_w_scale) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- int t_in = -PAD_P + t * stride_t;
- int h_in = -PAD_T + h * stride_h;
- int w_in = -PAD_L + w * stride_w;
-
- int k;
- for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3x3_packed_<
- true, /*SUM_A*/
- false, /*remainder*/
- true /*per-channel*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- B_zero_point + k,
- C_int32 + k,
- 0,
- &row_offsets[k]);
- }
- int remainder = K - k;
- if (remainder) {
- inner_prod_3x3x3_packed_<
- true, /*SUM_A*/
- true, /*remainder*/
- true /*per-channel*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- B_zero_point + k,
- C_int32 + k,
- remainder,
- &row_offsets[k]);
- }
- requantize_<
- FUSE_RELU,
- HAS_BIAS,
- true, /*PER_CHAN_QUANT*/
- A_SYMMETRIC,
- false /*B_SYMM*/>(
- A_zero_point,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
- K,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
-}
-
-static pair<int, int> closest_factors_(int n) {
- int a = (int)std::sqrt(n);
- while (n % a != 0) {
- a--;
- }
- return {a, n / a}; // a <= n / a
-}
-
// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0
// This implemntation should be general enough to handle not just 3x3 but other
// filter shapes by parameterizing with R and S but restricting it to just 3x3
@@ -1641,7 +326,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
int32_t* C_int32,
@@ -1971,123 +656,6 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
} // for each n
};
-template <
- bool FUSE_RELU,
- bool HAS_BIAS,
- bool A_SYMMETRIC,
- bool B_SYMMETRIC,
- typename BIAS_TYPE>
-static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- float act_times_w_scale,
- int thread_id,
- int num_threads) {
- assert(K % 8 == 0);
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
- const int8_t* Bp = B.PackedMat();
-
- int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
-
- int n_begin, n_end;
- int t_begin, t_end, h_begin, h_end;
- if (N >= num_threads) {
- int n_per_thread = (N + num_threads - 1) / num_threads;
- n_begin = std::min(thread_id * n_per_thread, N);
- n_end = std::min(n_begin + n_per_thread, N);
- t_begin = 0;
- t_end = T_OUT;
- h_begin = 0;
- h_end = H_OUT;
- } else {
- int nthreads_per_n = num_threads / N;
- n_begin = std::min(thread_id / nthreads_per_n, N);
- n_end = std::min(n_begin + 1, N);
-
- int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
- int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
- int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
- int tid_within_n = thread_id - tid_of_n_begin;
- assert(tid_within_n >= 0);
- assert(tid_within_n < nthreads_of_n);
-
- // n is processed by num_threads_t * num_threads_h 2D grid of threads
- int num_threads_t, num_threads_h;
- // num_threads_w <= num_threads_h
- tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
- int tid_t = tid_within_n / num_threads_h;
- int tid_h = tid_within_n % num_threads_h;
-
- int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
- t_begin = std::min(tid_t * t_per_thread, T_OUT);
- t_end = std::min(t_begin + t_per_thread, T_OUT);
-
- int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
- h_begin = std::min(tid_h * h_per_thread, H_OUT);
- h_end = std::min(h_begin + h_per_thread, H_OUT);
- }
-
- for (int n = n_begin; n < n_end; ++n) {
- const uint8_t* A_base = A + n * T * H * W * K;
- uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
-
- for (int t = t_begin; t < t_end; ++t) {
- for (int h = h_begin; h < h_end; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- depthwise_3x3x3_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- B_SYMMETRIC>(
- T,
- H,
- W,
- K,
- t,
- h,
- w,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- } // w
- } // h
- } // t
- } // for each n
-};
-
template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
static inline __attribute__((always_inline)) void
depthwise_3x3_per_channel_quantization_pad_1_(
@@ -2100,7 +668,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
const float* C_multiplier,
int32_t C_zero_point,
int32_t* C_int32,
@@ -2421,119 +989,6 @@ depthwise_3x3_per_channel_quantization_pad_1_(
} // for each n
};
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
-static inline __attribute__((always_inline)) void
-depthwise_3x3x3_per_channel_quantization_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- const float* C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- const float* act_times_w_scale,
- int thread_id,
- int num_threads) {
- assert(K % 8 == 0);
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
- const int8_t* Bp = B.PackedMat();
-
- int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
-
- int n_begin, n_end;
- int t_begin, t_end, h_begin, h_end;
- if (N >= num_threads) {
- int n_per_thread = (N + num_threads - 1) / num_threads;
- n_begin = std::min(thread_id * n_per_thread, N);
- n_end = std::min(n_begin + n_per_thread, N);
- t_begin = 0;
- t_end = T_OUT;
- h_begin = 0;
- h_end = H_OUT;
- } else {
- int nthreads_per_n = num_threads / N;
- n_begin = std::min(thread_id / nthreads_per_n, N);
- n_end = std::min(n_begin + 1, N);
-
- int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
- int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
- int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
- int tid_within_n = thread_id - tid_of_n_begin;
- assert(tid_within_n >= 0);
- assert(tid_within_n < nthreads_of_n);
-
- // n is processed by num_threads_t * num_threads_h 2D grid of threads
- int num_threads_t, num_threads_h;
- // num_threads_w <= num_threads_h
- tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
- int tid_t = tid_within_n / num_threads_h;
- int tid_h = tid_within_n % num_threads_h;
-
- int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
- t_begin = std::min(tid_t * t_per_thread, T_OUT);
- t_end = std::min(t_begin + t_per_thread, T_OUT);
-
- int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
- h_begin = std::min(tid_h * h_per_thread, H_OUT);
- h_end = std::min(h_begin + h_per_thread, H_OUT);
- }
-
- for (int n = n_begin; n < n_end; ++n) {
- const uint8_t* A_base = A + n * T * H * W * K;
- uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
-
- for (int t = t_begin; t < t_end; ++t) {
- for (int h = h_begin; h < h_end; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- depthwise_3x3x3_per_channel_quantization_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- T,
- H,
- W,
- K,
- t,
- h,
- w,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- } // w
- } // h
- } // t
- } // for each n
-};
-
// Dispatch A_SYMMETRIC and B_SYMMETRIC
template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
static void depthwise_3x3_pad_1_(
@@ -2546,7 +1001,7 @@ static void depthwise_3x3_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
@@ -2679,7 +1134,7 @@ static void depthwise_3x3_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
@@ -2744,7 +1199,7 @@ void depthwise_3x3_pad_1(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
@@ -2754,6 +1209,12 @@ void depthwise_3x3_pad_1(
float act_times_w_scale,
int thread_id,
int num_threads) {
+ if (B.GetKernelProduct() != 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3) + " but has " + to_string(B.GetKernelProduct());
+ throw logic_error(msg);
+ }
if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
return;
@@ -2969,300 +1430,6 @@ void depthwise_3x3_pad_1(
}
}
-// Dispatch A_SYMMETRIC and B_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
-static void depthwise_3x3x3_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- float act_times_w_scale,
- int thread_id,
- int num_threads) {
- int32_t C_int32_temp[(K + 31) / 32 * 32];
- if (A_zero_point == 0 || col_offsets == nullptr) {
- if (B_zero_point == 0) {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_symmetric*/,
- true /*B_symmetric*/,
- BIAS_TYPE>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_symmetric*/,
- false /*B_symmetric*/,
- BIAS_TYPE>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- }
- } else {
- if (B_zero_point == 0) {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_symmetric*/,
- true /*B_symmetric*/,
- BIAS_TYPE>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_symmetric*/,
- false /*B_symmetric*/,
- BIAS_TYPE>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- }
- }
-}
-
-// Dispatch HAS_BIAS
-template <bool FUSE_RELU, typename BIAS_TYPE>
-static void depthwise_3x3x3_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- float act_times_w_scale,
- int thread_id,
- int num_threads) {
- if (bias) {
- depthwise_3x3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- }
-}
-
-// Dispatch FUSE_RELU
-template <typename BIAS_TYPE>
-void depthwise_3x3x3_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- bool fuse_relu,
- float act_times_w_scale,
- int thread_id,
- int num_threads) {
- if (stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0) {
- assert(
- 0 &&
- "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0");
- return;
- }
- if (N == 0) {
- // In C2, batch size 0 is allowed, so we should just early return.
- return;
- }
- if (fuse_relu) {
- depthwise_3x3x3_pad_1_<true /*FUSE_RELU*/, BIAS_TYPE>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<false /*FUSE_RELU*/, BIAS_TYPE>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- }
-}
-
// Dispatch A_SYMMETRIC
template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
static void depthwise_3x3_per_channel_quantization_pad_1_(
@@ -3275,7 +1442,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
@@ -3350,7 +1517,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
@@ -3420,7 +1587,7 @@ void depthwise_3x3_per_channel_quantization_pad_1(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
@@ -3430,6 +1597,12 @@ void depthwise_3x3_per_channel_quantization_pad_1(
const float* act_times_w_scale,
int thread_id,
int num_threads) {
+ if (Bp.GetKernelProduct() != 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3) + " but has " + to_string(Bp.GetKernelProduct());
+ throw logic_error(msg);
+ }
if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
return;
@@ -3665,248 +1838,6 @@ void depthwise_3x3_per_channel_quantization_pad_1(
}
}
-// Dispatch A_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
-static void depthwise_3x3x3_per_channel_quantization_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- const float* act_times_w_scale,
- int thread_id,
- int num_threads) {
- int32_t C_int32_temp[(K + 31) / 32 * 32];
- if (A_zero_point == 0 || col_offsets == nullptr) {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_SYMM*/,
- BIAS_TYPE>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_SYMM*/,
- BIAS_TYPE>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- }
-}
-
-// Dispatch HAS_BIAS
-template <bool FUSE_RELU, typename BIAS_TYPE>
-static void depthwise_3x3x3_per_channel_quantization_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- const float* act_times_w_scale,
- int thread_id,
- int num_threads) {
- if (bias) {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- true /* HAS_BIAS */,
- BIAS_TYPE>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- false /* HAS_BIAS */,
- BIAS_TYPE>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- }
-}
-
-// Dispatch FUSE_RELU
-template <typename BIAS_TYPE>
-void depthwise_3x3x3_per_channel_quantization_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- bool fuse_relu,
- const float* act_times_w_scale,
- int thread_id,
- int num_threads) {
- if (stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0) {
- assert(
- 0 &&
- "stride_t == 0 || stride_h == 0 || stride_w == 0 || num_threads == 0");
- return;
- }
- if (N == 0) {
- // In C2, batch size 0 is allowed, so we should just early return.
- return;
- }
- if (fuse_relu) {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RELU */,
- BIAS_TYPE>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- false /* FUSE_RELU */,
- BIAS_TYPE>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- }
-}
-
// To be removed
void depthwise_3x3_pad_1(
int N,
@@ -3918,7 +1849,7 @@ void depthwise_3x3_pad_1(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
@@ -3960,7 +1891,7 @@ void depthwise_3x3_per_channel_quantization_pad_1(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
@@ -3991,97 +1922,6 @@ void depthwise_3x3_per_channel_quantization_pad_1(
num_threads);
}
-// To be removed
-void depthwise_3x3x3_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- bool fuse_relu,
- int thread_id,
- int num_threads) {
- depthwise_3x3x3_pad_1<int32_t>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- fuse_relu,
- 1.0f, // act_scale * weight_scale
- thread_id,
- num_threads);
-}
-
-void depthwise_3x3x3_per_channel_quantization_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- bool fuse_relu,
- int thread_id,
- int num_threads) {
- depthwise_3x3x3_per_channel_quantization_pad_1(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- fuse_relu,
- nullptr, // act_scale * weight_scale
- thread_id,
- num_threads);
-}
-
template void depthwise_3x3_pad_1(
int N,
int H,
@@ -4092,7 +1932,7 @@ template void depthwise_3x3_pad_1(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
@@ -4113,7 +1953,7 @@ template void depthwise_3x3_pad_1(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
@@ -4134,7 +1974,7 @@ template void depthwise_3x3_per_channel_quantization_pad_1(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
@@ -4155,99 +1995,7 @@ template void depthwise_3x3_per_channel_quantization_pad_1(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const float* bias,
- bool fuse_relu,
- const float* act_times_w_scale,
- int thread_id,
- int num_threads);
-
-template void depthwise_3x3x3_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- bool fuse_relu,
- float act_times_w_scale,
- int thread_id,
- int num_threads);
-
-template void depthwise_3x3x3_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const float* bias,
- bool fuse_relu,
- float act_times_w_scale,
- int thread_id,
- int num_threads);
-
-template void depthwise_3x3x3_per_channel_quantization_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- bool fuse_relu,
- const float* act_times_w_scale,
- int thread_id,
- int num_threads);
-
-template void depthwise_3x3x3_per_channel_quantization_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
diff --git a/src/PackDepthwiseConvMatrixAvx2.cc b/src/PackDepthwiseConvMatrixAvx2.cc
new file mode 100644
index 0000000..0e17bcd
--- /dev/null
+++ b/src/PackDepthwiseConvMatrixAvx2.cc
@@ -0,0 +1,203 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+
+#include <immintrin.h>
+
+using namespace std;
+
+namespace fbgemm {
+
+// clang-format off
+static int masks[8][8] = {
+ // NOTE: clang-format wants to use a different formatting but the current
+ // formatting should be easier to read.
+ { 0, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, 0, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, 0, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, 0, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, 0, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, 0, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, 0, 0, },
+ { -1, -1, -1, -1, -1, -1, -1, 0, },
+};
+// clang-format on
+
+PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix(
+ int K,
+ int kernel_prod,
+ const int8_t* smat)
+ : K_(K), kernel_prod_(kernel_prod) {
+ // Transpose the input matrix to make packing faster.
+ alignas(64) int8_t smat_transposed[K * kernel_prod];
+ for (int i = 0; i < kernel_prod; ++i) {
+ for (int j = 0; j < K; ++j) {
+ smat_transposed[i * K + j] = smat[i + j * kernel_prod];
+ }
+ }
+
+ // Allocate packed arrays
+ int kernel_prod_aligned = (kernel_prod + 1) / 2 * 2;
+ // pmat_ = static_cast<int8_t *>(fbgemmAlignedAlloc(
+ // 64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t)));
+ posix_memalign(
+ (void**)&pmat_,
+ 64,
+ ((K + 31) / 32) * kernel_prod_aligned * 32 * sizeof(int8_t));
+
+ // Pack input matrix
+ // The layout is optimized to use vpmaddubsw efficiently (see
+ // madd_epi16x4_packed function).
+ // For a group of 32 channels, we have 10 32B SIMD registers.
+ // Denote ith channel jth filter as (i, j)
+ // 0th SIMD register:
+ // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3)
+ // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3)
+ // 1st SIMD register:
+ // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3)
+ // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3)
+ // 2nd SIMD register:
+ // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3)
+ // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3)
+ // 3rd SIMD register:
+ // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3)
+ // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3)
+ // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter
+ // coefficients
+ // ...
+ //
+ // REMAINDER
+ // If kernel_prod % 4 == 1 for example when kernel_prod == 9
+ // 8th SIMD register:
+ // (0, 8), zero, ..., (7, 8), zero
+ // (16, 8), zero, ..., (23, 8), zero
+ // 9th SIMD register:
+ // (8, 8), zero, ..., (15, 8), zero
+ // (24, 8), zero, ..., (31, 8), zero
+ // We use madd_epi16_packed for this case
+ //
+ // If kernel_prod % 4 == 2 for example when kernel_prod == 10
+ // 8th SIMD register:
+ // (0, 8), (0, 9), ..., (7, 8), (7, 9)
+ // (16, 8), (16, 9), ..., (23, 8), (23, 9)
+ // 9th SIMD register:
+ // (8, 8), (8, 9), ..., (15, 8), (15, 9)
+ // (24, 8), (24, 9), ..., (31, 8), (31, 9)
+ //
+ // If kernel_prod % 4 == 3 for example when kernel_prod == 11
+ // 8th SIMD register:
+ // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero
+ // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero
+ // 9th SIMD register:
+ // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero
+ // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero
+ // 10th SIMD register:
+ // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero
+ // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero
+ // 11th SIMD register:
+ // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero
+ // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero
+ for (int k1 = 0; k1 < K; k1 += 32) {
+ __m256i b_v[kernel_prod];
+ int remainder = K - k1;
+ if (remainder < 32) {
+ __m256i mask_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(masks[remainder / 4]));
+ for (int i = 0; i < kernel_prod; ++i) {
+ b_v[i] = _mm256_maskload_epi32(
+ reinterpret_cast<const int*>(smat_transposed + i * K + k1), mask_v);
+ }
+ } else {
+ for (int i = 0; i < kernel_prod; ++i) {
+ b_v[i] = _mm256_lddqu_si256(
+ reinterpret_cast<const __m256i*>(smat_transposed + i * K + k1));
+ }
+ }
+
+ // Interleave 2 SIMD registers
+ __m256i b_interleaved_epi16[kernel_prod_aligned];
+ __m256i zero_v = _mm256_setzero_si256();
+ for (int i = 0; i < kernel_prod_aligned / 2; ++i) {
+ if (2 * i + 1 >= kernel_prod) {
+ b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v);
+ b_interleaved_epi16[2 * i + 1] =
+ _mm256_unpackhi_epi8(b_v[2 * i], zero_v);
+ } else {
+ b_interleaved_epi16[2 * i] =
+ _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]);
+ b_interleaved_epi16[2 * i + 1] =
+ _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]);
+ }
+ }
+
+ // Interleave 4 SIMD registers
+ __m256i b_interleaved_epi32[kernel_prod_aligned];
+ for (int i = 0; i < kernel_prod_aligned / 4; ++i) {
+ b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16(
+ b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
+ b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16(
+ b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
+ b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16(
+ b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
+ b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16(
+ b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
+ }
+ for (int i = kernel_prod_aligned / 4 * 4; i < kernel_prod_aligned; ++i) {
+ b_interleaved_epi32[i] = b_interleaved_epi16[i];
+ }
+
+ for (int i = 0; i < kernel_prod_aligned; ++i) {
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(
+ &pmat_[((k1 / 32) * kernel_prod_aligned + i) * 32]),
+ b_interleaved_epi32[i]);
+ }
+ }
+}
+
+int PackedDepthWiseConvMatrix::addr(int r, int c) {
+ int kernel_prod_aligned = (kernel_prod_ + 1) / 2 * 2;
+ if (c >= kernel_prod_ / 4 * 4 &&
+ (kernel_prod_ % 4 == 1 || kernel_prod_ % 4 == 2)) {
+ int kBlock = r / 32;
+ int reg_idx = (r % 16) / 8 + c / 4 * 4;
+
+ int blk_idx = kBlock * kernel_prod_aligned + reg_idx;
+
+ int r_ = r % 8;
+ int c_ = c % 4;
+
+ int in_blk_idx = (r % 32) / 16 * 16 + 2 * r_ + c_;
+ return blk_idx * 32 + in_blk_idx;
+
+ } else {
+ int kBlock = r / 32;
+ int reg_idx = (r % 16) / 4 + c / 4 * 4;
+
+ int blk_idx = kBlock * kernel_prod_aligned + reg_idx;
+
+ int r_ = r % 4;
+ int c_ = c % 4;
+
+ int in_blk_idx = (r % 32) / 16 * 16 + 4 * r_ + c_;
+ return blk_idx * 32 + in_blk_idx;
+ }
+}
+
+void PackedDepthWiseConvMatrix::unpack(int8_t* unpacked_data) {
+ for (int r = 0; r < K_; ++r) {
+ for (int c = 0; c < kernel_prod_; ++c) {
+ unpacked_data[r * kernel_prod_ + c] = pmat_[addr(r, c)];
+ }
+ }
+}
+
+PackedDepthWiseConvMatrix::~PackedDepthWiseConvMatrix() {
+ free(pmat_);
+}
+
+} // namespace fbgemm
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
index 44f210e..192fb00 100644
--- a/src/PackWeightsForConv.cc
+++ b/src/PackWeightsForConv.cc
@@ -23,35 +23,17 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
// FbgemmConv.cc
switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) {
case optimized_conv_t::depthwise: {
- if (SPATIAL_DIM == 3) {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ =
- std::make_shared<Packed3x3x3ConvMatrix>(conv_p.G, sdata);
- W_gconv_packed_ = nullptr;
- } else {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ =
- std::make_shared<Packed3x3ConvMatrix>(conv_p.G, sdata);
- W_dw_3D_packed_ = nullptr;
- W_gconv_packed_ = nullptr;
- }
+ W_dw_packed_ = std::make_shared<PackedDepthWiseConvMatrix>(
+ conv_p.G, SPATIAL_DIM == 3 ? 3 * 3 * 3 : 3 * 3, sdata);
break;
}
case optimized_conv_t::groupwise: {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ = nullptr;
W_gconv_packed_ =
std::make_shared<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>(
matrix_op_t::Transpose, conv_p, sdata, nullptr);
break;
}
case optimized_conv_t::pointwise: {
- W_im2col_packed_ = nullptr;
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ = nullptr;
- W_gconv_packed_ = nullptr;
int NDim = conv_p.OC / conv_p.G;
int KDim = conv_p.K[0] * conv_p.K[1] * conv_p.IC;
W_pointwise_packed_ = std::make_shared<PackBMatrix<T, accT>>(
@@ -77,9 +59,6 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
nullptr,
conv_p.G,
blocking_params);
- W_dw_2D_packed_ = nullptr;
- W_dw_3D_packed_ = nullptr;
- W_gconv_packed_ = nullptr;
break;
}
} // switch
@@ -87,10 +66,8 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
template <int SPATIAL_DIM, typename T, typename accT>
void PackWeightsForConv<SPATIAL_DIM, T, accT>::unpack(T* origin_buf) {
- if (W_dw_2D_packed_) {
- W_dw_2D_packed_->unpack(origin_buf);
- } else if (W_dw_3D_packed_) {
- W_dw_3D_packed_->unpack(origin_buf);
+ if (W_dw_packed_) {
+ W_dw_packed_->unpack(origin_buf);
} else if (W_gconv_packed_) {
W_gconv_packed_->unpack(origin_buf);
} else if (W_im2col_packed_) {
@@ -139,7 +116,7 @@ std::string PackWeightsForConv<SPATIAL_DIM, T, accT>::mismatchingParams(
};
auto combineInt = [&combineStr](std::string id, int int1, int int2) {
- return combineStr(id, std::to_string(int1), std::to_string(int2));
+ return combineStr(id, std::to_string(int1), std::to_string(int2));
};
if (conv_param_.IC != test_conv_p.IC) {
diff --git a/test/I8DepthwiseTest.cc b/test/I8DepthwiseTest.cc
index 8492acb..9de6943 100644
--- a/test/I8DepthwiseTest.cc
+++ b/test/I8DepthwiseTest.cc
@@ -193,7 +193,7 @@ TEST_P(FBGemmDepthWiseTest, Test3x3) {
K);
}
- Packed3x3ConvMatrix Bp(K, B.data());
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data());
depthwise_3x3_pad_1(
N,
@@ -330,7 +330,7 @@ TEST_P(FBGemmDepthWiseTest, Test3x3x3) {
K);
}
- Packed3x3x3ConvMatrix Bp(K, B.data());
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data());
depthwise_3x3x3_pad_1(
N,
@@ -464,7 +464,7 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) {
K);
}
- Packed3x3ConvMatrix Bp(K, B.data());
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data());
depthwise_3x3_per_channel_quantization_pad_1(
N,
@@ -596,7 +596,7 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) {
K);
}
- Packed3x3x3ConvMatrix Bp(K, B.data());
+ PackedDepthWiseConvMatrix Bp(K, 3 * 3 * 3, B.data());
depthwise_3x3x3_per_channel_quantization_pad_1(
N,
@@ -652,36 +652,8 @@ TEST_P(FBGemmDepthWisePackUnpackTest, TestPackUnpack) {
aligned_vector<int8_t> BUnpacked(K * kernel_prod);
- if (kernel_prod == 1) {
- Packed1ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 2) {
- Packed2ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 3) {
- Packed3ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 4) {
- Packed4ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 5) {
- Packed5ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 9) {
- Packed3x3ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 10) {
- Packed10ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 11) {
- Packed11ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else if (kernel_prod == 27) {
- Packed3x3x3ConvMatrix BPacked(K, B.data());
- BPacked.unpack(BUnpacked.data());
- } else {
- ASSERT_TRUE(false);
- }
+ PackedDepthWiseConvMatrix BPacked(K, kernel_prod, B.data());
+ BPacked.unpack(BUnpacked.data());
ASSERT_EQ(B, BUnpacked)
<< "Original and unpacked data elements are not the same";
diff --git a/test/UniConvTest.cc b/test/UniConvTest.cc
index 83674ce..cead3a6 100644
--- a/test/UniConvTest.cc
+++ b/test/UniConvTest.cc
@@ -147,23 +147,19 @@ TEST_P(uniConvTest, packingTest) {
case optimized_conv_t::depthwise: {
ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
<< "im2col packed matrix should be null";
- ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
<< "groupwise packed matrix should be null";
ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr)
<< "pointwise packed matrix should be null";
- ASSERT_NE(packedB_2D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
+ ASSERT_NE(packedB_2D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix is null";
break;
}
case optimized_conv_t::groupwise: {
ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
<< "im2col packed matrix should be null";
- ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix should be null";
ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr)
<< "pointwise packed matrix should be null";
ASSERT_NE(packedB_2D.getPackedWForGroupwise(), nullptr)
@@ -173,10 +169,8 @@ TEST_P(uniConvTest, packingTest) {
case optimized_conv_t::pointwise: {
ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
<< "im2col packed matrix should be null";
- ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix should null";
ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
<< "Groupwise packed matrix should be null";
ASSERT_NE(packedB_2D.getPackedWForPointwise(), nullptr)
@@ -184,10 +178,8 @@ TEST_P(uniConvTest, packingTest) {
break;
}
case optimized_conv_t::im2col: {
- ASSERT_EQ(packedB_2D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_2D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix should be null";
ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
<< "groupwise packed matrix should be null";
ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr)
@@ -215,16 +207,14 @@ TEST_P(uniConvTest, packingTest) {
switch (ConvFastPath<3, int32_t>(conv_p_3d)) {
case optimized_conv_t::depthwise: {
- ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr)
<< "im2col packed matrix should be null";
ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
<< "groupwise packed matrix should be null";
ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr)
<< "pointwise packed matrix should be null";
- ASSERT_NE(packedB_3D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
+ ASSERT_NE(packedB_3D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix is null";
break;
}
case optimized_conv_t::groupwise: {
@@ -232,10 +222,8 @@ TEST_P(uniConvTest, packingTest) {
break;
}
case optimized_conv_t::pointwise: {
- ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_3D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix should be null";
ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
<< "groupwise packed matrix should be null";
ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr)
@@ -245,10 +233,8 @@ TEST_P(uniConvTest, packingTest) {
break;
}
case optimized_conv_t::im2col: {
- ASSERT_EQ(packedB_3D.getPackedWFor2DDW(), nullptr)
- << "2D depthwise packed matrix is null";
- ASSERT_EQ(packedB_3D.getPackedWFor3DDW(), nullptr)
- << "3D depthwise packed matrix should be null";
+ ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr)
+ << "depthwise packed matrix should be null";
ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
<< "groupwise packed matrix should be null";
ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr)