diff options
author | Jongsoo Park <jongsoo@fb.com> | 2019-04-03 17:59:57 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-04-03 18:02:56 +0300 |
commit | c6e86067e41a363af718dae7f8d7494068aad868 (patch) | |
tree | 29aafe2a702a9a5da25eaa1a7c17d97cc531de89 /src | |
parent | f12ec122be12b0647ada3ff2c374cca57aa4ae95 (diff) |
optimize dw conv for symmetric quant (#73)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/73
Skip computing row_offset if B uses symmetric quantization. Skip adding col_offset if A uses symmetric quantization.
Reviewed By: jianyuh
Differential Revision: D14055973
fbshipit-source-id: 91da8f0755b2f90175e94a893b5a3ad6342c506d
Diffstat (limited to 'src')
-rw-r--r-- | src/FbgemmI8DepthwiseAvx2.cc | 685 | ||||
-rw-r--r-- | src/FbgemmI8DepthwiseAvx2.h | 45 |
2 files changed, 631 insertions, 99 deletions
diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc index 017c4c8..2620e43 100644 --- a/src/FbgemmI8DepthwiseAvx2.cc +++ b/src/FbgemmI8DepthwiseAvx2.cc @@ -525,7 +525,12 @@ static inline __attribute__((always_inline)) void inner_prod_3x3_packed_( // Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different // row_offsets for each row because of depth-wise convolution -template <bool FUSE_RELU, bool HAS_BIAS, bool PER_CHANNEL_QUANTIZATION> +template < + bool FUSE_RELU, + bool HAS_BIAS, + bool PER_CHANNEL_QUANTIZATION, + bool A_SYMMETRIC, + bool B_SYMMETRIC> static inline __attribute__((always_inline)) void requantize_( int32_t A_zero_point, const float* C_multiplier, @@ -544,6 +549,9 @@ static inline __attribute__((always_inline)) void requantize_( __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0)); __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255)); + if (A_SYMMETRIC) { + assert(A_zero_point == 0 || col_offsets == nullptr); + } __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point); __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point); __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point); @@ -563,36 +571,59 @@ static inline __attribute__((always_inline)) void requantize_( __m256i w_v = _mm256_loadu_si256( reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN)); - __m256i col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(col_offsets + j))); - __m256i row_offset_v = - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j)); - x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v); - - col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(col_offsets + j + VLEN))); - row_offset_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(row_offsets + j + VLEN)); - y_v = _mm256_sub_epi32(_mm256_sub_epi32(y_v, col_off_v), row_offset_v); - - col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN))); - row_offset_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN)); - z_v = _mm256_sub_epi32(_mm256_sub_epi32(z_v, col_off_v), row_offset_v); - - col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN))); - row_offset_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN)); - w_v = _mm256_sub_epi32(_mm256_sub_epi32(w_v, col_off_v), row_offset_v); + __m256i row_offset_v; + if (!B_SYMMETRIC) { + row_offset_v = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j)); + x_v = _mm256_sub_epi32(x_v, row_offset_v); + } + __m256i col_off_v; + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j))); + x_v = _mm256_sub_epi32(x_v, col_off_v); + } + + if (!B_SYMMETRIC) { + row_offset_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(row_offsets + j + VLEN)); + y_v = _mm256_sub_epi32(y_v, row_offset_v); + } + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j + VLEN))); + y_v = _mm256_sub_epi32(y_v, col_off_v); + } + + if (!B_SYMMETRIC) { + row_offset_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN)); + z_v = _mm256_sub_epi32(z_v, row_offset_v); + } + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN))); + z_v = _mm256_sub_epi32(z_v, col_off_v); + } + + if (!B_SYMMETRIC) { + row_offset_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN)); + w_v = _mm256_sub_epi32(w_v, row_offset_v); + } + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN))); + w_v = _mm256_sub_epi32(w_v, col_off_v); + } if (HAS_BIAS) { // static if x_v = _mm256_add_epi32( @@ -653,12 +684,18 @@ static inline __attribute__((always_inline)) void requantize_( __m256i x_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j)); - __m256i col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(col_offsets + j))); - __m256i row_offset_v = - _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j)); - x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v); + if (!B_SYMMETRIC) { + __m256i row_offset_v = + _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j)); + x_v = _mm256_sub_epi32(x_v, row_offset_v); + } + if (!A_SYMMETRIC) { + __m256i col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(col_offsets + j))); + x_v = _mm256_sub_epi32(x_v, col_off_v); + } if (HAS_BIAS) { // static if x_v = _mm256_add_epi32( @@ -687,7 +724,13 @@ static inline __attribute__((always_inline)) void requantize_( } // j loop vectorized for (; j < n; ++j) { - int32_t raw = C_int32[j] - A_zero_point * col_offsets[j] - row_offsets[j]; + int32_t raw = C_int32[j]; + if (!B_SYMMETRIC) { + raw -= row_offsets[j]; + } + if (!A_SYMMETRIC) { + raw -= A_zero_point * col_offsets[j]; + } if (HAS_BIAS) { // static if raw += bias[j]; } @@ -1074,7 +1117,7 @@ static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_( } } -template <bool FUSE_RELU, bool HAS_BIAS> +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC> static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( int H, int W, @@ -1102,7 +1145,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( int k; for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3_packed_<true /*SUM_A*/>( + inner_prod_3x3_packed_<!B_SYMMETRIC /*SUM_A*/>( H, W, K, @@ -1114,11 +1157,11 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( &B_zero_point, C_int32 + k, 0, - &row_offsets[k]); + B_SYMMETRIC ? nullptr : &row_offsets[k]); } int remainder = K - k; if (remainder) { - inner_prod_3x3_packed_<true /*SUM_A*/, true /*REMAINDER*/>( + inner_prod_3x3_packed_<!B_SYMMETRIC, true>( H, W, K, @@ -1130,10 +1173,15 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( &B_zero_point, C_int32 + k, remainder, - &row_offsets[k]); + B_SYMMETRIC ? nullptr : &row_offsets[k]); } - requantize_<FUSE_RELU, HAS_BIAS, false /*PER_CHAN_QUANT*/>( + requantize_< + FUSE_RELU, + HAS_BIAS, + false, /*PER_CHAN_QUANT*/ + A_SYMMETRIC, + B_SYMMETRIC>( A_zero_point, &C_multiplier, C_zero_point, @@ -1145,7 +1193,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( bias); } -template <bool FUSE_RELU, bool HAS_BIAS> +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC> static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( int T, int H, @@ -1178,7 +1226,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( int k; for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3x3_packed_<true /*SUM_A*/>( + inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/>( T, H, W, @@ -1192,11 +1240,11 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( &B_zero_point, C_int32 + k, 0, - &row_offsets[k]); + B_SYMMETRIC ? nullptr : &row_offsets[k]); } int remainder = K - k; if (remainder) { - inner_prod_3x3x3_packed_<true /*SUM_A*/, true /*REMAINDER*/>( + inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/, true>( T, H, W, @@ -1210,10 +1258,15 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( &B_zero_point, C_int32 + k, remainder, - &row_offsets[k]); + B_SYMMETRIC ? nullptr : &row_offsets[k]); } - requantize_<FUSE_RELU, HAS_BIAS, false /*PER_CHAN_QUANT*/>( + requantize_< + FUSE_RELU, + HAS_BIAS, + false, /*PER_CHAN_QUANT*/ + A_SYMMETRIC, + B_SYMMETRIC>( A_zero_point, &C_multiplier, C_zero_point, @@ -1225,7 +1278,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( bias); } -template <bool FUSE_RELU, bool HAS_BIAS> +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC> static inline __attribute__((always_inline)) void depthwise_3x3_per_channel_quantization_kernel_( int H, @@ -1255,8 +1308,8 @@ depthwise_3x3_per_channel_quantization_kernel_( int k; for (k = 0; k < K / 32 * 32; k += 32) { inner_prod_3x3_packed_< - true /*SUM_A*/, - false /*remainder*/, + true, /*SUM_A*/ + false, /*remainder*/ true /*per-channel*/>( H, W, @@ -1274,8 +1327,8 @@ depthwise_3x3_per_channel_quantization_kernel_( int remainder = K - k; if (remainder) { inner_prod_3x3_packed_< - true /*SUM_A*/, - true /*remainder*/, + true, /*SUM_A*/ + true, /*remainder*/ true /*per-channel*/>( H, W, @@ -1291,7 +1344,12 @@ depthwise_3x3_per_channel_quantization_kernel_( &row_offsets[k]); } - requantize_<FUSE_RELU, HAS_BIAS, true /*PER_CHAN_QUANT*/>( + requantize_< + FUSE_RELU, + HAS_BIAS, + true, /*PER_CHAN_QUANT*/ + A_SYMMETRIC, + false /*B_SYMM*/>( A_zero_point, C_multiplier, C_zero_point, @@ -1303,7 +1361,7 @@ depthwise_3x3_per_channel_quantization_kernel_( bias); } -template <bool FUSE_RELU, bool HAS_BIAS> +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC> static inline __attribute__((always_inline)) void depthwise_3x3x3_per_channel_quantization_kernel_( int T, @@ -1338,8 +1396,8 @@ depthwise_3x3x3_per_channel_quantization_kernel_( int k; for (k = 0; k < K / 32 * 32; k += 32) { inner_prod_3x3x3_packed_< - true /*SUM_A*/, - false /*remainder*/, + true, /*SUM_A*/ + false, /*remainder*/ true /*per-channel*/>( T, H, @@ -1359,8 +1417,8 @@ depthwise_3x3x3_per_channel_quantization_kernel_( int remainder = K - k; if (remainder) { inner_prod_3x3x3_packed_< - true /*SUM_A*/, - true /*remainder*/, + true, /*SUM_A*/ + true, /*remainder*/ true /*per-channel*/>( T, H, @@ -1377,7 +1435,12 @@ depthwise_3x3x3_per_channel_quantization_kernel_( remainder, &row_offsets[k]); } - requantize_<FUSE_RELU, HAS_BIAS, true /*PER_CHAN_QUANT*/>( + requantize_< + FUSE_RELU, + HAS_BIAS, + true, /*PER_CHAN_QUANT*/ + A_SYMMETRIC, + false /*B_SYMM*/>( A_zero_point, C_multiplier, C_zero_point, @@ -1401,7 +1464,7 @@ static pair<int, int> closest_factors_(int n) { // This implemntation should be general enough to handle not just 3x3 but other // filter shapes by parameterizing with R and S but restricting it to just 3x3 // for now. -template <bool FUSE_RELU, bool HAS_BIAS> +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC> static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( int N, int H, @@ -1477,7 +1540,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( if (h_begin == 0) { if (w_begin == 0) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( H, W, K, @@ -1499,7 +1562,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( H, W, K, @@ -1522,7 +1585,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( H, W, K, @@ -1547,7 +1610,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) { if (w_begin == 0) { w = 0; - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( H, W, K, @@ -1569,7 +1632,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( H, W, K, @@ -1592,7 +1655,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( H, W, K, @@ -1618,7 +1681,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( h = H_OUT - 1; w = 0; if (w_begin == 0) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( H, W, K, @@ -1640,7 +1703,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( H, W, K, @@ -1663,7 +1726,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>( H, W, K, @@ -1687,7 +1750,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( } // for each n }; -template <bool FUSE_RELU, bool HAS_BIAS> +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC> static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( int N, int T, @@ -1765,7 +1828,11 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( for (int t = t_begin; t < t_end; ++t) { for (int h = h_begin; h < h_end; ++h) { for (int w = 0; w < W_OUT; ++w) { - depthwise_3x3x3_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC>( T, H, W, @@ -1793,7 +1860,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( } // for each n }; -template <bool FUSE_RELU, bool HAS_BIAS> +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC> static inline __attribute__((always_inline)) void depthwise_3x3_per_channel_quantization_pad_1_( int N, @@ -1870,7 +1937,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( if (h_begin == 0) { if (w_begin == 0) { - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -1892,7 +1962,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -1915,7 +1988,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -1940,7 +2016,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) { if (w_begin == 0) { w = 0; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -1962,7 +2041,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -1985,7 +2067,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -2011,7 +2096,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( h = H_OUT - 1; w = 0; if (w_begin == 0) { - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -2033,7 +2121,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -2056,7 +2147,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -2080,7 +2174,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( } // for each n }; -template <bool FUSE_RELU, bool HAS_BIAS> +template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC> static inline __attribute__((always_inline)) void depthwise_3x3x3_per_channel_quantization_pad_1_( int N, @@ -2159,7 +2253,10 @@ depthwise_3x3x3_per_channel_quantization_pad_1_( for (int t = t_begin; t < t_end; ++t) { for (int h = h_begin; h < h_end; ++h) { for (int w = 0; w < W_OUT; ++w) { - depthwise_3x3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( + depthwise_3x3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( T, H, W, @@ -2187,6 +2284,130 @@ depthwise_3x3x3_per_channel_quantization_pad_1_( } // for each n }; +// Dispatch A_SYMMETRIC and B_SYMMETRIC +template <bool FUSE_RELU, bool HAS_BIAS> +static void depthwise_3x3_pad_1_( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const Packed3x3ConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, + int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (A_zero_point == 0 || col_offsets == nullptr) { + if (B_zero_point == 0) { + depthwise_3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_symmetric*/, + true /*B_symmetric*/>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_symmetric*/, + false /*B_symmetric*/>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } + } else { + if (B_zero_point == 0) { + depthwise_3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_symmetric*/, + true /*B_symmetric*/>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_symmetric*/, + false /*B_symmetric*/>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } + } +} + // Dispatch HAS_BIAS template <bool FUSE_RELU> static void depthwise_3x3_pad_1_( @@ -2207,7 +2428,6 @@ static void depthwise_3x3_pad_1_( const int32_t* bias, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; if (bias) { depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>( N, @@ -2222,7 +2442,6 @@ static void depthwise_3x3_pad_1_( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2242,7 +2461,6 @@ static void depthwise_3x3_pad_1_( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2469,6 +2687,140 @@ void depthwise_3x3_pad_1( } } +// Dispatch A_SYMMETRIC and B_SYMMETRIC +template <bool FUSE_RELU, bool HAS_BIAS> +static void depthwise_3x3x3_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const Packed3x3x3ConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, + int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (A_zero_point == 0 || col_offsets == nullptr) { + if (B_zero_point == 0) { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_symmetric*/, + true /*B_symmetric*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_symmetric*/, + false /*B_symmetric*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } + } else { + if (B_zero_point == 0) { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_symmetric*/, + true /*B_symmetric*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_symmetric*/, + false /*B_symmetric*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } + } +} + // Dispatch HAS_BIAS template <bool FUSE_RELU> static void depthwise_3x3x3_pad_1_( @@ -2491,7 +2843,6 @@ static void depthwise_3x3x3_pad_1_( const int32_t* bias, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; if (bias) { depthwise_3x3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>( N, @@ -2508,7 +2859,6 @@ static void depthwise_3x3x3_pad_1_( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2530,7 +2880,6 @@ static void depthwise_3x3x3_pad_1_( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2606,6 +2955,76 @@ void depthwise_3x3x3_pad_1( } } +// Dispatch A_SYMMETRIC +template <bool FUSE_RELU, bool HAS_BIAS> +static void depthwise_3x3_per_channel_quantization_pad_1_( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const Packed3x3ConvMatrix& Bp, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, + int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (A_zero_point == 0 || col_offsets == nullptr) { + depthwise_3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_SYMM*/>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_SYMM*/>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } +} + // Dispatch HAS_BIAS template <bool FUSE_RELU> static void depthwise_3x3_per_channel_quantization_pad_1_( @@ -2626,7 +3045,6 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( const int32_t* bias, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; if (bias) { depthwise_3x3_per_channel_quantization_pad_1_< FUSE_RELU, @@ -2643,7 +3061,6 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( Bp, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2665,7 +3082,6 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( Bp, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2891,6 +3307,82 @@ void depthwise_3x3_per_channel_quantization_pad_1( } } +// Dispatch A_SYMMETRIC +template <bool FUSE_RELU, bool HAS_BIAS> +static void depthwise_3x3x3_per_channel_quantization_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const Packed3x3x3ConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, + int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (A_zero_point == 0 || col_offsets == nullptr) { + depthwise_3x3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_SYMM*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_SYMM*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } +} + // Dispatch HAS_BIAS template <bool FUSE_RELU> static void depthwise_3x3x3_per_channel_quantization_pad_1_( @@ -2913,7 +3405,6 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_( const int32_t* bias, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; if (bias) { depthwise_3x3x3_per_channel_quantization_pad_1_< FUSE_RELU, @@ -2932,7 +3423,6 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2956,7 +3446,6 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, diff --git a/src/FbgemmI8DepthwiseAvx2.h b/src/FbgemmI8DepthwiseAvx2.h index 9fd16ae..069ff77 100644 --- a/src/FbgemmI8DepthwiseAvx2.h +++ b/src/FbgemmI8DepthwiseAvx2.h @@ -34,7 +34,6 @@ using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>; /** * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 - * * @params A The input image in NHWK layout * @params Bp The pre-packed filter */ @@ -47,6 +46,26 @@ FBGEMM_API void depthwise_3x3_pad_1( int stride_w, std::int32_t A_zero_point, const std::uint8_t* A, + const Packed3x3ConvMatrix& Bp, + std::int32_t* C, + int thread_id = 0, + int num_threads = 1); + +/** + * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 + * This version is fused with requantization. + * + * @col_offsets nullptr if col_offsets are folded into bias + */ +FBGEMM_API void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, std::int32_t B_zero_point, const Packed3x3ConvMatrix& Bp, float C_multiplier, @@ -61,6 +80,8 @@ FBGEMM_API void depthwise_3x3_pad_1( /** * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 * This version is fused with requantization and uses per-channel quantization. + * + * @col_offsets nullptr if col_offsets are folded into bias */ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1( int N, @@ -93,6 +114,25 @@ FBGEMM_API void depthwise_3x3x3_pad_1( int stride_w, std::int32_t A_zero_point, const std::uint8_t* A, + const Packed3x3x3ConvMatrix& Bp, + std::int32_t* C, + int thread_id = 0, + int num_threads = 1); + +/** + * @col_offsets nullptr if col_offsets are folded into bias + */ +FBGEMM_API void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, std::int32_t B_zero_point, const Packed3x3x3ConvMatrix& Bp, float C_multiplier, @@ -104,6 +144,9 @@ FBGEMM_API void depthwise_3x3x3_pad_1( int thread_id = 0, int num_threads = 1); +/** + * @col_offsets nullptr if col_offsets are folded into bias + */ FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1( int N, int T, |