author    | Jongsoo Park <jongsoo@fb.com>                                      | 2020-04-07 07:01:53 +0300
committer | Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com> | 2020-04-07 07:04:04 +0300
commit    | 1f42be50b7f53f000b381ece2712ad361c7bf556 (patch)
tree      | a92ceda2086e2d6aa93193e7cfebb472a7756c8c
parent    | 35e486b706abd4d575ae9b7aaf090002aa78551e (diff)
JIT depth-wise conv (#338)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/338
Depth-wise convolution was generating a large amount of code due to its extensive use of template specializations. This diff uses JIT'ing instead, which reduces code size and also brings performance gains.
TODO: we may want to land D20860370 first to reduce JIT'ing overhead, but D20860370 has a dependency on C++14.
Reviewed By: dskhudia
Differential Revision: D20858973
fbshipit-source-id: f7f35153fbf2cb96b4a31a82854d669cf164033f
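
With template specializations, every combination of (filter size, row-sum accumulation, remainder handling, per-channel quantization) is instantiated at compile time whether or not it is ever called; JIT'ing generates machine code only for the combinations a model actually uses, memoized by a key. The sketch below illustrates that caching idea with assumed names — the real implementation is the GenI8Depthwise class added in src/GenerateI8Depthwise.cc:

```cpp
#include <map>
#include <mutex>
#include <tuple>

// Everything the generated code is specialized on, mirroring the
// getOrCreate() arguments that appear in the diff below.
using KernelKey = std::tuple<
    int,  // D: spatial dimensionality of the convolution (2 or 3)
    int,  // S: filter size (3 or 5)
    bool, // compute_a_sum: needed when B's zero point is nonzero
    bool, // per-channel quantization
    int,  // remainder: K % 32, with 32 standing in for an exact multiple
    int, int, int, int, int, int>; // prev/next/top/bottom/left/right skips

// Simplified function-pointer type; the real jit_kernel_signature takes
// the activation/weight/output pointers, shapes, masks, and zero points.
using JitKernel = void (*)();

JitKernel emitKernel(const KernelKey& /*key*/) {
  // Placeholder for the asmjit-based generator in src/GenerateI8Depthwise.cc.
  return nullptr;
}

JitKernel getOrCreate(const KernelKey& key) {
  static std::map<KernelKey, JitKernel> cache;
  static std::mutex mtx;
  std::lock_guard<std::mutex> guard(mtx);
  auto it = cache.find(key);
  if (it == cache.end()) {
    // Code generation runs once per parameter combination actually used,
    // instead of compiling every template combination ahead of time.
    it = cache.emplace(key, emitKernel(key)).first;
  }
  return it->second;
}
```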
-rw-r--r-- | CMakeLists.txt                    |    1
-rw-r--r-- | src/FbgemmI8Depthwise2DAvx2-inl.h |  938
-rw-r--r-- | src/FbgemmI8Depthwise3DAvx2.cc    | 1026
-rw-r--r-- | src/FbgemmI8DepthwiseAvx2-inl.h   |  352
-rw-r--r-- | src/GenerateI8Depthwise.cc        |  506
-rw-r--r-- | src/GenerateI8Depthwise.h         |   41
6 files changed, 1229 insertions(+), 1635 deletions(-)
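
In the rewritten kernels, each output pixel obtains a JIT'ed routine: interior pixels share one pregenerated "middle" kernel (all skip counts zero), while border pixels request a variant whose skip counts encode how many filter taps fall outside the input. The call pattern below is condensed from depthwise_2d_kernel_ in this diff for readability (h_in/w_in are the top-left input coordinates of the current filter window and may be negative); read it as an excerpt of the diff, not a stable API:

```cpp
int remainder = K % 32;
if (remainder == 0) {
  remainder = 32; // a full 32-channel tile
}
GenI8Depthwise::jit_kernel_signature kernel = pregenerated_kernel
    ? *pregenerated_kernel // interior pixels reuse one cached kernel
    : GenI8Depthwise().getOrCreate(
          /*D=*/2,
          S,
          /*compute_a_sum=*/!B_SYMMETRIC,
          PER_CHANNEL_QUANTIZATION,
          remainder,
          0, // prev_skip (3D only)
          0, // next_skip (3D only)
          /*top_skip=*/std::max(-h_in, 0),
          /*bottom_skip=*/std::max(h_in + S - H, 0),
          /*left_skip=*/std::max(-w_in, 0),
          /*right_skip=*/std::max(w_in + S - W, 0));
kernel(
    A + (h_in * W + w_in) * K,
    Bp,
    C_int32,
    B_SYMMETRIC ? nullptr : row_offsets,
    H,
    W,
    K,
    internal::avx2_ps_or_epi32_combined_mask,
    A_zero_point,
    B_zero_point);
```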
diff --git a/CMakeLists.txt b/CMakeLists.txt index 9abd2c9..cc891fb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,7 @@ set(FBGEMM_GENERIC_SRCS src/EmbeddingSpMDM.cc src/FbgemmFloat16Convert.cc src/FbgemmI64.cc src/FbgemmI8Spmdm.cc + src/GenerateI8Depthwise.cc src/GenerateKernelU8S8S32ACC16.cc src/GenerateKernelU8S8S32ACC16Avx512.cc src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc diff --git a/src/FbgemmI8Depthwise2DAvx2-inl.h b/src/FbgemmI8Depthwise2DAvx2-inl.h index 7488a3c..f4f1a42 100644 --- a/src/FbgemmI8Depthwise2DAvx2-inl.h +++ b/src/FbgemmI8Depthwise2DAvx2-inl.h @@ -6,356 +6,21 @@ */ #pragma once -#include "fbgemm/UtilsAvx2.h" #include "fbgemm/Utils.h" +#include "fbgemm/UtilsAvx2.h" #include "src/FbgemmI8DepthwiseAvx2-inl.h" +#include "src/GenerateI8Depthwise.h" #include "src/MaskAvx2.h" namespace fbgemm { -template <int S = 3, bool SUM_A = false, bool REMAINDER = false> -static ALWAYS_INLINE void inner_prod_2d_packed_( - const __m256i* a_v, - const __m256i* Bp, - std::int32_t* C, - int remainder, - __m256i* a_sum = nullptr) { - return inner_prod_packed_<S * S, SUM_A, REMAINDER>( - a_v, Bp, C, remainder, a_sum); -} - -template < - bool SUM_A, - bool REMAINDER = false, - bool PER_CHANNEL_QUANTIZATION = false> -static ALWAYS_INLINE void inner_prod_3x3_packed_( - int H, - int W, - int K, - int h_in, - int w_in, - const std::uint8_t* A, - std::int32_t A_zero_point, - const std::int8_t* Bp, - const std::int32_t* B_zero_point, - std::int32_t* C, - int remainder, - std::int32_t* row_offsets) { - __m256i A_zero_point_v = - _mm256_set1_epi8(static_cast<std::uint8_t>(A_zero_point)); - __m256i mask_v = _mm256_setzero_si256(); - if (REMAINDER) { - mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>( - internal::avx2_ps_or_epi32_masks[remainder / 4])); - } - - // The code below can be written as a simple R*S loop but the compiler - // doesn't unroll so we're manually unrolling it. 
- // constexpr int R = 3, S = 3; - // array<__m256i, R * S> a_v; - // for (int r = 0; r < R; ++r) { - // for (int s = 0; s < S; ++s) { - // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) { - // if (REMAINDER) { - // a_v[r * S + s] = - // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K), - // mask_v); - // } else { - // a_v[r * S + s] = - // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K)); - // } - // } else { - // a_v[r * S + s] = A_zero_point_v; - // } - // } - // } - __m256i a_v[9] = { - A_zero_point_v, - A_zero_point_v, - A_zero_point_v, - A_zero_point_v, - A_zero_point_v, - A_zero_point_v, - A_zero_point_v, - A_zero_point_v, - A_zero_point_v, - }; - - if (h_in >= 0 && h_in < H) { - if (w_in >= 0 && w_in < W) { - a_v[0] = load_a<REMAINDER>(A + (0 * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[1] = load_a<REMAINDER>(A + (0 * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[2] = load_a<REMAINDER>(A + (0 * W + 2) * K, mask_v); - } - } - - if (h_in + 1 >= 0 && h_in + 1 < H) { - if (w_in >= 0 && w_in < W) { - a_v[3] = load_a<REMAINDER>(A + (1 * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[4] = load_a<REMAINDER>(A + (1 * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[5] = load_a<REMAINDER>(A + (1 * W + 2) * K, mask_v); - } - } - - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in >= 0 && w_in < W) { - a_v[6] = load_a<REMAINDER>(A + (2 * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[7] = load_a<REMAINDER>(A + (2 * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[8] = load_a<REMAINDER>(A + (2 * W + 2) * K, mask_v); - } - } - - __m256i a_sum[4]; - inner_prod_2d_packed_<3, SUM_A, REMAINDER>( - a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum); - if (SUM_A) { - __m256i B_zero_point_v; - for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) { - if (PER_CHANNEL_QUANTIZATION) { - B_zero_point_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(B_zero_point + i * 8)); - } else { - B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]); - } - _mm256_store_si256( - reinterpret_cast<__m256i*>(&row_offsets[i * 8]), - _mm256_mullo_epi32(a_sum[i], B_zero_point_v)); - } - } -} - -template < - bool SUM_A, - bool REMAINDER = false, - bool PER_CHANNEL_QUANTIZATION = false> -static ALWAYS_INLINE void inner_prod_5x5_packed_( - int H, - int W, - int K, - int h_in, - int w_in, - const std::uint8_t* A, - std::int32_t A_zero_point, - const std::int8_t* Bp, - const std::int32_t* B_zero_point, - std::int32_t* C, - int remainder, - std::int32_t* row_offsets) { - __m256i A_zero_point_v = - _mm256_set1_epi8(static_cast<std::uint8_t>(A_zero_point)); - __m256i mask_v = _mm256_setzero_si256(); - if (REMAINDER) { - mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>( - internal::avx2_ps_or_epi32_masks[remainder / 4])); - } - - // The code below can be written as a simple R*S loop but the compiler - // doesn't unroll so we're manually unrolling it. 
- // constexpr int R = 5, S = 5; - // array<__m256i, R * S> a_v; - // for (int r = 0; r < R; ++r) { - // for (int s = 0; s < S; ++s) { - // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) { - // if (REMAINDER) { - // a_v[r * S + s] = - // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K), - // mask_v); - // } else { - // a_v[r * S + s] = - // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K)); - // } - // } else { - // a_v[r * S + s] = A_zero_point_v; - // } - // } - // } - __m256i a_v[25] = { - A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v, - A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v, - A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v, - A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v, - A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v, - A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v, - A_zero_point_v, - }; - - if (h_in >= 0 && h_in < H) { - if (w_in >= 0 && w_in < W) { - a_v[0] = load_a<REMAINDER>(A + (0 * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[1] = load_a<REMAINDER>(A + (0 * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[2] = load_a<REMAINDER>(A + (0 * W + 2) * K, mask_v); - } - if (w_in + 3 >= 0 && w_in + 3 < W) { - a_v[3] = load_a<REMAINDER>(A + (0 * W + 3) * K, mask_v); - } - if (w_in + 4 >= 0 && w_in + 4 < W) { - a_v[4] = load_a<REMAINDER>(A + (0 * W + 4) * K, mask_v); - } - } - - if (h_in + 1 >= 0 && h_in + 1 < H) { - if (w_in >= 0 && w_in < W) { - a_v[5] = load_a<REMAINDER>(A + (1 * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[6] = load_a<REMAINDER>(A + (1 * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[7] = load_a<REMAINDER>(A + (1 * W + 2) * K, mask_v); - } - if (w_in + 3 >= 0 && w_in + 3 < W) { - a_v[8] = load_a<REMAINDER>(A + (1 * W + 3) * K, mask_v); - } - if (w_in + 4 >= 0 && w_in + 4 < W) { - a_v[9] = load_a<REMAINDER>(A + (1 * W + 4) * K, mask_v); - } - } - - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in >= 0 && w_in < W) { - a_v[10] = load_a<REMAINDER>(A + (2 * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[11] = load_a<REMAINDER>(A + (2 * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[12] = load_a<REMAINDER>(A + (2 * W + 2) * K, mask_v); - } - if (w_in + 3 >= 0 && w_in + 3 < W) { - a_v[13] = load_a<REMAINDER>(A + (2 * W + 3) * K, mask_v); - } - if (w_in + 4 >= 0 && w_in + 4 < W) { - a_v[14] = load_a<REMAINDER>(A + (2 * W + 4) * K, mask_v); - } - } - - if (h_in + 3 >= 0 && h_in + 3 < H) { - if (w_in >= 0 && w_in < W) { - a_v[15] = load_a<REMAINDER>(A + (3 * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[16] = load_a<REMAINDER>(A + (3 * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[17] = load_a<REMAINDER>(A + (3 * W + 2) * K, mask_v); - } - if (w_in + 3 >= 0 && w_in + 3 < W) { - a_v[18] = load_a<REMAINDER>(A + (3 * W + 3) * K, mask_v); - } - if (w_in + 4 >= 0 && w_in + 4 < W) { - a_v[19] = load_a<REMAINDER>(A + (3 * W + 4) * K, mask_v); - } - } - - if (h_in + 4 >= 0 && h_in + 4 < H) { - if (w_in >= 0 && w_in < W) { - a_v[20] = load_a<REMAINDER>(A + (4 * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[21] = load_a<REMAINDER>(A + (4 * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[22] = load_a<REMAINDER>(A + (4 * W + 2) * K, mask_v); - } - if (w_in + 3 >= 0 && w_in + 3 < W) { - a_v[23] = 
load_a<REMAINDER>(A + (4 * W + 3) * K, mask_v); - } - if (w_in + 4 >= 0 && w_in + 4 < W) { - a_v[24] = load_a<REMAINDER>(A + (4 * W + 4) * K, mask_v); - } - } - - __m256i a_sum[4]; - inner_prod_2d_packed_<5, SUM_A, REMAINDER>( - a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum); - if (SUM_A) { - __m256i B_zero_point_v; - for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) { - if (PER_CHANNEL_QUANTIZATION) { - B_zero_point_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(B_zero_point + i * 8)); - } else { - B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]); - } - _mm256_store_si256( - reinterpret_cast<__m256i*>(&row_offsets[i * 8]), - _mm256_mullo_epi32(a_sum[i], B_zero_point_v)); - } - } -} - -template < - int S, - bool SUM_A, - bool REMAINDER = false, - bool PER_CHANNEL_QUANTIZATION = false> -static ALWAYS_INLINE void inner_prod_2d_packed_( - int H, - int W, - int K, - int h_in, - int w_in, - const std::uint8_t* A, - std::int32_t A_zero_point, - const std::int8_t* Bp, - const std::int32_t* B_zero_point, - std::int32_t* C, - int remainder, - std::int32_t* row_offsets) { - if (S == 3) { - inner_prod_3x3_packed_<SUM_A, REMAINDER, PER_CHANNEL_QUANTIZATION>( - H, - W, - K, - h_in, - w_in, - A, - A_zero_point, - Bp, - B_zero_point, - C, - remainder, - row_offsets); - } else { - assert(S == 5); - inner_prod_5x5_packed_<SUM_A, REMAINDER, PER_CHANNEL_QUANTIZATION>( - H, - W, - K, - h_in, - w_in, - A, - A_zero_point, - Bp, - B_zero_point, - C, - remainder, - row_offsets); - } -} - template < int S, bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC, + bool PER_CHANNEL_QUANTIZAITON, typename BIAS_TYPE> static ALWAYS_INLINE void depthwise_2d_kernel_( int H, @@ -367,16 +32,17 @@ static ALWAYS_INLINE void depthwise_2d_kernel_( int stride_w, std::int32_t A_zero_point, const std::uint8_t* A, - std::int32_t B_zero_point, + const std::int32_t* B_zero_point, const std::int8_t* Bp, - float C_multiplier, + const float* C_multiplier, std::int32_t C_zero_point, std::int32_t* C_int32, std::uint8_t* C_uint8, std::int32_t* row_offsets, const std::int32_t* col_offsets, const BIAS_TYPE* bias, - float act_times_w_scale) { + const float* act_times_w_scale, + GenI8Depthwise::jit_kernel_signature* pregenerated_kernel = nullptr) { constexpr int PAD_T = (S - 1) / 2, PAD_L = (S - 1) / 2, PAD_R = (S - 1) / 2; int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; int h_in = -PAD_T + h * stride_h; @@ -384,138 +50,44 @@ static ALWAYS_INLINE void depthwise_2d_kernel_( constexpr int KERNEL_PROD_ALIGNED = (S * S + 1) / 2 * 2; - int k; - for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_2d_packed_<S, !B_SYMMETRIC /*SUM_A*/>( - H, - W, - K, - h_in, - w_in, - A + (h_in * W + w_in) * K + k, - A_zero_point, - Bp + k * KERNEL_PROD_ALIGNED, - &B_zero_point, - C_int32 + k, - 0, - B_SYMMETRIC ? nullptr : &row_offsets[k]); - } - int remainder = K - k; - if (remainder) { - inner_prod_2d_packed_<S, !B_SYMMETRIC, true>( - H, - W, - K, - h_in, - w_in, - A + (h_in * W + w_in) * K + k, - A_zero_point, - Bp + k * KERNEL_PROD_ALIGNED, - &B_zero_point, - C_int32 + k, - remainder, - B_SYMMETRIC ? nullptr : &row_offsets[k]); + int remainder = K % 32; + if (remainder == 0) { + remainder = 32; } - requantize_< - FUSE_RELU, - HAS_BIAS, - false, /*PER_CHAN_QUANT*/ - A_SYMMETRIC, - B_SYMMETRIC, - BIAS_TYPE>( - A_zero_point, - &C_multiplier, - C_zero_point, + GenI8Depthwise::jit_kernel_signature kernel = pregenerated_kernel + ? 
*pregenerated_kernel + : GenI8Depthwise().getOrCreate( + /*D=*/2, + S, + /*compute_a_sum=*/!B_SYMMETRIC, + PER_CHANNEL_QUANTIZAITON, + remainder, + 0, + 0, + /*top_skip=*/std::max(-h_in, 0), + /*bottom_skip=*/std::max(h_in + S - H, 0), + /*left_skip=*/std::max(-w_in, 0), + /*right_skip=*/std::max(w_in + S - W, 0)); + + kernel( + A + (h_in * W + w_in) * K, + Bp, C_int32, - C_uint8 + (h * W_OUT + w) * K, + B_SYMMETRIC ? nullptr : row_offsets, + H, + W, K, - row_offsets, - col_offsets, - bias, - &act_times_w_scale); -} - -template < - int S, - bool FUSE_RELU, - bool HAS_BIAS, - bool A_SYMMETRIC, - typename BIAS_TYPE> -static ALWAYS_INLINE void depthwise_2d_per_channel_quantization_kernel_( - int H, - int W, - int K, - int h, - int w, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, - const std::int32_t* B_zero_point, - const std::int8_t* Bp, - const float* C_multiplier, - std::int32_t C_zero_point, - std::int32_t* C_int32, - std::uint8_t* C_uint8, - std::int32_t* row_offsets, - const std::int32_t* col_offsets, - const BIAS_TYPE* bias, - const float* act_times_w_scale) { - constexpr int PAD_T = (S - 1) / 2, PAD_L = (S - 1) / 2, PAD_R = (S - 1) / 2; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - int h_in = -PAD_T + h * stride_h; - int w_in = -PAD_L + w * stride_w; - - constexpr int KERNEL_PROD_ALIGNED = (S * S + 1) / 2 * 2; - - int k; - for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_2d_packed_< - S, - true, /*SUM_A*/ - false, /*remainder*/ - true /*per-channel*/>( - H, - W, - K, - h_in, - w_in, - A + (h_in * W + w_in) * K + k, - A_zero_point, - Bp + k * KERNEL_PROD_ALIGNED, - B_zero_point + k, - C_int32 + k, - 0, - &row_offsets[k]); - } - int remainder = K - k; - if (remainder) { - inner_prod_2d_packed_< - S, - true, /*SUM_A*/ - true, /*remainder*/ - true /*per-channel*/>( - H, - W, - K, - h_in, - w_in, - A + (h_in * W + w_in) * K + k, - A_zero_point, - Bp + k * KERNEL_PROD_ALIGNED, - B_zero_point + k, - C_int32 + k, - remainder, - &row_offsets[k]); - } + internal::avx2_ps_or_epi32_combined_mask, + A_zero_point, + B_zero_point); requantize_< FUSE_RELU, HAS_BIAS, - true, /*PER_CHAN_QUANT*/ + PER_CHANNEL_QUANTIZAITON, A_SYMMETRIC, - false, /*B_SYMM*/ + B_SYMMETRIC, BIAS_TYPE>( A_zero_point, C_multiplier, @@ -539,7 +111,8 @@ template < bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC, - typename BIAS_TYPE> + typename BIAS_TYPE, + bool PER_CHANNEL_QUANTIZATION> static ALWAYS_INLINE void depthwise_2d_( int N, int H, @@ -549,15 +122,15 @@ static ALWAYS_INLINE void depthwise_2d_( int stride_w, std::int32_t A_zero_point, const std::uint8_t* A, - std::int32_t B_zero_point, + const std::int32_t* B_zero_point, const PackedDepthWiseConvMatrix& B, - float C_multiplier, + const float* C_multiplier, std::int32_t C_zero_point, std::int32_t* C_int32, std::uint8_t* C_uint8, const std::int32_t* col_offsets, const BIAS_TYPE* bias, - float act_times_w_scale, + const float* act_times_w_scale, int thread_id, int num_threads) { assert(K % 8 == 0); @@ -586,6 +159,8 @@ static ALWAYS_INLINE void depthwise_2d_( fbgemmPartition1D( th_info.n_thread_id, th_info.n_num_threads, W_OUT, w_begin, w_end); + GenI8Depthwise::jit_kernel_signature middle_kernel; + for (int n = n_begin; n < n_end; ++n) { const std::uint8_t* A_base = A + n * H * W * K; std::uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K; @@ -593,14 +168,15 @@ static ALWAYS_INLINE void depthwise_2d_( int h = 0; int w = 0; - for (h = h_begin; h < std::max(PAD_T, h_begin); ++h) { - for (w = w_begin; w < 
std::max(PAD_L, w_begin); ++w) { + for (h = h_begin; h < PAD_T; ++h) { + for (w = w_begin; w < PAD_L; ++w) { depthwise_2d_kernel_< S, FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION, BIAS_TYPE>( H, W, @@ -623,13 +199,14 @@ static ALWAYS_INLINE void depthwise_2d_( act_times_w_scale); } - for (; w < std::min(W_OUT - PAD_R, w_end); ++w) { + for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) { depthwise_2d_kernel_< S, FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION, BIAS_TYPE>( H, W, @@ -659,6 +236,7 @@ static ALWAYS_INLINE void depthwise_2d_( HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION, BIAS_TYPE>( H, W, @@ -682,14 +260,27 @@ static ALWAYS_INLINE void depthwise_2d_( } } - for (; h < std::min(H - PAD_B, h_end); ++h) { - for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) { + // h <= H_OUT - PAD_B - stride_h + // h <= (H + PAD_T + PAD_B - S) / stride_h + 1 - PAD_B - stride_h + // h_in <= -PAD_T + + // ((H + PAD_T + PAD_B - S) / stride_h + 1 - PAD_B - stride_h) * stride_h + // Case 1) For stride_h == 1, + // h_in <= -PAD_T + H + PAD_T + PAD_B - S + 1 - PAD_B - 1 + // h_in + S - H <= 0 + // Case 2) For stride_h == 2, + // h_in <= -PAD_L + + // H + PAD_T + PAD_B - S + 1 + (1 - PAD_B - stride_h) * stride_h + // h_in + S - H <= PAD_B * (1 - stride_h) + 1 + (1 - stride_h) * stride_h + // <= -PAD_B + 1 - stride_h <= 0 + for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) { + for (w = w_begin; w < PAD_L; ++w) { depthwise_2d_kernel_< S, FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION, BIAS_TYPE>( H, W, @@ -712,13 +303,32 @@ static ALWAYS_INLINE void depthwise_2d_( act_times_w_scale); } - for (; w < std::min(W_OUT - PAD_R, w_end); ++w) { + for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) { + if (n == n_begin && w == std::max(PAD_L, w_begin)) { + int remainder = K % 32; + if (remainder == 0) { + remainder = 32; + } + middle_kernel = GenI8Depthwise().getOrCreate( + /*D=*/2, + S, + /*compute_a_sum=*/!B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION, + remainder, + 0, + 0, + 0, + 0, + 0, + 0); + } depthwise_2d_kernel_< S, FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION, BIAS_TYPE>( H, W, @@ -738,7 +348,8 @@ static ALWAYS_INLINE void depthwise_2d_( row_offsets, col_offsets, bias, - act_times_w_scale); + act_times_w_scale, + &middle_kernel); } for (; w < w_end; ++w) { @@ -748,6 +359,7 @@ static ALWAYS_INLINE void depthwise_2d_( HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION, BIAS_TYPE>( H, W, @@ -772,13 +384,14 @@ static ALWAYS_INLINE void depthwise_2d_( } for (; h < h_end; ++h) { - for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) { + for (w = w_begin; w < PAD_L; ++w) { depthwise_2d_kernel_< S, FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION, BIAS_TYPE>( H, W, @@ -801,13 +414,14 @@ static ALWAYS_INLINE void depthwise_2d_( act_times_w_scale); } - for (; w < std::min(W_OUT - PAD_R, w_end); ++w) { + for (; w < std::min(W_OUT - PAD_R - stride_w + 1, w_end); ++w) { depthwise_2d_kernel_< S, FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION, BIAS_TYPE>( H, W, @@ -837,327 +451,7 @@ static ALWAYS_INLINE void depthwise_2d_( HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - } - } // 
for each n - - fbgemmAlignedFree(row_offsets); -}; - -template < - int S, - bool FUSE_RELU, - bool HAS_BIAS, - bool A_SYMMETRIC, - typename BIAS_TYPE> -static ALWAYS_INLINE void depthwise_2d_per_channel_quantization_( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, - const std::int32_t* B_zero_point, - const PackedDepthWiseConvMatrix& B, - const float* C_multiplier, - std::int32_t C_zero_point, - std::int32_t* C_int32, - std::uint8_t* C_uint8, - const std::int32_t* col_offsets, - const BIAS_TYPE* bias, - const float* act_times_w_scale, - int thread_id, - int num_threads) { - assert(K % 8 == 0); - constexpr int R = S; - constexpr int PAD_T = (R - 1) / 2, PAD_B = PAD_T, PAD_L = (S - 1) / 2, - PAD_R = PAD_L; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - const std::int8_t* Bp = B.PackedMat(); - - int32_t* row_offsets = static_cast<int32_t*>( - fbgemmAlignedAlloc(64, (K + 31) / 32 * 32 * sizeof(int32_t))); - - int n_begin, n_end, h_begin, h_end, w_begin, w_end; - // Reuse the 3-dim partition scheme for parallelization in matrix - // multiplication. - thread_type_t th_info = - fbgemmGetThreadPartition(N, H_OUT, W_OUT, thread_id, num_threads); - // Calculate the begin and end index along the batch (N) dimension - fbgemmPartition1D( - th_info.g_thread_id, th_info.g_num_threads, N, n_begin, n_end); - // Calculate the begin and end index along the H dimension - fbgemmPartition1D( - th_info.m_thread_id, th_info.m_num_threads, H_OUT, h_begin, h_end); - // Calculate the begin and end index along the W dimension - fbgemmPartition1D( - th_info.n_thread_id, th_info.n_num_threads, W_OUT, w_begin, w_end); - - for (int n = n_begin; n < n_end; ++n) { - const std::uint8_t* A_base = A + n * H * W * K; - std::uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K; - - int h = 0; - int w = 0; - - for (h = h_begin; h < std::max(PAD_T, h_begin); ++h) { - for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) { - depthwise_2d_per_channel_quantization_kernel_< - S, - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - for (; w < std::min(W_OUT - PAD_R, w_end); ++w) { - depthwise_2d_per_channel_quantization_kernel_< - S, - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - for (; w < w_end; ++w) { - depthwise_2d_per_channel_quantization_kernel_< - S, - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - } - - for (; h < std::min(H - PAD_B, h_end); ++h) { - for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) { - depthwise_2d_per_channel_quantization_kernel_< - S, - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - 
act_times_w_scale); - } - - for (; w < std::min(W_OUT - PAD_R, w_end); ++w) { - depthwise_2d_per_channel_quantization_kernel_< - S, - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - for (; w < w_end; ++w) { - depthwise_2d_per_channel_quantization_kernel_< - S, - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - } - - for (; h < h_end; ++h) { - for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) { - depthwise_2d_per_channel_quantization_kernel_< - S, - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - for (; w < std::min(W_OUT - PAD_R, w_end); ++w) { - depthwise_2d_per_channel_quantization_kernel_< - S, - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - for (; w < w_end; ++w) { - depthwise_2d_per_channel_quantization_kernel_< - S, - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, + PER_CHANNEL_QUANTIZATION, BIAS_TYPE>( H, W, @@ -1216,7 +510,8 @@ static void depthwise_2d_( HAS_BIAS, true /*A_symmetric*/, true /*B_symmetric*/, - BIAS_TYPE>( + BIAS_TYPE, + false /*PER_CHANNEL_QUANTIZAITON*/>( N, H, W, @@ -1225,15 +520,15 @@ static void depthwise_2d_( stride_w, A_zero_point, A, - B_zero_point, + &B_zero_point, B, - C_multiplier, + &C_multiplier, C_zero_point, C_int32_temp, C, col_offsets, bias, - act_times_w_scale, + &act_times_w_scale, thread_id, num_threads); } else { @@ -1243,7 +538,8 @@ static void depthwise_2d_( HAS_BIAS, true /*A_symmetric*/, false /*B_symmetric*/, - BIAS_TYPE>( + BIAS_TYPE, + false /*PER_CHANNEL_QUANTIZAITON*/>( N, H, W, @@ -1252,15 +548,15 @@ static void depthwise_2d_( stride_w, A_zero_point, A, - B_zero_point, + &B_zero_point, B, - C_multiplier, + &C_multiplier, C_zero_point, C_int32_temp, C, col_offsets, bias, - act_times_w_scale, + &act_times_w_scale, thread_id, num_threads); } @@ -1272,7 +568,8 @@ static void depthwise_2d_( HAS_BIAS, false /*A_symmetric*/, true /*B_symmetric*/, - BIAS_TYPE>( + BIAS_TYPE, + false /*PER_CHANNEL_QUANTIZAITON*/>( N, H, W, @@ -1281,15 +578,15 @@ static void depthwise_2d_( stride_w, A_zero_point, A, - B_zero_point, + &B_zero_point, B, - C_multiplier, + &C_multiplier, C_zero_point, C_int32_temp, C, col_offsets, bias, - act_times_w_scale, + &act_times_w_scale, thread_id, num_threads); } else { @@ -1299,7 +596,8 @@ static void depthwise_2d_( HAS_BIAS, false /*A_symmetric*/, false /*B_symmetric*/, - BIAS_TYPE>( + BIAS_TYPE, + false /*PER_CHANNEL_QUANTIZAITON*/>( N, H, W, @@ -1308,15 +606,15 @@ static void depthwise_2d_( stride_w, A_zero_point, A, - B_zero_point, + &B_zero_point, B, - C_multiplier, + &C_multiplier, C_zero_point, C_int32_temp, C, col_offsets, bias, - act_times_w_scale, + &act_times_w_scale, thread_id, num_threads); } @@ -1412,12 +710,14 @@ static 
void depthwise_2d_per_channel_quantization_( int32_t* C_int32_temp = static_cast<int32_t*>( fbgemmAlignedAlloc(64, (K + 31) / 32 * 32 * sizeof(int32_t))); if (A_zero_point == 0 || col_offsets == nullptr) { - depthwise_2d_per_channel_quantization_< + depthwise_2d_< S, FUSE_RELU, HAS_BIAS, true /*A_SYMM*/, - BIAS_TYPE>( + false /*B_SYMM*/, + BIAS_TYPE, + true /*PER_CHANNEL_QUANTIZAITON*/>( N, H, W, @@ -1438,12 +738,14 @@ static void depthwise_2d_per_channel_quantization_( thread_id, num_threads); } else { - depthwise_2d_per_channel_quantization_< + depthwise_2d_< S, FUSE_RELU, HAS_BIAS, false /*A_SYMM*/, - BIAS_TYPE>( + false /*B_SYMM*/, + BIAS_TYPE, + true /*PER_CHANNEL_QUANTIZAITON*/>( N, H, W, diff --git a/src/FbgemmI8Depthwise3DAvx2.cc b/src/FbgemmI8Depthwise3DAvx2.cc index f70f915..8993e2b 100644 --- a/src/FbgemmI8Depthwise3DAvx2.cc +++ b/src/FbgemmI8Depthwise3DAvx2.cc @@ -11,6 +11,7 @@ #include <string> #include "./FbgemmI8DepthwiseAvx2-inl.h" +#include "./GenerateI8Depthwise.h" #include "./MaskAvx2.h" #include "fbgemm/Utils.h" #include "fbgemm/UtilsAvx2.h" @@ -20,261 +21,11 @@ using namespace std; namespace fbgemm { template < - bool SUM_A, - bool REMAINDER = false, - bool PER_CHANNEL_QUANTIZATION = false> -static ALWAYS_INLINE void inner_prod_3x3x3_packed_( - int T, - int H, - int W, - int K, - int t_in, - int h_in, - int w_in, - const uint8_t* A, - int32_t A_zero_point, - const int8_t* Bp, - const int32_t* B_zero_point, - int32_t* C, - int remainder, - int32_t* row_offsets) { - __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point)); - __m256i mask_v = _mm256_setzero_si256(); - if (REMAINDER) { - mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>( - internal::avx2_ps_or_epi32_masks[remainder / 4])); - } - - // The code below can be written as a simple R*S loop but the compiler - // doesn't unroll so we're manually unrolling it. 
- // constexpr int R = 3, S = 3; - // array<__m256i, R * S> a_v; - // for (int r = 0; r < R; ++r) { - // for (int s = 0; s < S; ++s) { - // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) { - // if (REMAINDER) { - // a_v[r * S + s] = - // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K), - // mask_v); - // } else { - // a_v[r * S + s] = - // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K)); - // } - // } else { - // a_v[r * S + s] = A_zero_point_v; - // } - // } - // } - __m256i a_v[8]; - a_v[0] = A_zero_point_v; - a_v[1] = A_zero_point_v; - a_v[2] = A_zero_point_v; - a_v[3] = A_zero_point_v; - a_v[4] = A_zero_point_v; - a_v[5] = A_zero_point_v; - a_v[6] = A_zero_point_v; - a_v[7] = A_zero_point_v; - - if (t_in >= 0 && t_in < T) { - if (h_in >= 0 && h_in < H) { - if (w_in >= 0 && w_in < W) { - a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v); - } - } - - if (h_in + 1 >= 0 && h_in + 1 < H) { - if (w_in >= 0 && w_in < W) { - a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v); - } - } - - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in >= 0 && w_in < W) { - a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v); - } - } - } - - __m256i a_sum[4]; - inner_prod_packed_<8, SUM_A, REMAINDER>( - a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum); - - a_v[0] = A_zero_point_v; - a_v[1] = A_zero_point_v; - a_v[2] = A_zero_point_v; - a_v[3] = A_zero_point_v; - a_v[4] = A_zero_point_v; - a_v[5] = A_zero_point_v; - a_v[6] = A_zero_point_v; - a_v[7] = A_zero_point_v; - - if (t_in >= 0 && t_in < T) { - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v); - } - } - } - - if (t_in + 1 >= 0 && t_in + 1 < T) { - if (h_in >= 0 && h_in < H) { - if (w_in >= 0 && w_in < W) { - a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v); - } - } - - if (h_in + 1 >= 0 && h_in + 1 < H) { - if (w_in >= 0 && w_in < W) { - a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v); - } - } - - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in >= 0 && w_in < W) { - a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v); - } - } - } - - __m256i a_sum_temp[4]; - inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>( - a_v, reinterpret_cast<const __m256i*>(Bp) + 8, C, remainder, a_sum_temp); - if (SUM_A) { - a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); - a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); - a_sum[2] = 
_mm256_add_epi32(a_sum[2], a_sum_temp[2]); - a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); - } - - a_v[0] = A_zero_point_v; - a_v[1] = A_zero_point_v; - a_v[2] = A_zero_point_v; - a_v[3] = A_zero_point_v; - a_v[4] = A_zero_point_v; - a_v[5] = A_zero_point_v; - a_v[6] = A_zero_point_v; - a_v[7] = A_zero_point_v; - - if (t_in + 1 >= 0 && t_in + 1 < T) { - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v); - } - } - } - - if (t_in + 2 >= 0 && t_in + 2 < T) { - if (h_in >= 0 && h_in < H) { - if (w_in >= 0 && w_in < W) { - a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v); - } - } - - if (h_in + 1 >= 0 && h_in + 1 < H) { - if (w_in >= 0 && w_in < W) { - a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v); - } - } - } - - inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>( - a_v, reinterpret_cast<const __m256i*>(Bp) + 16, C, remainder, a_sum_temp); - if (SUM_A) { - a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); - a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); - a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); - a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); - } - - a_v[0] = A_zero_point_v; - a_v[1] = A_zero_point_v; - a_v[2] = A_zero_point_v; - - if (t_in + 2 >= 0 && t_in + 2 < T) { - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in >= 0 && w_in < W) { - a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v); - } - } - } - - inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>( - a_v, reinterpret_cast<const __m256i*>(Bp) + 24, C, remainder, a_sum_temp); - - if (SUM_A) { - a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]); - a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]); - a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]); - a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]); - - __m256i B_zero_point_v; - for (int i = 0; i < (REMAINDER ? 
(remainder / 8) : 4); ++i) { - if (PER_CHANNEL_QUANTIZATION) { - B_zero_point_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(B_zero_point + i * 8)); - } else { - B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]); - } - _mm256_store_si256( - reinterpret_cast<__m256i*>(&row_offsets[i * 8]), - _mm256_mullo_epi32(a_sum[i], B_zero_point_v)); - } - } -} - -template < bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC, + bool PER_CHANNEL_QUANTIZATION, typename BIAS_TYPE> static ALWAYS_INLINE void depthwise_3x3x3_kernel_( int T, @@ -289,16 +40,17 @@ static ALWAYS_INLINE void depthwise_3x3x3_kernel_( int stride_w, int32_t A_zero_point, const uint8_t* A, - int32_t B_zero_point, + const int32_t* B_zero_point, const int8_t* Bp, - float C_multiplier, + const float* C_multiplier, int32_t C_zero_point, int32_t* C_int32, uint8_t* C_uint8, int32_t* row_offsets, const int32_t* col_offsets, const BIAS_TYPE* bias, - float act_times_w_scale) { + const float* act_times_w_scale, + GenI8Depthwise::jit_kernel_signature* pregenerated_kernel = nullptr) { constexpr int R = 3, S = 3; constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; @@ -307,141 +59,43 @@ static ALWAYS_INLINE void depthwise_3x3x3_kernel_( int h_in = -PAD_T + h * stride_h; int w_in = -PAD_L + w * stride_w; - int k; - for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/>( - T, - H, - W, - K, - t_in, - h_in, - w_in, - A + ((t_in * H + h_in) * W + w_in) * K + k, - A_zero_point, - Bp + k * 28, - &B_zero_point, - C_int32 + k, - 0, - B_SYMMETRIC ? nullptr : &row_offsets[k]); - } - int remainder = K - k; - if (remainder) { - inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/, true>( - T, - H, - W, - K, - t_in, - h_in, - w_in, - A + ((t_in * H + h_in) * W + w_in) * K + k, - A_zero_point, - Bp + k * 28, - &B_zero_point, - C_int32 + k, - remainder, - B_SYMMETRIC ? nullptr : &row_offsets[k]); + int remainder = K % 32; + if (remainder == 0) { + remainder = 32; } - requantize_< - FUSE_RELU, - HAS_BIAS, - false, /*PER_CHAN_QUANT*/ - A_SYMMETRIC, - B_SYMMETRIC>( - A_zero_point, - &C_multiplier, - C_zero_point, + GenI8Depthwise::jit_kernel_signature kernel = pregenerated_kernel + ? *pregenerated_kernel + : GenI8Depthwise().getOrCreate( + /*D=*/3, + /*S=*/3, + /*compute_a_sum=*/!B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION, + remainder, + /*prev_skip=*/std::max(-t_in, 0), + /*next_skip=*/std::max(t_in + 3 - T, 0), + /*top_skip=*/std::max(-h_in, 0), + /*bottom_skip=*/std::max(h_in + 3 - H, 0), + /*left_skip=*/std::max(-w_in, 0), + /*right_skip=*/std::max(w_in + 3 - W, 0)); + kernel( + A + ((t_in * H + h_in) * W + w_in) * K, + Bp, C_int32, - C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, + B_SYMMETRIC ? 
nullptr : row_offsets, + H, + W, K, - row_offsets, - col_offsets, - bias, - &act_times_w_scale); -} - -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE> -static ALWAYS_INLINE void depthwise_3x3x3_per_channel_quantization_kernel_( - int T, - int H, - int W, - int K, - int t, - int h, - int w, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int32_t* B_zero_point, - const int8_t* Bp, - const float* C_multiplier, - int32_t C_zero_point, - int32_t* C_int32, - uint8_t* C_uint8, - int32_t* row_offsets, - const int32_t* col_offsets, - const BIAS_TYPE* bias, - const float* act_times_w_scale) { - constexpr int R = 3, S = 3; - constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - int t_in = -PAD_P + t * stride_t; - int h_in = -PAD_T + h * stride_h; - int w_in = -PAD_L + w * stride_w; + internal::avx2_ps_or_epi32_combined_mask, + A_zero_point, + B_zero_point); - int k; - for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3x3_packed_< - true, /*SUM_A*/ - false, /*remainder*/ - true /*per-channel*/>( - T, - H, - W, - K, - t_in, - h_in, - w_in, - A + ((t_in * H + h_in) * W + w_in) * K + k, - A_zero_point, - Bp + k * 28, - B_zero_point + k, - C_int32 + k, - 0, - &row_offsets[k]); - } - int remainder = K - k; - if (remainder) { - inner_prod_3x3x3_packed_< - true, /*SUM_A*/ - true, /*remainder*/ - true /*per-channel*/>( - T, - H, - W, - K, - t_in, - h_in, - w_in, - A + ((t_in * H + h_in) * W + w_in) * K + k, - A_zero_point, - Bp + k * 28, - B_zero_point + k, - C_int32 + k, - remainder, - &row_offsets[k]); - } requantize_< FUSE_RELU, HAS_BIAS, - true, /*PER_CHAN_QUANT*/ + PER_CHANNEL_QUANTIZATION, A_SYMMETRIC, - false /*B_SYMM*/>( + B_SYMMETRIC>( A_zero_point, C_multiplier, C_zero_point, @@ -459,7 +113,8 @@ template < bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC, - typename BIAS_TYPE> + typename BIAS_TYPE, + bool PER_CHANNEL_QUANTIZATION> static ALWAYS_INLINE void depthwise_3x3x3_pad_1_( int N, int T, @@ -471,15 +126,15 @@ static ALWAYS_INLINE void depthwise_3x3x3_pad_1_( int stride_w, int32_t A_zero_point, const uint8_t* A, - int32_t B_zero_point, + const int32_t* B_zero_point, const PackedDepthWiseConvMatrix& B, - float C_multiplier, + const float* C_multiplier, int32_t C_zero_point, int32_t* C_int32, uint8_t* C_uint8, const int32_t* col_offsets, const BIAS_TYPE* bias, - float act_times_w_scale, + const float* act_times_w_scale, int thread_id, int num_threads) { assert(K % 8 == 0); @@ -509,18 +164,173 @@ static ALWAYS_INLINE void depthwise_3x3x3_pad_1_( fbgemmPartition1D( th_info.n_thread_id, th_info.n_num_threads, H_OUT, h_begin, h_end); + GenI8Depthwise::jit_kernel_signature middle_kernel; + for (int n = n_begin; n < n_end; ++n) { const uint8_t* A_base = A + n * T * H * W * K; uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K; - for (int t = t_begin; t < t_end; ++t) { - for (int h = h_begin; h < h_end; ++h) { + int t; + for (t = t_begin; t < PAD_P; ++t) { + int h; + for (h = h_begin; h < PAD_T; ++h) { for (int w = 0; w < W_OUT; ++w) { depthwise_3x3x3_kernel_< FUSE_RELU, HAS_BIAS, A_SYMMETRIC, - B_SYMMETRIC>( + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + 
act_times_w_scale); + } // w + } // h + + for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) { + int w; + for (w = 0; w < PAD_L; ++w) { + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } // w + + GenI8Depthwise::jit_kernel_signature kernel; + for (; w < W_OUT - PAD_R - stride_w + 1; ++w) { + if (w == PAD_L) { + int remainder = K % 32; + if (remainder == 0) { + remainder = 32; + } + int t_in = -PAD_P + t * stride_t; + kernel = GenI8Depthwise().getOrCreate( + /*D=*/3, + /*F=*/3, + /*compute_a_sum=*/!B_SYMMETRIC, + /*per_chnnale_quantization=*/PER_CHANNEL_QUANTIZATION, + remainder, + /*prev_skip=*/std::max(-t_in, 0), + /*next_skip=*/std::max(t_in + 3 - T, 0), + 0, + 0, + 0, + 0); + } + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale, + &kernel); + } // w + + for (; w < W_OUT; ++w) { + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } // w + } // h + + for (; h < h_end; ++h) { + for (int w = 0; w < W_OUT; ++w) { + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( T, H, W, @@ -546,72 +356,165 @@ static ALWAYS_INLINE void depthwise_3x3x3_pad_1_( } // w } // h } // t - } // for each n - fbgemmAlignedFree(row_offsets); -}; -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE> -static ALWAYS_INLINE void depthwise_3x3x3_per_channel_quantization_pad_1_( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int32_t* B_zero_point, - const PackedDepthWiseConvMatrix& B, - const float* C_multiplier, - int32_t C_zero_point, - int32_t* C_int32, - uint8_t* C_uint8, - const int32_t* col_offsets, - const BIAS_TYPE* bias, - const float* act_times_w_scale, - int thread_id, - int num_threads) { - assert(K % 8 == 0); - constexpr int K_T = 3, K_H = 3, K_W = 3; - constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, - PAD_R = 1; - int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; - int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; - const int8_t* Bp = B.PackedMat(); + for (; t < std::min(T_OUT - PAD_N - stride_t + 1, t_end); ++t) { + int h; + for (h = h_begin; h < PAD_T; ++h) { + for (int w = 0; w < W_OUT; ++w) { + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } // w + } // 
h - int32_t* row_offsets = static_cast<int32_t*>( - fbgemmAlignedAlloc(64, (K + 31) / 32 * 32 * sizeof(int32_t))); + for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) { + int w; + for (w = 0; w < PAD_L; ++w) { + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } // w - int n_begin, n_end, t_begin, t_end, h_begin, h_end; - // Reuse the 3-dim partition scheme for parallelization in matrix - // multiplication. - thread_type_t th_info = - fbgemmGetThreadPartition(N, T_OUT, H_OUT, thread_id, num_threads); - // Calculate the begin and end index along the batch (N) dimension - fbgemmPartition1D( - th_info.g_thread_id, th_info.g_num_threads, N, n_begin, n_end); - // Calculate the begin and end index along the T dimension - fbgemmPartition1D( - th_info.m_thread_id, th_info.m_num_threads, T_OUT, t_begin, t_end); - // Calculate the begin and end index along the H dimension - fbgemmPartition1D( - th_info.n_thread_id, th_info.n_num_threads, H_OUT, h_begin, h_end); + for (; w < W_OUT - PAD_R - stride_w + 1; ++w) { + if (n == n_begin && w == PAD_L) { + int remainder = K % 32; + if (remainder == 0) { + remainder = 32; + } + middle_kernel = GenI8Depthwise().getOrCreate( + /*D=*/3, + /*F=*/3, + /*compute_a_sum=*/!B_SYMMETRIC, + /*per_chnnale_quantization=*/PER_CHANNEL_QUANTIZATION, + remainder, + 0, + 0, + 0, + 0, + 0, + 0); + } + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale, + &middle_kernel); + } - for (int n = n_begin; n < n_end; ++n) { - const uint8_t* A_base = A + n * T * H * W * K; - uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K; + for (; w < W_OUT; ++w) { + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + } // h - for (int t = t_begin; t < t_end; ++t) { - for (int h = h_begin; h < h_end; ++h) { + for (; h < h_end; ++h) { for (int w = 0; w < W_OUT; ++w) { - depthwise_3x3x3_per_channel_quantization_kernel_< + depthwise_3x3x3_kernel_< FUSE_RELU, HAS_BIAS, A_SYMMETRIC, - BIAS_TYPE>( + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( T, H, W, @@ -637,8 +540,193 @@ static ALWAYS_INLINE void depthwise_3x3x3_per_channel_quantization_pad_1_( } // w } // h } // t - } // for each n + for (; t < t_end; ++t) { + int h; + for (h = h_begin; h < PAD_T; ++h) { + for (int w = 0; w < W_OUT; ++w) { + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } // w + } // h + + for (; h < std::min(H_OUT - PAD_B - stride_h + 
1, h_end); ++h) { + int w; + for (w = 0; w < PAD_L; ++w) { + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } // w + + GenI8Depthwise::jit_kernel_signature kernel; + for (; w < W_OUT - PAD_R - stride_w + 1; ++w) { + if (w == PAD_L) { + int remainder = K % 32; + if (remainder == 0) { + remainder = 32; + } + int t_in = -PAD_P + t * stride_t; + kernel = GenI8Depthwise().getOrCreate( + /*D=*/3, + /*F=*/3, + /*compute_a_sum=*/!B_SYMMETRIC, + /*per_chnnale_quantization=*/PER_CHANNEL_QUANTIZATION, + remainder, + /*prev_skip=*/std::max(-t_in, 0), + /*next_skip=*/std::max(t_in + 3 - T, 0), + 0, + 0, + 0, + 0); + } + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale, + &kernel); + } // w + + for (; w < W_OUT; ++w) { + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } // w + } // h + + for (; h < h_end; ++h) { + for (int w = 0; w < W_OUT; ++w) { + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + PER_CHANNEL_QUANTIZATION>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } // w + } // h + } // t + } // for each n fbgemmAlignedFree(row_offsets); }; @@ -674,7 +762,8 @@ static void depthwise_3x3x3_pad_1_( HAS_BIAS, true /*A_symmetric*/, true /*B_symmetric*/, - BIAS_TYPE>( + BIAS_TYPE, + false /*PER_CHANNEL_QUANTIZATION*/>( N, T, H, @@ -685,15 +774,15 @@ static void depthwise_3x3x3_pad_1_( stride_w, A_zero_point, A, - B_zero_point, + &B_zero_point, B, - C_multiplier, + &C_multiplier, C_zero_point, C_int32_temp, C, col_offsets, bias, - act_times_w_scale, + &act_times_w_scale, thread_id, num_threads); } else { @@ -702,7 +791,8 @@ static void depthwise_3x3x3_pad_1_( HAS_BIAS, true /*A_symmetric*/, false /*B_symmetric*/, - BIAS_TYPE>( + BIAS_TYPE, + false /*PER_CHANNEL_QUANTIZATION*/>( N, T, H, @@ -713,15 +803,15 @@ static void depthwise_3x3x3_pad_1_( stride_w, A_zero_point, A, - B_zero_point, + &B_zero_point, B, - C_multiplier, + &C_multiplier, C_zero_point, C_int32_temp, C, col_offsets, bias, - act_times_w_scale, + &act_times_w_scale, thread_id, num_threads); } @@ -732,7 +822,8 @@ static void depthwise_3x3x3_pad_1_( HAS_BIAS, false /*A_symmetric*/, true /*B_symmetric*/, - BIAS_TYPE>( + BIAS_TYPE, + false /*PER_CHANNEL_QUANTIZATION*/>( N, T, H, @@ -743,15 +834,15 @@ static void depthwise_3x3x3_pad_1_( stride_w, A_zero_point, A, - B_zero_point, + &B_zero_point, B, - C_multiplier, + &C_multiplier, C_zero_point, C_int32_temp, C, col_offsets, bias, - act_times_w_scale, + &act_times_w_scale, thread_id, 
num_threads); } else { @@ -760,7 +851,8 @@ static void depthwise_3x3x3_pad_1_( HAS_BIAS, false /*A_symmetric*/, false /*B_symmetric*/, - BIAS_TYPE>( + BIAS_TYPE, + false /*PER_CHANNEL_QUANTIZATION*/>( N, T, H, @@ -771,15 +863,15 @@ static void depthwise_3x3x3_pad_1_( stride_w, A_zero_point, A, - B_zero_point, + &B_zero_point, B, - C_multiplier, + &C_multiplier, C_zero_point, C_int32_temp, C, col_offsets, bias, - act_times_w_scale, + &act_times_w_scale, thread_id, num_threads); } @@ -970,11 +1062,13 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_( int32_t* C_int32_temp = static_cast<int32_t*>( fbgemmAlignedAlloc(64, (K + 31) / 32 * 32 * sizeof(int32_t))); if (A_zero_point == 0 || col_offsets == nullptr) { - depthwise_3x3x3_per_channel_quantization_pad_1_< + depthwise_3x3x3_pad_1_< FUSE_RELU, HAS_BIAS, true /*A_SYMM*/, - BIAS_TYPE>( + false /*B_SYMM*/, + BIAS_TYPE, + true /*PER_CHANNEL_QUANTIZATION*/>( N, T, H, @@ -997,11 +1091,13 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_( thread_id, num_threads); } else { - depthwise_3x3x3_per_channel_quantization_pad_1_< + depthwise_3x3x3_pad_1_< FUSE_RELU, HAS_BIAS, false /*A_SYMM*/, - BIAS_TYPE>( + false /*B_SYMM*/, + BIAS_TYPE, + true /*PER_CHANNEL_QUANTIZATION*/>( N, T, H, diff --git a/src/FbgemmI8DepthwiseAvx2-inl.h b/src/FbgemmI8DepthwiseAvx2-inl.h index 18dda3b..11714c8 100644 --- a/src/FbgemmI8DepthwiseAvx2-inl.h +++ b/src/FbgemmI8DepthwiseAvx2-inl.h @@ -16,349 +16,6 @@ namespace fbgemm { -// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3 -// A is in uint8_t -// B is in int8_t and pre-interleaved -// C is in int32_t and 4 registers have results in the following layout: -// c0_v: c[0:4], c[16:20] -// c1_v: c[4:8], c[20:24] -// c2_v: c[8:12], c[24:28] -// c3_v: c[12:16], c[28:32] -template <bool SUM_A = false> -static ALWAYS_INLINE void madd_epi16x4_packed( - __m256i a0_v, - __m256i a1_v, - __m256i a2_v, - __m256i a3_v, - const __m256i* b, - __m256i* c0_v, - __m256i* c1_v, - __m256i* c2_v, - __m256i* c3_v, - __m256i* a_sum = nullptr) { - __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); - __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); - __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v); - __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v); - - if (SUM_A) { - __m256i one_epi8_v = _mm256_set1_epi8(1); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]); - } - - __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v); - __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v); - __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v); - __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v); - - __m256i b0_v = _mm256_load_si256(b + 0); - __m256i b1_v = _mm256_load_si256(b + 1); - __m256i b2_v = _mm256_load_si256(b + 2); - __m256i b3_v = _mm256_load_si256(b + 3); - - __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v); - __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v); - __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v); - __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v); - - __m256i one_v = _mm256_set1_epi16(1); - *c0_v = _mm256_madd_epi16(ab0, one_v); - *c1_v = _mm256_madd_epi16(ab1, one_v); - *c2_v = _mm256_madd_epi16(ab2, 
one_v); - *c3_v = _mm256_madd_epi16(ab3, one_v); -} - -// c = a0 * b0 + a1 * b1 + a2 * b2 -// A is in uint8_t -// B is in int8_t and pre-interleaved -// C is in int32_t and 4 registers have results in the following layout: -// c0_v: c[0:4], c[16:20] -// c1_v: c[4:8], c[20:24] -// c2_v: c[8:12], c[24:28] -// c3_v: c[12:16], c[28:32] -template <bool SUM_A = false> -static ALWAYS_INLINE void madd_epi16x3_packed( - __m256i a0_v, - __m256i a1_v, - __m256i a2_v, - const __m256i* b, - __m256i* c0_v, - __m256i* c1_v, - __m256i* c2_v, - __m256i* c3_v, - __m256i* a_sum = nullptr) { - __m256i zero_v = _mm256_setzero_si256(); - - __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); - __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); - __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v); - __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v); - - if (SUM_A) { - __m256i one_epi8_v = _mm256_set1_epi8(1); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]); - } - - __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v); - __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v); - __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v); - __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v); - - __m256i b0_v = _mm256_load_si256(b + 0); - __m256i b1_v = _mm256_load_si256(b + 1); - __m256i b2_v = _mm256_load_si256(b + 2); - __m256i b3_v = _mm256_load_si256(b + 3); - - __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v); - __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v); - __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v); - __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v); - - __m256i one_v = _mm256_set1_epi16(1); - *c0_v = _mm256_madd_epi16(ab0, one_v); - *c1_v = _mm256_madd_epi16(ab1, one_v); - *c2_v = _mm256_madd_epi16(ab2, one_v); - *c3_v = _mm256_madd_epi16(ab3, one_v); -} - -// c = a0 * b0 + a1 * b1 -// A is in uint8_t -// B is in int8_t and pre-interleaved -// C is in int32_t and 4 registers have results in the following layout: -// c0_v: c[0:4], c[4:8] -// c1_v: c[8:12], c[12:16] -// c2_v: c[16:20], c[20:24] -// c3_v: c[24:28], c[28:32] -template <bool SUM_A = false> -static ALWAYS_INLINE void madd_epi16x2_packed( - __m256i a0_v, - __m256i a1_v, - const __m256i* b, - __m256i* c0_v, - __m256i* c1_v, - __m256i* c2_v, - __m256i* c3_v, - __m256i* a_sum = nullptr) { - __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v); - __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v); - - if (SUM_A) { - __m256i one_epi8_v = _mm256_set1_epi8(1); - a_sum[0] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]); - a_sum[1] = - _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]); - } - - __m256i b0_v = _mm256_load_si256(b + 0); - __m256i b1_v = _mm256_load_si256(b + 1); - - __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v); - __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v); - - *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v)); - *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v)); - *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1)); - *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1)); -} - -// c = a0 * b0 -// A is in uint8_t -// 
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[4:8]
-// c1_v: c[8:12], c[12:16]
-// c2_v: c[16:20], c[20:24]
-// c3_v: c[24:28], c[28:32]
-template <bool SUM_A = false>
-static ALWAYS_INLINE void madd_epi16_packed(
-    __m256i a_v,
-    const __m256i* b,
-    __m256i* c0_v,
-    __m256i* c1_v,
-    __m256i* c2_v,
-    __m256i* c3_v,
-    __m256i* a_sum = nullptr) {
-  __m256i zero_v = _mm256_setzero_si256();
-
-  __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v);
-  __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v);
-
-  if (SUM_A) {
-    __m256i one_epi8_v = _mm256_set1_epi8(1);
-    a_sum[0] =
-        _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
-    a_sum[1] =
-        _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
-  }
-
-  __m256i b0_v = _mm256_load_si256(b + 0);
-  __m256i b1_v = _mm256_load_si256(b + 1);
-
-  __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
-  __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
-
-  *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
-  *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
-  *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
-  *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
-}
-
-// K is the number of accumulations we're doing
-template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false>
-static ALWAYS_INLINE void inner_prod_packed_(
-    const __m256i* a_v,
-    const __m256i* Bp,
-    std::int32_t* C,
-    int remainder,
-    __m256i* a_sum = nullptr) {
-  __m256i c[4], c_temp[4];
-  __m256i a_sum_temp[2] = {0, 0};
-
-  int k = 0;
-  if (K >= 4) {
-    madd_epi16x4_packed<SUM_A>(
-        a_v[0],
-        a_v[1],
-        a_v[2],
-        a_v[3],
-        Bp,
-        &c[0],
-        &c[1],
-        &c[2],
-        &c[3],
-        a_sum_temp);
-
-    for (k = 4; k < K / 4 * 4; k += 4) {
-      madd_epi16x4_packed<SUM_A>(
-          a_v[k + 0],
-          a_v[k + 1],
-          a_v[k + 2],
-          a_v[k + 3],
-          Bp + k,
-          &c_temp[0],
-          &c_temp[1],
-          &c_temp[2],
-          &c_temp[3],
-          a_sum_temp);
-
-      c[0] = _mm256_add_epi32(c[0], c_temp[0]);
-      c[1] = _mm256_add_epi32(c[1], c_temp[1]);
-      c[2] = _mm256_add_epi32(c[2], c_temp[2]);
-      c[3] = _mm256_add_epi32(c[3], c_temp[3]);
-    }
-  } else {
-    c[0] = _mm256_setzero_si256();
-    c[1] = _mm256_setzero_si256();
-    c[2] = _mm256_setzero_si256();
-    c[3] = _mm256_setzero_si256();
-  }
-
-  if (K - k == 3) {
-    madd_epi16x3_packed<SUM_A>(
-        a_v[k],
-        a_v[k + 1],
-        a_v[k + 2],
-        Bp + k,
-        &c_temp[0],
-        &c_temp[1],
-        &c_temp[2],
-        &c_temp[3],
-        a_sum_temp);
-
-    c[0] = _mm256_add_epi32(c[0], c_temp[0]);
-    c[1] = _mm256_add_epi32(c[1], c_temp[1]);
-    c[2] = _mm256_add_epi32(c[2], c_temp[2]);
-    c[3] = _mm256_add_epi32(c[3], c_temp[3]);
-  }
-
-  c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20);
-  c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20);
-  c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31);
-  c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31);
-
-  if (K - k == 0 || K - k == 3) {
-    c[0] = c_temp[0];
-    c[1] = c_temp[1];
-    c[2] = c_temp[2];
-    c[3] = c_temp[3];
-  } else {
-    if (K - k == 1) {
-      madd_epi16_packed<SUM_A>(
-          a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
-    } else if (K - k == 2) {
-      madd_epi16x2_packed<SUM_A>(
-          a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
-    }
-
-    c[0] = _mm256_add_epi32(c[0], c_temp[0]);
-    c[1] = _mm256_add_epi32(c[1], c_temp[1]);
-    c[2] = _mm256_add_epi32(c[2], c_temp[2]);
-    c[3] = _mm256_add_epi32(c[3], c_temp[3]);
-  }
-
-  if (REMAINDER) {
-    for (int r = 0; r < remainder / 8; ++r) {
-      if (ACC) {
-        _mm256_storeu_si256(
-            reinterpret_cast<__m256i*>(C + r * 8),
-            _mm256_add_epi32(
-                _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)),
-                c[r]));
-      } else {
-        _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]);
-      }
-    }
-  } else {
-    if (ACC) {
-      _mm256_storeu_si256(
-          reinterpret_cast<__m256i*>(C),
-          _mm256_add_epi32(
-              _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0]));
-      _mm256_storeu_si256(
-          reinterpret_cast<__m256i*>(C + 8),
-          _mm256_add_epi32(
-              _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1]));
-      _mm256_storeu_si256(
-          reinterpret_cast<__m256i*>(C + 16),
-          _mm256_add_epi32(
-              _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2]));
-      _mm256_storeu_si256(
-          reinterpret_cast<__m256i*>(C + 24),
-          _mm256_add_epi32(
-              _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3]));
-    } else {
-      _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]);
-      _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]);
-      _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]);
-      _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]);
-    }
-  }
-
-  if (SUM_A) {
-    a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0]));
-    a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1]));
-    a_sum[2] =
-        _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1));
-    a_sum[3] =
-        _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1));
-  }
-}
-
// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different
// row_offsets for each row because of depth-wise convolution
template <
@@ -672,15 +329,6 @@ static ALWAYS_INLINE void requantize_(
  }
}
-template <bool REMAINDER>
-static ALWAYS_INLINE __m256i load_a(const std::uint8_t* A, __m256i mask_v) {
-  if (REMAINDER) {
-    return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v);
-  } else {
-    return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A));
-  }
-}
-
static inline std::pair<int, int> closest_factors_(int n) {
  int a = static_cast<int>(std::sqrt(n));
  while (n % a != 0) {
diff --git a/src/GenerateI8Depthwise.cc b/src/GenerateI8Depthwise.cc
new file mode 100644
index 0000000..4b9eb7e
--- /dev/null
+++ b/src/GenerateI8Depthwise.cc
@@ -0,0 +1,506 @@
+#include "./GenerateI8Depthwise.h"
+
+#include <asmjit/asmjit.h>
+#include <iostream>
+
+#include "./CodeCache.h"
+#include "./CodeGenHelpers.h"
+#include "fbgemm/Utils.h"
+
+namespace fbgemm {
+
+namespace {
+asmjit::JitRuntime& runtime() {
+  static asmjit::JitRuntime rt; //< JIT runtime for asmjit;
+                                // depends on other static
+                                // variables. Required to prevent
+                                // initialization order fiasco
+  return rt;
+}
+
+// Control access to the runtime.
+std::mutex rtMutex_;
+
+// The hash depends on D, F, compute_a_sum, per_channel_quantization, remainder,
+// prev_skip, next_skip, top_skip, bottom_skip, left_skip, and right_skip.
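The declaration that follows memoizes generated kernels on exactly that tuple, so each distinct shape/padding/quantization combination is JIT'ed at most once and subsequent calls reduce to a lookup. As a rough illustration of the pattern (a hedged sketch: MemoCache is a hypothetical stand-in, not FBGEMM's actual CodeCache API):

#include <functional>
#include <map>
#include <mutex>

// Hypothetical stand-in for fbgemm's CodeCache, for illustration only:
// a mutex-guarded memo table from kernel signature to generated function.
template <typename Key, typename Fn>
class MemoCache {
 public:
  Fn getOrCreate(const Key& key, std::function<Fn()> generator) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto it = table_.find(key);
    if (it == table_.end()) {
      // Generate (JIT) at most once per distinct key.
      it = table_.emplace(key, generator()).first;
    }
    return it->second;
  }

 private:
  std::map<Key, Fn> table_; // Key needs operator<; std::tuple provides it
  std::mutex mutex_;
};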
+CodeCache< + std::tuple<int, int, bool, bool, int, int, int, int, int, int, int>, + GenI8Depthwise::jit_kernel_signature> + codeCache_; +} // namespace + +namespace x86 = asmjit::x86; + +// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3 +// A is in uint8_t +// B is in int8_t and pre-interleaved +// C is in int32_t and 4 registers have results in the following layout: +// c0_v: c[0:4], c[16:20] +// c1_v: c[4:8], c[20:24] +// c2_v: c[8:12], c[24:28] +// c3_v: c[12:16], c[28:32] +static void genMaddEpi16xNPacked( + x86::Emitter* e, + x86::Ymm a[4], + x86::Gp b, + x86::Ymm c[4], + x86::Ymm* a_sum, + int n, + int remainder, + bool accumulation, + x86::Ymm one_epi8, + x86::Ymm one_epi16, + x86::Ymm zero) { + // Interleave inputs. Reuse a[1] and a[3] to save registers + x86::Ymm a01_lo(0), a01_hi(1), a23_lo(a[1].id()), a23_hi(a[3].id()); + e->vpunpcklbw(a01_lo, a[0], n == 1 ? zero : a[1]); + if (remainder >= 8) { + e->vpunpckhbw(a01_hi, a[0], n == 1 ? zero : a[1]); + } + if (n > 2) { + e->vpunpcklbw(a23_lo, a[2], n == 3 ? zero : a[3]); + if (remainder >= 8) { + e->vpunpckhbw(a23_hi, a[2], n == 3 ? zero : a[3]); + } + } + + // Compute row_wise sum of A for row_offsets + if (a_sum) { + if (accumulation) { + e->vpmaddubsw(a[0], a01_lo, one_epi8); + e->vpaddsw(a_sum[0], a[0], a_sum[0]); + + if (remainder >= 8) { + e->vpmaddubsw(a[2], a01_hi, one_epi8); + e->vpaddsw(a_sum[1], a[2], a_sum[1]); + } + } else { + e->vpmaddubsw(a_sum[0], a01_lo, one_epi8); + if (remainder >= 8) { + e->vpmaddubsw(a_sum[1], a01_hi, one_epi8); + } + } + + if (n > 2) { + e->vpmaddubsw(a[0], a23_lo, one_epi8); + e->vpaddsw(a_sum[0], a[0], a_sum[0]); + + if (remainder >= 8) { + e->vpmaddubsw(a[2], a23_hi, one_epi8); + e->vpaddsw(a_sum[1], a[2], a_sum[1]); + } + } + } + + if (n > 2) { + // Reusing a + e->vpunpcklwd(a[0], a01_lo, a23_lo); + e->vpunpckhwd(a[1], a01_lo, a23_lo); + if (remainder >= 16) { + e->vpunpcklwd(a[2], a01_hi, a23_hi); + e->vpunpckhwd(a[3], a01_hi, a23_hi); + } + + e->vpmaddubsw(a[0], a[0], x86::ymmword_ptr(b)); + e->vpmaddubsw(a[1], a[1], x86::ymmword_ptr(b, 32)); + if (remainder >= 16) { + e->vpmaddubsw(a[2], a[2], x86::ymmword_ptr(b, 64)); + e->vpmaddubsw(a[3], a[3], x86::ymmword_ptr(b, 96)); + } + + if (accumulation) { + e->vpmaddwd(a[0], a[0], one_epi16); + e->vpaddd(c[0], c[0], a[0]); + e->vpmaddwd(a[1], a[1], one_epi16); + e->vpaddd(c[1], c[1], a[1]); + + if (remainder >= 16) { + e->vpmaddwd(a[2], a[2], one_epi16); + e->vpaddd(c[2], c[2], a[2]); + e->vpmaddwd(a[3], a[3], one_epi16); + e->vpaddd(c[3], c[3], a[3]); + } + } else { + e->vpmaddwd(c[0], a[0], one_epi16); + e->vpmaddwd(c[1], a[1], one_epi16); + + if (remainder >= 16) { + e->vpmaddwd(c[2], a[2], one_epi16); + e->vpmaddwd(c[3], a[3], one_epi16); + } + } + } else { + // Reusing a + e->vpmaddubsw(a[0], a01_lo, x86::ymmword_ptr(b)); + e->vpmaddubsw(a[1], a01_hi, x86::ymmword_ptr(b, 32)); + + if (accumulation) { + e->vpmovsxwd(a[2], x86::Xmm(a[0].id())); + e->vpaddd(c[0], c[0], a[2]); + e->vpmovsxwd(a[3], x86::Xmm(a[1].id())); + e->vpaddd(c[1], c[1], a[3]); + + if (remainder >= 16) { + e->vextracti128(x86::Xmm(a[0].id()), a[0], asmjit::Imm(1)); + e->vpmovsxwd(a[0], x86::Xmm(a[0].id())); + e->vpaddd(c[2], c[2], a[0]); + e->vextracti128(x86::Xmm(a[1].id()), a[1], asmjit::Imm(1)); + e->vpmovsxwd(a[1], x86::Xmm(a[1].id())); + e->vpaddd(c[3], c[3], a[1]); + } + } else { + e->vpmovsxwd(c[0], x86::Xmm(a[0].id())); + e->vpmovsxwd(c[1], x86::Xmm(a[1].id())); + + if (remainder >= 16) { + e->vextracti128(x86::Xmm(a[0].id()), a[0], asmjit::Imm(1)); + 
e->vpmovsxwd(c[2], x86::Xmm(a[0].id()));
+        e->vextracti128(x86::Xmm(a[1].id()), a[1], asmjit::Imm(1));
+        e->vpmovsxwd(c[3], x86::Xmm(a[1].id()));
+      }
+    }
+  }
+}
+
+GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate(
+    int D,
+    int S,
+    bool compute_a_sum,
+    bool per_channel_quantization,
+    int remainder,
+    int prev_skip,
+    int next_skip,
+    int top_skip,
+    int bottom_skip,
+    int left_skip,
+    int right_skip) {
+  std::tuple<int, int, bool, bool, int, int, int, int, int, int, int>
+      kernelSig = std::make_tuple(
+          D,
+          S,
+          compute_a_sum,
+          per_channel_quantization,
+          remainder,
+          prev_skip,
+          next_skip,
+          top_skip,
+          bottom_skip,
+          left_skip,
+          right_skip);
+
+  return codeCache_.getOrCreate(kernelSig, [&]() -> jit_kernel_signature {
+    asmjit::CodeHolder code;
+    code.init(runtime().codeInfo());
+    x86::Assembler assembler(&code);
+    x86::Emitter* e = assembler.as<x86::Emitter>();
+
+    x86::Gp a_addr = e->zdi();
+    x86::Gp b_addr = e->zsi();
+    x86::Gp c_addr = e->zdx();
+    x86::Gp a_sum_addr = e->zcx();
+    x86::Gp h = e->gpz(8);
+    x86::Gp w = e->gpz(9);
+    x86::Gp c_in = e->gpz(10);
+    x86::Gp mask_addr = e->gpz(11);
+    x86::Gp a_zero_point = e->gpz(12);
+    x86::Gp b_zero_point_addr = e->gpz(13);
+    x86::Gp ic_loop_count = e->gpz(14);
+    x86::Gp a_addr_save = e->gpz(15);
+
+    asmjit::FuncDetail func;
+    func.init(asmjit::FuncSignatureT<
+              void,
+              const std::uint8_t*,
+              const std::int8_t*,
+              std::int32_t*,
+              std::int32_t*,
+              int,
+              int,
+              int,
+              const int*,
+              int,
+              const std::int32_t*>(asmjit::CallConv::kIdHost));
+
+    asmjit::FuncFrame frame;
+    frame.init(func);
+
+    frame.setDirtyRegs(
+        x86::Reg::kGroupVec,
+        asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
+            asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+    frame.setDirtyRegs(
+        x86::Reg::kGroupGp,
+        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
+
+    asmjit::FuncArgsAssignment args(&func);
+    args.assignAll(
+        a_addr,
+        b_addr,
+        c_addr,
+        a_sum_addr,
+        h,
+        w,
+        c_in,
+        mask_addr,
+        a_zero_point,
+        b_zero_point_addr);
+
+    args.updateFuncFrame(frame);
+    frame.finalize();
+
+    e->emitProlog(frame);
+    e->emitArgsAssignment(frame, args);
+
+    // Assign vector registers
+    x86::Ymm a[4];
+    x86::Ymm c[4];
+    x86::Ymm a_sum[2];
+
+    int vreg_id = 2; // reserve 2 for temp vreg
+    for (int i = 0; i < 4; ++i, ++vreg_id) {
+      a[i] = x86::Ymm(vreg_id);
+    }
+    for (int i = 0; i < 4; ++i, ++vreg_id) {
+      c[i] = x86::Ymm(vreg_id);
+    }
+    if (compute_a_sum) {
+      a_sum[0] = x86::Ymm(vreg_id);
+      ++vreg_id;
+      a_sum[1] = x86::Ymm(vreg_id);
+      ++vreg_id;
+    }
+    x86::Ymm mask_vreg(vreg_id);
+    constexpr int vlen = simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
+    if (remainder != simd_info<inst_set_t::avx2>::WIDTH_BYTES) {
+      ++vreg_id;
+      e->vmovups(
+          mask_vreg,
+          x86::ymmword_ptr(
+              mask_addr, (vlen - remainder / 4) % vlen * sizeof(int32_t)));
+    }
+    x86::Ymm one_epi8(vreg_id);
+    if (compute_a_sum) {
+      ++vreg_id;
+      gen8BitVectorOne(e, one_epi8);
+    }
+
+    int K = 1;
+    for (int i = 0; i < D; ++i) {
+      K *= S;
+    }
+    x86::Ymm one_epi16(vreg_id);
+    if (K > 2) {
+      ++vreg_id;
+      gen16BitVectorOne(e, one_epi16);
+    }
+
+    bool has_pad = prev_skip || next_skip || top_skip || bottom_skip ||
+        left_skip || right_skip;
+    bool need_zero = K % 4 == 3 || K % 4 == 1;
+    // When out of registers, zero and A_zero_point_vreg need to share.
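The sharing remark above follows from AVX2's budget of sixteen ymm registers; tallying the assignments made earlier in getOrCreate shows why (an editorial accounting derived from the code above, not text from the patch):

// ymm0-ymm1   two temporaries reserved up front (vreg_id starts at 2)
// ymm2-ymm9   a[0..3] input vectors and c[0..3] accumulators
// ymm10-ymm11 a_sum[0..1]      (only if compute_a_sum)
// ymm12       remainder mask   (only if remainder != 32 bytes)
// ymm13       one_epi8         (only if compute_a_sum)
// ymm14       one_epi16        (only if K > 2)
// ymm15      the one register left: when every option above is enabled,
//            the zero vector and the broadcast A_zero_point must share it,
//            each re-materialized (vxorps / vpbroadcastb) when needed.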
+    bool recompute_zero = vreg_id == 15 && need_zero;
+
+    x86::Ymm a_zero_point_vreg(vreg_id);
+    if (!recompute_zero && has_pad) {
+      e->movq(x86::Xmm(a_zero_point_vreg.id()), a_zero_point);
+      e->vpbroadcastb(a_zero_point_vreg, x86::Xmm(a_zero_point_vreg.id()));
+    }
+    if (vreg_id < 15) {
+      ++vreg_id;
+    }
+    x86::Ymm zero(vreg_id);
+    if (need_zero && (!recompute_zero || !has_pad)) {
+      e->vxorps(zero, zero, zero);
+    }
+
+    // Compute address strides in the scalar registers
+    e->imul(w, c_in);
+    e->imul(h, w);
+    if (D >= 3) {
+      e->mov(a_addr_save, w);
+      e->imul(a_addr_save, S);
+      e->sub(h, a_addr_save);
+    }
+    e->mov(a_addr_save, c_in);
+    e->imul(a_addr_save, S);
+    e->sub(w, a_addr_save);
+
+    e->mov(ic_loop_count, c_in);
+    e->add(ic_loop_count, asmjit::Imm(31));
+    e->sar(ic_loop_count, asmjit::Imm(5));
+
+    e->mov(a_addr_save, a_addr);
+    asmjit::Label ic_loop_begin = e->newLabel(), ic_loop_end = e->newLabel();
+
+    // main_loop == false: the last vector iteration across input channels
+    for (bool main_loop : {true, false}) {
+      if (main_loop) {
+        e->bind(ic_loop_begin);
+        e->dec(ic_loop_count);
+        e->jle(ic_loop_end);
+      }
+
+      if (recompute_zero && has_pad) {
+        e->movq(x86::Xmm(a_zero_point_vreg.id()), a_zero_point);
+        e->vpbroadcastb(a_zero_point_vreg, x86::Xmm(a_zero_point_vreg.id()));
+      }
+
+      int i = 0;
+      // Iterate across the reduction (filter) dimension
+      for (int f_t = 0; f_t < ((D == 2) ? 1 : S); ++f_t) {
+        for (int f_h = 0; f_h < S; ++f_h) {
+          for (int f_w = 0; f_w < S; ++f_w, ++i) {
+            bool pad = false;
+            if (D > 2) {
+              if (f_t < prev_skip || f_t >= S - next_skip) {
+                pad = true;
+              }
+            }
+            if (f_h < top_skip || f_h >= S - bottom_skip || f_w < left_skip ||
+                f_w >= S - right_skip) {
+              pad = true;
+            }
+
+            // Load A
+            if (pad) {
+              e->vmovups(a[i % 4], a_zero_point_vreg);
+            } else {
+              if (!main_loop && remainder != 32) {
+                e->vmaskmovps(a[i % 4], mask_vreg, x86::ymmword_ptr(a_addr));
+              } else {
+                e->vmovups(a[i % 4], x86::ymmword_ptr(a_addr));
+              }
+            }
+
+            // Compute when we have 4 inputs or this is the last iteration
+            if (i % 4 == 3 || i == K - 1) {
+              if (i == K - 1 && (i / 4 * 4 == K - 3 || i / 4 * 4 == K - 1)) {
+                if (recompute_zero && has_pad) {
+                  e->vxorps(zero, zero, zero);
+                }
+              }
+
+              genMaddEpi16xNPacked(
+                  e,
+                  a,
+                  b_addr,
+                  c,
+                  compute_a_sum ? a_sum : nullptr,
+                  /*n=*/std::min(K - i / 4 * 4, 4),
+                  main_loop ? 32 : remainder,
+                  /*accumulation=*/i / 4 > 0,
+                  one_epi8,
+                  one_epi16,
+                  zero);
+
+              if (i != K - 1) {
+                e->add(b_addr, asmjit::Imm(32 * 4));
+              } else if (main_loop) {
+                e->add(b_addr, asmjit::Imm(32 * (K - i / 4 * 4 + 1) / 2 * 2));
+              }
+
+              if (K - i / 4 * 4 >= 3 && K - i / 4 * 4 <= 6) {
+                for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) {
+                  e->vperm2f128(
+                      a[r],
+                      c[r % 2 * 2],
+                      c[r % 2 * 2 + 1],
+                      asmjit::Imm(r < 2 ? 0x20 : 0x31));
+                }
+                for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) {
+                  e->vmovaps(c[r], a[r]);
+                }
+              }
+            }
+            if (i != K - 1) {
+              e->add(a_addr, c_in);
+            }
+          }
+          if (i != K - 1) {
+            e->add(a_addr, w);
+          }
+        }
+        if (D >= 3 && i != K - 1) {
+          e->add(a_addr, h);
+        }
+      }
+
+      for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) {
+        e->vmovups(x86::ymmword_ptr(c_addr, r * 32), c[r]);
+      }
+
+      if (compute_a_sum) {
+        if (per_channel_quantization) {
+          e->vmovups(c[0], x86::ymmword_ptr(b_zero_point_addr));
+        } else {
+          e->vpbroadcastd(c[0], x86::dword_ptr(b_zero_point_addr));
+        }
+        e->vpmovsxwd(a[0], x86::Xmm(a_sum[0].id()));
+        e->vpmulld(a[0], a[0], c[0]);
+        e->vmovups(x86::ymmword_ptr(a_sum_addr), a[0]);
+
+        if (main_loop || remainder >= 8) {
+          if (per_channel_quantization) {
+            e->vmovups(c[0], x86::ymmword_ptr(b_zero_point_addr, 32));
+          }
+          e->vpmovsxwd(a[1], x86::Xmm(a_sum[1].id()));
+          e->vpmulld(a[1], a[1], c[0]);
+          e->vmovups(x86::ymmword_ptr(a_sum_addr, 32), a[1]);
+        }
+
+        if (main_loop || remainder >= 16) {
+          if (per_channel_quantization) {
+            e->vmovups(c[0], x86::ymmword_ptr(b_zero_point_addr, 64));
+          }
+          e->vextracti128(x86::Xmm(a_sum[0].id()), a_sum[0], asmjit::Imm(1));
+          e->vpmovsxwd(a_sum[0], x86::Xmm(a_sum[0].id()));
+          e->vpmulld(a_sum[0], a_sum[0], c[0]);
+          e->vmovups(x86::ymmword_ptr(a_sum_addr, 64), a_sum[0]);
+        }
+
+        if (main_loop || remainder >= 24) {
+          if (per_channel_quantization) {
+            e->vmovups(c[0], x86::ymmword_ptr(b_zero_point_addr, 96));
+          }
+          e->vextracti128(x86::Xmm(a_sum[1].id()), a_sum[1], asmjit::Imm(1));
+          e->vpmovsxwd(a_sum[1], x86::Xmm(a_sum[1].id()));
+          e->vpmulld(a_sum[1], a_sum[1], c[0]);
+          e->vmovups(x86::ymmword_ptr(a_sum_addr, 96), a_sum[1]);
+        }
+
+        if (main_loop) {
+          if (per_channel_quantization) {
+            e->add(b_zero_point_addr, asmjit::Imm(128));
+          }
+          e->add(a_sum_addr, asmjit::Imm(128));
+        }
+      }
+
+      if (main_loop) {
+        e->add(c_addr, asmjit::Imm(128));
+        e->add(a_addr_save, asmjit::Imm(32));
+        e->mov(a_addr, a_addr_save);
+        e->jmp(ic_loop_begin);
+
+        e->bind(ic_loop_end);
+      }
+    }
+
+    e->emitEpilog(frame);
+
+    jit_kernel_signature fn;
+    asmjit::Error err;
+    {
+      std::unique_lock<std::mutex> lock(rtMutex_);
+      err = runtime().add(&fn, &code);
+    }
+    if (err) {
+      std::cout << "Error: in fn add" << std::endl;
+      return nullptr;
+    }
+
+    return fn;
+  });
+}
+
+} // namespace fbgemm
diff --git a/src/GenerateI8Depthwise.h b/src/GenerateI8Depthwise.h
new file mode 100644
index 0000000..4e5d2ee
--- /dev/null
+++ b/src/GenerateI8Depthwise.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <cstdint>
+
+namespace fbgemm {
+
+class GenI8Depthwise {
+ public:
+  using jit_kernel_signature = void (*)(
+      const std::uint8_t* a,
+      const std::int8_t* b,
+      std::int32_t* c,
+      std::int32_t* a_sum, // row-wise sum of A
+      int h,
+      int w,
+      int c_in, // the number of input channels
+      const int* mask,
+      int A_zero_point,
+      const std::int32_t* B_zero_point);
+
+  jit_kernel_signature getOrCreate(
+      int D, // spatial dimension (2 or 3)
+      int F, // filter size per dimension
+      bool compute_a_sum,
+      bool per_channel_quantization,
+      int remainder, // the number of channels in the remainder loop
+      int prev_skip,
+      int next_skip,
+      int top_skip,
+      int bottom_skip,
+      int left_skip,
+      int right_skip);
+};
+
+} // namespace fbgemm
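To see how the two new files fit together, here is a hedged usage sketch of the public API. The driver function, parameter values, and buffer-size choices are illustrative assumptions, not code from this patch; the real call sites live in the depth-wise convolution paths, and remainder == 32 is read as "no masked tail" based on the generated code's remainder != 32 checks above.

#include "src/GenerateI8Depthwise.h"

#include <cstdint>
#include <vector>

// Hedged sketch: JIT and run a 2D 3x3 depth-wise kernel on one interior
// output pixel, with C_in assumed to be a multiple of 32 so the full-width
// path is taken and the mask pointer is never consumed.
void run_one_pixel(
    const std::uint8_t* A, // top-left input pixel of the 3x3 window, HWC layout
    const std::int8_t* Bp, // pre-interleaved int8 weights
    std::int32_t* C, // C_in int32 accumulators
    int H,
    int W,
    int C_in,
    int A_zero_point,
    const std::int32_t* B_zero_point) {
  fbgemm::GenI8Depthwise gen;
  fbgemm::GenI8Depthwise::jit_kernel_signature kernel = gen.getOrCreate(
      /*D=*/2,
      /*F=*/3,
      /*compute_a_sum=*/true,
      /*per_channel_quantization=*/false,
      /*remainder=*/32, // assumption: 32 means a full last vector iteration
      /*prev_skip=*/0,
      /*next_skip=*/0,
      /*top_skip=*/0, // interior pixel: no padding rows/columns skipped
      /*bottom_skip=*/0,
      /*left_skip=*/0,
      /*right_skip=*/0);
  // a_sum is written 32 channels at a time, so round its size up.
  std::vector<std::int32_t> a_sum((C_in + 31) / 32 * 32);
  kernel(
      A, Bp, C, a_sum.data(), H, W, C_in,
      /*mask=*/nullptr, // unused when remainder == 32
      A_zero_point, B_zero_point);
}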