diff options
author | Jongsoo Park <jongsoo@fb.com> | 2019-02-14 07:35:32 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-02-14 07:44:53 +0300 |
commit | 05ce78e3a5735217cb9154a2c1572dc956ffe6fc (patch) | |
tree | 6d2486304b84ef15887385ade7ea16b7b62a571e /src | |
parent | 7813a2f2233fa48199b18aa8c03bb439b1fe9ff5 (diff) |
clean up depthwise conv interface (#72)
Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/72
Depthwise conv without requantization is not really useful, and it was generating additional template parameter options.
Reviewed By: jianyuh
Differential Revision: D14021514
fbshipit-source-id: 61f646373fcd902fdb2854a96d003a548f29f8eb
Diffstat (limited to 'src')
-rw-r--r-- | src/FbgemmI8DepthwiseAvx2.cc | 789 | ||||
-rw-r--r-- | src/FbgemmI8DepthwiseAvx2.h | 35 |
2 files changed, 346 insertions, 478 deletions
diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc index 555b93a..017c4c8 100644 --- a/src/FbgemmI8DepthwiseAvx2.cc +++ b/src/FbgemmI8DepthwiseAvx2.cc @@ -701,52 +701,6 @@ static inline __attribute__((always_inline)) void requantize_( } } -template <bool FUSE_RELU, bool HAS_BIAS> -static inline __attribute__((always_inline)) void requantize_( - int32_t A_zero_point, - float C_multiplier, - int32_t C_zero_point, - const int32_t* C_int32, - uint8_t* C_uint8, - int n, - const int32_t* row_offsets, - const int32_t* col_offsets, - const int32_t* bias) { - requantize_<FUSE_RELU, HAS_BIAS, false /* PER_CHANNEL_QUANTIZATION */>( - A_zero_point, - &C_multiplier, - C_zero_point, - C_int32, - C_uint8, - n, - row_offsets, - col_offsets, - bias); -} - -template <bool FUSE_RELU, bool HAS_BIAS> -static inline __attribute__((always_inline)) void requantize_per_channel_( - int32_t A_zero_point, - const float* C_multiplier, - int32_t C_zero_point, - const int32_t* C_int32, - uint8_t* C_uint8, - int n, - const int32_t* row_offsets, - const int32_t* col_offsets, - const int32_t* bias) { - requantize_<FUSE_RELU, HAS_BIAS, true /* PER_CHANNEL_QUANTIZATION */>( - A_zero_point, - C_multiplier, - C_zero_point, - C_int32, - C_uint8, - n, - row_offsets, - col_offsets, - bias); -} - template <bool REMAINDER> static inline __attribute__((always_inline)) __m256i load_a( const uint8_t* A, @@ -1120,7 +1074,7 @@ static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_( } } -template <bool SUM_A, bool FUSE_RELU> +template <bool FUSE_RELU, bool HAS_BIAS> static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( int H, int W, @@ -1148,7 +1102,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( int k; for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3_packed_<SUM_A>( + inner_prod_3x3_packed_<true /*SUM_A*/>( H, W, K, @@ -1164,7 +1118,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( } int 
remainder = K - k; if (remainder) { - inner_prod_3x3_packed_<SUM_A, true>( + inner_prod_3x3_packed_<true /*SUM_A*/, true /*REMAINDER*/>( H, W, K, @@ -1178,21 +1132,20 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( remainder, &row_offsets[k]); } - if (SUM_A) { - requantize_<FUSE_RELU, true>( - A_zero_point, - C_multiplier, - C_zero_point, - C_int32, - C_uint8 + (h * W_OUT + w) * K, - K, - row_offsets, - col_offsets, - bias); - } + + requantize_<FUSE_RELU, HAS_BIAS, false /*PER_CHAN_QUANT*/>( + A_zero_point, + &C_multiplier, + C_zero_point, + C_int32, + C_uint8 + (h * W_OUT + w) * K, + K, + row_offsets, + col_offsets, + bias); } -template <bool SUM_A, bool FUSE_RELU> +template <bool FUSE_RELU, bool HAS_BIAS> static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( int T, int H, @@ -1225,7 +1178,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( int k; for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3x3_packed_<SUM_A>( + inner_prod_3x3x3_packed_<true /*SUM_A*/>( T, H, W, @@ -1243,7 +1196,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( } int remainder = K - k; if (remainder) { - inner_prod_3x3x3_packed_<SUM_A, true>( + inner_prod_3x3x3_packed_<true /*SUM_A*/, true /*REMAINDER*/>( T, H, W, @@ -1259,21 +1212,20 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( remainder, &row_offsets[k]); } - if (SUM_A) { - requantize_<FUSE_RELU, true>( - A_zero_point, - C_multiplier, - C_zero_point, - C_int32, - C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, - K, - row_offsets, - col_offsets, - bias); - } + + requantize_<FUSE_RELU, HAS_BIAS, false /*PER_CHAN_QUANT*/>( + A_zero_point, + &C_multiplier, + C_zero_point, + C_int32, + C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, + K, + row_offsets, + col_offsets, + bias); } -template <bool SUM_A, bool FUSE_RELU> +template <bool FUSE_RELU, bool HAS_BIAS> static inline __attribute__((always_inline)) void 
depthwise_3x3_per_channel_quantization_kernel_( int H, @@ -1302,7 +1254,10 @@ depthwise_3x3_per_channel_quantization_kernel_( int k; for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3_packed_<SUM_A, false /*remainder*/, true /*per-channel*/>( + inner_prod_3x3_packed_< + true /*SUM_A*/, + false /*remainder*/, + true /*per-channel*/>( H, W, K, @@ -1318,7 +1273,10 @@ depthwise_3x3_per_channel_quantization_kernel_( } int remainder = K - k; if (remainder) { - inner_prod_3x3_packed_<SUM_A, true /*remainder*/, true /*per-channel*/>( + inner_prod_3x3_packed_< + true /*SUM_A*/, + true /*remainder*/, + true /*per-channel*/>( H, W, K, @@ -1332,21 +1290,20 @@ depthwise_3x3_per_channel_quantization_kernel_( remainder, &row_offsets[k]); } - if (SUM_A) { - requantize_per_channel_<FUSE_RELU, true /*HAS_BIAS*/>( - A_zero_point, - C_multiplier, - C_zero_point, - C_int32, - C_uint8 + (h * W_OUT + w) * K, - K, - row_offsets, - col_offsets, - bias); - } + + requantize_<FUSE_RELU, HAS_BIAS, true /*PER_CHAN_QUANT*/>( + A_zero_point, + C_multiplier, + C_zero_point, + C_int32, + C_uint8 + (h * W_OUT + w) * K, + K, + row_offsets, + col_offsets, + bias); } -template <bool SUM_A, bool FUSE_RELU> +template <bool FUSE_RELU, bool HAS_BIAS> static inline __attribute__((always_inline)) void depthwise_3x3x3_per_channel_quantization_kernel_( int T, @@ -1380,7 +1337,10 @@ depthwise_3x3x3_per_channel_quantization_kernel_( int k; for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3x3_packed_<SUM_A, false /*remainder*/, true /*per-channel*/>( + inner_prod_3x3x3_packed_< + true /*SUM_A*/, + false /*remainder*/, + true /*per-channel*/>( T, H, W, @@ -1398,7 +1358,10 @@ depthwise_3x3x3_per_channel_quantization_kernel_( } int remainder = K - k; if (remainder) { - inner_prod_3x3x3_packed_<SUM_A, true /*remainder*/, true /*per-channel*/>( + inner_prod_3x3x3_packed_< + true /*SUM_A*/, + true /*remainder*/, + true /*per-channel*/>( T, H, W, @@ -1414,18 +1377,16 @@ 
depthwise_3x3x3_per_channel_quantization_kernel_( remainder, &row_offsets[k]); } - if (SUM_A) { - requantize_per_channel_<FUSE_RELU, true>( - A_zero_point, - C_multiplier, - C_zero_point, - C_int32, - C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, - K, - row_offsets, - col_offsets, - bias); - } + requantize_<FUSE_RELU, HAS_BIAS, true /*PER_CHAN_QUANT*/>( + A_zero_point, + C_multiplier, + C_zero_point, + C_int32, + C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, + K, + row_offsets, + col_offsets, + bias); } static pair<int, int> closest_factors_(int n) { @@ -1440,7 +1401,7 @@ static pair<int, int> closest_factors_(int n) { // This implemntation should be general enough to handle not just 3x3 but other // filter shapes by parameterizing with R and S but restricting it to just 3x3 // for now. -template <bool FUSE_RESCALE = true, bool FUSE_RELU = false> +template <bool FUSE_RELU, bool HAS_BIAS> static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( int N, int H, @@ -1468,7 +1429,6 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( const int8_t* Bp = B.PackedMat(); int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); - int32_t* C_temp; int n_begin, n_end; int h_begin, h_end, w_begin, w_end; @@ -1517,9 +1477,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( if (h_begin == 0) { if (w_begin == 0) { - C_temp = FUSE_RESCALE ? C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -1533,7 +1491,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -1541,9 +1499,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - C_temp = FUSE_RESCALE ? 
C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -1557,7 +1513,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -1566,9 +1522,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - C_temp = FUSE_RESCALE ? C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -1582,7 +1536,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -1593,9 +1547,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) { if (w_begin == 0) { w = 0; - C_temp = FUSE_RESCALE ? C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -1609,7 +1561,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -1617,9 +1569,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - C_temp = FUSE_RESCALE ? 
C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -1633,7 +1583,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -1642,9 +1592,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - C_temp = FUSE_RESCALE ? C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -1658,7 +1606,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -1670,9 +1618,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( h = H_OUT - 1; w = 0; if (w_begin == 0) { - C_temp = FUSE_RESCALE ? C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -1686,7 +1632,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -1694,9 +1640,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - C_temp = FUSE_RESCALE ? 
C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -1710,7 +1654,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -1719,9 +1663,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - C_temp = FUSE_RESCALE ? C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -1735,7 +1677,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -1745,7 +1687,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( } // for each n }; -template <bool FUSE_RESCALE = true, bool FUSE_RELU = false> +template <bool FUSE_RELU, bool HAS_BIAS> static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( int N, int T, @@ -1777,7 +1719,6 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( const int8_t* Bp = B.PackedMat(); int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); - int32_t* C_temp; int n_begin, n_end; int t_begin, t_end, h_begin, h_end; @@ -1824,10 +1765,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( for (int t = t_begin; t < t_end; ++t) { for (int h = h_begin; h < h_end; ++h) { for (int w = 0; w < W_OUT; ++w) { - C_temp = FUSE_RESCALE - ? 
C_int32 - : C_int32 + (((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3x3_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3x3_kernel_<FUSE_RELU, HAS_BIAS>( T, H, W, @@ -1844,7 +1782,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -1855,7 +1793,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( } // for each n }; -template <bool FUSE_RESCALE = true, bool FUSE_RELU = false> +template <bool FUSE_RELU, bool HAS_BIAS> static inline __attribute__((always_inline)) void depthwise_3x3_per_channel_quantization_pad_1_( int N, @@ -1884,7 +1822,6 @@ depthwise_3x3_per_channel_quantization_pad_1_( const int8_t* Bp = B.PackedMat(); int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); - int32_t* C_temp; int n_begin, n_end; int h_begin, h_end, w_begin, w_end; @@ -1933,9 +1870,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( if (h_begin == 0) { if (w_begin == 0) { - C_temp = FUSE_RESCALE ? C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -1949,7 +1884,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -1957,9 +1892,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - C_temp = FUSE_RESCALE ? 
C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -1973,7 +1906,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -1982,9 +1915,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - C_temp = FUSE_RESCALE ? C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -1998,7 +1929,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -2009,9 +1940,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) { if (w_begin == 0) { w = 0; - C_temp = FUSE_RESCALE ? C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -2025,7 +1954,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -2033,9 +1962,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - C_temp = FUSE_RESCALE ? 
C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -2049,7 +1976,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -2058,9 +1985,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - C_temp = FUSE_RESCALE ? C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -2074,7 +1999,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -2086,9 +2011,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( h = H_OUT - 1; w = 0; if (w_begin == 0) { - C_temp = FUSE_RESCALE ? C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -2102,7 +2025,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -2110,9 +2033,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - C_temp = FUSE_RESCALE ? 
C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -2126,7 +2047,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -2135,9 +2056,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - C_temp = FUSE_RESCALE ? C_int32 - : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( H, W, K, @@ -2151,7 +2070,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -2161,7 +2080,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( } // for each n }; -template <bool FUSE_RESCALE = true, bool FUSE_RELU = false> +template <bool FUSE_RELU, bool HAS_BIAS> static inline __attribute__((always_inline)) void depthwise_3x3x3_per_channel_quantization_pad_1_( int N, @@ -2194,7 +2113,6 @@ depthwise_3x3x3_per_channel_quantization_pad_1_( const int8_t* Bp = B.PackedMat(); int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); - int32_t* C_temp; int n_begin, n_end; int t_begin, t_end, h_begin, h_end; @@ -2241,12 +2159,7 @@ depthwise_3x3x3_per_channel_quantization_pad_1_( for (int t = t_begin; t < t_end; ++t) { for (int h = h_begin; h < h_end; ++h) { for (int w = 0; w < W_OUT; ++w) { - C_temp = FUSE_RESCALE - ? 
C_int32 - : C_int32 + (((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3x3_per_channel_quantization_kernel_< - FUSE_RESCALE, - FUSE_RELU>( + depthwise_3x3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>( T, H, W, @@ -2263,7 +2176,7 @@ depthwise_3x3x3_per_channel_quantization_pad_1_( Bp, C_multiplier, C_zero_point, - C_temp, + C_int32, C_uint8_base, row_offsets, col_offsets, @@ -2274,8 +2187,9 @@ depthwise_3x3x3_per_channel_quantization_pad_1_( } // for each n }; -// assumption: W > 3 and H > 3 -void depthwise_3x3_pad_1( +// Dispatch HAS_BIAS +template <bool FUSE_RELU> +static void depthwise_3x3_pad_1_( int N, int H, int W, @@ -2284,72 +2198,18 @@ void depthwise_3x3_pad_1( int stride_w, int32_t A_zero_point, const uint8_t* A, + int32_t B_zero_point, const Packed3x3ConvMatrix& B, - int32_t* C, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, int thread_id, int num_threads) { - if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { - depthwise_3x3_pad_1_<false>( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A, - 0, - B, - 0.0f, - 0, - C, - nullptr, - nullptr, - nullptr, - thread_id, - num_threads); - } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { - depthwise_3x3_pad_1_<false>( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A, - 0, - B, - 0.0f, - 0, - C, - nullptr, - nullptr, - nullptr, - thread_id, - num_threads); - } else if (1 == stride_h && 1 == stride_w) { - depthwise_3x3_pad_1_<false>( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A, - 0, - B, - 0.0f, - 0, - C, - nullptr, - nullptr, - nullptr, - thread_id, - num_threads); - } else if (2 == stride_h && 2 == stride_w) { - depthwise_3x3_pad_1_<false>( + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (bias) { + depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>( N, H, W, @@ -2358,18 +2218,18 @@ void depthwise_3x3_pad_1( stride_w, A_zero_point, A, - 0, + 
B_zero_point, B, - 0.0f, - 0, + C_multiplier, + C_zero_point, + C_int32_temp, C, - nullptr, - nullptr, - nullptr, + col_offsets, + bias, thread_id, num_threads); } else { - depthwise_3x3_pad_1_<false>( + depthwise_3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/>( N, H, W, @@ -2378,19 +2238,21 @@ void depthwise_3x3_pad_1( stride_w, A_zero_point, A, - 0, + B_zero_point, B, - 0.0f, - 0, + C_multiplier, + C_zero_point, + C_int32_temp, C, - nullptr, - nullptr, - nullptr, + col_offsets, + bias, thread_id, num_threads); } } +// Dispatch input shape and FUSE_RELU +// assumption: W > 3 and H > 3 void depthwise_3x3_pad_1( int N, int H, @@ -2410,10 +2272,9 @@ void depthwise_3x3_pad_1( bool fuse_relu, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; if (fuse_relu) { if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { - depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */>( N, H, W, @@ -2426,14 +2287,13 @@ void depthwise_3x3_pad_1( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { - depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */>( N, H, W, @@ -2446,14 +2306,13 @@ void depthwise_3x3_pad_1( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else if (1 == stride_h && 1 == stride_w) { - depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */>( N, H, W, @@ -2466,14 +2325,13 @@ void depthwise_3x3_pad_1( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else if (2 == stride_h && 2 == stride_w) { - depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */>( N, H, W, @@ -2486,14 +2344,13 @@ void depthwise_3x3_pad_1( B, 
C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else { - depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>( + depthwise_3x3_pad_1_<true /* FUSE_RELU */>( N, H, W, @@ -2506,7 +2363,6 @@ void depthwise_3x3_pad_1( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2515,7 +2371,7 @@ void depthwise_3x3_pad_1( } } else { if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { - depthwise_3x3_pad_1_( + depthwise_3x3_pad_1_<false /* FUSE_RELU */>( N, H, W, @@ -2528,14 +2384,13 @@ void depthwise_3x3_pad_1( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { - depthwise_3x3_pad_1_( + depthwise_3x3_pad_1_<false /* FUSE_RELU */>( N, H, W, @@ -2548,14 +2403,13 @@ void depthwise_3x3_pad_1( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else if (1 == stride_h && 1 == stride_w) { - depthwise_3x3_pad_1_( + depthwise_3x3_pad_1_<false /* FUSE_RELU */>( N, H, W, @@ -2568,14 +2422,13 @@ void depthwise_3x3_pad_1( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else if (2 == stride_h && 2 == stride_w) { - depthwise_3x3_pad_1_( + depthwise_3x3_pad_1_<false /* FUSE_RELU */>( N, H, W, @@ -2588,14 +2441,13 @@ void depthwise_3x3_pad_1( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else { - depthwise_3x3_pad_1_( + depthwise_3x3_pad_1_<false /* FUSE_RELU */>( N, H, W, @@ -2608,7 +2460,6 @@ void depthwise_3x3_pad_1( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2618,44 +2469,8 @@ void depthwise_3x3_pad_1( } } -void depthwise_3x3x3_pad_1( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const Packed3x3x3ConvMatrix& B, - int32_t* C, - int 
thread_id, - int num_threads) { - depthwise_3x3x3_pad_1_<false /* FUSE_RESCALE */>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - 0, - B, - 0.0f, - 0, - C, - nullptr, - nullptr, - nullptr, - thread_id, - num_threads); -} - +// Dispatch HAS_BIAS +template <bool FUSE_RELU> static void depthwise_3x3x3_pad_1_( int N, int T, @@ -2677,30 +2492,55 @@ static void depthwise_3x3x3_pad_1_( int thread_id, int num_threads) { int32_t C_int32_temp[(K + 31) / 32 * 32]; - depthwise_3x3x3_pad_1_<true /* FUSE_RESCALE */, false /* FUSE_RELU */>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); + if (bias) { + depthwise_3x3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } } -static void depthwise_3x3x3_pad_1_relu_fused_( +// Dispatch FUSE_RELU +void depthwise_3x3x3_pad_1( int N, int T, int H, @@ -2718,91 +2558,114 @@ static void depthwise_3x3x3_pad_1_relu_fused_( uint8_t* C, const int32_t* col_offsets, const int32_t* bias, + bool fuse_relu, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; - depthwise_3x3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); + if (fuse_relu) { + 
depthwise_3x3x3_pad_1_<true /*FUSE_RELU*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3x3_pad_1_<false /*FUSE_RELU*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + thread_id, + num_threads); + } } -void depthwise_3x3x3_pad_1( +// Dispatch HAS_BIAS +template <bool FUSE_RELU> +static void depthwise_3x3_per_channel_quantization_pad_1_( int N, - int T, int H, int W, int K, - int stride_t, int stride_h, int stride_w, int32_t A_zero_point, const uint8_t* A, - int32_t B_zero_point, - const Packed3x3x3ConvMatrix& B, - float C_multiplier, + const int32_t* B_zero_point, + const Packed3x3ConvMatrix& Bp, + const float* C_multiplier, int32_t C_zero_point, uint8_t* C, const int32_t* col_offsets, const int32_t* bias, - bool fuse_relu, int thread_id, int num_threads) { - // If we inline the following two functions, I see stack overflow. 
- if (fuse_relu) { - depthwise_3x3x3_pad_1_relu_fused_( + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (bias) { + depthwise_3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + true /* HAS_BIAS */>( N, - T, H, W, K, - stride_t, stride_h, stride_w, A_zero_point, A, B_zero_point, - B, + Bp, C_multiplier, C_zero_point, + C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else { - depthwise_3x3x3_pad_1_( + depthwise_3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + false /* HAS_BIAS */>( N, - T, H, W, K, - stride_t, stride_h, stride_w, A_zero_point, A, B_zero_point, - B, + Bp, C_multiplier, C_zero_point, + C_int32_temp, C, col_offsets, bias, @@ -2811,6 +2674,7 @@ void depthwise_3x3x3_pad_1( } } +// Dispatch input shape and FUSE_RELU void depthwise_3x3_per_channel_quantization_pad_1( int N, int H, @@ -2830,12 +2694,9 @@ void depthwise_3x3_per_channel_quantization_pad_1( bool fuse_relu, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; if (fuse_relu) { if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_< - true /* FUSE_RESCALE */, - true /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( N, H, W, @@ -2848,16 +2709,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( Bp, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_< - true /* FUSE_RESCALE */, - true /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( N, H, W, @@ -2870,16 +2728,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( Bp, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else if (1 == stride_h && 1 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_< - true /* FUSE_RESCALE */, - true /* FUSE_RELU */>( + 
depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( N, H, W, @@ -2892,16 +2747,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( Bp, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else if (2 == stride_h && 2 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_< - true /* FUSE_RESCALE */, - true /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( N, H, W, @@ -2914,16 +2766,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( Bp, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else { - depthwise_3x3_per_channel_quantization_pad_1_< - true /* FUSE_RESCALE */, - true /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( N, H, W, @@ -2936,7 +2785,6 @@ void depthwise_3x3_per_channel_quantization_pad_1( Bp, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2945,9 +2793,7 @@ void depthwise_3x3_per_channel_quantization_pad_1( } } else { if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_< - true /* FUSE_RESCALE */, - false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( N, H, W, @@ -2960,16 +2806,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( Bp, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_< - true /* FUSE_RESCALE */, - false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( N, H, W, @@ -2982,16 +2825,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( Bp, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else if (1 == stride_h && 1 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_< - true /* 
FUSE_RESCALE */, - false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( N, H, W, @@ -3004,16 +2844,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( Bp, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else if (2 == stride_h && 2 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_< - true /* FUSE_RESCALE */, - false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( N, H, W, @@ -3026,16 +2863,13 @@ void depthwise_3x3_per_channel_quantization_pad_1( Bp, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, thread_id, num_threads); } else { - depthwise_3x3_per_channel_quantization_pad_1_< - true /* FUSE_RESCALE */, - false /* FUSE_RELU */>( + depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( N, H, W, @@ -3048,7 +2882,6 @@ void depthwise_3x3_per_channel_quantization_pad_1( Bp, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -3058,7 +2891,9 @@ void depthwise_3x3_per_channel_quantization_pad_1( } } -void depthwise_3x3x3_per_channel_quantization_pad_1( +// Dispatch HAS_BIAS +template <bool FUSE_RELU> +static void depthwise_3x3x3_per_channel_quantization_pad_1_( int N, int T, int H, @@ -3076,14 +2911,13 @@ void depthwise_3x3x3_per_channel_quantization_pad_1( uint8_t* C, const int32_t* col_offsets, const int32_t* bias, - bool fuse_relu, int thread_id, int num_threads) { int32_t C_int32_temp[(K + 31) / 32 * 32]; - if (fuse_relu) { + if (bias) { depthwise_3x3x3_per_channel_quantization_pad_1_< - true /* FUSE_RESCALE */, - true /* FUSE_RELU */>( + FUSE_RELU, + true /* HAS_BIAS */>( N, T, H, @@ -3106,8 +2940,8 @@ void depthwise_3x3x3_per_channel_quantization_pad_1( num_threads); } else { depthwise_3x3x3_per_channel_quantization_pad_1_< - true /* FUSE_RESCALE */, - false /* FUSE_RELU */>( + FUSE_RELU, + false /* HAS_BIAS */>( N, T, H, @@ -3131,4 +2965,71 @@ void 
depthwise_3x3x3_per_channel_quantization_pad_1( } } +// Dispatch FUSE_RELU +void depthwise_3x3x3_per_channel_quantization_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const Packed3x3x3ConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + int thread_id, + int num_threads) { + if (fuse_relu) { + depthwise_3x3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + thread_id, + num_threads); + } +} + } // namespace fbgemm diff --git a/src/FbgemmI8DepthwiseAvx2.h b/src/FbgemmI8DepthwiseAvx2.h index 53c6e8a..9fd16ae 100644 --- a/src/FbgemmI8DepthwiseAvx2.h +++ b/src/FbgemmI8DepthwiseAvx2.h @@ -34,6 +34,7 @@ using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>; /** * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 + * * @params A The input image in NHWK layout * @params Bp The pre-packed filter */ @@ -46,24 +47,6 @@ FBGEMM_API void depthwise_3x3_pad_1( int stride_w, std::int32_t A_zero_point, const std::uint8_t* A, - const Packed3x3ConvMatrix& Bp, - std::int32_t* C, - int thread_id = 0, - int num_threads = 1); - -/** - * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 - * This version is fused with requantization. 
- */ -FBGEMM_API void depthwise_3x3_pad_1( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, std::int32_t B_zero_point, const Packed3x3ConvMatrix& Bp, float C_multiplier, @@ -110,22 +93,6 @@ FBGEMM_API void depthwise_3x3x3_pad_1( int stride_w, std::int32_t A_zero_point, const std::uint8_t* A, - const Packed3x3x3ConvMatrix& Bp, - std::int32_t* C, - int thread_id = 0, - int num_threads = 1); - -FBGEMM_API void depthwise_3x3x3_pad_1( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - std::int32_t A_zero_point, - const std::uint8_t* A, std::int32_t B_zero_point, const Packed3x3x3ConvMatrix& Bp, float C_multiplier, |