
github.com/marian-nmt/FBGEMM.git
author     Jongsoo Park <jongsoo@fb.com>   2020-04-07 07:01:53 +0300
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>   2020-04-07 07:04:04 +0300
commit     1f42be50b7f53f000b381ece2712ad361c7bf556 (patch)
tree       a92ceda2086e2d6aa93193e7cfebb472a7756c8c
parent     35e486b706abd4d575ae9b7aaf090002aa78551e (diff)
JIT depth-wise conv (#338)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/338

Depth-wise convolution was causing a large code size due to its extensive use of template specializations. This diff instead uses JIT'ing to reduce the code size and obtain performance gains.

TODO: we may want to land D20860370 first to reduce the JIT'ing overhead, but D20860370 has a dependency on C++14.

Reviewed By: dskhudia

Differential Revision: D20858973

fbshipit-source-id: f7f35153fbf2cb96b4a31a82854d669cf164033f
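The core idea of the change is to replace a family of compile-time template specializations with a single code generator whose output is cached under a runtime key (filter size, channel remainder, quantization flags, and how far the filter window hangs over the image border), as GenI8Depthwise::getOrCreate does in the hunks below. What follows is a minimal C++ sketch of that caching pattern; KernelCache, KernelKey, and KernelFn are illustrative names, not FBGEMM APIs.

#include <cstdint>
#include <functional>
#include <map>
#include <mutex>
#include <tuple>

// Hypothetical signature of a JIT'ed depth-wise kernel (illustrative only;
// the real entry point in this diff is GenI8Depthwise::jit_kernel_signature).
using KernelFn = std::function<
    void(const std::uint8_t* A, const std::int8_t* B, std::int32_t* C)>;

// Everything the generated code depends on becomes part of the cache key,
// mirroring the getOrCreate() arguments used in the hunks below.
using KernelKey = std::tuple<
    int /*D*/, int /*S*/, bool /*compute_a_sum*/, bool /*per_channel*/,
    int /*remainder*/, int /*top_skip*/, int /*bottom_skip*/,
    int /*left_skip*/, int /*right_skip*/>;

class KernelCache {
 public:
  // Return a cached kernel for this key, generating it on first use.
  KernelFn getOrCreate(const KernelKey& key) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      it = cache_.emplace(key, generate(key)).first;  // JIT once per key
    }
    return it->second;
  }

 private:
  // Stand-in for the assembler-based code generation in GenerateI8Depthwise.cc.
  static KernelFn generate(const KernelKey& /*key*/) {
    return [](const std::uint8_t*, const std::int8_t*, std::int32_t*) {
      // A real implementation would emit AVX2 code specialized for the key.
    };
  }

  std::mutex mutex_;
  std::map<KernelKey, KernelFn> cache_;
};

Generating on first use keeps the binary small while paying the code-generation cost only once per distinct parameter combination, which is the JIT'ing overhead the TODO above refers to.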
-rw-r--r--  CMakeLists.txt                        1
-rw-r--r--  src/FbgemmI8Depthwise2DAvx2-inl.h   938
-rw-r--r--  src/FbgemmI8Depthwise3DAvx2.cc     1026
-rw-r--r--  src/FbgemmI8DepthwiseAvx2-inl.h     352
-rw-r--r--  src/GenerateI8Depthwise.cc          506
-rw-r--r--  src/GenerateI8Depthwise.h            41
6 files changed, 1229 insertions(+), 1635 deletions(-)
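A recurring detail in the hunks below is that each call site picks the kernel variant by how many filter rows and columns fall outside the image for the current output position (the top/bottom/left/right "skip" arguments of getOrCreate). A small self-contained sketch of that computation, with hypothetical names:

#include <algorithm>

// How many filter rows/columns fall outside the image for output (h, w).
// Mirrors the skip arguments passed to getOrCreate() in the 2D kernel below;
// the struct and function names here are illustrative.
struct BorderSkips {
  int top, bottom, left, right;
};

inline BorderSkips compute_skips(
    int H, int W, int S, int stride_h, int stride_w, int pad, int h, int w) {
  int h_in = -pad + h * stride_h;  // top-left corner of the filter window
  int w_in = -pad + w * stride_w;
  return BorderSkips{
      std::max(-h_in, 0),          // rows above the image
      std::max(h_in + S - H, 0),   // rows below the image
      std::max(-w_in, 0),          // columns left of the image
      std::max(w_in + S - W, 0)};  // columns right of the image
}

Interior outputs all map to the all-zero skip case, which is why the loops below generate a single middle_kernel once and reuse it across the bulk of the image.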
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9abd2c9..cc891fb 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -38,6 +38,7 @@ set(FBGEMM_GENERIC_SRCS src/EmbeddingSpMDM.cc
src/FbgemmFloat16Convert.cc
src/FbgemmI64.cc
src/FbgemmI8Spmdm.cc
+ src/GenerateI8Depthwise.cc
src/GenerateKernelU8S8S32ACC16.cc
src/GenerateKernelU8S8S32ACC16Avx512.cc
src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
diff --git a/src/FbgemmI8Depthwise2DAvx2-inl.h b/src/FbgemmI8Depthwise2DAvx2-inl.h
index 7488a3c..f4f1a42 100644
--- a/src/FbgemmI8Depthwise2DAvx2-inl.h
+++ b/src/FbgemmI8Depthwise2DAvx2-inl.h
@@ -6,356 +6,21 @@
*/
#pragma once
-#include "fbgemm/UtilsAvx2.h"
#include "fbgemm/Utils.h"
+#include "fbgemm/UtilsAvx2.h"
#include "src/FbgemmI8DepthwiseAvx2-inl.h"
+#include "src/GenerateI8Depthwise.h"
#include "src/MaskAvx2.h"
namespace fbgemm {
-template <int S = 3, bool SUM_A = false, bool REMAINDER = false>
-static ALWAYS_INLINE void inner_prod_2d_packed_(
- const __m256i* a_v,
- const __m256i* Bp,
- std::int32_t* C,
- int remainder,
- __m256i* a_sum = nullptr) {
- return inner_prod_packed_<S * S, SUM_A, REMAINDER>(
- a_v, Bp, C, remainder, a_sum);
-}
-
-template <
- bool SUM_A,
- bool REMAINDER = false,
- bool PER_CHANNEL_QUANTIZATION = false>
-static ALWAYS_INLINE void inner_prod_3x3_packed_(
- int H,
- int W,
- int K,
- int h_in,
- int w_in,
- const std::uint8_t* A,
- std::int32_t A_zero_point,
- const std::int8_t* Bp,
- const std::int32_t* B_zero_point,
- std::int32_t* C,
- int remainder,
- std::int32_t* row_offsets) {
- __m256i A_zero_point_v =
- _mm256_set1_epi8(static_cast<std::uint8_t>(A_zero_point));
- __m256i mask_v = _mm256_setzero_si256();
- if (REMAINDER) {
- mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
- internal::avx2_ps_or_epi32_masks[remainder / 4]));
- }
-
- // The code below can be written as a simple R*S loop but the compiler
- // doesn't unroll so we're manually unrolling it.
- // constexpr int R = 3, S = 3;
- // array<__m256i, R * S> a_v;
- // for (int r = 0; r < R; ++r) {
- // for (int s = 0; s < S; ++s) {
- // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
- // if (REMAINDER) {
- // a_v[r * S + s] =
- // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
- // mask_v);
- // } else {
- // a_v[r * S + s] =
- // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
- // }
- // } else {
- // a_v[r * S + s] = A_zero_point_v;
- // }
- // }
- // }
- __m256i a_v[9] = {
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- };
-
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[0] = load_a<REMAINDER>(A + (0 * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[1] = load_a<REMAINDER>(A + (0 * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[2] = load_a<REMAINDER>(A + (0 * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[3] = load_a<REMAINDER>(A + (1 * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[4] = load_a<REMAINDER>(A + (1 * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[5] = load_a<REMAINDER>(A + (1 * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[6] = load_a<REMAINDER>(A + (2 * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[7] = load_a<REMAINDER>(A + (2 * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[8] = load_a<REMAINDER>(A + (2 * W + 2) * K, mask_v);
- }
- }
-
- __m256i a_sum[4];
- inner_prod_2d_packed_<3, SUM_A, REMAINDER>(
- a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum);
- if (SUM_A) {
- __m256i B_zero_point_v;
- for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
- if (PER_CHANNEL_QUANTIZATION) {
- B_zero_point_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
- } else {
- B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
- }
- _mm256_store_si256(
- reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
- _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
- }
- }
-}
-
-template <
- bool SUM_A,
- bool REMAINDER = false,
- bool PER_CHANNEL_QUANTIZATION = false>
-static ALWAYS_INLINE void inner_prod_5x5_packed_(
- int H,
- int W,
- int K,
- int h_in,
- int w_in,
- const std::uint8_t* A,
- std::int32_t A_zero_point,
- const std::int8_t* Bp,
- const std::int32_t* B_zero_point,
- std::int32_t* C,
- int remainder,
- std::int32_t* row_offsets) {
- __m256i A_zero_point_v =
- _mm256_set1_epi8(static_cast<std::uint8_t>(A_zero_point));
- __m256i mask_v = _mm256_setzero_si256();
- if (REMAINDER) {
- mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
- internal::avx2_ps_or_epi32_masks[remainder / 4]));
- }
-
- // The code below can be written as a simple R*S loop but the compiler
- // doesn't unroll so we're manually unrolling it.
- // constexpr int R = 5, S = 5;
- // array<__m256i, R * S> a_v;
- // for (int r = 0; r < R; ++r) {
- // for (int s = 0; s < S; ++s) {
- // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
- // if (REMAINDER) {
- // a_v[r * S + s] =
- // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
- // mask_v);
- // } else {
- // a_v[r * S + s] =
- // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
- // }
- // } else {
- // a_v[r * S + s] = A_zero_point_v;
- // }
- // }
- // }
- __m256i a_v[25] = {
- A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v,
- A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v,
- A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v,
- A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v,
- A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v,
- A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v,
- A_zero_point_v,
- };
-
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[0] = load_a<REMAINDER>(A + (0 * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[1] = load_a<REMAINDER>(A + (0 * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[2] = load_a<REMAINDER>(A + (0 * W + 2) * K, mask_v);
- }
- if (w_in + 3 >= 0 && w_in + 3 < W) {
- a_v[3] = load_a<REMAINDER>(A + (0 * W + 3) * K, mask_v);
- }
- if (w_in + 4 >= 0 && w_in + 4 < W) {
- a_v[4] = load_a<REMAINDER>(A + (0 * W + 4) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[5] = load_a<REMAINDER>(A + (1 * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[6] = load_a<REMAINDER>(A + (1 * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[7] = load_a<REMAINDER>(A + (1 * W + 2) * K, mask_v);
- }
- if (w_in + 3 >= 0 && w_in + 3 < W) {
- a_v[8] = load_a<REMAINDER>(A + (1 * W + 3) * K, mask_v);
- }
- if (w_in + 4 >= 0 && w_in + 4 < W) {
- a_v[9] = load_a<REMAINDER>(A + (1 * W + 4) * K, mask_v);
- }
- }
-
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[10] = load_a<REMAINDER>(A + (2 * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[11] = load_a<REMAINDER>(A + (2 * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[12] = load_a<REMAINDER>(A + (2 * W + 2) * K, mask_v);
- }
- if (w_in + 3 >= 0 && w_in + 3 < W) {
- a_v[13] = load_a<REMAINDER>(A + (2 * W + 3) * K, mask_v);
- }
- if (w_in + 4 >= 0 && w_in + 4 < W) {
- a_v[14] = load_a<REMAINDER>(A + (2 * W + 4) * K, mask_v);
- }
- }
-
- if (h_in + 3 >= 0 && h_in + 3 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[15] = load_a<REMAINDER>(A + (3 * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[16] = load_a<REMAINDER>(A + (3 * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[17] = load_a<REMAINDER>(A + (3 * W + 2) * K, mask_v);
- }
- if (w_in + 3 >= 0 && w_in + 3 < W) {
- a_v[18] = load_a<REMAINDER>(A + (3 * W + 3) * K, mask_v);
- }
- if (w_in + 4 >= 0 && w_in + 4 < W) {
- a_v[19] = load_a<REMAINDER>(A + (3 * W + 4) * K, mask_v);
- }
- }
-
- if (h_in + 4 >= 0 && h_in + 4 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[20] = load_a<REMAINDER>(A + (4 * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[21] = load_a<REMAINDER>(A + (4 * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[22] = load_a<REMAINDER>(A + (4 * W + 2) * K, mask_v);
- }
- if (w_in + 3 >= 0 && w_in + 3 < W) {
- a_v[23] = load_a<REMAINDER>(A + (4 * W + 3) * K, mask_v);
- }
- if (w_in + 4 >= 0 && w_in + 4 < W) {
- a_v[24] = load_a<REMAINDER>(A + (4 * W + 4) * K, mask_v);
- }
- }
-
- __m256i a_sum[4];
- inner_prod_2d_packed_<5, SUM_A, REMAINDER>(
- a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum);
- if (SUM_A) {
- __m256i B_zero_point_v;
- for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
- if (PER_CHANNEL_QUANTIZATION) {
- B_zero_point_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
- } else {
- B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
- }
- _mm256_store_si256(
- reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
- _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
- }
- }
-}
-
-template <
- int S,
- bool SUM_A,
- bool REMAINDER = false,
- bool PER_CHANNEL_QUANTIZATION = false>
-static ALWAYS_INLINE void inner_prod_2d_packed_(
- int H,
- int W,
- int K,
- int h_in,
- int w_in,
- const std::uint8_t* A,
- std::int32_t A_zero_point,
- const std::int8_t* Bp,
- const std::int32_t* B_zero_point,
- std::int32_t* C,
- int remainder,
- std::int32_t* row_offsets) {
- if (S == 3) {
- inner_prod_3x3_packed_<SUM_A, REMAINDER, PER_CHANNEL_QUANTIZATION>(
- H,
- W,
- K,
- h_in,
- w_in,
- A,
- A_zero_point,
- Bp,
- B_zero_point,
- C,
- remainder,
- row_offsets);
- } else {
- assert(S == 5);
- inner_prod_5x5_packed_<SUM_A, REMAINDER, PER_CHANNEL_QUANTIZATION>(
- H,
- W,
- K,
- h_in,
- w_in,
- A,
- A_zero_point,
- Bp,
- B_zero_point,
- C,
- remainder,
- row_offsets);
- }
-}
-
template <
int S,
bool FUSE_RELU,
bool HAS_BIAS,
bool A_SYMMETRIC,
bool B_SYMMETRIC,
+ bool PER_CHANNEL_QUANTIZAITON,
typename BIAS_TYPE>
static ALWAYS_INLINE void depthwise_2d_kernel_(
int H,
@@ -367,16 +32,17 @@ static ALWAYS_INLINE void depthwise_2d_kernel_(
int stride_w,
std::int32_t A_zero_point,
const std::uint8_t* A,
- std::int32_t B_zero_point,
+ const std::int32_t* B_zero_point,
const std::int8_t* Bp,
- float C_multiplier,
+ const float* C_multiplier,
std::int32_t C_zero_point,
std::int32_t* C_int32,
std::uint8_t* C_uint8,
std::int32_t* row_offsets,
const std::int32_t* col_offsets,
const BIAS_TYPE* bias,
- float act_times_w_scale) {
+ const float* act_times_w_scale,
+ GenI8Depthwise::jit_kernel_signature* pregenerated_kernel = nullptr) {
constexpr int PAD_T = (S - 1) / 2, PAD_L = (S - 1) / 2, PAD_R = (S - 1) / 2;
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
int h_in = -PAD_T + h * stride_h;
@@ -384,138 +50,44 @@ static ALWAYS_INLINE void depthwise_2d_kernel_(
constexpr int KERNEL_PROD_ALIGNED = (S * S + 1) / 2 * 2;
- int k;
- for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_2d_packed_<S, !B_SYMMETRIC /*SUM_A*/>(
- H,
- W,
- K,
- h_in,
- w_in,
- A + (h_in * W + w_in) * K + k,
- A_zero_point,
- Bp + k * KERNEL_PROD_ALIGNED,
- &B_zero_point,
- C_int32 + k,
- 0,
- B_SYMMETRIC ? nullptr : &row_offsets[k]);
- }
- int remainder = K - k;
- if (remainder) {
- inner_prod_2d_packed_<S, !B_SYMMETRIC, true>(
- H,
- W,
- K,
- h_in,
- w_in,
- A + (h_in * W + w_in) * K + k,
- A_zero_point,
- Bp + k * KERNEL_PROD_ALIGNED,
- &B_zero_point,
- C_int32 + k,
- remainder,
- B_SYMMETRIC ? nullptr : &row_offsets[k]);
+ int remainder = K % 32;
+ if (remainder == 0) {
+ remainder = 32;
}
- requantize_<
- FUSE_RELU,
- HAS_BIAS,
- false, /*PER_CHAN_QUANT*/
- A_SYMMETRIC,
- B_SYMMETRIC,
- BIAS_TYPE>(
- A_zero_point,
- &C_multiplier,
- C_zero_point,
+ GenI8Depthwise::jit_kernel_signature kernel = pregenerated_kernel
+ ? *pregenerated_kernel
+ : GenI8Depthwise().getOrCreate(
+ /*D=*/2,
+ S,
+ /*compute_a_sum=*/!B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZAITON,
+ remainder,
+ 0,
+ 0,
+ /*top_skip=*/std::max(-h_in, 0),
+ /*bottom_skip=*/std::max(h_in + S - H, 0),
+ /*left_skip=*/std::max(-w_in, 0),
+ /*right_skip=*/std::max(w_in + S - W, 0));
+
+ kernel(
+ A + (h_in * W + w_in) * K,
+ Bp,
C_int32,
- C_uint8 + (h * W_OUT + w) * K,
+ B_SYMMETRIC ? nullptr : row_offsets,
+ H,
+ W,
K,
- row_offsets,
- col_offsets,
- bias,
- &act_times_w_scale);
-}
-
-template <
- int S,
- bool FUSE_RELU,
- bool HAS_BIAS,
- bool A_SYMMETRIC,
- typename BIAS_TYPE>
-static ALWAYS_INLINE void depthwise_2d_per_channel_quantization_kernel_(
- int H,
- int W,
- int K,
- int h,
- int w,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
- const std::int32_t* B_zero_point,
- const std::int8_t* Bp,
- const float* C_multiplier,
- std::int32_t C_zero_point,
- std::int32_t* C_int32,
- std::uint8_t* C_uint8,
- std::int32_t* row_offsets,
- const std::int32_t* col_offsets,
- const BIAS_TYPE* bias,
- const float* act_times_w_scale) {
- constexpr int PAD_T = (S - 1) / 2, PAD_L = (S - 1) / 2, PAD_R = (S - 1) / 2;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- int h_in = -PAD_T + h * stride_h;
- int w_in = -PAD_L + w * stride_w;
-
- constexpr int KERNEL_PROD_ALIGNED = (S * S + 1) / 2 * 2;
-
- int k;
- for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_2d_packed_<
- S,
- true, /*SUM_A*/
- false, /*remainder*/
- true /*per-channel*/>(
- H,
- W,
- K,
- h_in,
- w_in,
- A + (h_in * W + w_in) * K + k,
- A_zero_point,
- Bp + k * KERNEL_PROD_ALIGNED,
- B_zero_point + k,
- C_int32 + k,
- 0,
- &row_offsets[k]);
- }
- int remainder = K - k;
- if (remainder) {
- inner_prod_2d_packed_<
- S,
- true, /*SUM_A*/
- true, /*remainder*/
- true /*per-channel*/>(
- H,
- W,
- K,
- h_in,
- w_in,
- A + (h_in * W + w_in) * K + k,
- A_zero_point,
- Bp + k * KERNEL_PROD_ALIGNED,
- B_zero_point + k,
- C_int32 + k,
- remainder,
- &row_offsets[k]);
- }
+ internal::avx2_ps_or_epi32_combined_mask,
+ A_zero_point,
+ B_zero_point);
requantize_<
FUSE_RELU,
HAS_BIAS,
- true, /*PER_CHAN_QUANT*/
+ PER_CHANNEL_QUANTIZAITON,
A_SYMMETRIC,
- false, /*B_SYMM*/
+ B_SYMMETRIC,
BIAS_TYPE>(
A_zero_point,
C_multiplier,
@@ -539,7 +111,8 @@ template <
bool HAS_BIAS,
bool A_SYMMETRIC,
bool B_SYMMETRIC,
- typename BIAS_TYPE>
+ typename BIAS_TYPE,
+ bool PER_CHANNEL_QUANTIZATION>
static ALWAYS_INLINE void depthwise_2d_(
int N,
int H,
@@ -549,15 +122,15 @@ static ALWAYS_INLINE void depthwise_2d_(
int stride_w,
std::int32_t A_zero_point,
const std::uint8_t* A,
- std::int32_t B_zero_point,
+ const std::int32_t* B_zero_point,
const PackedDepthWiseConvMatrix& B,
- float C_multiplier,
+ const float* C_multiplier,
std::int32_t C_zero_point,
std::int32_t* C_int32,
std::uint8_t* C_uint8,
const std::int32_t* col_offsets,
const BIAS_TYPE* bias,
- float act_times_w_scale,
+ const float* act_times_w_scale,
int thread_id,
int num_threads) {
assert(K % 8 == 0);
@@ -586,6 +159,8 @@ static ALWAYS_INLINE void depthwise_2d_(
fbgemmPartition1D(
th_info.n_thread_id, th_info.n_num_threads, W_OUT, w_begin, w_end);
+ GenI8Depthwise::jit_kernel_signature middle_kernel;
+
for (int n = n_begin; n < n_end; ++n) {
const std::uint8_t* A_base = A + n * H * W * K;
std::uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K;
@@ -593,14 +168,15 @@ static ALWAYS_INLINE void depthwise_2d_(
int h = 0;
int w = 0;
- for (h = h_begin; h < std::max(PAD_T, h_begin); ++h) {
- for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) {
+ for (h = h_begin; h < PAD_T; ++h) {
+ for (w = w_begin; w < PAD_L; ++w) {
depthwise_2d_kernel_<
S,
FUSE_RELU,
HAS_BIAS,
A_SYMMETRIC,
B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION,
BIAS_TYPE>(
H,
W,
@@ -623,13 +199,14 @@ static ALWAYS_INLINE void depthwise_2d_(
act_times_w_scale);
}
- for (; w < std::min(W_OUT - PAD_R, w_end); ++w) {
+ for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) {
depthwise_2d_kernel_<
S,
FUSE_RELU,
HAS_BIAS,
A_SYMMETRIC,
B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION,
BIAS_TYPE>(
H,
W,
@@ -659,6 +236,7 @@ static ALWAYS_INLINE void depthwise_2d_(
HAS_BIAS,
A_SYMMETRIC,
B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION,
BIAS_TYPE>(
H,
W,
@@ -682,14 +260,27 @@ static ALWAYS_INLINE void depthwise_2d_(
}
}
- for (; h < std::min(H - PAD_B, h_end); ++h) {
- for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) {
+ // h <= H_OUT - PAD_B - stride_h
+ // h <= (H + PAD_T + PAD_B - S) / stride_h + 1 - PAD_B - stride_h
+ // h_in <= -PAD_T +
+ // ((H + PAD_T + PAD_B - S) / stride_h + 1 - PAD_B - stride_h) * stride_h
+ // Case 1) For stride_h == 1,
+ // h_in <= -PAD_T + H + PAD_T + PAD_B - S + 1 - PAD_B - 1
+ // h_in + S - H <= 0
+ // Case 2) For stride_h == 2,
+ // h_in <= -PAD_L +
+ // H + PAD_T + PAD_B - S + 1 + (1 - PAD_B - stride_h) * stride_h
+ // h_in + S - H <= PAD_B * (1 - stride_h) + 1 + (1 - stride_h) * stride_h
+ // <= -PAD_B + 1 - stride_h <= 0
+ for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) {
+ for (w = w_begin; w < PAD_L; ++w) {
depthwise_2d_kernel_<
S,
FUSE_RELU,
HAS_BIAS,
A_SYMMETRIC,
B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION,
BIAS_TYPE>(
H,
W,
@@ -712,13 +303,32 @@ static ALWAYS_INLINE void depthwise_2d_(
act_times_w_scale);
}
- for (; w < std::min(W_OUT - PAD_R, w_end); ++w) {
+ for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) {
+ if (n == n_begin && w == std::max(PAD_L, w_begin)) {
+ int remainder = K % 32;
+ if (remainder == 0) {
+ remainder = 32;
+ }
+ middle_kernel = GenI8Depthwise().getOrCreate(
+ /*D=*/2,
+ S,
+ /*compute_a_sum=*/!B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION,
+ remainder,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0);
+ }
depthwise_2d_kernel_<
S,
FUSE_RELU,
HAS_BIAS,
A_SYMMETRIC,
B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION,
BIAS_TYPE>(
H,
W,
@@ -738,7 +348,8 @@ static ALWAYS_INLINE void depthwise_2d_(
row_offsets,
col_offsets,
bias,
- act_times_w_scale);
+ act_times_w_scale,
+ &middle_kernel);
}
for (; w < w_end; ++w) {
@@ -748,6 +359,7 @@ static ALWAYS_INLINE void depthwise_2d_(
HAS_BIAS,
A_SYMMETRIC,
B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION,
BIAS_TYPE>(
H,
W,
@@ -772,13 +384,14 @@ static ALWAYS_INLINE void depthwise_2d_(
}
for (; h < h_end; ++h) {
- for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) {
+ for (w = w_begin; w < PAD_L; ++w) {
depthwise_2d_kernel_<
S,
FUSE_RELU,
HAS_BIAS,
A_SYMMETRIC,
B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION,
BIAS_TYPE>(
H,
W,
@@ -801,13 +414,14 @@ static ALWAYS_INLINE void depthwise_2d_(
act_times_w_scale);
}
- for (; w < std::min(W_OUT - PAD_R, w_end); ++w) {
+ for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) {
depthwise_2d_kernel_<
S,
FUSE_RELU,
HAS_BIAS,
A_SYMMETRIC,
B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION,
BIAS_TYPE>(
H,
W,
@@ -837,327 +451,7 @@ static ALWAYS_INLINE void depthwise_2d_(
HAS_BIAS,
A_SYMMETRIC,
B_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
- }
- } // for each n
-
- fbgemmAlignedFree(row_offsets);
-};
-
-template <
- int S,
- bool FUSE_RELU,
- bool HAS_BIAS,
- bool A_SYMMETRIC,
- typename BIAS_TYPE>
-static ALWAYS_INLINE void depthwise_2d_per_channel_quantization_(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
- const std::int32_t* B_zero_point,
- const PackedDepthWiseConvMatrix& B,
- const float* C_multiplier,
- std::int32_t C_zero_point,
- std::int32_t* C_int32,
- std::uint8_t* C_uint8,
- const std::int32_t* col_offsets,
- const BIAS_TYPE* bias,
- const float* act_times_w_scale,
- int thread_id,
- int num_threads) {
- assert(K % 8 == 0);
- constexpr int R = S;
- constexpr int PAD_T = (R - 1) / 2, PAD_B = PAD_T, PAD_L = (S - 1) / 2,
- PAD_R = PAD_L;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- const std::int8_t* Bp = B.PackedMat();
-
- int32_t* row_offsets = static_cast<int32_t*>(
- fbgemmAlignedAlloc(64, (K + 31) / 32 * 32 * sizeof(int32_t)));
-
- int n_begin, n_end, h_begin, h_end, w_begin, w_end;
- // Reuse the 3-dim partition scheme for parallelization in matrix
- // multiplication.
- thread_type_t th_info =
- fbgemmGetThreadPartition(N, H_OUT, W_OUT, thread_id, num_threads);
- // Calculate the begin and end index along the batch (N) dimension
- fbgemmPartition1D(
- th_info.g_thread_id, th_info.g_num_threads, N, n_begin, n_end);
- // Calculate the begin and end index along the H dimension
- fbgemmPartition1D(
- th_info.m_thread_id, th_info.m_num_threads, H_OUT, h_begin, h_end);
- // Calculate the begin and end index along the W dimension
- fbgemmPartition1D(
- th_info.n_thread_id, th_info.n_num_threads, W_OUT, w_begin, w_end);
-
- for (int n = n_begin; n < n_end; ++n) {
- const std::uint8_t* A_base = A + n * H * W * K;
- std::uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K;
-
- int h = 0;
- int w = 0;
-
- for (h = h_begin; h < std::max(PAD_T, h_begin); ++h) {
- for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) {
- depthwise_2d_per_channel_quantization_kernel_<
- S,
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- for (; w < std::min(W_OUT - PAD_R, w_end); ++w) {
- depthwise_2d_per_channel_quantization_kernel_<
- S,
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- for (; w < w_end; ++w) {
- depthwise_2d_per_channel_quantization_kernel_<
- S,
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
- }
-
- for (; h < std::min(H - PAD_B, h_end); ++h) {
- for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) {
- depthwise_2d_per_channel_quantization_kernel_<
- S,
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- for (; w < std::min(W_OUT - PAD_R, w_end); ++w) {
- depthwise_2d_per_channel_quantization_kernel_<
- S,
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- for (; w < w_end; ++w) {
- depthwise_2d_per_channel_quantization_kernel_<
- S,
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
- }
-
- for (; h < h_end; ++h) {
- for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) {
- depthwise_2d_per_channel_quantization_kernel_<
- S,
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- for (; w < std::min(W_OUT - PAD_R, w_end); ++w) {
- depthwise_2d_per_channel_quantization_kernel_<
- S,
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- for (; w < w_end; ++w) {
- depthwise_2d_per_channel_quantization_kernel_<
- S,
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION,
BIAS_TYPE>(
H,
W,
@@ -1216,7 +510,8 @@ static void depthwise_2d_(
HAS_BIAS,
true /*A_symmetric*/,
true /*B_symmetric*/,
- BIAS_TYPE>(
+ BIAS_TYPE,
+ false /*PER_CHANNEL_QUANTIZAITON*/>(
N,
H,
W,
@@ -1225,15 +520,15 @@ static void depthwise_2d_(
stride_w,
A_zero_point,
A,
- B_zero_point,
+ &B_zero_point,
B,
- C_multiplier,
+ &C_multiplier,
C_zero_point,
C_int32_temp,
C,
col_offsets,
bias,
- act_times_w_scale,
+ &act_times_w_scale,
thread_id,
num_threads);
} else {
@@ -1243,7 +538,8 @@ static void depthwise_2d_(
HAS_BIAS,
true /*A_symmetric*/,
false /*B_symmetric*/,
- BIAS_TYPE>(
+ BIAS_TYPE,
+ false /*PER_CHANNEL_QUANTIZAITON*/>(
N,
H,
W,
@@ -1252,15 +548,15 @@ static void depthwise_2d_(
stride_w,
A_zero_point,
A,
- B_zero_point,
+ &B_zero_point,
B,
- C_multiplier,
+ &C_multiplier,
C_zero_point,
C_int32_temp,
C,
col_offsets,
bias,
- act_times_w_scale,
+ &act_times_w_scale,
thread_id,
num_threads);
}
@@ -1272,7 +568,8 @@ static void depthwise_2d_(
HAS_BIAS,
false /*A_symmetric*/,
true /*B_symmetric*/,
- BIAS_TYPE>(
+ BIAS_TYPE,
+ false /*PER_CHANNEL_QUANTIZAITON*/>(
N,
H,
W,
@@ -1281,15 +578,15 @@ static void depthwise_2d_(
stride_w,
A_zero_point,
A,
- B_zero_point,
+ &B_zero_point,
B,
- C_multiplier,
+ &C_multiplier,
C_zero_point,
C_int32_temp,
C,
col_offsets,
bias,
- act_times_w_scale,
+ &act_times_w_scale,
thread_id,
num_threads);
} else {
@@ -1299,7 +596,8 @@ static void depthwise_2d_(
HAS_BIAS,
false /*A_symmetric*/,
false /*B_symmetric*/,
- BIAS_TYPE>(
+ BIAS_TYPE,
+ false /*PER_CHANNEL_QUANTIZAITON*/>(
N,
H,
W,
@@ -1308,15 +606,15 @@ static void depthwise_2d_(
stride_w,
A_zero_point,
A,
- B_zero_point,
+ &B_zero_point,
B,
- C_multiplier,
+ &C_multiplier,
C_zero_point,
C_int32_temp,
C,
col_offsets,
bias,
- act_times_w_scale,
+ &act_times_w_scale,
thread_id,
num_threads);
}
@@ -1412,12 +710,14 @@ static void depthwise_2d_per_channel_quantization_(
int32_t* C_int32_temp = static_cast<int32_t*>(
fbgemmAlignedAlloc(64, (K + 31) / 32 * 32 * sizeof(int32_t)));
if (A_zero_point == 0 || col_offsets == nullptr) {
- depthwise_2d_per_channel_quantization_<
+ depthwise_2d_<
S,
FUSE_RELU,
HAS_BIAS,
true /*A_SYMM*/,
- BIAS_TYPE>(
+ false /*B_SYMM*/,
+ BIAS_TYPE,
+ true /*PER_CHANNEL_QUANTIZAITON*/>(
N,
H,
W,
@@ -1438,12 +738,14 @@ static void depthwise_2d_per_channel_quantization_(
thread_id,
num_threads);
} else {
- depthwise_2d_per_channel_quantization_<
+ depthwise_2d_<
S,
FUSE_RELU,
HAS_BIAS,
false /*A_SYMM*/,
- BIAS_TYPE>(
+ false /*B_SYMM*/,
+ BIAS_TYPE,
+ true /*PER_CHANNEL_QUANTIZAITON*/>(
N,
H,
W,
diff --git a/src/FbgemmI8Depthwise3DAvx2.cc b/src/FbgemmI8Depthwise3DAvx2.cc
index f70f915..8993e2b 100644
--- a/src/FbgemmI8Depthwise3DAvx2.cc
+++ b/src/FbgemmI8Depthwise3DAvx2.cc
@@ -11,6 +11,7 @@
#include <string>
#include "./FbgemmI8DepthwiseAvx2-inl.h"
+#include "./GenerateI8Depthwise.h"
#include "./MaskAvx2.h"
#include "fbgemm/Utils.h"
#include "fbgemm/UtilsAvx2.h"
@@ -20,261 +21,11 @@ using namespace std;
namespace fbgemm {
template <
- bool SUM_A,
- bool REMAINDER = false,
- bool PER_CHANNEL_QUANTIZATION = false>
-static ALWAYS_INLINE void inner_prod_3x3x3_packed_(
- int T,
- int H,
- int W,
- int K,
- int t_in,
- int h_in,
- int w_in,
- const uint8_t* A,
- int32_t A_zero_point,
- const int8_t* Bp,
- const int32_t* B_zero_point,
- int32_t* C,
- int remainder,
- int32_t* row_offsets) {
- __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
- __m256i mask_v = _mm256_setzero_si256();
- if (REMAINDER) {
- mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(
- internal::avx2_ps_or_epi32_masks[remainder / 4]));
- }
-
- // The code below can be written as a simple R*S loop but the compiler
- // doesn't unroll so we're manually unrolling it.
- // constexpr int R = 3, S = 3;
- // array<__m256i, R * S> a_v;
- // for (int r = 0; r < R; ++r) {
- // for (int s = 0; s < S; ++s) {
- // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
- // if (REMAINDER) {
- // a_v[r * S + s] =
- // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
- // mask_v);
- // } else {
- // a_v[r * S + s] =
- // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
- // }
- // } else {
- // a_v[r * S + s] = A_zero_point_v;
- // }
- // }
- // }
- __m256i a_v[8];
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in >= 0 && t_in < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v);
- }
- }
- }
-
- __m256i a_sum[4];
- inner_prod_packed_<8, SUM_A, REMAINDER>(
- a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum);
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in >= 0 && t_in < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- if (t_in + 1 >= 0 && t_in + 1 < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v);
- }
- }
- }
-
- __m256i a_sum_temp[4];
- inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 8, C, remainder, a_sum_temp);
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
- }
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in + 1 >= 0 && t_in + 1 < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- if (t_in + 2 >= 0 && t_in + 2 < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v);
- }
- }
- }
-
- inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 16, C, remainder, a_sum_temp);
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
- }
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
-
- if (t_in + 2 >= 0 && t_in + 2 < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 24, C, remainder, a_sum_temp);
-
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
-
- __m256i B_zero_point_v;
- for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
- if (PER_CHANNEL_QUANTIZATION) {
- B_zero_point_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
- } else {
- B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
- }
- _mm256_store_si256(
- reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
- _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
- }
- }
-}
-
-template <
bool FUSE_RELU,
bool HAS_BIAS,
bool A_SYMMETRIC,
bool B_SYMMETRIC,
+ bool PER_CHANNEL_QUANTIZATION,
typename BIAS_TYPE>
static ALWAYS_INLINE void depthwise_3x3x3_kernel_(
int T,
@@ -289,16 +40,17 @@ static ALWAYS_INLINE void depthwise_3x3x3_kernel_(
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
- int32_t B_zero_point,
+ const int32_t* B_zero_point,
const int8_t* Bp,
- float C_multiplier,
+ const float* C_multiplier,
int32_t C_zero_point,
int32_t* C_int32,
uint8_t* C_uint8,
int32_t* row_offsets,
const int32_t* col_offsets,
const BIAS_TYPE* bias,
- float act_times_w_scale) {
+ const float* act_times_w_scale,
+ GenI8Depthwise::jit_kernel_signature* pregenerated_kernel = nullptr) {
constexpr int R = 3, S = 3;
constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
@@ -307,141 +59,43 @@ static ALWAYS_INLINE void depthwise_3x3x3_kernel_(
int h_in = -PAD_T + h * stride_h;
int w_in = -PAD_L + w * stride_w;
- int k;
- for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- &B_zero_point,
- C_int32 + k,
- 0,
- B_SYMMETRIC ? nullptr : &row_offsets[k]);
- }
- int remainder = K - k;
- if (remainder) {
- inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/, true>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- &B_zero_point,
- C_int32 + k,
- remainder,
- B_SYMMETRIC ? nullptr : &row_offsets[k]);
+ int remainder = K % 32;
+ if (remainder == 0) {
+ remainder = 32;
}
- requantize_<
- FUSE_RELU,
- HAS_BIAS,
- false, /*PER_CHAN_QUANT*/
- A_SYMMETRIC,
- B_SYMMETRIC>(
- A_zero_point,
- &C_multiplier,
- C_zero_point,
+ GenI8Depthwise::jit_kernel_signature kernel = pregenerated_kernel
+ ? *pregenerated_kernel
+ : GenI8Depthwise().getOrCreate(
+ /*D=*/3,
+ /*S=*/3,
+ /*compute_a_sum=*/!B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION,
+ remainder,
+ /*prev_skip=*/std::max(-t_in, 0),
+ /*next_skip=*/std::max(t_in + 3 - T, 0),
+ /*top_skip=*/std::max(-h_in, 0),
+ /*bottom_skip=*/std::max(h_in + 3 - H, 0),
+ /*left_skip=*/std::max(-w_in, 0),
+ /*right_skip=*/std::max(w_in + 3 - W, 0));
+ kernel(
+ A + ((t_in * H + h_in) * W + w_in) * K,
+ Bp,
C_int32,
- C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
+ B_SYMMETRIC ? nullptr : row_offsets,
+ H,
+ W,
K,
- row_offsets,
- col_offsets,
- bias,
- &act_times_w_scale);
-}
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
-static ALWAYS_INLINE void depthwise_3x3x3_per_channel_quantization_kernel_(
- int T,
- int H,
- int W,
- int K,
- int t,
- int h,
- int w,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const int8_t* Bp,
- const float* C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- const float* act_times_w_scale) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- int t_in = -PAD_P + t * stride_t;
- int h_in = -PAD_T + h * stride_h;
- int w_in = -PAD_L + w * stride_w;
+ internal::avx2_ps_or_epi32_combined_mask,
+ A_zero_point,
+ B_zero_point);
- int k;
- for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3x3_packed_<
- true, /*SUM_A*/
- false, /*remainder*/
- true /*per-channel*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- B_zero_point + k,
- C_int32 + k,
- 0,
- &row_offsets[k]);
- }
- int remainder = K - k;
- if (remainder) {
- inner_prod_3x3x3_packed_<
- true, /*SUM_A*/
- true, /*remainder*/
- true /*per-channel*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- B_zero_point + k,
- C_int32 + k,
- remainder,
- &row_offsets[k]);
- }
requantize_<
FUSE_RELU,
HAS_BIAS,
- true, /*PER_CHAN_QUANT*/
+ PER_CHANNEL_QUANTIZATION,
A_SYMMETRIC,
- false /*B_SYMM*/>(
+ B_SYMMETRIC>(
A_zero_point,
C_multiplier,
C_zero_point,
@@ -459,7 +113,8 @@ template <
bool HAS_BIAS,
bool A_SYMMETRIC,
bool B_SYMMETRIC,
- typename BIAS_TYPE>
+ typename BIAS_TYPE,
+ bool PER_CHANNEL_QUANTIZATION>
static ALWAYS_INLINE void depthwise_3x3x3_pad_1_(
int N,
int T,
@@ -471,15 +126,15 @@ static ALWAYS_INLINE void depthwise_3x3x3_pad_1_(
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
- int32_t B_zero_point,
+ const int32_t* B_zero_point,
const PackedDepthWiseConvMatrix& B,
- float C_multiplier,
+ const float* C_multiplier,
int32_t C_zero_point,
int32_t* C_int32,
uint8_t* C_uint8,
const int32_t* col_offsets,
const BIAS_TYPE* bias,
- float act_times_w_scale,
+ const float* act_times_w_scale,
int thread_id,
int num_threads) {
assert(K % 8 == 0);
@@ -509,18 +164,173 @@ static ALWAYS_INLINE void depthwise_3x3x3_pad_1_(
fbgemmPartition1D(
th_info.n_thread_id, th_info.n_num_threads, H_OUT, h_begin, h_end);
+ GenI8Depthwise::jit_kernel_signature middle_kernel;
+
for (int n = n_begin; n < n_end; ++n) {
const uint8_t* A_base = A + n * T * H * W * K;
uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
- for (int t = t_begin; t < t_end; ++t) {
- for (int h = h_begin; h < h_end; ++h) {
+ int t;
+ for (t = t_begin; t < PAD_P; ++t) {
+ int h;
+ for (h = h_begin; h < PAD_T; ++h) {
for (int w = 0; w < W_OUT; ++w) {
depthwise_3x3x3_kernel_<
FUSE_RELU,
HAS_BIAS,
A_SYMMETRIC,
- B_SYMMETRIC>(
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
+ } // h
+
+ for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) {
+ int w;
+ for (w = 0; w < PAD_L; ++w) {
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
+
+ GenI8Depthwise::jit_kernel_signature kernel;
+ for (; w < W_OUT - PAD_R - stride_w + 1; ++w) {
+ if (w == PAD_L) {
+ int remainder = K % 32;
+ if (remainder == 0) {
+ remainder = 32;
+ }
+ int t_in = -PAD_P + t * stride_t;
+ kernel = GenI8Depthwise().getOrCreate(
+ /*D=*/3,
+ /*F=*/3,
+ /*compute_a_sum=*/!B_SYMMETRIC,
+ /*per_chnnale_quantization=*/PER_CHANNEL_QUANTIZATION,
+ remainder,
+ /*prev_skip=*/std::max(-t_in, 0),
+ /*next_skip=*/std::max(t_in + 3 - T, 0),
+ 0,
+ 0,
+ 0,
+ 0);
+ }
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ &kernel);
+ } // w
+
+ for (; w < W_OUT; ++w) {
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
+ } // h
+
+ for (; h < h_end; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
T,
H,
W,
@@ -546,72 +356,165 @@ static ALWAYS_INLINE void depthwise_3x3x3_pad_1_(
} // w
} // h
} // t
- } // for each n
- fbgemmAlignedFree(row_offsets);
-};
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
-static ALWAYS_INLINE void depthwise_3x3x3_per_channel_quantization_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const PackedDepthWiseConvMatrix& B,
- const float* C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- const float* act_times_w_scale,
- int thread_id,
- int num_threads) {
- assert(K % 8 == 0);
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
- const int8_t* Bp = B.PackedMat();
+ for (; t < std::min(T_OUT - PAD_N - stride_t + 1, t_end); ++t) {
+ int h;
+ for (h = h_begin; h < PAD_T; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
+ } // h
- int32_t* row_offsets = static_cast<int32_t*>(
- fbgemmAlignedAlloc(64, (K + 31) / 32 * 32 * sizeof(int32_t)));
+ for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) {
+ int w;
+ for (w = 0; w < PAD_L; ++w) {
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
- int n_begin, n_end, t_begin, t_end, h_begin, h_end;
- // Reuse the 3-dim partition scheme for parallelization in matrix
- // multiplication.
- thread_type_t th_info =
- fbgemmGetThreadPartition(N, T_OUT, H_OUT, thread_id, num_threads);
- // Calculate the begin and end index along the batch (N) dimension
- fbgemmPartition1D(
- th_info.g_thread_id, th_info.g_num_threads, N, n_begin, n_end);
- // Calculate the begin and end index along the T dimension
- fbgemmPartition1D(
- th_info.m_thread_id, th_info.m_num_threads, T_OUT, t_begin, t_end);
- // Calculate the begin and end index along the H dimension
- fbgemmPartition1D(
- th_info.n_thread_id, th_info.n_num_threads, H_OUT, h_begin, h_end);
+ for (; w < W_OUT - PAD_R - stride_w + 1; ++w) {
+ if (n == n_begin && w == PAD_L) {
+ int remainder = K % 32;
+ if (remainder == 0) {
+ remainder = 32;
+ }
+ middle_kernel = GenI8Depthwise().getOrCreate(
+ /*D=*/3,
+ /*F=*/3,
+ /*compute_a_sum=*/!B_SYMMETRIC,
+ /*per_chnnale_quantization=*/PER_CHANNEL_QUANTIZATION,
+ remainder,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0);
+ }
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ &middle_kernel);
+ }
- for (int n = n_begin; n < n_end; ++n) {
- const uint8_t* A_base = A + n * T * H * W * K;
- uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
+ for (; w < W_OUT; ++w) {
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+ } // h
- for (int t = t_begin; t < t_end; ++t) {
- for (int h = h_begin; h < h_end; ++h) {
+ for (; h < h_end; ++h) {
for (int w = 0; w < W_OUT; ++w) {
- depthwise_3x3x3_per_channel_quantization_kernel_<
+ depthwise_3x3x3_kernel_<
FUSE_RELU,
HAS_BIAS,
A_SYMMETRIC,
- BIAS_TYPE>(
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
T,
H,
W,
@@ -637,8 +540,193 @@ static ALWAYS_INLINE void depthwise_3x3x3_per_channel_quantization_pad_1_(
} // w
} // h
} // t
- } // for each n
+ for (; t < t_end; ++t) {
+ int h;
+ for (h = h_begin; h < PAD_T; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
+ } // h
+
+ for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) {
+ int w;
+ for (w = 0; w < PAD_L; ++w) {
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
+
+ GenI8Depthwise::jit_kernel_signature kernel;
+ for (; w < W_OUT - PAD_R - stride_w + 1; ++w) {
+ if (w == PAD_L) {
+ int remainder = K % 32;
+ if (remainder == 0) {
+ remainder = 32;
+ }
+ int t_in = -PAD_P + t * stride_t;
+ kernel = GenI8Depthwise().getOrCreate(
+ /*D=*/3,
+ /*F=*/3,
+ /*compute_a_sum=*/!B_SYMMETRIC,
+ /*per_chnnale_quantization=*/PER_CHANNEL_QUANTIZATION,
+ remainder,
+ /*prev_skip=*/std::max(-t_in, 0),
+ /*next_skip=*/std::max(t_in + 3 - T, 0),
+ 0,
+ 0,
+ 0,
+ 0);
+ }
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ &kernel);
+ } // w
+
+ for (; w < W_OUT; ++w) {
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
+ } // h
+
+ for (; h < h_end; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ depthwise_3x3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ PER_CHANNEL_QUANTIZATION>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ } // w
+ } // h
+ } // t
+ } // for each n
fbgemmAlignedFree(row_offsets);
};
@@ -674,7 +762,8 @@ static void depthwise_3x3x3_pad_1_(
HAS_BIAS,
true /*A_symmetric*/,
true /*B_symmetric*/,
- BIAS_TYPE>(
+ BIAS_TYPE,
+ false /*PER_CHANNEL_QUANTIZATION*/>(
N,
T,
H,
@@ -685,15 +774,15 @@ static void depthwise_3x3x3_pad_1_(
stride_w,
A_zero_point,
A,
- B_zero_point,
+ &B_zero_point,
B,
- C_multiplier,
+ &C_multiplier,
C_zero_point,
C_int32_temp,
C,
col_offsets,
bias,
- act_times_w_scale,
+ &act_times_w_scale,
thread_id,
num_threads);
} else {
@@ -702,7 +791,8 @@ static void depthwise_3x3x3_pad_1_(
HAS_BIAS,
true /*A_symmetric*/,
false /*B_symmetric*/,
- BIAS_TYPE>(
+ BIAS_TYPE,
+ false /*PER_CHANNEL_QUANTIZATION*/>(
N,
T,
H,
@@ -713,15 +803,15 @@ static void depthwise_3x3x3_pad_1_(
stride_w,
A_zero_point,
A,
- B_zero_point,
+ &B_zero_point,
B,
- C_multiplier,
+ &C_multiplier,
C_zero_point,
C_int32_temp,
C,
col_offsets,
bias,
- act_times_w_scale,
+ &act_times_w_scale,
thread_id,
num_threads);
}
@@ -732,7 +822,8 @@ static void depthwise_3x3x3_pad_1_(
HAS_BIAS,
false /*A_symmetric*/,
true /*B_symmetric*/,
- BIAS_TYPE>(
+ BIAS_TYPE,
+ false /*PER_CHANNEL_QUANTIZATION*/>(
N,
T,
H,
@@ -743,15 +834,15 @@ static void depthwise_3x3x3_pad_1_(
stride_w,
A_zero_point,
A,
- B_zero_point,
+ &B_zero_point,
B,
- C_multiplier,
+ &C_multiplier,
C_zero_point,
C_int32_temp,
C,
col_offsets,
bias,
- act_times_w_scale,
+ &act_times_w_scale,
thread_id,
num_threads);
} else {
@@ -760,7 +851,8 @@ static void depthwise_3x3x3_pad_1_(
HAS_BIAS,
false /*A_symmetric*/,
false /*B_symmetric*/,
- BIAS_TYPE>(
+ BIAS_TYPE,
+ false /*PER_CHANNEL_QUANTIZATION*/>(
N,
T,
H,
@@ -771,15 +863,15 @@ static void depthwise_3x3x3_pad_1_(
stride_w,
A_zero_point,
A,
- B_zero_point,
+ &B_zero_point,
B,
- C_multiplier,
+ &C_multiplier,
C_zero_point,
C_int32_temp,
C,
col_offsets,
bias,
- act_times_w_scale,
+ &act_times_w_scale,
thread_id,
num_threads);
}
@@ -970,11 +1062,13 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_(
int32_t* C_int32_temp = static_cast<int32_t*>(
fbgemmAlignedAlloc(64, (K + 31) / 32 * 32 * sizeof(int32_t)));
if (A_zero_point == 0 || col_offsets == nullptr) {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
+ depthwise_3x3x3_pad_1_<
FUSE_RELU,
HAS_BIAS,
true /*A_SYMM*/,
- BIAS_TYPE>(
+ false /*B_SYMM*/,
+ BIAS_TYPE,
+ true /*PER_CHANNEL_QUANTIZATION*/>(
N,
T,
H,
@@ -997,11 +1091,13 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_(
thread_id,
num_threads);
} else {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
+ depthwise_3x3x3_pad_1_<
FUSE_RELU,
HAS_BIAS,
false /*A_SYMM*/,
- BIAS_TYPE>(
+ false /*B_SYMM*/,
+ BIAS_TYPE,
+ true /*PER_CHANNEL_QUANTIZATION*/>(
N,
T,
H,
diff --git a/src/FbgemmI8DepthwiseAvx2-inl.h b/src/FbgemmI8DepthwiseAvx2-inl.h
index 18dda3b..11714c8 100644
--- a/src/FbgemmI8DepthwiseAvx2-inl.h
+++ b/src/FbgemmI8DepthwiseAvx2-inl.h
@@ -16,349 +16,6 @@
namespace fbgemm {
-// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[16:20]
-// c1_v: c[4:8], c[20:24]
-// c2_v: c[8:12], c[24:28]
-// c3_v: c[12:16], c[28:32]
-template <bool SUM_A = false>
-static ALWAYS_INLINE void madd_epi16x4_packed(
- __m256i a0_v,
- __m256i a1_v,
- __m256i a2_v,
- __m256i a3_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
- __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v);
- __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
- __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
- __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
- __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
- __m256i b2_v = _mm256_load_si256(b + 2);
- __m256i b3_v = _mm256_load_si256(b + 3);
-
- __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
- __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
- __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
- __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
-
- __m256i one_v = _mm256_set1_epi16(1);
- *c0_v = _mm256_madd_epi16(ab0, one_v);
- *c1_v = _mm256_madd_epi16(ab1, one_v);
- *c2_v = _mm256_madd_epi16(ab2, one_v);
- *c3_v = _mm256_madd_epi16(ab3, one_v);
-}
-
-// c = a0 * b0 + a1 * b1 + a2 * b2
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[16:20]
-// c1_v: c[4:8], c[20:24]
-// c2_v: c[8:12], c[24:28]
-// c3_v: c[12:16], c[28:32]
-template <bool SUM_A = false>
-static ALWAYS_INLINE void madd_epi16x3_packed(
- __m256i a0_v,
- __m256i a1_v,
- __m256i a2_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i zero_v = _mm256_setzero_si256();
-
- __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
- __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v);
- __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
- __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
- __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
- __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
- __m256i b2_v = _mm256_load_si256(b + 2);
- __m256i b3_v = _mm256_load_si256(b + 3);
-
- __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
- __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
- __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
- __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
-
- __m256i one_v = _mm256_set1_epi16(1);
- *c0_v = _mm256_madd_epi16(ab0, one_v);
- *c1_v = _mm256_madd_epi16(ab1, one_v);
- *c2_v = _mm256_madd_epi16(ab2, one_v);
- *c3_v = _mm256_madd_epi16(ab3, one_v);
-}
-
-// c = a0 * b0 + a1 * b1
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[4:8]
-// c1_v: c[8:12], c[12:16]
-// c2_v: c[16:20], c[20:24]
-// c3_v: c[24:28], c[28:32]
-template <bool SUM_A = false>
-static ALWAYS_INLINE void madd_epi16x2_packed(
- __m256i a0_v,
- __m256i a1_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
-
- __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
- __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
-
- *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
- *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
- *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
- *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
-}
-
-// c = a0 * b0
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[4:8]
-// c1_v: c[8:12], c[12:16]
-// c2_v: c[16:20], c[20:24]
-// c3_v: c[24:28], c[28:32]
-template <bool SUM_A = false>
-static ALWAYS_INLINE void madd_epi16_packed(
- __m256i a_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i zero_v = _mm256_setzero_si256();
-
- __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v);
- __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
-
- __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
- __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
-
- *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
- *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
- *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
- *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
-}
-
-// K is the number of accumulations we're doing
-template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false>
-static ALWAYS_INLINE void inner_prod_packed_(
- const __m256i* a_v,
- const __m256i* Bp,
- std::int32_t* C,
- int remainder,
- __m256i* a_sum = nullptr) {
- __m256i c[4], c_temp[4];
- __m256i a_sum_temp[2] = {0, 0};
-
- int k = 0;
- if (K >= 4) {
- madd_epi16x4_packed<SUM_A>(
- a_v[0],
- a_v[1],
- a_v[2],
- a_v[3],
- Bp,
- &c[0],
- &c[1],
- &c[2],
- &c[3],
- a_sum_temp);
-
- for (k = 4; k < K / 4 * 4; k += 4) {
- madd_epi16x4_packed<SUM_A>(
- a_v[k + 0],
- a_v[k + 1],
- a_v[k + 2],
- a_v[k + 3],
- Bp + k,
- &c_temp[0],
- &c_temp[1],
- &c_temp[2],
- &c_temp[3],
- a_sum_temp);
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
- } else {
- c[0] = _mm256_setzero_si256();
- c[1] = _mm256_setzero_si256();
- c[2] = _mm256_setzero_si256();
- c[3] = _mm256_setzero_si256();
- }
-
- if (K - k == 3) {
- madd_epi16x3_packed<SUM_A>(
- a_v[k],
- a_v[k + 1],
- a_v[k + 2],
- Bp + k,
- &c_temp[0],
- &c_temp[1],
- &c_temp[2],
- &c_temp[3],
- a_sum_temp);
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
-
- c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20);
- c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20);
- c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31);
- c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31);
-
- if (K - k == 0 || K - k == 3) {
- c[0] = c_temp[0];
- c[1] = c_temp[1];
- c[2] = c_temp[2];
- c[3] = c_temp[3];
- } else {
- if (K - k == 1) {
- madd_epi16_packed<SUM_A>(
- a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
- } else if (K - k == 2) {
- madd_epi16x2_packed<SUM_A>(
- a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
- }
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
-
- if (REMAINDER) {
- for (int r = 0; r < remainder / 8; ++r) {
- if (ACC) {
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + r * 8),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)),
- c[r]));
- } else {
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]);
- }
- }
- } else {
- if (ACC) {
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 8),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 16),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 24),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3]));
- } else {
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]);
- }
- }
-
- if (SUM_A) {
- a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0]));
- a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1]));
- a_sum[2] =
- _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1));
- a_sum[3] =
- _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1));
- }
-}
-
// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different
// row_offsets for each row because of depth-wise convolution
template <
@@ -672,15 +329,6 @@ static ALWAYS_INLINE void requantize_(
}
}
-template <bool REMAINDER>
-static ALWAYS_INLINE __m256i load_a(const std::uint8_t* A, __m256i mask_v) {
- if (REMAINDER) {
- return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v);
- } else {
- return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A));
- }
-}
-
static inline std::pair<int, int> closest_factors_(int n) {
int a = static_cast<int>(std::sqrt(n));
while (n % a != 0) {
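
The templates removed above (madd_epi16x*_packed, inner_prod_packed_, load_a) implemented the per-channel inner product that the new JIT generator now emits at runtime. A scalar reference of what each 32-channel block computes, given as an orientation sketch rather than FBGEMM API:

#include <cstdint>

// Scalar reference (sketch only) of the packed inner product:
// for each of 32 channels, C[j] = sum_k A[k][j] * B[k][j], optionally also
// accumulating the row-wise sum of A used later for zero-point correction.
void inner_prod_reference(
    const std::uint8_t* A, // K x 32 activation block
    const std::int8_t* B,  // K x 32 pre-packed filter block
    std::int32_t* C,       // 32 accumulators
    std::int32_t* a_sum,   // 32 row-wise sums of A (may be nullptr)
    int K) {
  for (int j = 0; j < 32; ++j) {
    std::int32_t acc = 0, sum = 0;
    for (int k = 0; k < K; ++k) {
      acc += static_cast<std::int32_t>(A[k * 32 + j]) * B[k * 32 + j];
      sum += A[k * 32 + j];
    }
    C[j] = acc;
    if (a_sum) {
      a_sum[j] = sum;
    }
  }
}
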
diff --git a/src/GenerateI8Depthwise.cc b/src/GenerateI8Depthwise.cc
new file mode 100644
index 0000000..4b9eb7e
--- /dev/null
+++ b/src/GenerateI8Depthwise.cc
@@ -0,0 +1,506 @@
+#include "./GenerateI8Depthwise.h"
+
+#include <asmjit/asmjit.h>
+#include <iostream>
+
+#include "./CodeCache.h"
+#include "./CodeGenHelpers.h"
+#include "fbgemm/Utils.h"
+
+namespace fbgemm {
+
+namespace {
+asmjit::JitRuntime& runtime() {
+  static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
+                                // depends on other static
+                                // variables. Required to prevent
+                                // initialization order fiasco.
+ return rt;
+}
+
+// Control access to runtime.
+std::mutex rtMutex_;
+
+// The hash depends on D, F, compute_a_sum, per_channel_quantization, remainder,
+// prev_skip, next_skip, top_skip, bottom_skip, left_skip, and right_skip.
+CodeCache<
+ std::tuple<int, int, bool, bool, int, int, int, int, int, int, int>,
+ GenI8Depthwise::jit_kernel_signature>
+ codeCache_;
+} // namespace
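
The anonymous namespace above pairs a lazily constructed JitRuntime with a CodeCache keyed on the full kernel configuration, so each distinct tuple is JIT'ed at most once. A minimal sketch of that memoization pattern, assuming a plain map and mutex (the real implementation lives in CodeCache.h):

#include <functional>
#include <map>
#include <mutex>

// Sketch of a configuration-keyed kernel cache: generate() runs only on a
// cache miss, and concurrent callers are serialized by the mutex.
template <typename Key, typename Fn>
class SimpleCodeCache {
 public:
  Fn getOrCreate(const Key& key, std::function<Fn()> generate) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      it = cache_.emplace(key, generate()).first; // JIT only once per config
    }
    return it->second;
  }

 private:
  std::map<Key, Fn> cache_;
  std::mutex mutex_;
};
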
+
+namespace x86 = asmjit::x86;
+
+// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3
+// A is in uint8_t
+// B is in int8_t and pre-interleaved
+// C is in int32_t and 4 registers have results in the following layout:
+// c0_v: c[0:4], c[16:20]
+// c1_v: c[4:8], c[20:24]
+// c2_v: c[8:12], c[24:28]
+// c3_v: c[12:16], c[28:32]
+static void genMaddEpi16xNPacked(
+ x86::Emitter* e,
+ x86::Ymm a[4],
+ x86::Gp b,
+ x86::Ymm c[4],
+ x86::Ymm* a_sum,
+ int n,
+ int remainder,
+ bool accumulation,
+ x86::Ymm one_epi8,
+ x86::Ymm one_epi16,
+ x86::Ymm zero) {
+ // Interleave inputs. Reuse a[1] and a[3] to save registers
+ x86::Ymm a01_lo(0), a01_hi(1), a23_lo(a[1].id()), a23_hi(a[3].id());
+ e->vpunpcklbw(a01_lo, a[0], n == 1 ? zero : a[1]);
+ if (remainder >= 8) {
+ e->vpunpckhbw(a01_hi, a[0], n == 1 ? zero : a[1]);
+ }
+ if (n > 2) {
+ e->vpunpcklbw(a23_lo, a[2], n == 3 ? zero : a[3]);
+ if (remainder >= 8) {
+ e->vpunpckhbw(a23_hi, a[2], n == 3 ? zero : a[3]);
+ }
+ }
+
+  // Compute the row-wise sum of A for row_offsets
+ if (a_sum) {
+ if (accumulation) {
+ e->vpmaddubsw(a[0], a01_lo, one_epi8);
+ e->vpaddsw(a_sum[0], a[0], a_sum[0]);
+
+ if (remainder >= 8) {
+ e->vpmaddubsw(a[2], a01_hi, one_epi8);
+ e->vpaddsw(a_sum[1], a[2], a_sum[1]);
+ }
+ } else {
+ e->vpmaddubsw(a_sum[0], a01_lo, one_epi8);
+ if (remainder >= 8) {
+ e->vpmaddubsw(a_sum[1], a01_hi, one_epi8);
+ }
+ }
+
+ if (n > 2) {
+ e->vpmaddubsw(a[0], a23_lo, one_epi8);
+ e->vpaddsw(a_sum[0], a[0], a_sum[0]);
+
+ if (remainder >= 8) {
+ e->vpmaddubsw(a[2], a23_hi, one_epi8);
+ e->vpaddsw(a_sum[1], a[2], a_sum[1]);
+ }
+ }
+ }
+
+ if (n > 2) {
+ // Reusing a
+ e->vpunpcklwd(a[0], a01_lo, a23_lo);
+ e->vpunpckhwd(a[1], a01_lo, a23_lo);
+ if (remainder >= 16) {
+ e->vpunpcklwd(a[2], a01_hi, a23_hi);
+ e->vpunpckhwd(a[3], a01_hi, a23_hi);
+ }
+
+ e->vpmaddubsw(a[0], a[0], x86::ymmword_ptr(b));
+ e->vpmaddubsw(a[1], a[1], x86::ymmword_ptr(b, 32));
+ if (remainder >= 16) {
+ e->vpmaddubsw(a[2], a[2], x86::ymmword_ptr(b, 64));
+ e->vpmaddubsw(a[3], a[3], x86::ymmword_ptr(b, 96));
+ }
+
+ if (accumulation) {
+ e->vpmaddwd(a[0], a[0], one_epi16);
+ e->vpaddd(c[0], c[0], a[0]);
+ e->vpmaddwd(a[1], a[1], one_epi16);
+ e->vpaddd(c[1], c[1], a[1]);
+
+ if (remainder >= 16) {
+ e->vpmaddwd(a[2], a[2], one_epi16);
+ e->vpaddd(c[2], c[2], a[2]);
+ e->vpmaddwd(a[3], a[3], one_epi16);
+ e->vpaddd(c[3], c[3], a[3]);
+ }
+ } else {
+ e->vpmaddwd(c[0], a[0], one_epi16);
+ e->vpmaddwd(c[1], a[1], one_epi16);
+
+ if (remainder >= 16) {
+ e->vpmaddwd(c[2], a[2], one_epi16);
+ e->vpmaddwd(c[3], a[3], one_epi16);
+ }
+ }
+ } else {
+ // Reusing a
+ e->vpmaddubsw(a[0], a01_lo, x86::ymmword_ptr(b));
+ e->vpmaddubsw(a[1], a01_hi, x86::ymmword_ptr(b, 32));
+
+ if (accumulation) {
+ e->vpmovsxwd(a[2], x86::Xmm(a[0].id()));
+ e->vpaddd(c[0], c[0], a[2]);
+ e->vpmovsxwd(a[3], x86::Xmm(a[1].id()));
+ e->vpaddd(c[1], c[1], a[3]);
+
+ if (remainder >= 16) {
+ e->vextracti128(x86::Xmm(a[0].id()), a[0], asmjit::Imm(1));
+ e->vpmovsxwd(a[0], x86::Xmm(a[0].id()));
+ e->vpaddd(c[2], c[2], a[0]);
+ e->vextracti128(x86::Xmm(a[1].id()), a[1], asmjit::Imm(1));
+ e->vpmovsxwd(a[1], x86::Xmm(a[1].id()));
+ e->vpaddd(c[3], c[3], a[1]);
+ }
+ } else {
+ e->vpmovsxwd(c[0], x86::Xmm(a[0].id()));
+ e->vpmovsxwd(c[1], x86::Xmm(a[1].id()));
+
+ if (remainder >= 16) {
+ e->vextracti128(x86::Xmm(a[0].id()), a[0], asmjit::Imm(1));
+ e->vpmovsxwd(c[2], x86::Xmm(a[0].id()));
+ e->vextracti128(x86::Xmm(a[1].id()), a[1], asmjit::Imm(1));
+ e->vpmovsxwd(c[3], x86::Xmm(a[1].id()));
+ }
+ }
+ }
+}
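
genMaddEpi16xNPacked strings together vpunpck*, vpmaddubsw, and vpmaddwd so that each 32-bit output lane accumulates up to four uint8*int8 products. A scalar model of that per-lane computation (a sketch; the saturation vpmaddubsw applies to its int16 intermediate is ignored here):

#include <cstdint>

// Scalar model of one output lane of the vpmaddubsw + vpmaddwd(1) pair.
std::int32_t madd_u8s8x4(const std::uint8_t a[4], const std::int8_t b[4]) {
  // vpmaddubsw: two adjacent u8*s8 products summed into an int16
  // (saturation is ignored in this sketch).
  std::int16_t lo = static_cast<std::int16_t>(a[0] * b[0] + a[1] * b[1]);
  std::int16_t hi = static_cast<std::int16_t>(a[2] * b[2] + a[3] * b[3]);
  // vpmaddwd with a vector of 16-bit ones: widen both int16s and add them.
  return static_cast<std::int32_t>(lo) + static_cast<std::int32_t>(hi);
}
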
+
+GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate(
+ int D,
+ int S,
+ bool compute_a_sum,
+ bool per_channel_quantization,
+ int remainder,
+ int prev_skip,
+ int next_skip,
+ int top_skip,
+ int bottom_skip,
+ int left_skip,
+ int right_skip) {
+ std::tuple<int, int, bool, bool, int, int, int, int, int, int, int>
+ kernelSig = std::make_tuple(
+ D,
+ S,
+ compute_a_sum,
+ per_channel_quantization,
+ remainder,
+ prev_skip,
+ next_skip,
+ top_skip,
+ bottom_skip,
+ left_skip,
+ right_skip);
+
+ return codeCache_.getOrCreate(kernelSig, [&]() -> jit_kernel_signature {
+ asmjit::CodeHolder code;
+ code.init(runtime().codeInfo());
+ x86::Assembler assembler(&code);
+ x86::Emitter* e = assembler.as<x86::Emitter>();
+
+ x86::Gp a_addr = e->zdi();
+ x86::Gp b_addr = e->zsi();
+ x86::Gp c_addr = e->zdx();
+ x86::Gp a_sum_addr = e->zcx();
+ x86::Gp h = e->gpz(8);
+ x86::Gp w = e->gpz(9);
+ x86::Gp c_in = e->gpz(10);
+ x86::Gp mask_addr = e->gpz(11);
+ x86::Gp a_zero_point = e->gpz(12);
+ x86::Gp b_zero_point_addr = e->gpz(13);
+ x86::Gp ic_loop_count = e->gpz(14);
+ x86::Gp a_addr_save = e->gpz(15);
+
+ asmjit::FuncDetail func;
+ func.init(asmjit::FuncSignatureT<
+ void,
+ const std::uint8_t*,
+ const std::int8_t*,
+ std::int32_t*,
+ std::int32_t*,
+ int,
+ int,
+ int,
+ const int*,
+ int,
+ const std::int32_t*>(asmjit::CallConv::kIdHost));
+
+ asmjit::FuncFrame frame;
+ frame.init(func);
+
+ frame.setDirtyRegs(
+ x86::Reg::kGroupVec,
+ asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+ frame.setDirtyRegs(
+ x86::Reg::kGroupGp,
+ asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+ asmjit::FuncArgsAssignment args(&func);
+ args.assignAll(
+ a_addr,
+ b_addr,
+ c_addr,
+ a_sum_addr,
+ h,
+ w,
+ c_in,
+ mask_addr,
+ a_zero_point,
+ b_zero_point_addr);
+
+ args.updateFuncFrame(frame);
+ frame.finalize();
+
+ e->emitProlog(frame);
+ e->emitArgsAssignment(frame, args);
+
+ // Assign vector registers
+ x86::Ymm a[4];
+ x86::Ymm c[4];
+ x86::Ymm a_sum[2];
+
+ int vreg_id = 2; // reserve 2 for temp vreg
+ for (int i = 0; i < 4; ++i, ++vreg_id) {
+ a[i] = x86::Ymm(vreg_id);
+ }
+ for (int i = 0; i < 4; ++i, ++vreg_id) {
+ c[i] = x86::Ymm(vreg_id);
+ }
+ if (compute_a_sum) {
+ a_sum[0] = x86::Ymm(vreg_id);
+ ++vreg_id;
+ a_sum[1] = x86::Ymm(vreg_id);
+ ++vreg_id;
+ }
+ x86::Ymm mask_vreg(vreg_id);
+ constexpr int vlen = simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
+ if (remainder != simd_info<inst_set_t::avx2>::WIDTH_BYTES) {
+ ++vreg_id;
+ e->vmovups(
+ mask_vreg,
+ x86::ymmword_ptr(
+ mask_addr, (vlen - remainder / 4) % vlen * sizeof(int32_t)));
+ }
+ x86::Ymm one_epi8(vreg_id);
+ if (compute_a_sum) {
+ ++vreg_id;
+ gen8BitVectorOne(e, one_epi8);
+ }
+
+ int K = 1;
+ for (int i = 0; i < D; ++i) {
+ K *= S;
+ }
+ x86::Ymm one_epi16(vreg_id);
+ if (K > 2) {
+ ++vreg_id;
+ gen16BitVectorOne(e, one_epi16);
+ }
+
+ bool has_pad = prev_skip || next_skip || top_skip || bottom_skip ||
+ left_skip || right_skip;
+ bool need_zero = K % 4 == 3 || K % 4 == 1;
+ // When out of registers, zero and A_zero_point_vreg need to share.
+ bool recompute_zero = vreg_id == 15 && need_zero;
+
+ x86::Ymm a_zero_point_vreg(vreg_id);
+ if (!recompute_zero && has_pad) {
+ e->movq(x86::Xmm(a_zero_point_vreg.id()), a_zero_point);
+ e->vpbroadcastb(a_zero_point_vreg, x86::Xmm(a_zero_point_vreg.id()));
+ }
+ if (vreg_id < 15) {
+ ++vreg_id;
+ }
+ x86::Ymm zero(vreg_id);
+ if (need_zero && (!recompute_zero || !has_pad)) {
+ e->vxorps(zero, zero, zero);
+ }
+
+ // Assign scalar registers
+ e->imul(w, c_in);
+ e->imul(h, w);
+ if (D >= 3) {
+ e->mov(a_addr_save, w);
+ e->imul(a_addr_save, S);
+ e->sub(h, a_addr_save);
+ }
+ e->mov(a_addr_save, c_in);
+ e->imul(a_addr_save, S);
+ e->sub(w, a_addr_save);
+
+ e->mov(ic_loop_count, c_in);
+ e->add(ic_loop_count, asmjit::Imm(31));
+ e->sar(ic_loop_count, asmjit::Imm(5));
+
+ e->mov(a_addr_save, a_addr);
+ asmjit::Label ic_loop_begin = e->newLabel(), ic_loop_end = e->newLabel();
+
+ // main_loop == false: the last vector iteration across input channels
+ for (bool main_loop : {true, false}) {
+ if (main_loop) {
+ e->bind(ic_loop_begin);
+ e->dec(ic_loop_count);
+ e->jle(ic_loop_end);
+ }
+
+ if (recompute_zero && has_pad) {
+ e->movq(x86::Xmm(a_zero_point_vreg.id()), a_zero_point);
+ e->vpbroadcastb(a_zero_point_vreg, x86::Xmm(a_zero_point_vreg.id()));
+ }
+
+ int i = 0;
+ // Iterate across the reduction (filter) dimension
+ for (int f_t = 0; f_t < ((D == 2) ? 1 : S); ++f_t) {
+ for (int f_h = 0; f_h < S; ++f_h) {
+ for (int f_w = 0; f_w < S; ++f_w, ++i) {
+ bool pad = false;
+ if (D > 2) {
+ if (f_t < prev_skip || f_t >= S - next_skip) {
+ pad = true;
+ }
+ }
+ if (f_h < top_skip || f_h >= S - bottom_skip || f_w < left_skip ||
+ f_w >= S - right_skip) {
+ pad = true;
+ }
+
+ // Load A
+ if (pad) {
+ e->vmovups(a[i % 4], a_zero_point_vreg);
+ } else {
+ if (!main_loop && remainder != 32) {
+ e->vmaskmovps(a[i % 4], mask_vreg, x86::ymmword_ptr(a_addr));
+ } else {
+ e->vmovups(a[i % 4], x86::ymmword_ptr(a_addr));
+ }
+ }
+
+ // Compute when we have 4 inputs or this is the last iteration
+ if (i % 4 == 3 || i == K - 1) {
+ if (i == K - 1 && (i / 4 * 4 == K - 3 || i / 4 * 4 == K - 1)) {
+ if (recompute_zero && has_pad) {
+ e->vxorps(zero, zero, zero);
+ }
+ }
+
+ genMaddEpi16xNPacked(
+ e,
+ a,
+ b_addr,
+ c,
+ compute_a_sum ? a_sum : nullptr,
+ /*n=*/std::min(K - i / 4 * 4, 4),
+ main_loop ? 32 : remainder,
+ /*accumulation=*/i / 4 > 0,
+ one_epi8,
+ one_epi16,
+ zero);
+
+ if (i != K - 1) {
+ e->add(b_addr, asmjit::Imm(32 * 4));
+ } else if (main_loop) {
+ e->add(b_addr, asmjit::Imm(32 * (K - i / 4 * 4 + 1) / 2 * 2));
+ }
+
+ if (K - i / 4 * 4 >= 3 && K - i / 4 * 4 <= 6) {
+ for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) {
+ e->vperm2f128(
+ a[r],
+ c[r % 2 * 2],
+ c[r % 2 * 2 + 1],
+ asmjit::Imm(r < 2 ? 0x20 : 0x31));
+ }
+ for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) {
+ e->vmovaps(c[r], a[r]);
+ }
+ }
+ }
+ if (i != K - 1) {
+ e->add(a_addr, c_in);
+ }
+ }
+ if (i != K - 1) {
+ e->add(a_addr, w);
+ }
+ }
+ if (D >= 3 && i != K - 1) {
+ e->add(a_addr, h);
+ }
+ }
+
+ for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) {
+ e->vmovups(x86::ymmword_ptr(c_addr, r * 32), c[r]);
+ }
+
+ if (compute_a_sum) {
+ if (per_channel_quantization) {
+ e->vmovups(c[0], x86::ymmword_ptr(b_zero_point_addr));
+ } else {
+ e->vpbroadcastd(c[0], x86::dword_ptr(b_zero_point_addr));
+ }
+ e->vpmovsxwd(a[0], x86::Xmm(a_sum[0].id()));
+ e->vpmulld(a[0], a[0], c[0]);
+ e->vmovups(x86::ymmword_ptr(a_sum_addr), a[0]);
+
+ if (main_loop || remainder >= 8) {
+ if (per_channel_quantization) {
+ e->vmovups(c[0], x86::ymmword_ptr(b_zero_point_addr, 32));
+ }
+ e->vpmovsxwd(a[1], x86::Xmm(a_sum[1].id()));
+ e->vpmulld(a[1], a[1], c[0]);
+ e->vmovups(x86::ymmword_ptr(a_sum_addr, 32), a[1]);
+ }
+
+ if (main_loop || remainder >= 16) {
+ if (per_channel_quantization) {
+ e->vmovups(c[0], x86::ymmword_ptr(b_zero_point_addr, 64));
+ }
+ e->vextracti128(x86::Xmm(a_sum[0].id()), a_sum[0], asmjit::Imm(1));
+ e->vpmovsxwd(a_sum[0], x86::Xmm(a_sum[0].id()));
+ e->vpmulld(a_sum[0], a_sum[0], c[0]);
+ e->vmovups(x86::ymmword_ptr(a_sum_addr, 64), a_sum[0]);
+ }
+
+ if (main_loop || remainder >= 24) {
+ if (per_channel_quantization) {
+ e->vmovups(c[0], x86::ymmword_ptr(b_zero_point_addr, 96));
+ }
+ e->vextracti128(x86::Xmm(a_sum[1].id()), a_sum[1], asmjit::Imm(1));
+ e->vpmovsxwd(a_sum[1], x86::Xmm(a_sum[1].id()));
+ e->vpmulld(a_sum[1], a_sum[1], c[0]);
+ e->vmovups(x86::ymmword_ptr(a_sum_addr, 96), a_sum[1]);
+ }
+
+ if (main_loop) {
+ if (per_channel_quantization) {
+ e->add(b_zero_point_addr, asmjit::Imm(128));
+ }
+ e->add(a_sum_addr, asmjit::Imm(128));
+ }
+ }
+
+ if (main_loop) {
+ e->add(c_addr, asmjit::Imm(128));
+ e->add(a_addr_save, asmjit::Imm(32));
+ e->mov(a_addr, a_addr_save);
+ e->jmp(ic_loop_begin);
+
+ e->bind(ic_loop_end);
+ }
+ }
+
+ e->emitEpilog(frame);
+
+ jit_kernel_signature fn;
+ asmjit::Error err;
+ {
+ std::unique_lock<std::mutex> lock(rtMutex_);
+ err = runtime().add(&fn, &code);
+ }
+ if (err) {
+ std::cout << "Error: in fn add" << std::endl;
+ return nullptr;
+ }
+
+ return fn;
+ });
+}
+
+} // namespace fbgemm
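
The tail of the generated kernel multiplies the row-wise sums of A by the B zero point(s) before storing them to a_sum. A worked scalar sketch of the quantization identity this serves, assumed from standard quantized-GEMM algebra rather than taken verbatim from this diff:

#include <cstdint>

// For quantized operands,
//   sum_k (A[k] - A_zp) * (B[k] - B_zp)
//     = sum_k A[k]*B[k] - B_zp * sum_k A[k] - A_zp * sum_k B[k]
//       + K * A_zp * B_zp,
// so the kernel can emit the raw products (into C) and B_zp times the
// row-wise sum of A (into a_sum), leaving the A_zp terms to requantization.
std::int32_t dequantized_acc(
    const std::uint8_t* A, const std::int8_t* B, int K,
    std::int32_t A_zp, std::int32_t B_zp) {
  std::int32_t raw = 0, a_row_sum = 0, b_col_sum = 0;
  for (int k = 0; k < K; ++k) {
    raw += A[k] * B[k];
    a_row_sum += A[k];
    b_col_sum += B[k];
  }
  return raw - B_zp * a_row_sum - A_zp * b_col_sum + K * A_zp * B_zp;
}
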
diff --git a/src/GenerateI8Depthwise.h b/src/GenerateI8Depthwise.h
new file mode 100644
index 0000000..4e5d2ee
--- /dev/null
+++ b/src/GenerateI8Depthwise.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <cstdint>
+
+namespace fbgemm {
+
+class GenI8Depthwise {
+ public:
+ using jit_kernel_signature = void (*)(
+ const std::uint8_t* a,
+ const std::int8_t* b,
+ std::int32_t* c,
+      std::int32_t* a_sum, // row-wise sum of A
+ int h,
+ int w,
+ int c_in, // the number of input channels
+ const int* mask,
+ int A_zero_point,
+ const int32_t* B_zero_point);
+
+ jit_kernel_signature getOrCreate(
+ int D, // dimension
+ int F, // filter size per dimension
+ bool compute_a_sum,
+ bool per_channel_quantization,
+ int remainder, // the number of channels in the remainder loop
+ int prev_skip,
+ int next_skip,
+ int top_skip,
+ int bottom_skip,
+ int left_skip,
+ int right_skip);
+};
+
+} // namespace fbgemm
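
A hypothetical call-site sketch for the interface above; the argument values, the run_example wrapper, and the include path are illustrative assumptions, not taken from FBGEMM call sites:

#include <cstdint>

#include "src/GenerateI8Depthwise.h"

// Sketch: build (or fetch from cache) a 2D 3x3 depth-wise kernel and run it.
void run_example(const std::uint8_t* A, const std::int8_t* Bp,
                 std::int32_t* C, std::int32_t* a_sum,
                 const int* mask, const std::int32_t* B_zero_point) {
  fbgemm::GenI8Depthwise gen;
  // Full 32-channel vector, no padding skips (all values are assumptions).
  auto kernel = gen.getOrCreate(
      /*D=*/2, /*F=*/3, /*compute_a_sum=*/true,
      /*per_channel_quantization=*/false, /*remainder=*/32,
      /*prev_skip=*/0, /*next_skip=*/0, /*top_skip=*/0,
      /*bottom_skip=*/0, /*left_skip=*/0, /*right_skip=*/0);
  kernel(A, Bp, C, a_sum, /*h=*/3, /*w=*/3, /*c_in=*/32, mask,
         /*A_zero_point=*/0, B_zero_point);
}
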