From c6e86067e41a363af718dae7f8d7494068aad868 Mon Sep 17 00:00:00 2001 From: Jongsoo Park Date: Wed, 3 Apr 2019 07:59:57 -0700 Subject: optimize dw conv for symmetric quant (#73) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/73 Skip computing row_offset if B uses symmetric quantization. Skip adding col_offset if A uses symmetric quantization. Reviewed By: jianyuh Differential Revision: D14055973 fbshipit-source-id: 91da8f0755b2f90175e94a893b5a3ad6342c506d --- src/FbgemmI8DepthwiseAvx2.cc | 685 ++++++++++++++++++++++++++++++++++++------- src/FbgemmI8DepthwiseAvx2.h | 45 ++- test/I8DepthwiseTest.cc | 42 ++- 3 files changed, 660 insertions(+), 112 deletions(-) diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc index 017c4c8..2620e43 100644 --- a/src/FbgemmI8DepthwiseAvx2.cc +++ b/src/FbgemmI8DepthwiseAvx2.cc @@ -525,7 +525,12 @@ static inline __attribute__((always_inline)) void inner_prod_3x3_packed_( // Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different // row_offsets for each row because of depth-wise convolution -template +template < + bool FUSE_RELU, + bool HAS_BIAS, + bool PER_CHANNEL_QUANTIZATION, + bool A_SYMMETRIC, + bool B_SYMMETRIC> static inline __attribute__((always_inline)) void requantize_( int32_t A_zero_point, const float* C_multiplier, @@ -544,6 +549,9 @@ static inline __attribute__((always_inline)) void requantize_( __m256i min_v = _mm256_set1_epi8(static_cast(0)); __m256i max_v = _mm256_set1_epi8(static_cast(255)); + if (A_SYMMETRIC) { + assert(A_zero_point == 0 || col_offsets == nullptr); + } __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point); __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point); __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point); @@ -563,36 +571,59 @@ static inline __attribute__((always_inline)) void requantize_( __m256i w_v = _mm256_loadu_si256( reinterpret_cast(C_int32 + j + 3 * VLEN)); - __m256i col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256(reinterpret_cast(col_offsets + j))); - __m256i row_offset_v = - _mm256_loadu_si256(reinterpret_cast(row_offsets + j)); - x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v); - - col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast(col_offsets + j + VLEN))); - row_offset_v = _mm256_loadu_si256( - reinterpret_cast(row_offsets + j + VLEN)); - y_v = _mm256_sub_epi32(_mm256_sub_epi32(y_v, col_off_v), row_offset_v); - - col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast(col_offsets + j + 2 * VLEN))); - row_offset_v = _mm256_loadu_si256( - reinterpret_cast(row_offsets + j + 2 * VLEN)); - z_v = _mm256_sub_epi32(_mm256_sub_epi32(z_v, col_off_v), row_offset_v); - - col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256( - reinterpret_cast(col_offsets + j + 3 * VLEN))); - row_offset_v = _mm256_loadu_si256( - reinterpret_cast(row_offsets + j + 3 * VLEN)); - w_v = _mm256_sub_epi32(_mm256_sub_epi32(w_v, col_off_v), row_offset_v); + __m256i row_offset_v; + if (!B_SYMMETRIC) { + row_offset_v = + _mm256_loadu_si256(reinterpret_cast(row_offsets + j)); + x_v = _mm256_sub_epi32(x_v, row_offset_v); + } + __m256i col_off_v; + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast(col_offsets + j))); + x_v = _mm256_sub_epi32(x_v, col_off_v); + } + + if (!B_SYMMETRIC) { + row_offset_v = _mm256_loadu_si256( + reinterpret_cast(row_offsets + j + VLEN)); + 
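// (A scalar sketch, not part of the original patch.) Per output channel j,
// this vectorized section and the remainder loop further down apply the same
// correction; making A_SYMMETRIC / B_SYMMETRIC template flags lets the
// compiler delete these subtractions outright instead of branching at run
// time:
//   if (!B_SYMMETRIC) C_int32[j] -= row_offsets[j];
//   if (!A_SYMMETRIC) C_int32[j] -= A_zero_point * col_offsets[j];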
y_v = _mm256_sub_epi32(y_v, row_offset_v); + } + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast(col_offsets + j + VLEN))); + y_v = _mm256_sub_epi32(y_v, col_off_v); + } + + if (!B_SYMMETRIC) { + row_offset_v = _mm256_loadu_si256( + reinterpret_cast(row_offsets + j + 2 * VLEN)); + z_v = _mm256_sub_epi32(z_v, row_offset_v); + } + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast(col_offsets + j + 2 * VLEN))); + z_v = _mm256_sub_epi32(z_v, col_off_v); + } + + if (!B_SYMMETRIC) { + row_offset_v = _mm256_loadu_si256( + reinterpret_cast(row_offsets + j + 3 * VLEN)); + w_v = _mm256_sub_epi32(w_v, row_offset_v); + } + if (!A_SYMMETRIC) { + col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast(col_offsets + j + 3 * VLEN))); + w_v = _mm256_sub_epi32(w_v, col_off_v); + } if (HAS_BIAS) { // static if x_v = _mm256_add_epi32( @@ -653,12 +684,18 @@ static inline __attribute__((always_inline)) void requantize_( __m256i x_v = _mm256_loadu_si256(reinterpret_cast(C_int32 + j)); - __m256i col_off_v = _mm256_mullo_epi32( - A_zero_point_v, - _mm256_loadu_si256(reinterpret_cast(col_offsets + j))); - __m256i row_offset_v = - _mm256_loadu_si256(reinterpret_cast(row_offsets + j)); - x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v); + if (!B_SYMMETRIC) { + __m256i row_offset_v = + _mm256_loadu_si256(reinterpret_cast(row_offsets + j)); + x_v = _mm256_sub_epi32(x_v, row_offset_v); + } + if (!A_SYMMETRIC) { + __m256i col_off_v = _mm256_mullo_epi32( + A_zero_point_v, + _mm256_loadu_si256( + reinterpret_cast(col_offsets + j))); + x_v = _mm256_sub_epi32(x_v, col_off_v); + } if (HAS_BIAS) { // static if x_v = _mm256_add_epi32( @@ -687,7 +724,13 @@ static inline __attribute__((always_inline)) void requantize_( } // j loop vectorized for (; j < n; ++j) { - int32_t raw = C_int32[j] - A_zero_point * col_offsets[j] - row_offsets[j]; + int32_t raw = C_int32[j]; + if (!B_SYMMETRIC) { + raw -= row_offsets[j]; + } + if (!A_SYMMETRIC) { + raw -= A_zero_point * col_offsets[j]; + } if (HAS_BIAS) { // static if raw += bias[j]; } @@ -1074,7 +1117,7 @@ static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_( } } -template +template static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( int H, int W, @@ -1102,7 +1145,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( int k; for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3_packed_( + inner_prod_3x3_packed_( H, W, K, @@ -1114,11 +1157,11 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( &B_zero_point, C_int32 + k, 0, - &row_offsets[k]); + B_SYMMETRIC ? nullptr : &row_offsets[k]); } int remainder = K - k; if (remainder) { - inner_prod_3x3_packed_( + inner_prod_3x3_packed_( H, W, K, @@ -1130,10 +1173,15 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( &B_zero_point, C_int32 + k, remainder, - &row_offsets[k]); + B_SYMMETRIC ? 
nullptr : &row_offsets[k]); } - requantize_( + requantize_< + FUSE_RELU, + HAS_BIAS, + false, /*PER_CHAN_QUANT*/ + A_SYMMETRIC, + B_SYMMETRIC>( A_zero_point, &C_multiplier, C_zero_point, @@ -1145,7 +1193,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( bias); } -template +template static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( int T, int H, @@ -1178,7 +1226,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( int k; for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3x3_packed_( + inner_prod_3x3x3_packed_( T, H, W, @@ -1192,11 +1240,11 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( &B_zero_point, C_int32 + k, 0, - &row_offsets[k]); + B_SYMMETRIC ? nullptr : &row_offsets[k]); } int remainder = K - k; if (remainder) { - inner_prod_3x3x3_packed_( + inner_prod_3x3x3_packed_( T, H, W, @@ -1210,10 +1258,15 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( &B_zero_point, C_int32 + k, remainder, - &row_offsets[k]); + B_SYMMETRIC ? nullptr : &row_offsets[k]); } - requantize_( + requantize_< + FUSE_RELU, + HAS_BIAS, + false, /*PER_CHAN_QUANT*/ + A_SYMMETRIC, + B_SYMMETRIC>( A_zero_point, &C_multiplier, C_zero_point, @@ -1225,7 +1278,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( bias); } -template +template static inline __attribute__((always_inline)) void depthwise_3x3_per_channel_quantization_kernel_( int H, @@ -1255,8 +1308,8 @@ depthwise_3x3_per_channel_quantization_kernel_( int k; for (k = 0; k < K / 32 * 32; k += 32) { inner_prod_3x3_packed_< - true /*SUM_A*/, - false /*remainder*/, + true, /*SUM_A*/ + false, /*remainder*/ true /*per-channel*/>( H, W, @@ -1274,8 +1327,8 @@ depthwise_3x3_per_channel_quantization_kernel_( int remainder = K - k; if (remainder) { inner_prod_3x3_packed_< - true /*SUM_A*/, - true /*remainder*/, + true, /*SUM_A*/ + true, /*remainder*/ true /*per-channel*/>( H, W, @@ -1291,7 +1344,12 @@ depthwise_3x3_per_channel_quantization_kernel_( &row_offsets[k]); } - requantize_( + requantize_< + FUSE_RELU, + HAS_BIAS, + true, /*PER_CHAN_QUANT*/ + A_SYMMETRIC, + false /*B_SYMM*/>( A_zero_point, C_multiplier, C_zero_point, @@ -1303,7 +1361,7 @@ depthwise_3x3_per_channel_quantization_kernel_( bias); } -template +template static inline __attribute__((always_inline)) void depthwise_3x3x3_per_channel_quantization_kernel_( int T, @@ -1338,8 +1396,8 @@ depthwise_3x3x3_per_channel_quantization_kernel_( int k; for (k = 0; k < K / 32 * 32; k += 32) { inner_prod_3x3x3_packed_< - true /*SUM_A*/, - false /*remainder*/, + true, /*SUM_A*/ + false, /*remainder*/ true /*per-channel*/>( T, H, @@ -1359,8 +1417,8 @@ depthwise_3x3x3_per_channel_quantization_kernel_( int remainder = K - k; if (remainder) { inner_prod_3x3x3_packed_< - true /*SUM_A*/, - true /*remainder*/, + true, /*SUM_A*/ + true, /*remainder*/ true /*per-channel*/>( T, H, @@ -1377,7 +1435,12 @@ depthwise_3x3x3_per_channel_quantization_kernel_( remainder, &row_offsets[k]); } - requantize_( + requantize_< + FUSE_RELU, + HAS_BIAS, + true, /*PER_CHAN_QUANT*/ + A_SYMMETRIC, + false /*B_SYMM*/>( A_zero_point, C_multiplier, C_zero_point, @@ -1401,7 +1464,7 @@ static pair closest_factors_(int n) { // This implemntation should be general enough to handle not just 3x3 but other // filter shapes by parameterizing with R and S but restricting it to just 3x3 // for now. 
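// A condensed sketch (illustrative helper name, not part of the original
// patch) of the run-time -> compile-time dispatch this patch adds further
// below: the zero points are tested once, and each of the four
// (A_SYMMETRIC, B_SYMMETRIC) combinations instantiates its own fully
// specialized kernel with the dead offset arithmetic compiled out.
//
//   template <bool FUSE_RELU, bool HAS_BIAS>
//   static void dispatch_symmetry(
//       int32_t A_zero_point, int32_t B_zero_point,
//       const int32_t* col_offsets /*, ... remaining args ... */) {
//     if (A_zero_point == 0 || col_offsets == nullptr) {
//       if (B_zero_point == 0)
//         depthwise_3x3_pad_1_<FUSE_RELU, HAS_BIAS, true, true>(/*...*/);
//       else
//         depthwise_3x3_pad_1_<FUSE_RELU, HAS_BIAS, true, false>(/*...*/);
//     } else {
//       if (B_zero_point == 0)
//         depthwise_3x3_pad_1_<FUSE_RELU, HAS_BIAS, false, true>(/*...*/);
//       else
//         depthwise_3x3_pad_1_<FUSE_RELU, HAS_BIAS, false, false>(/*...*/);
//     }
//   }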
-template +template static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( int N, int H, @@ -1477,7 +1540,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( if (h_begin == 0) { if (w_begin == 0) { - depthwise_3x3_kernel_( + depthwise_3x3_kernel_( H, W, K, @@ -1499,7 +1562,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_( + depthwise_3x3_kernel_( H, W, K, @@ -1522,7 +1585,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_kernel_( + depthwise_3x3_kernel_( H, W, K, @@ -1547,7 +1610,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) { if (w_begin == 0) { w = 0; - depthwise_3x3_kernel_( + depthwise_3x3_kernel_( H, W, K, @@ -1569,7 +1632,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_( + depthwise_3x3_kernel_( H, W, K, @@ -1592,7 +1655,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_kernel_( + depthwise_3x3_kernel_( H, W, K, @@ -1618,7 +1681,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( h = H_OUT - 1; w = 0; if (w_begin == 0) { - depthwise_3x3_kernel_( + depthwise_3x3_kernel_( H, W, K, @@ -1640,7 +1703,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_( + depthwise_3x3_kernel_( H, W, K, @@ -1663,7 +1726,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_kernel_( + depthwise_3x3_kernel_( H, W, K, @@ -1687,7 +1750,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( } // for each n }; -template +template static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( int N, int T, @@ -1765,7 +1828,11 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( for (int t = t_begin; t < t_end; ++t) { for (int h = h_begin; h < h_end; ++h) { for (int w = 0; w < W_OUT; ++w) { - depthwise_3x3x3_kernel_( + depthwise_3x3x3_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC>( T, H, W, @@ -1793,7 +1860,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( } // for each n }; -template +template static inline __attribute__((always_inline)) void depthwise_3x3_per_channel_quantization_pad_1_( int N, @@ -1870,7 +1937,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( if (h_begin == 0) { if (w_begin == 0) { - depthwise_3x3_per_channel_quantization_kernel_( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -1892,7 +1962,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_per_channel_quantization_kernel_( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -1915,7 +1988,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_per_channel_quantization_kernel_( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + 
HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -1940,7 +2016,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) { if (w_begin == 0) { w = 0; - depthwise_3x3_per_channel_quantization_kernel_( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -1962,7 +2041,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_per_channel_quantization_kernel_( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -1985,7 +2067,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_per_channel_quantization_kernel_( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -2011,7 +2096,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( h = H_OUT - 1; w = 0; if (w_begin == 0) { - depthwise_3x3_per_channel_quantization_kernel_( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -2033,7 +2121,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( } for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_per_channel_quantization_kernel_( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -2056,7 +2147,10 @@ depthwise_3x3_per_channel_quantization_pad_1_( if (w_end == W_OUT) { w = W_OUT - 1; - depthwise_3x3_per_channel_quantization_kernel_( + depthwise_3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( H, W, K, @@ -2080,7 +2174,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( } // for each n }; -template +template static inline __attribute__((always_inline)) void depthwise_3x3x3_per_channel_quantization_pad_1_( int N, @@ -2159,7 +2253,10 @@ depthwise_3x3x3_per_channel_quantization_pad_1_( for (int t = t_begin; t < t_end; ++t) { for (int h = h_begin; h < h_end; ++h) { for (int w = 0; w < W_OUT; ++w) { - depthwise_3x3x3_per_channel_quantization_kernel_( + depthwise_3x3x3_per_channel_quantization_kernel_< + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC>( T, H, W, @@ -2187,6 +2284,130 @@ depthwise_3x3x3_per_channel_quantization_pad_1_( } // for each n }; +// Dispatch A_SYMMETRIC and B_SYMMETRIC +template +static void depthwise_3x3_pad_1_( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const Packed3x3ConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, + int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (A_zero_point == 0 || col_offsets == nullptr) { + if (B_zero_point == 0) { + depthwise_3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_symmetric*/, + true /*B_symmetric*/>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_symmetric*/, + false /*B_symmetric*/>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } + } else { + if (B_zero_point == 0) { + depthwise_3x3_pad_1_< + FUSE_RELU, + 
HAS_BIAS, + false /*A_symmetric*/, + true /*B_symmetric*/>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_symmetric*/, + false /*B_symmetric*/>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } + } +} + // Dispatch HAS_BIAS template static void depthwise_3x3_pad_1_( @@ -2207,7 +2428,6 @@ static void depthwise_3x3_pad_1_( const int32_t* bias, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; if (bias) { depthwise_3x3_pad_1_( N, @@ -2222,7 +2442,6 @@ static void depthwise_3x3_pad_1_( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2242,7 +2461,6 @@ static void depthwise_3x3_pad_1_( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2469,6 +2687,140 @@ void depthwise_3x3_pad_1( } } +// Dispatch A_SYMMETRIC and B_SYMMETRIC +template +static void depthwise_3x3x3_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const Packed3x3x3ConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, + int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (A_zero_point == 0 || col_offsets == nullptr) { + if (B_zero_point == 0) { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_symmetric*/, + true /*B_symmetric*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_symmetric*/, + false /*B_symmetric*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } + } else { + if (B_zero_point == 0) { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_symmetric*/, + true /*B_symmetric*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3x3_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_symmetric*/, + false /*B_symmetric*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } + } +} + // Dispatch HAS_BIAS template static void depthwise_3x3x3_pad_1_( @@ -2491,7 +2843,6 @@ static void depthwise_3x3x3_pad_1_( const int32_t* bias, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; if (bias) { depthwise_3x3x3_pad_1_( N, @@ -2508,7 +2859,6 @@ static void depthwise_3x3x3_pad_1_( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2530,7 +2880,6 @@ static void depthwise_3x3x3_pad_1_( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2606,6 +2955,76 @@ void 
depthwise_3x3x3_pad_1( } } +// Dispatch A_SYMMETRIC +template +static void depthwise_3x3_per_channel_quantization_pad_1_( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const Packed3x3ConvMatrix& Bp, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, + int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (A_zero_point == 0 || col_offsets == nullptr) { + depthwise_3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_SYMM*/>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_SYMM*/>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } +} + // Dispatch HAS_BIAS template static void depthwise_3x3_per_channel_quantization_pad_1_( @@ -2626,7 +3045,6 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( const int32_t* bias, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; if (bias) { depthwise_3x3_per_channel_quantization_pad_1_< FUSE_RELU, @@ -2643,7 +3061,6 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( Bp, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2665,7 +3082,6 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( Bp, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ -2891,6 +3307,82 @@ void depthwise_3x3_per_channel_quantization_pad_1( } } +// Dispatch A_SYMMETRIC +template +static void depthwise_3x3x3_per_channel_quantization_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const Packed3x3x3ConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, + int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (A_zero_point == 0 || col_offsets == nullptr) { + depthwise_3x3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + HAS_BIAS, + true /*A_SYMM*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3x3_per_channel_quantization_pad_1_< + FUSE_RELU, + HAS_BIAS, + false /*A_SYMM*/>( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } +} + // Dispatch HAS_BIAS template static void depthwise_3x3x3_per_channel_quantization_pad_1_( @@ -2913,7 +3405,6 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_( const int32_t* bias, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; if (bias) { depthwise_3x3x3_per_channel_quantization_pad_1_< FUSE_RELU, @@ -2932,7 +3423,6 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, @@ 
-2956,7 +3446,6 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_( B, C_multiplier, C_zero_point, - C_int32_temp, C, col_offsets, bias, diff --git a/src/FbgemmI8DepthwiseAvx2.h b/src/FbgemmI8DepthwiseAvx2.h index 9fd16ae..069ff77 100644 --- a/src/FbgemmI8DepthwiseAvx2.h +++ b/src/FbgemmI8DepthwiseAvx2.h @@ -34,10 +34,29 @@ using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>; /** * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 - * * @params A The input image in NHWK layout * @params Bp The pre-packed filter */ +FBGEMM_API void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const Packed3x3ConvMatrix& Bp, + std::int32_t* C, + int thread_id = 0, + int num_threads = 1); + +/** + * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 + * This version is fused with requantization. + * + * @col_offsets nullptr if col_offsets are folded into bias + */ FBGEMM_API void depthwise_3x3_pad_1( int N, int H, @@ -61,6 +80,8 @@ FBGEMM_API void depthwise_3x3_pad_1( /** * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 * This version is fused with requantization and uses per-channel quantization. + * + * @col_offsets nullptr if col_offsets are folded into bias */ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1( int N, @@ -82,6 +103,25 @@ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1( int thread_id = 0, int num_threads = 1); +FBGEMM_API void depthwise_3x3x3_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const Packed3x3x3ConvMatrix& Bp, + std::int32_t* C, + int thread_id = 0, + int num_threads = 1); + +/** + * @col_offsets nullptr if col_offsets are folded into bias + */ FBGEMM_API void depthwise_3x3x3_pad_1( int N, int T, @@ -104,6 +144,9 @@ FBGEMM_API void depthwise_3x3x3_pad_1( int thread_id = 0, int num_threads = 1); +/** + * @col_offsets nullptr if col_offsets are folded into bias + */ FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1( int N, int T, diff --git a/test/I8DepthwiseTest.cc b/test/I8DepthwiseTest.cc index d985807..7843aae 100644 --- a/test/I8DepthwiseTest.cc +++ b/test/I8DepthwiseTest.cc @@ -68,7 +68,20 @@ static vector> shapes = { { 1, 8, 4, 4, 1, }, }; -TEST(FBGemmDepthWiseTest, Test3x3) { +namespace { +class FBGemmDepthWiseTest + : public testing::TestWithParam> {}; +} // namespace + +INSTANTIATE_TEST_CASE_P( + InstantiationName, + FBGemmDepthWiseTest, + ::testing::Combine(::testing::Bool(), ::testing::Bool())); + +TEST_P(FBGemmDepthWiseTest, Test3x3) { + bool a_symmetric, b_symmetric; + tie(a_symmetric, b_symmetric) = GetParam(); + for (auto shape : shapes) { int N = shape[0]; int K = shape[1]; @@ -86,10 +99,10 @@ TEST(FBGemmDepthWiseTest, Test3x3) { aligned_vector C_ref(N * H_OUT * W_OUT * K), C(C_ref.size()); randFill(A, 0, 86); - int32_t A_zero_point = 43; + int32_t A_zero_point = a_symmetric ? 0 : 43; randFill(B, -16, 16); - int32_t B_zero_point = 5; + int32_t B_zero_point = b_symmetric ? 0 : 5; depthwise_3x3_pad_1_ref( N, @@ -148,7 +161,7 @@ TEST(FBGemmDepthWiseTest, Test3x3) { C_multiplier, C_zero_point, C_uint8.data(), - col_offsets.data(), + a_symmetric ? 
nullptr : col_offsets.data(), bias.data(), false, /* fuse_relu */ 0, @@ -172,8 +185,15 @@ TEST(FBGemmDepthWiseTest, Test3x3) { } // for each shape } // Test3x3 -TEST(FBGemmDepthWiseTest, Test3x3x3) { - for (auto shape : shapes_3d) { +TEST_P(FBGemmDepthWiseTest, Test3x3x3) { + bool a_symmetric, b_symmetric; + tie(a_symmetric, b_symmetric) = GetParam(); + + // 3x3x3 tests take a long time so for a symmetric quantization, we only + // test with 2 shapes. + for (auto shape : a_symmetric || b_symmetric + ? vector>(shapes_3d.cbegin(), shapes_3d.cbegin() + 2) + : shapes_3d) { int N = shape[0]; int K = shape[1]; int T = shape[2]; @@ -195,7 +215,7 @@ TEST(FBGemmDepthWiseTest, Test3x3x3) { C(C_ref.size()); randFill(A, 0, 86); - int32_t A_zero_point = 43; + int32_t A_zero_point = a_symmetric ? 0 : 43; randFill(B, -16, 16); int32_t B_zero_point = 5; @@ -277,8 +297,8 @@ TEST(FBGemmDepthWiseTest, Test3x3x3) { for (int k = 0; k < K; ++k) { int32_t expected = C_uint8_ref [(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k]; - int32_t actual = - C_uint8[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k]; + int32_t actual = C_uint8 + [(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k]; EXPECT_EQ(expected, actual) << "Depthwise 3x3 results differ at (" << n << ", " << t << ", " << h << ", " << w << ", " << k << ")."; @@ -343,8 +363,6 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) { int32_t minimum = *min_element(C_ref_k_begin, C_ref_k_end); int32_t maximum = *max_element(C_ref_k_begin, C_ref_k_end); C_multiplier[k] = 255. / (maximum - minimum); - cerr << "k " << k << " minimum " << minimum << " maximum " << maximum - << " multiplier " << C_multiplier[k] << endl; } int32_t C_zero_point = 5; @@ -470,8 +488,6 @@ TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) { int32_t minimum = *min_element(C_ref_k_begin, C_ref_k_end); int32_t maximum = *max_element(C_ref_k_begin, C_ref_k_end); C_multiplier[k] = 255. / (maximum - minimum); - cerr << "k " << k << " minimum " << minimum << " maximum " << maximum - << " multiplier " << C_multiplier[k] << endl; } int32_t C_zero_point = 5; -- cgit v1.2.3
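For reference, the requantization that requantize_ vectorizes reduces to the scalar form below. The two subtracted terms come from expanding sum_r (A[r] - A_zp)(B[r] - B_zp): the A_zp * sum(B) part lives in col_offsets and the B_zp * sum(A) part in row_offsets, so a zero zero-point on either side removes its term (the constant cross term is folded into col_offsets/bias per the library's packing convention rather than anything shown in this patch). This is a minimal, self-contained sketch; requantize_scalar is an illustrative name, not an FBGEMM API. Passing nullptr for row_offsets corresponds to B_SYMMETRIC, and a nullptr col_offsets (or A_zero_point == 0) to A_SYMMETRIC:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar model of requantize_ for one output channel j. Skipping the
// row_offsets term when B_zero_point == 0 and the col_offsets term when
// A_zero_point == 0 is exactly the optimization this patch makes static.
inline uint8_t requantize_scalar(
    int32_t acc,                // C_int32[j]: raw int32 accumulator
    int32_t A_zero_point,
    const int32_t* col_offsets, // nullptr: A symmetric / folded into bias
    const int32_t* row_offsets, // nullptr: B uses symmetric quantization
    const int32_t* bias,        // nullptr: no bias
    float C_multiplier,
    int32_t C_zero_point,
    bool fuse_relu,
    int j) {
  int32_t raw = acc;
  if (row_offsets) {
    raw -= row_offsets[j];      // carries the B_zero_point * sum(A window) part
  }
  if (A_zero_point != 0 && col_offsets) {
    raw -= A_zero_point * col_offsets[j];
  }
  if (bias) {
    raw += bias[j];
  }
  int32_t rounded =
      static_cast<int32_t>(std::nearbyint(raw * C_multiplier)) + C_zero_point;
  // Saturate to uint8; with fused ReLU the lower clamp moves to C_zero_point.
  int32_t lower = fuse_relu ? C_zero_point : 0;
  return static_cast<uint8_t>(
      std::min<int32_t>(255, std::max<int32_t>(lower, rounded)));
}

Because the symmetry decisions are template parameters rather than run-time branches, the offset loads and subtractions disappear entirely from the AVX2 inner loops of the specialized kernels, which is where the speedup for symmetric quantization comes from.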