diff options
author | Jongsoo Park <jongsoo@fb.com> | 2019-01-12 06:33:40 +0300 |
---|---|---|
committer | Facebook Github Bot <facebook-github-bot@users.noreply.github.com> | 2019-01-12 06:35:56 +0300 |
commit | 36309fc56728e32b5d78c3be85b48a93f00ed0bf (patch) | |
tree | 5172c281585bc1855073a92ed11dbe3096dc4c31 | |
parent | f6e4f991351d16738ebc92ee681cb4eac83d3941 (diff) |
3x3x3 depthwise convolution with per channel quantization (#15775)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15775
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/55
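FBGEMM did not have per-channel quantization for 3x3x3 depth-wise convolution; this change adds it (AVX2 kernel, reference implementation, and test). It also adds a fuse_relu argument to depthwise_3x3_per_channel_quantization_pad_1 and moves fuse_relu ahead of thread_id/num_threads in depthwise_3x3_pad_1, so callers that pass those arguments positionally must be updated.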
Reviewed By: jianyuh
Differential Revision: D13587438
fbshipit-source-id: 91c36fae7a0e8386e3bc49808e18918b01681dd1
-rw-r--r-- | bench/Depthwise3DBenchmark.cc | 2 | ||||
-rw-r--r-- | bench/DepthwiseBenchmark.cc | 1 | ||||
-rw-r--r-- | src/FbgemmI8DepthwiseAvx2.cc | 551 | ||||
-rw-r--r-- | src/FbgemmI8DepthwiseAvx2.h | 27 | ||||
-rw-r--r-- | src/RefImplementations.cc | 89 | ||||
-rw-r--r-- | src/RefImplementations.h | 19 | ||||
-rw-r--r-- | test/I8DepthwiseTest.cc | 138 |
7 files changed, 744 insertions, 83 deletions
diff --git a/bench/Depthwise3DBenchmark.cc b/bench/Depthwise3DBenchmark.cc index 63efc9d..a9b4e4b 100644 --- a/bench/Depthwise3DBenchmark.cc +++ b/bench/Depthwise3DBenchmark.cc @@ -215,7 +215,7 @@ int main() { C_uint8.data(), col_offsets.data(), bias.data(), - false /* fuse_relu */, + false, /* fuse_relu */ tid, num_threads); } diff --git a/bench/DepthwiseBenchmark.cc b/bench/DepthwiseBenchmark.cc index 13be6f6..0fe1b5c 100644 --- a/bench/DepthwiseBenchmark.cc +++ b/bench/DepthwiseBenchmark.cc @@ -292,6 +292,7 @@ int main() { C_uint8.data(), col_offsets.data(), bias.data(), + false, /* fuse_relu */ tid, num_threads); } diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc index 9dcd7e1..555b93a 100644 --- a/src/FbgemmI8DepthwiseAvx2.cc +++ b/src/FbgemmI8DepthwiseAvx2.cc @@ -1273,7 +1273,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( } } -template <bool SUM_A> +template <bool SUM_A, bool FUSE_RELU> static inline __attribute__((always_inline)) void depthwise_3x3_per_channel_quantization_kernel_( int H, @@ -1333,7 +1333,7 @@ depthwise_3x3_per_channel_quantization_kernel_( &row_offsets[k]); } if (SUM_A) { - requantize_per_channel_<false, true>( + requantize_per_channel_<FUSE_RELU, true /*HAS_BIAS*/>( A_zero_point, C_multiplier, C_zero_point, @@ -1346,6 +1346,88 @@ depthwise_3x3_per_channel_quantization_kernel_( } } +template <bool SUM_A, bool FUSE_RELU> +static inline __attribute__((always_inline)) void +depthwise_3x3x3_per_channel_quantization_kernel_( + int T, + int H, + int W, + int K, + int t, + int h, + int w, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const int8_t* Bp, + const float* C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, + uint8_t* C_uint8, + int32_t* row_offsets, + const int32_t* col_offsets, + const int32_t* bias) { + constexpr int R = 3, S = 3; + constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R 
= 1; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int t_in = -PAD_P + t * stride_t; + int h_in = -PAD_T + h * stride_h; + int w_in = -PAD_L + w * stride_w; + + int k; + for (k = 0; k < K / 32 * 32; k += 32) { + inner_prod_3x3x3_packed_<SUM_A, false /*remainder*/, true /*per-channel*/>( + T, + H, + W, + K, + t_in, + h_in, + w_in, + A + ((t_in * H + h_in) * W + w_in) * K + k, + A_zero_point, + Bp + k * 28, + B_zero_point + k, + C_int32 + k, + 0, + &row_offsets[k]); + } + int remainder = K - k; + if (remainder) { + inner_prod_3x3x3_packed_<SUM_A, true /*remainder*/, true /*per-channel*/>( + T, + H, + W, + K, + t_in, + h_in, + w_in, + A + ((t_in * H + h_in) * W + w_in) * K + k, + A_zero_point, + Bp + k * 28, + B_zero_point + k, + C_int32 + k, + remainder, + &row_offsets[k]); + } + if (SUM_A) { + requantize_per_channel_<FUSE_RELU, true>( + A_zero_point, + C_multiplier, + C_zero_point, + C_int32, + C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, + K, + row_offsets, + col_offsets, + bias); + } +} + static pair<int, int> closest_factors_(int n) { int a = (int)std::sqrt(n); while (n % a != 0) { @@ -1773,7 +1855,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( } // for each n }; -template <bool FUSE_RESCALE = true> +template <bool FUSE_RESCALE = true, bool FUSE_RELU = false> static inline __attribute__((always_inline)) void depthwise_3x3_per_channel_quantization_pad_1_( int N, @@ -1853,7 +1935,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( if (w_begin == 0) { C_temp = FUSE_RESCALE ? C_int32 : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( H, W, K, @@ -1877,7 +1959,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { C_temp = FUSE_RESCALE ? 
C_int32 : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( H, W, K, @@ -1902,7 +1984,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( w = W_OUT - 1; C_temp = FUSE_RESCALE ? C_int32 : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( H, W, K, @@ -1929,7 +2011,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( w = 0; C_temp = FUSE_RESCALE ? C_int32 : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( H, W, K, @@ -1953,7 +2035,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { C_temp = FUSE_RESCALE ? C_int32 : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( H, W, K, @@ -1978,7 +2060,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( w = W_OUT - 1; C_temp = FUSE_RESCALE ? C_int32 : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( H, W, K, @@ -2006,7 +2088,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( if (w_begin == 0) { C_temp = FUSE_RESCALE ? C_int32 : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( H, W, K, @@ -2030,7 +2112,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { C_temp = FUSE_RESCALE ? 
C_int32 : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( H, W, K, @@ -2055,7 +2137,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( w = W_OUT - 1; C_temp = FUSE_RESCALE ? C_int32 : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K; - depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>( + depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>( H, W, K, @@ -2079,6 +2161,119 @@ depthwise_3x3_per_channel_quantization_pad_1_( } // for each n }; +template <bool FUSE_RESCALE = true, bool FUSE_RELU = false> +static inline __attribute__((always_inline)) void +depthwise_3x3x3_per_channel_quantization_pad_1_( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const Packed3x3x3ConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + int32_t* C_int32, + uint8_t* C_uint8, + const int32_t* col_offsets, + const int32_t* bias, + int thread_id, + int num_threads) { + assert(K % 8 == 0); + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + const int8_t* Bp = B.PackedMat(); + + int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); + int32_t* C_temp; + + int n_begin, n_end; + int t_begin, t_end, h_begin, h_end; + if (N >= num_threads) { + int n_per_thread = (N + num_threads - 1) / num_threads; + n_begin = std::min(thread_id * n_per_thread, N); + n_end = std::min(n_begin + n_per_thread, N); + t_begin = 0; + t_end = T_OUT; + h_begin = 0; + h_end = H_OUT; + } else { + int nthreads_per_n = num_threads / N; + n_begin = std::min(thread_id / 
nthreads_per_n, N); + n_end = std::min(n_begin + 1, N); + + int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); + int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); + int nthreads_of_n = tid_of_n_end - tid_of_n_begin; + int tid_within_n = thread_id - tid_of_n_begin; + assert(tid_within_n >= 0); + assert(tid_within_n < nthreads_of_n); + + // n is processed by num_threads_t * num_threads_h 2D grid of threads + int num_threads_t, num_threads_h; + // num_threads_w <= num_threads_h + tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n); + int tid_t = tid_within_n / num_threads_h; + int tid_h = tid_within_n % num_threads_h; + + int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t; + t_begin = std::min(tid_t * t_per_thread, T_OUT); + t_end = std::min(t_begin + t_per_thread, T_OUT); + + int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; + h_begin = std::min(tid_h * h_per_thread, H_OUT); + h_end = std::min(h_begin + h_per_thread, H_OUT); + } + + for (int n = n_begin; n < n_end; ++n) { + const uint8_t* A_base = A + n * T * H * W * K; + uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K; + + for (int t = t_begin; t < t_end; ++t) { + for (int h = h_begin; h < h_end; ++h) { + for (int w = 0; w < W_OUT; ++w) { + C_temp = FUSE_RESCALE + ? 
C_int32 + : C_int32 + (((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K; + depthwise_3x3x3_per_channel_quantization_kernel_< + FUSE_RESCALE, + FUSE_RELU>( + T, + H, + W, + K, + t, + h, + w, + stride_t, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_temp, + C_uint8_base, + row_offsets, + col_offsets, + bias); + } // w + } // h + } // t + } // for each n +}; + // assumption: W > 3 and H > 3 void depthwise_3x3_pad_1( int N, @@ -2212,9 +2407,9 @@ void depthwise_3x3_pad_1( uint8_t* C, const int32_t* col_offsets, const int32_t* bias, + bool fuse_relu, int thread_id, - int num_threads, - bool fuse_relu) { + int num_threads) { int32_t C_int32_temp[(K + 31) / 32 * 32]; if (fuse_relu) { if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { @@ -2632,81 +2827,275 @@ void depthwise_3x3_per_channel_quantization_pad_1( uint8_t* C, const int32_t* col_offsets, const int32_t* bias, + bool fuse_relu, int thread_id, int num_threads) { int32_t C_int32_temp[(K + 31) / 32 * 32]; - if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else if (1 == stride_h && 1 == stride_w) { - depthwise_3x3_per_channel_quantization_pad_1_( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - thread_id, - num_threads); - } else if (2 == stride_h && 2 == stride_w) { - 
depthwise_3x3_per_channel_quantization_pad_1_( + if (fuse_relu) { + if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RESCALE */, + true /* FUSE_RELU */>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RESCALE */, + true /* FUSE_RELU */>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else if (1 == stride_h && 1 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RESCALE */, + true /* FUSE_RELU */>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else if (2 == stride_h && 2 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RESCALE */, + true /* FUSE_RELU */>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RESCALE */, + true /* FUSE_RELU */>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } + } else { + if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RESCALE */, + false /* FUSE_RELU */>( + N, + H, + W, + K, + stride_h, + stride_w, + 
A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RESCALE */, + false /* FUSE_RELU */>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else if (1 == stride_h && 1 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RESCALE */, + false /* FUSE_RELU */>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else if (2 == stride_h && 2 == stride_w) { + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RESCALE */, + false /* FUSE_RELU */>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } else { + depthwise_3x3_per_channel_quantization_pad_1_< + true /* FUSE_RESCALE */, + false /* FUSE_RELU */>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + thread_id, + num_threads); + } + } +} + +void depthwise_3x3x3_per_channel_quantization_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const Packed3x3x3ConvMatrix& B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + int thread_id, + int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if 
(fuse_relu) { + depthwise_3x3x3_per_channel_quantization_pad_1_< + true /* FUSE_RESCALE */, + true /* FUSE_RELU */>( N, + T, H, W, K, + stride_t, stride_h, stride_w, A_zero_point, A, B_zero_point, - Bp, + B, C_multiplier, C_zero_point, C_int32_temp, @@ -2716,17 +3105,21 @@ void depthwise_3x3_per_channel_quantization_pad_1( thread_id, num_threads); } else { - depthwise_3x3_per_channel_quantization_pad_1_( + depthwise_3x3x3_per_channel_quantization_pad_1_< + true /* FUSE_RESCALE */, + false /* FUSE_RELU */>( N, + T, H, W, K, + stride_t, stride_h, stride_w, A_zero_point, A, B_zero_point, - Bp, + B, C_multiplier, C_zero_point, C_int32_temp, diff --git a/src/FbgemmI8DepthwiseAvx2.h b/src/FbgemmI8DepthwiseAvx2.h index f0c21a6..53c6e8a 100644 --- a/src/FbgemmI8DepthwiseAvx2.h +++ b/src/FbgemmI8DepthwiseAvx2.h @@ -71,9 +71,9 @@ FBGEMM_API void depthwise_3x3_pad_1( std::uint8_t* C, const std::int32_t* col_offsets, const std::int32_t* bias, + bool fuse_relu = false, int thread_id = 0, - int num_threads = 1, - bool fuse_relu = false); + int num_threads = 1); /** * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8 @@ -95,6 +95,7 @@ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1( std::uint8_t* C, const std::int32_t* col_offsets, const std::int32_t* bias, + bool fuse_relu = false, int thread_id = 0, int num_threads = 1); @@ -136,4 +137,26 @@ FBGEMM_API void depthwise_3x3x3_pad_1( int thread_id = 0, int num_threads = 1); +FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int32_t* B_zero_point, + const Packed3x3x3ConvMatrix& Bp, + const float* C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias, + bool fuse_relu = false, + int thread_id = 0, + int num_threads = 1); + } // namespace fbgemm diff --git 
a/src/RefImplementations.cc b/src/RefImplementations.cc index 6aebc3d..5168a15 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -797,4 +797,93 @@ void depthwise_3x3x3_pad_1_ref( } }; +void depthwise_3x3x3_per_channel_quantization_pad_1_ref( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const int8_t* B, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias) { + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + + vector<int32_t> C_int32(N * T_OUT * H_OUT * W_OUT * K); + depthwise_3x3x3_pad_1_ref( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A, + B, + C_int32.data()); + + vector<int32_t> row_offsets(N * T_OUT * H_OUT * W_OUT * K); + for (int n = 0; n < N; ++n) { + for (int t = 0; t < T_OUT; ++t) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int sum = 0; + for (int k_t = 0; k_t < K_T; ++k_t) { + int t_in = -PAD_P + t * stride_t + k_t; + for (int k_h = 0; k_h < K_H; ++k_h) { + int h_in = -PAD_T + h * stride_h + k_h; + for (int k_w = 0; k_w < K_W; ++k_w) { + int w_in = -PAD_L + w * stride_w + k_w; + int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H || + w_in < 0 || w_in >= W + ? 
A_zero_point + : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k]; + sum += a; + } + } + } + row_offsets[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] = + sum; + } + } // w + } // h + } // t + } // for each n + + for (int i = 0; i < N * T_OUT * H_OUT * W_OUT; ++i) { + for (int k = 0; k < K; ++k) { + requantize_u8acc32_ref( + 1, + 1, + 1, + C_int32.data() + i * K + k, + C + i * K + k, + &C_multiplier[k], + C_zero_point, + A_zero_point, + &B_zero_point[k], + &row_offsets[i * K + k], + col_offsets + k, + bias ? bias + k : nullptr, + 1); + } + } +}; + } // namespace fbgemm diff --git a/src/RefImplementations.h b/src/RefImplementations.h index e2f52d5..fce68e6 100644 --- a/src/RefImplementations.h +++ b/src/RefImplementations.h @@ -298,4 +298,23 @@ FBGEMM_API void depthwise_3x3x3_pad_1_ref( const std::int32_t* col_offsets, const std::int32_t* bias); +FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1_ref( + int N, + int T, + int H, + int W, + int K, + int stride_t, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int32_t* B_zero_point, + const std::int8_t* B, + const float* C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const std::int32_t* bias); + } // namespace fbgemm diff --git a/test/I8DepthwiseTest.cc b/test/I8DepthwiseTest.cc index 7242b7a..ab44f0f 100644 --- a/test/I8DepthwiseTest.cc +++ b/test/I8DepthwiseTest.cc @@ -168,6 +168,7 @@ TEST(FBGemmDepthWiseTest, Test3x3) { C_uint8.data(), col_offsets.data(), bias.data(), + false, /* fuse_relu */ 0, 1); @@ -317,7 +318,7 @@ TEST(FBGemmDepthWiseTest, Test3x3x3) { C_uint8.data(), col_offsets.data(), bias.data(), - false /* fuse_relu */, + false, /* fuse_relu */ 0, 1); @@ -441,6 +442,7 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) { C_uint8.data(), col_offsets.data(), bias.data(), + false, /* fuse_relu */ 0, 1); @@ -462,4 +464,138 @@ TEST(FBGemmDepthWiseTest, 
Test3x3PerChannelQuantization) { } // for each shape } // Test3x3PerChannelQuantization +TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) { + for (auto shape : shapes_3d) { + int N = shape[0]; + int K = shape[1]; + int T = shape[2]; + int H = shape[3]; + int W = shape[4]; + int stride_t = shape[5]; + int stride_h = stride_t; + int stride_w = stride_t; + constexpr int K_T = 3, K_H = 3, K_W = 3; + constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, + PAD_R = 1; + int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1; + int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; + + aligned_vector<uint8_t> A(N * T * H * W * K); + aligned_vector<int8_t> B(K * K_T * K_H * K_W); + int32_t C_num_rows = N * T_OUT * H_OUT * W_OUT; + aligned_vector<int32_t> C_ref(C_num_rows * K), C(C_ref.size()); + + randFill<uint8_t>(A, 0, 86); + int32_t A_zero_point = 43; + + // Each row of G has a different range to really test per-channel + // quantization. + vector<int32_t> B_zero_point(K); + for (auto k = 0; k < K; ++k) { + aligned_vector<int8_t> Bk(K_T * K_H * K_W); + randFill<int8_t>(Bk, -16 + k, 16 + k); + copy(Bk.begin(), Bk.end(), B.begin() + k * K_T * K_H * K_W); + + B_zero_point[k] = 5 + k; + } + + depthwise_3x3x3_pad_1_ref( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A.data(), + B.data(), + C_ref.data()); + + aligned_vector<int32_t> C_ref_transpose(C_ref); + transpose_matrix(C_ref.data(), C_num_rows, K); + vector<float> C_multiplier(K); + for (auto k = 0; k < K; ++k) { + auto C_ref_k_begin = C_ref_transpose.begin() + k * C_num_rows; + auto C_ref_k_end = C_ref_k_begin + C_num_rows; + int32_t minimum = *min_element(C_ref_k_begin, C_ref_k_end); + int32_t maximum = *max_element(C_ref_k_begin, C_ref_k_end); + C_multiplier[k] = 255. 
/ (maximum - minimum); + cerr << "k " << k << " minimum " << minimum << " maximum " << maximum + << " multiplier " << C_multiplier[k] << endl; + } + int32_t C_zero_point = 5; + + aligned_vector<int32_t> col_offsets(K); + aligned_vector<int32_t> bias(K); + randFill(col_offsets, -100, 100); + randFill(bias, -40, 40); + + aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size()); + depthwise_3x3x3_per_channel_quantization_pad_1_ref( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A.data(), + B_zero_point.data(), + B.data(), + C_multiplier.data(), + C_zero_point, + C_uint8_ref.data(), + col_offsets.data(), + bias.data()); + + Packed3x3x3ConvMatrix Bp(K, B.data()); + + depthwise_3x3x3_per_channel_quantization_pad_1( + N, + T, + H, + W, + K, + stride_t, + stride_h, + stride_w, + A_zero_point, + A.data(), + B_zero_point.data(), + Bp, + C_multiplier.data(), + C_zero_point, + C_uint8.data(), + col_offsets.data(), + bias.data(), + false, /* fuse_relu */ + 0, + 1); + + // correctness check + for (int n = 0; n < N; ++n) { + for (int t = 0; t < T_OUT; ++t) { + for (int h = 0; h < H_OUT; ++h) { + for (int w = 0; w < W_OUT; ++w) { + for (int k = 0; k < K; ++k) { + int32_t expected = C_uint8_ref + [(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k]; + int32_t actual = + C_uint8[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k]; + ASSERT_EQ(expected, actual) + << "Depthwise 3x3 results differ at (" << n << ", " << t + << ", " << h << ", " << w << ", " << k << ")."; + } + } // w + } // h + } // t + } // n + } // for each shape +} // Test3x3PerChannelQuantization + } // namespace fbgemm |