github.com/marian-nmt/FBGEMM.git
author    Jongsoo Park <jongsoo@fb.com>  2019-01-12 06:33:40 +0300
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>  2019-01-12 06:35:56 +0300
commit    36309fc56728e32b5d78c3be85b48a93f00ed0bf (patch)
tree      5172c281585bc1855073a92ed11dbe3096dc4c31
parent    f6e4f991351d16738ebc92ee681cb4eac83d3941 (diff)
3x3x3 depthwise convolution with per channel quantization (#15775)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15775
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/55

fbgemm didn't have per-channel quantization for 3x3x3 depth-wise convolution.

Reviewed By: jianyuh

Differential Revision: D13587438

fbshipit-source-id: 91c36fae7a0e8386e3bc49808e18918b01681dd1
-rw-r--r--  bench/Depthwise3DBenchmark.cc     2
-rw-r--r--  bench/DepthwiseBenchmark.cc       1
-rw-r--r--  src/FbgemmI8DepthwiseAvx2.cc    551
-rw-r--r--  src/FbgemmI8DepthwiseAvx2.h      27
-rw-r--r--  src/RefImplementations.cc        89
-rw-r--r--  src/RefImplementations.h         19
-rw-r--r--  test/I8DepthwiseTest.cc         138
7 files changed, 744 insertions(+), 83 deletions(-)
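
Before the per-file diffs, a minimal sketch of calling the new public entry point this commit adds. The signature follows the FbgemmI8DepthwiseAvx2.h hunk below; the include path, tensor shapes, and fill values are illustrative assumptions, not part of the commit.

#include <cstdint>
#include <vector>

#include "FbgemmI8DepthwiseAvx2.h" // adjust the path to your build layout

using namespace fbgemm;

void example_3x3x3_per_channel() {
  // NTHWC activations, K channels; the kernel asserts K % 8 == 0.
  int N = 1, T = 8, H = 8, W = 8, K = 32;
  int stride_t = 1, stride_h = 1, stride_w = 1;
  // pad = 1 on every side, 3x3x3 kernel:
  int T_OUT = (T + 2 - 3) / stride_t + 1;
  int H_OUT = (H + 2 - 3) / stride_h + 1;
  int W_OUT = (W + 2 - 3) / stride_w + 1;

  std::vector<std::uint8_t> A(N * T * H * W * K, 43); // quantized activations
  std::vector<std::int8_t> B(K * 3 * 3 * 3, 1);       // one 3x3x3 filter per channel
  std::vector<std::int32_t> B_zero_point(K, 5);       // per-channel weight zero points
  std::vector<float> C_multiplier(K, 0.01f);          // per-channel output scales
  std::vector<std::int32_t> col_offsets(K, 0), bias(K, 0);
  std::vector<std::uint8_t> C(N * T_OUT * H_OUT * W_OUT * K);

  Packed3x3x3ConvMatrix Bp(K, B.data()); // pre-pack the filters, as in the test below

  depthwise_3x3x3_per_channel_quantization_pad_1(
      N, T, H, W, K,
      stride_t, stride_h, stride_w,
      /*A_zero_point=*/43, A.data(),
      B_zero_point.data(), Bp,
      C_multiplier.data(), /*C_zero_point=*/5,
      C.data(), col_offsets.data(), bias.data(),
      /*fuse_relu=*/false, /*thread_id=*/0, /*num_threads=*/1);
}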
diff --git a/bench/Depthwise3DBenchmark.cc b/bench/Depthwise3DBenchmark.cc
index 63efc9d..a9b4e4b 100644
--- a/bench/Depthwise3DBenchmark.cc
+++ b/bench/Depthwise3DBenchmark.cc
@@ -215,7 +215,7 @@ int main() {
C_uint8.data(),
col_offsets.data(),
bias.data(),
- false /* fuse_relu */,
+ false, /* fuse_relu */
tid,
num_threads);
}
diff --git a/bench/DepthwiseBenchmark.cc b/bench/DepthwiseBenchmark.cc
index 13be6f6..0fe1b5c 100644
--- a/bench/DepthwiseBenchmark.cc
+++ b/bench/DepthwiseBenchmark.cc
@@ -292,6 +292,7 @@ int main() {
C_uint8.data(),
col_offsets.data(),
bias.data(),
+ false, /* fuse_relu */
tid,
num_threads);
}
diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc
index 9dcd7e1..555b93a 100644
--- a/src/FbgemmI8DepthwiseAvx2.cc
+++ b/src/FbgemmI8DepthwiseAvx2.cc
@@ -1273,7 +1273,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_(
}
}
-template <bool SUM_A>
+template <bool SUM_A, bool FUSE_RELU>
static inline __attribute__((always_inline)) void
depthwise_3x3_per_channel_quantization_kernel_(
int H,
@@ -1333,7 +1333,7 @@ depthwise_3x3_per_channel_quantization_kernel_(
&row_offsets[k]);
}
if (SUM_A) {
- requantize_per_channel_<false, true>(
+ requantize_per_channel_<FUSE_RELU, true /*HAS_BIAS*/>(
A_zero_point,
C_multiplier,
C_zero_point,
@@ -1346,6 +1346,88 @@ depthwise_3x3_per_channel_quantization_kernel_(
}
}
+template <bool SUM_A, bool FUSE_RELU>
+static inline __attribute__((always_inline)) void
+depthwise_3x3x3_per_channel_quantization_kernel_(
+ int T,
+ int H,
+ int W,
+ int K,
+ int t,
+ int h,
+ int w,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const int8_t* Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
+ constexpr int R = 3, S = 3;
+ constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int t_in = -PAD_P + t * stride_t;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
+ int k;
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_3x3x3_packed_<SUM_A, false /*remainder*/, true /*per-channel*/>(
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ B_zero_point + k,
+ C_int32 + k,
+ 0,
+ &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_3x3x3_packed_<SUM_A, true /*remainder*/, true /*per-channel*/>(
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ B_zero_point + k,
+ C_int32 + k,
+ remainder,
+ &row_offsets[k]);
+ }
+ if (SUM_A) {
+ requantize_per_channel_<FUSE_RELU, true>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
+ K,
+ row_offsets,
+ col_offsets,
+ bias);
+ }
+}
+
static pair<int, int> closest_factors_(int n) {
int a = (int)std::sqrt(n);
while (n % a != 0) {
@@ -1773,7 +1855,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_(
} // for each n
};
-template <bool FUSE_RESCALE = true>
+template <bool FUSE_RESCALE = true, bool FUSE_RELU = false>
static inline __attribute__((always_inline)) void
depthwise_3x3_per_channel_quantization_pad_1_(
int N,
@@ -1853,7 +1935,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
if (w_begin == 0) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
H,
W,
K,
@@ -1877,7 +1959,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
H,
W,
K,
@@ -1902,7 +1984,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
w = W_OUT - 1;
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
H,
W,
K,
@@ -1929,7 +2011,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
w = 0;
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
H,
W,
K,
@@ -1953,7 +2035,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
H,
W,
K,
@@ -1978,7 +2060,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
w = W_OUT - 1;
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
H,
W,
K,
@@ -2006,7 +2088,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
if (w_begin == 0) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
H,
W,
K,
@@ -2030,7 +2112,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
H,
W,
K,
@@ -2055,7 +2137,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
w = W_OUT - 1;
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
H,
W,
K,
@@ -2079,6 +2161,119 @@ depthwise_3x3_per_channel_quantization_pad_1_(
} // for each n
};
+template <bool FUSE_RESCALE = true, bool FUSE_RELU = false>
+static inline __attribute__((always_inline)) void
+depthwise_3x3x3_per_channel_quantization_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const Packed3x3x3ConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id,
+ int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+ const int8_t* Bp = B.PackedMat();
+
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+ int32_t* C_temp;
+
+ int n_begin, n_end;
+ int t_begin, t_end, h_begin, h_end;
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ t_begin = 0;
+ t_end = T_OUT;
+ h_begin = 0;
+ h_end = H_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+ // n is processed by num_threads_t * num_threads_h 2D grid of threads
+ int num_threads_t, num_threads_h;
+ // num_threads_t <= num_threads_h
+ tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_t = tid_within_n / num_threads_h;
+ int tid_h = tid_within_n % num_threads_h;
+
+ int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
+ t_begin = std::min(tid_t * t_per_thread, T_OUT);
+ t_end = std::min(t_begin + t_per_thread, T_OUT);
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+ }
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const uint8_t* A_base = A + n * T * H * W * K;
+ uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
+
+ for (int t = t_begin; t < t_end; ++t) {
+ for (int h = h_begin; h < h_end; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ C_temp = FUSE_RESCALE
+ ? C_int32
+ : C_int32 + (((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K;
+ depthwise_3x3x3_per_channel_quantization_kernel_<
+ FUSE_RESCALE,
+ FUSE_RELU>(
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
+ } // w
+ } // h
+ } // t
+ } // for each n
+};
+
// assumption: W > 3 and H > 3
void depthwise_3x3_pad_1(
int N,
@@ -2212,9 +2407,9 @@ void depthwise_3x3_pad_1(
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
+ bool fuse_relu,
int thread_id,
- int num_threads,
- bool fuse_relu) {
+ int num_threads) {
int32_t C_int32_temp[(K + 31) / 32 * 32];
if (fuse_relu) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
@@ -2632,81 +2827,275 @@ void depthwise_3x3_per_channel_quantization_pad_1(
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
+ bool fuse_relu,
int thread_id,
int num_threads) {
int32_t C_int32_temp[(K + 31) / 32 * 32];
- if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_(
+ if (fuse_relu) {
+ if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RESCALE */,
+ true /* FUSE_RELU */>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RESCALE */,
+ true /* FUSE_RELU */>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ } else if (1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RESCALE */,
+ true /* FUSE_RELU */>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ } else if (2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RESCALE */,
+ true /* FUSE_RELU */>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RESCALE */,
+ true /* FUSE_RELU */>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ }
+ } else {
+ if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RESCALE */,
+ false /* FUSE_RELU */>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RESCALE */,
+ false /* FUSE_RELU */>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ } else if (1 == stride_h && 1 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RESCALE */,
+ false /* FUSE_RELU */>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ } else if (2 == stride_h && 2 == stride_w) {
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RESCALE */,
+ false /* FUSE_RELU */>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RESCALE */,
+ false /* FUSE_RELU */>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ }
+ }
+}
+
+void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const Packed3x3x3ConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ int thread_id,
+ int num_threads) {
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ if (fuse_relu) {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RESCALE */,
+ true /* FUSE_RELU */>(
N,
+ T,
H,
W,
K,
+ stride_t,
stride_h,
stride_w,
A_zero_point,
A,
B_zero_point,
- Bp,
+ B,
C_multiplier,
C_zero_point,
C_int32_temp,
@@ -2716,17 +3105,21 @@ void depthwise_3x3_per_channel_quantization_pad_1(
thread_id,
num_threads);
} else {
- depthwise_3x3_per_channel_quantization_pad_1_(
+ depthwise_3x3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RESCALE */,
+ false /* FUSE_RELU */>(
N,
+ T,
H,
W,
K,
+ stride_t,
stride_h,
stride_w,
A_zero_point,
A,
B_zero_point,
- Bp,
+ B,
C_multiplier,
C_zero_point,
C_int32_temp,
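
The threading scheme in the new depthwise_3x3x3_per_channel_quantization_pad_1_ above mirrors the existing 2D routines: with at least as many images as threads, the batch dimension is split; otherwise each image is tiled over a num_threads_t x num_threads_h grid of threads produced by closest_factors_. A standalone sketch of that tile computation, simplified to the single-image case (closest_factors follows the closest_factors_ shown in the diff; tile_for_thread is a hypothetical helper name):

#include <algorithm>
#include <cmath>
#include <tuple>
#include <utility>

// Factor n into a * b with a <= b and a as close to sqrt(n) as possible.
static std::pair<int, int> closest_factors(int n) {
  int a = (int)std::sqrt(n);
  while (n % a != 0) {
    --a; // terminates at a == 1 in the worst case
  }
  return {a, n / a}; // a <= n / a, so first <= second
}

// Compute this thread's [t_begin, t_end) x [h_begin, h_end) tile when
// num_threads threads cooperate on one image.
static void tile_for_thread(
    int T_OUT, int H_OUT, int thread_id, int num_threads,
    int& t_begin, int& t_end, int& h_begin, int& h_end) {
  int num_threads_t, num_threads_h;
  std::tie(num_threads_t, num_threads_h) = closest_factors(num_threads);
  int tid_t = thread_id / num_threads_h; // row in the thread grid
  int tid_h = thread_id % num_threads_h; // column in the thread grid

  int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
  t_begin = std::min(tid_t * t_per_thread, T_OUT);
  t_end = std::min(t_begin + t_per_thread, T_OUT);

  int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
  h_begin = std::min(tid_h * h_per_thread, H_OUT);
  h_end = std::min(h_begin + h_per_thread, H_OUT);
}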
diff --git a/src/FbgemmI8DepthwiseAvx2.h b/src/FbgemmI8DepthwiseAvx2.h
index f0c21a6..53c6e8a 100644
--- a/src/FbgemmI8DepthwiseAvx2.h
+++ b/src/FbgemmI8DepthwiseAvx2.h
@@ -71,9 +71,9 @@ FBGEMM_API void depthwise_3x3_pad_1(
std::uint8_t* C,
const std::int32_t* col_offsets,
const std::int32_t* bias,
+ bool fuse_relu = false,
int thread_id = 0,
- int num_threads = 1,
- bool fuse_relu = false);
+ int num_threads = 1);
/**
* Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8
@@ -95,6 +95,7 @@ FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1(
std::uint8_t* C,
const std::int32_t* col_offsets,
const std::int32_t* bias,
+ bool fuse_relu = false,
int thread_id = 0,
int num_threads = 1);
@@ -136,4 +137,26 @@ FBGEMM_API void depthwise_3x3x3_pad_1(
int thread_id = 0,
int num_threads = 1);
+FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int32_t* B_zero_point,
+ const Packed3x3x3ConvMatrix& Bp,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias,
+ bool fuse_relu = false,
+ int thread_id = 0,
+ int num_threads = 1);
+
} // namespace fbgemm
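
A note on the fuse_relu flag threaded through these signatures: for uint8 outputs, C_zero_point is the quantized encoding of real 0, so fusing ReLU into requantization only raises the lower clamp bound from 0 to C_zero_point. A scalar sketch of that identity (this is the general quantized-ReLU argument, not code copied from this commit; the function name is hypothetical):

#include <algorithm>
#include <cstdint>

// Final clamp of a requantized value, optionally fusing ReLU.
// 'scaled' is the rounded, zero-point-shifted result just before clamping.
inline std::uint8_t clamp_with_optional_relu(
    std::int32_t scaled, std::int32_t C_zero_point, bool fuse_relu) {
  // Values below C_zero_point encode negative reals, so clamping at
  // C_zero_point is exactly max(x, 0) in real space.
  std::int32_t lo = fuse_relu ? C_zero_point : 0;
  return (std::uint8_t)std::min<std::int32_t>(255, std::max(lo, scaled));
}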
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index 6aebc3d..5168a15 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -797,4 +797,93 @@ void depthwise_3x3x3_pad_1_ref(
}
};
+void depthwise_3x3x3_per_channel_quantization_pad_1_ref(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const int8_t* B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+
+ vector<int32_t> C_int32(N * T_OUT * H_OUT * W_OUT * K);
+ depthwise_3x3x3_pad_1_ref(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B,
+ C_int32.data());
+
+ vector<int32_t> row_offsets(N * T_OUT * H_OUT * W_OUT * K);
+ for (int n = 0; n < N; ++n) {
+ for (int t = 0; t < T_OUT; ++t) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int sum = 0;
+ for (int k_t = 0; k_t < K_T; ++k_t) {
+ int t_in = -PAD_P + t * stride_t + k_t;
+ for (int k_h = 0; k_h < K_H; ++k_h) {
+ int h_in = -PAD_T + h * stride_h + k_h;
+ for (int k_w = 0; k_w < K_W; ++k_w) {
+ int w_in = -PAD_L + w * stride_w + k_w;
+ int a = t_in < 0 || t_in >= T || h_in < 0 || h_in >= H ||
+ w_in < 0 || w_in >= W
+ ? A_zero_point
+ : A[(((n * T + t_in) * H + h_in) * W + w_in) * K + k];
+ sum += a;
+ }
+ }
+ }
+ row_offsets[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k] =
+ sum;
+ }
+ } // w
+ } // h
+ } // t
+ } // for each n
+
+ for (int i = 0; i < N * T_OUT * H_OUT * W_OUT; ++i) {
+ for (int k = 0; k < K; ++k) {
+ requantize_u8acc32_ref(
+ 1,
+ 1,
+ 1,
+ C_int32.data() + i * K + k,
+ C + i * K + k,
+ &C_multiplier[k],
+ C_zero_point,
+ A_zero_point,
+ &B_zero_point[k],
+ &row_offsets[i * K + k],
+ col_offsets + k,
+ bias ? bias + k : nullptr,
+ 1);
+ }
+ }
+};
+
} // namespace fbgemm
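
The reference implementation above reduces per-channel requantization to K independent calls of requantize_u8acc32_ref, each with channel k's scale and zero point. The arithmetic each call is expected to perform follows the standard affine-quantization identity; the scalar sketch below is an assumption inferred from the argument names, not a verbatim copy of requantize_u8acc32_ref:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Requantize one int32 accumulator of channel k to uint8.
inline std::uint8_t requantize_one(
    std::int32_t acc,            // raw sum of A*B products over the 3x3x3 window
    std::int32_t A_zero_point,
    std::int32_t B_zero_point_k, // per-channel weight zero point
    std::int32_t row_offset,     // sum of A values in this window
    std::int32_t col_offset_k,   // sum of B values in channel k's filter
    std::int32_t bias_k,
    float C_multiplier_k,        // per-channel output scale
    std::int32_t C_zero_point) {
  // Subtract the zero-point cross terms, then add the bias.
  std::int32_t raw = acc - A_zero_point * col_offset_k
      - B_zero_point_k * row_offset + bias_k;
  // Scale into the output range and shift by the output zero point.
  std::int32_t rounded =
      (std::int32_t)std::nearbyint(raw * C_multiplier_k) + C_zero_point;
  return (std::uint8_t)std::min(255, std::max(0, rounded));
}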
diff --git a/src/RefImplementations.h b/src/RefImplementations.h
index e2f52d5..fce68e6 100644
--- a/src/RefImplementations.h
+++ b/src/RefImplementations.h
@@ -298,4 +298,23 @@ FBGEMM_API void depthwise_3x3x3_pad_1_ref(
const std::int32_t* col_offsets,
const std::int32_t* bias);
+FBGEMM_API void depthwise_3x3x3_per_channel_quantization_pad_1_ref(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int32_t* B_zero_point,
+ const std::int8_t* B,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const std::int32_t* bias);
+
} // namespace fbgemm
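
The T_OUT/H_OUT/W_OUT expressions that recur throughout this diff are the usual padded-convolution output-size formula. As a worked check: with H = 14, PAD_T = PAD_B = 1, K_H = 3, and stride_h = 2, H_OUT = (14 + 1 + 1 - 3) / 2 + 1 = 7 (integer division), which is why the 14x14 stride-2 case in the 2D dispatch above corresponds to a 7x7 output.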
diff --git a/test/I8DepthwiseTest.cc b/test/I8DepthwiseTest.cc
index 7242b7a..ab44f0f 100644
--- a/test/I8DepthwiseTest.cc
+++ b/test/I8DepthwiseTest.cc
@@ -168,6 +168,7 @@ TEST(FBGemmDepthWiseTest, Test3x3) {
C_uint8.data(),
col_offsets.data(),
bias.data(),
+ false, /* fuse_relu */
0,
1);
@@ -317,7 +318,7 @@ TEST(FBGemmDepthWiseTest, Test3x3x3) {
C_uint8.data(),
col_offsets.data(),
bias.data(),
- false /* fuse_relu */,
+ false, /* fuse_relu */
0,
1);
@@ -441,6 +442,7 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) {
C_uint8.data(),
col_offsets.data(),
bias.data(),
+ false, /* fuse_relu */
0,
1);
@@ -462,4 +464,138 @@ TEST(FBGemmDepthWiseTest, Test3x3PerChannelQuantization) {
} // for each shape
} // Test3x3PerChannelQuantization
+TEST(FBGemmDepthWiseTest, Test3x3x3PerChannelQuantization) {
+ for (auto shape : shapes_3d) {
+ int N = shape[0];
+ int K = shape[1];
+ int T = shape[2];
+ int H = shape[3];
+ int W = shape[4];
+ int stride_t = shape[5];
+ int stride_h = stride_t;
+ int stride_w = stride_t;
+ constexpr int K_T = 3, K_H = 3, K_W = 3;
+ constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
+ PAD_R = 1;
+ int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
+ int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
+
+ aligned_vector<uint8_t> A(N * T * H * W * K);
+ aligned_vector<int8_t> B(K * K_T * K_H * K_W);
+ int32_t C_num_rows = N * T_OUT * H_OUT * W_OUT;
+ aligned_vector<int32_t> C_ref(C_num_rows * K), C(C_ref.size());
+
+ randFill<uint8_t>(A, 0, 86);
+ int32_t A_zero_point = 43;
+
+ // Each row of B has a different range to really test per-channel
+ // quantization.
+ vector<int32_t> B_zero_point(K);
+ for (auto k = 0; k < K; ++k) {
+ aligned_vector<int8_t> Bk(K_T * K_H * K_W);
+ randFill<int8_t>(Bk, -16 + k, 16 + k);
+ copy(Bk.begin(), Bk.end(), B.begin() + k * K_T * K_H * K_W);
+
+ B_zero_point[k] = 5 + k;
+ }
+
+ depthwise_3x3x3_pad_1_ref(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B.data(),
+ C_ref.data());
+
+ aligned_vector<int32_t> C_ref_transpose(C_ref);
+ transpose_matrix(C_ref.data(), C_num_rows, K);
+ vector<float> C_multiplier(K);
+ for (auto k = 0; k < K; ++k) {
+ auto C_ref_k_begin = C_ref_transpose.begin() + k * C_num_rows;
+ auto C_ref_k_end = C_ref_k_begin + C_num_rows;
+ int32_t minimum = *min_element(C_ref_k_begin, C_ref_k_end);
+ int32_t maximum = *max_element(C_ref_k_begin, C_ref_k_end);
+ C_multiplier[k] = 255. / (maximum - minimum);
+ cerr << "k " << k << " minimum " << minimum << " maximum " << maximum
+ << " multiplier " << C_multiplier[k] << endl;
+ }
+ int32_t C_zero_point = 5;
+
+ aligned_vector<int32_t> col_offsets(K);
+ aligned_vector<int32_t> bias(K);
+ randFill(col_offsets, -100, 100);
+ randFill(bias, -40, 40);
+
+ aligned_vector<uint8_t> C_uint8_ref(C_ref.size()), C_uint8(C_ref.size());
+ depthwise_3x3x3_per_channel_quantization_pad_1_ref(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B_zero_point.data(),
+ B.data(),
+ C_multiplier.data(),
+ C_zero_point,
+ C_uint8_ref.data(),
+ col_offsets.data(),
+ bias.data());
+
+ Packed3x3x3ConvMatrix Bp(K, B.data());
+
+ depthwise_3x3x3_per_channel_quantization_pad_1(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A.data(),
+ B_zero_point.data(),
+ Bp,
+ C_multiplier.data(),
+ C_zero_point,
+ C_uint8.data(),
+ col_offsets.data(),
+ bias.data(),
+ false, /* fuse_relu */
+ 0,
+ 1);
+
+ // correctness check
+ for (int n = 0; n < N; ++n) {
+ for (int t = 0; t < T_OUT; ++t) {
+ for (int h = 0; h < H_OUT; ++h) {
+ for (int w = 0; w < W_OUT; ++w) {
+ for (int k = 0; k < K; ++k) {
+ int32_t expected = C_uint8_ref
+ [(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k];
+ int32_t actual =
+ C_uint8[(((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K + k];
+ ASSERT_EQ(expected, actual)
+ << "Depthwise 3x3 results differ at (" << n << ", " << t
+ << ", " << h << ", " << w << ", " << k << ").";
+ }
+ } // w
+ } // h
+ } // t
+ } // n
+ } // for each shape
+} // Test3x3x3PerChannelQuantization
+
} // namespace fbgemm