Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJongsoo Park <jongsoo@fb.com>2019-02-14 07:35:32 +0300
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>2019-02-14 07:44:53 +0300
commit05ce78e3a5735217cb9154a2c1572dc956ffe6fc (patch)
tree6d2486304b84ef15887385ade7ea16b7b62a571e /src
parent7813a2f2233fa48199b18aa8c03bb439b1fe9ff5 (diff)
clean up depthwise conv interface (#72)
Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/72 depthwise conv without requantization is not really useful and was generating more template parameter options Reviewed By: jianyuh Differential Revision: D14021514 fbshipit-source-id: 61f646373fcd902fdb2854a96d003a548f29f8eb
Diffstat (limited to 'src')
-rw-r--r--src/FbgemmI8DepthwiseAvx2.cc789
-rw-r--r--src/FbgemmI8DepthwiseAvx2.h35
2 files changed, 346 insertions, 478 deletions
diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc
index 555b93a..017c4c8 100644
--- a/src/FbgemmI8DepthwiseAvx2.cc
+++ b/src/FbgemmI8DepthwiseAvx2.cc
@@ -701,52 +701,6 @@ static inline __attribute__((always_inline)) void requantize_(
}
}
-template <bool FUSE_RELU, bool HAS_BIAS>
-static inline __attribute__((always_inline)) void requantize_(
- int32_t A_zero_point,
- float C_multiplier,
- int32_t C_zero_point,
- const int32_t* C_int32,
- uint8_t* C_uint8,
- int n,
- const int32_t* row_offsets,
- const int32_t* col_offsets,
- const int32_t* bias) {
- requantize_<FUSE_RELU, HAS_BIAS, false /* PER_CHANNEL_QUANTIZATION */>(
- A_zero_point,
- &C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8,
- n,
- row_offsets,
- col_offsets,
- bias);
-}
-
-template <bool FUSE_RELU, bool HAS_BIAS>
-static inline __attribute__((always_inline)) void requantize_per_channel_(
- int32_t A_zero_point,
- const float* C_multiplier,
- int32_t C_zero_point,
- const int32_t* C_int32,
- uint8_t* C_uint8,
- int n,
- const int32_t* row_offsets,
- const int32_t* col_offsets,
- const int32_t* bias) {
- requantize_<FUSE_RELU, HAS_BIAS, true /* PER_CHANNEL_QUANTIZATION */>(
- A_zero_point,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8,
- n,
- row_offsets,
- col_offsets,
- bias);
-}
-
template <bool REMAINDER>
static inline __attribute__((always_inline)) __m256i load_a(
const uint8_t* A,
@@ -1120,7 +1074,7 @@ static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_(
}
}
-template <bool SUM_A, bool FUSE_RELU>
+template <bool FUSE_RELU, bool HAS_BIAS>
static inline __attribute__((always_inline)) void depthwise_3x3_kernel_(
int H,
int W,
@@ -1148,7 +1102,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_(
int k;
for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3_packed_<SUM_A>(
+ inner_prod_3x3_packed_<true /*SUM_A*/>(
H,
W,
K,
@@ -1164,7 +1118,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_(
}
int remainder = K - k;
if (remainder) {
- inner_prod_3x3_packed_<SUM_A, true>(
+ inner_prod_3x3_packed_<true /*SUM_A*/, true /*REMAINDER*/>(
H,
W,
K,
@@ -1178,21 +1132,20 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_(
remainder,
&row_offsets[k]);
}
- if (SUM_A) {
- requantize_<FUSE_RELU, true>(
- A_zero_point,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8 + (h * W_OUT + w) * K,
- K,
- row_offsets,
- col_offsets,
- bias);
- }
+
+ requantize_<FUSE_RELU, HAS_BIAS, false /*PER_CHAN_QUANT*/>(
+ A_zero_point,
+ &C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + (h * W_OUT + w) * K,
+ K,
+ row_offsets,
+ col_offsets,
+ bias);
}
-template <bool SUM_A, bool FUSE_RELU>
+template <bool FUSE_RELU, bool HAS_BIAS>
static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_(
int T,
int H,
@@ -1225,7 +1178,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_(
int k;
for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3x3_packed_<SUM_A>(
+ inner_prod_3x3x3_packed_<true /*SUM_A*/>(
T,
H,
W,
@@ -1243,7 +1196,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_(
}
int remainder = K - k;
if (remainder) {
- inner_prod_3x3x3_packed_<SUM_A, true>(
+ inner_prod_3x3x3_packed_<true /*SUM_A*/, true /*REMAINDER*/>(
T,
H,
W,
@@ -1259,21 +1212,20 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_(
remainder,
&row_offsets[k]);
}
- if (SUM_A) {
- requantize_<FUSE_RELU, true>(
- A_zero_point,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
- K,
- row_offsets,
- col_offsets,
- bias);
- }
+
+ requantize_<FUSE_RELU, HAS_BIAS, false /*PER_CHAN_QUANT*/>(
+ A_zero_point,
+ &C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
+ K,
+ row_offsets,
+ col_offsets,
+ bias);
}
-template <bool SUM_A, bool FUSE_RELU>
+template <bool FUSE_RELU, bool HAS_BIAS>
static inline __attribute__((always_inline)) void
depthwise_3x3_per_channel_quantization_kernel_(
int H,
@@ -1302,7 +1254,10 @@ depthwise_3x3_per_channel_quantization_kernel_(
int k;
for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3_packed_<SUM_A, false /*remainder*/, true /*per-channel*/>(
+ inner_prod_3x3_packed_<
+ true /*SUM_A*/,
+ false /*remainder*/,
+ true /*per-channel*/>(
H,
W,
K,
@@ -1318,7 +1273,10 @@ depthwise_3x3_per_channel_quantization_kernel_(
}
int remainder = K - k;
if (remainder) {
- inner_prod_3x3_packed_<SUM_A, true /*remainder*/, true /*per-channel*/>(
+ inner_prod_3x3_packed_<
+ true /*SUM_A*/,
+ true /*remainder*/,
+ true /*per-channel*/>(
H,
W,
K,
@@ -1332,21 +1290,20 @@ depthwise_3x3_per_channel_quantization_kernel_(
remainder,
&row_offsets[k]);
}
- if (SUM_A) {
- requantize_per_channel_<FUSE_RELU, true /*HAS_BIAS*/>(
- A_zero_point,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8 + (h * W_OUT + w) * K,
- K,
- row_offsets,
- col_offsets,
- bias);
- }
+
+ requantize_<FUSE_RELU, HAS_BIAS, true /*PER_CHAN_QUANT*/>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + (h * W_OUT + w) * K,
+ K,
+ row_offsets,
+ col_offsets,
+ bias);
}
-template <bool SUM_A, bool FUSE_RELU>
+template <bool FUSE_RELU, bool HAS_BIAS>
static inline __attribute__((always_inline)) void
depthwise_3x3x3_per_channel_quantization_kernel_(
int T,
@@ -1380,7 +1337,10 @@ depthwise_3x3x3_per_channel_quantization_kernel_(
int k;
for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3x3_packed_<SUM_A, false /*remainder*/, true /*per-channel*/>(
+ inner_prod_3x3x3_packed_<
+ true /*SUM_A*/,
+ false /*remainder*/,
+ true /*per-channel*/>(
T,
H,
W,
@@ -1398,7 +1358,10 @@ depthwise_3x3x3_per_channel_quantization_kernel_(
}
int remainder = K - k;
if (remainder) {
- inner_prod_3x3x3_packed_<SUM_A, true /*remainder*/, true /*per-channel*/>(
+ inner_prod_3x3x3_packed_<
+ true /*SUM_A*/,
+ true /*remainder*/,
+ true /*per-channel*/>(
T,
H,
W,
@@ -1414,18 +1377,16 @@ depthwise_3x3x3_per_channel_quantization_kernel_(
remainder,
&row_offsets[k]);
}
- if (SUM_A) {
- requantize_per_channel_<FUSE_RELU, true>(
- A_zero_point,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
- K,
- row_offsets,
- col_offsets,
- bias);
- }
+ requantize_<FUSE_RELU, HAS_BIAS, true /*PER_CHAN_QUANT*/>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
+ K,
+ row_offsets,
+ col_offsets,
+ bias);
}
static pair<int, int> closest_factors_(int n) {
@@ -1440,7 +1401,7 @@ static pair<int, int> closest_factors_(int n) {
// This implementation should be general enough to handle not just 3x3 but other
// filter shapes by parameterizing with R and S but restricting it to just 3x3
// for now.
-template <bool FUSE_RESCALE = true, bool FUSE_RELU = false>
+template <bool FUSE_RELU, bool HAS_BIAS>
static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
int N,
int H,
@@ -1468,7 +1429,6 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
const int8_t* Bp = B.PackedMat();
int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
- int32_t* C_temp;
int n_begin, n_end;
int h_begin, h_end, w_begin, w_end;
@@ -1517,9 +1477,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
if (h_begin == 0) {
if (w_begin == 0) {
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -1533,7 +1491,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -1541,9 +1499,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -1557,7 +1513,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -1566,9 +1522,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
if (w_end == W_OUT) {
w = W_OUT - 1;
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -1582,7 +1536,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -1593,9 +1547,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) {
if (w_begin == 0) {
w = 0;
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -1609,7 +1561,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -1617,9 +1569,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -1633,7 +1583,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -1642,9 +1592,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
if (w_end == W_OUT) {
w = W_OUT - 1;
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -1658,7 +1606,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -1670,9 +1618,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
h = H_OUT - 1;
w = 0;
if (w_begin == 0) {
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -1686,7 +1632,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -1694,9 +1640,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -1710,7 +1654,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -1719,9 +1663,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
if (w_end == W_OUT) {
w = W_OUT - 1;
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -1735,7 +1677,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -1745,7 +1687,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
} // for each n
};
-template <bool FUSE_RESCALE = true, bool FUSE_RELU = false>
+template <bool FUSE_RELU, bool HAS_BIAS>
static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_(
int N,
int T,
@@ -1777,7 +1719,6 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_(
const int8_t* Bp = B.PackedMat();
int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
- int32_t* C_temp;
int n_begin, n_end;
int t_begin, t_end, h_begin, h_end;
@@ -1824,10 +1765,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_(
for (int t = t_begin; t < t_end; ++t) {
for (int h = h_begin; h < h_end; ++h) {
for (int w = 0; w < W_OUT; ++w) {
- C_temp = FUSE_RESCALE
- ? C_int32
- : C_int32 + (((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3x3_kernel_<FUSE_RELU, HAS_BIAS>(
T,
H,
W,
@@ -1844,7 +1782,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -1855,7 +1793,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_(
} // for each n
};
-template <bool FUSE_RESCALE = true, bool FUSE_RELU = false>
+template <bool FUSE_RELU, bool HAS_BIAS>
static inline __attribute__((always_inline)) void
depthwise_3x3_per_channel_quantization_pad_1_(
int N,
@@ -1884,7 +1822,6 @@ depthwise_3x3_per_channel_quantization_pad_1_(
const int8_t* Bp = B.PackedMat();
int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
- int32_t* C_temp;
int n_begin, n_end;
int h_begin, h_end, w_begin, w_end;
@@ -1933,9 +1870,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
if (h_begin == 0) {
if (w_begin == 0) {
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -1949,7 +1884,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -1957,9 +1892,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -1973,7 +1906,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -1982,9 +1915,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
if (w_end == W_OUT) {
w = W_OUT - 1;
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -1998,7 +1929,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -2009,9 +1940,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) {
if (w_begin == 0) {
w = 0;
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -2025,7 +1954,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -2033,9 +1962,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -2049,7 +1976,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -2058,9 +1985,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
if (w_end == W_OUT) {
w = W_OUT - 1;
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -2074,7 +1999,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -2086,9 +2011,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
h = H_OUT - 1;
w = 0;
if (w_begin == 0) {
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -2102,7 +2025,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -2110,9 +2033,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -2126,7 +2047,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -2135,9 +2056,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
if (w_end == W_OUT) {
w = W_OUT - 1;
- C_temp = FUSE_RESCALE ? C_int32
- : C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE, FUSE_RELU>(
+ depthwise_3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>(
H,
W,
K,
@@ -2151,7 +2070,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -2161,7 +2080,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
} // for each n
};
-template <bool FUSE_RESCALE = true, bool FUSE_RELU = false>
+template <bool FUSE_RELU, bool HAS_BIAS>
static inline __attribute__((always_inline)) void
depthwise_3x3x3_per_channel_quantization_pad_1_(
int N,
@@ -2194,7 +2113,6 @@ depthwise_3x3x3_per_channel_quantization_pad_1_(
const int8_t* Bp = B.PackedMat();
int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
- int32_t* C_temp;
int n_begin, n_end;
int t_begin, t_end, h_begin, h_end;
@@ -2241,12 +2159,7 @@ depthwise_3x3x3_per_channel_quantization_pad_1_(
for (int t = t_begin; t < t_end; ++t) {
for (int h = h_begin; h < h_end; ++h) {
for (int w = 0; w < W_OUT; ++w) {
- C_temp = FUSE_RESCALE
- ? C_int32
- : C_int32 + (((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K;
- depthwise_3x3x3_per_channel_quantization_kernel_<
- FUSE_RESCALE,
- FUSE_RELU>(
+ depthwise_3x3x3_per_channel_quantization_kernel_<FUSE_RELU, HAS_BIAS>(
T,
H,
W,
@@ -2263,7 +2176,7 @@ depthwise_3x3x3_per_channel_quantization_pad_1_(
Bp,
C_multiplier,
C_zero_point,
- C_temp,
+ C_int32,
C_uint8_base,
row_offsets,
col_offsets,
@@ -2274,8 +2187,9 @@ depthwise_3x3x3_per_channel_quantization_pad_1_(
} // for each n
};
-// assumption: W > 3 and H > 3
-void depthwise_3x3_pad_1(
+// Dispatch HAS_BIAS
+template <bool FUSE_RELU>
+static void depthwise_3x3_pad_1_(
int N,
int H,
int W,
@@ -2284,72 +2198,18 @@ void depthwise_3x3_pad_1(
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
+ int32_t B_zero_point,
const Packed3x3ConvMatrix& B,
- int32_t* C,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
int thread_id,
int num_threads) {
- if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<false>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- 0,
- B,
- 0.0f,
- 0,
- C,
- nullptr,
- nullptr,
- nullptr,
- thread_id,
- num_threads);
- } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<false>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- 0,
- B,
- 0.0f,
- 0,
- C,
- nullptr,
- nullptr,
- nullptr,
- thread_id,
- num_threads);
- } else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<false>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- 0,
- B,
- 0.0f,
- 0,
- C,
- nullptr,
- nullptr,
- nullptr,
- thread_id,
- num_threads);
- } else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<false>(
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ if (bias) {
+ depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>(
N,
H,
W,
@@ -2358,18 +2218,18 @@ void depthwise_3x3_pad_1(
stride_w,
A_zero_point,
A,
- 0,
+ B_zero_point,
B,
- 0.0f,
- 0,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
C,
- nullptr,
- nullptr,
- nullptr,
+ col_offsets,
+ bias,
thread_id,
num_threads);
} else {
- depthwise_3x3_pad_1_<false>(
+ depthwise_3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/>(
N,
H,
W,
@@ -2378,19 +2238,21 @@ void depthwise_3x3_pad_1(
stride_w,
A_zero_point,
A,
- 0,
+ B_zero_point,
B,
- 0.0f,
- 0,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
C,
- nullptr,
- nullptr,
- nullptr,
+ col_offsets,
+ bias,
thread_id,
num_threads);
}
}
+// Dispatch input shape and FUSE_RELU
+// assumption: W > 3 and H > 3
void depthwise_3x3_pad_1(
int N,
int H,
@@ -2410,10 +2272,9 @@ void depthwise_3x3_pad_1(
bool fuse_relu,
int thread_id,
int num_threads) {
- int32_t C_int32_temp[(K + 31) / 32 * 32];
if (fuse_relu) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
N,
H,
W,
@@ -2426,14 +2287,13 @@ void depthwise_3x3_pad_1(
B,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
N,
H,
W,
@@ -2446,14 +2306,13 @@ void depthwise_3x3_pad_1(
B,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
N,
H,
W,
@@ -2466,14 +2325,13 @@ void depthwise_3x3_pad_1(
B,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
N,
H,
W,
@@ -2486,14 +2344,13 @@ void depthwise_3x3_pad_1(
B,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else {
- depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
N,
H,
W,
@@ -2506,7 +2363,6 @@ void depthwise_3x3_pad_1(
B,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
@@ -2515,7 +2371,7 @@ void depthwise_3x3_pad_1(
}
} else {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
N,
H,
W,
@@ -2528,14 +2384,13 @@ void depthwise_3x3_pad_1(
B,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
N,
H,
W,
@@ -2548,14 +2403,13 @@ void depthwise_3x3_pad_1(
B,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
N,
H,
W,
@@ -2568,14 +2422,13 @@ void depthwise_3x3_pad_1(
B,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
N,
H,
W,
@@ -2588,14 +2441,13 @@ void depthwise_3x3_pad_1(
B,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else {
- depthwise_3x3_pad_1_(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
N,
H,
W,
@@ -2608,7 +2460,6 @@ void depthwise_3x3_pad_1(
B,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
@@ -2618,44 +2469,8 @@ void depthwise_3x3_pad_1(
}
}
-void depthwise_3x3x3_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const Packed3x3x3ConvMatrix& B,
- int32_t* C,
- int thread_id,
- int num_threads) {
- depthwise_3x3x3_pad_1_<false /* FUSE_RESCALE */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- 0,
- B,
- 0.0f,
- 0,
- C,
- nullptr,
- nullptr,
- nullptr,
- thread_id,
- num_threads);
-}
-
+// Dispatch HAS_BIAS
+template <bool FUSE_RELU>
static void depthwise_3x3x3_pad_1_(
int N,
int T,
@@ -2677,30 +2492,55 @@ static void depthwise_3x3x3_pad_1_(
int thread_id,
int num_threads) {
int32_t C_int32_temp[(K + 31) / 32 * 32];
- depthwise_3x3x3_pad_1_<true /* FUSE_RESCALE */, false /* FUSE_RELU */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
+ if (bias) {
+ depthwise_3x3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ }
}
-static void depthwise_3x3x3_pad_1_relu_fused_(
+// Dispatch FUSE_RELU
+void depthwise_3x3x3_pad_1(
int N,
int T,
int H,
@@ -2718,91 +2558,114 @@ static void depthwise_3x3x3_pad_1_relu_fused_(
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
+ bool fuse_relu,
int thread_id,
int num_threads) {
- int32_t C_int32_temp[(K + 31) / 32 * 32];
- depthwise_3x3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
+ if (fuse_relu) {
+ depthwise_3x3x3_pad_1_<true /*FUSE_RELU*/>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_pad_1_<false /*FUSE_RELU*/>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ }
}
-void depthwise_3x3x3_pad_1(
+// Dispatch HAS_BIAS
+template <bool FUSE_RELU>
+static void depthwise_3x3_per_channel_quantization_pad_1_(
int N,
- int T,
int H,
int W,
int K,
- int stride_t,
int stride_h,
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
+ const int32_t* B_zero_point,
+ const Packed3x3ConvMatrix& Bp,
+ const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
- bool fuse_relu,
int thread_id,
int num_threads) {
- // If we inline the following two functions, I see stack overflow.
- if (fuse_relu) {
- depthwise_3x3x3_pad_1_relu_fused_(
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ if (bias) {
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ FUSE_RELU,
+ true /* HAS_BIAS */>(
N,
- T,
H,
W,
K,
- stride_t,
stride_h,
stride_w,
A_zero_point,
A,
B_zero_point,
- B,
+ Bp,
C_multiplier,
C_zero_point,
+ C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else {
- depthwise_3x3x3_pad_1_(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ FUSE_RELU,
+ false /* HAS_BIAS */>(
N,
- T,
H,
W,
K,
- stride_t,
stride_h,
stride_w,
A_zero_point,
A,
B_zero_point,
- B,
+ Bp,
C_multiplier,
C_zero_point,
+ C_int32_temp,
C,
col_offsets,
bias,
@@ -2811,6 +2674,7 @@ void depthwise_3x3x3_pad_1(
}
}
+// Dispatch input shape and FUSE_RELU
void depthwise_3x3_per_channel_quantization_pad_1(
int N,
int H,
@@ -2830,12 +2694,9 @@ void depthwise_3x3_per_channel_quantization_pad_1(
bool fuse_relu,
int thread_id,
int num_threads) {
- int32_t C_int32_temp[(K + 31) / 32 * 32];
if (fuse_relu) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RESCALE */,
- true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
N,
H,
W,
@@ -2848,16 +2709,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
Bp,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RESCALE */,
- true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
N,
H,
W,
@@ -2870,16 +2728,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
Bp,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RESCALE */,
- true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
N,
H,
W,
@@ -2892,16 +2747,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
Bp,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RESCALE */,
- true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
N,
H,
W,
@@ -2914,16 +2766,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
Bp,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RESCALE */,
- true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
N,
H,
W,
@@ -2936,7 +2785,6 @@ void depthwise_3x3_per_channel_quantization_pad_1(
Bp,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
@@ -2945,9 +2793,7 @@ void depthwise_3x3_per_channel_quantization_pad_1(
}
} else {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RESCALE */,
- false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
N,
H,
W,
@@ -2960,16 +2806,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
Bp,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RESCALE */,
- false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
N,
H,
W,
@@ -2982,16 +2825,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
Bp,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RESCALE */,
- false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
N,
H,
W,
@@ -3004,16 +2844,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
Bp,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RESCALE */,
- false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
N,
H,
W,
@@ -3026,16 +2863,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
Bp,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
thread_id,
num_threads);
} else {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RESCALE */,
- false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
N,
H,
W,
@@ -3048,7 +2882,6 @@ void depthwise_3x3_per_channel_quantization_pad_1(
Bp,
C_multiplier,
C_zero_point,
- C_int32_temp,
C,
col_offsets,
bias,
@@ -3058,7 +2891,9 @@ void depthwise_3x3_per_channel_quantization_pad_1(
}
}
-void depthwise_3x3x3_per_channel_quantization_pad_1(
+// Dispatch HAS_BIAS
+template <bool FUSE_RELU>
+static void depthwise_3x3x3_per_channel_quantization_pad_1_(
int N,
int T,
int H,
@@ -3076,14 +2911,13 @@ void depthwise_3x3x3_per_channel_quantization_pad_1(
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
- bool fuse_relu,
int thread_id,
int num_threads) {
int32_t C_int32_temp[(K + 31) / 32 * 32];
- if (fuse_relu) {
+ if (bias) {
depthwise_3x3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RESCALE */,
- true /* FUSE_RELU */>(
+ FUSE_RELU,
+ true /* HAS_BIAS */>(
N,
T,
H,
@@ -3106,8 +2940,8 @@ void depthwise_3x3x3_per_channel_quantization_pad_1(
num_threads);
} else {
depthwise_3x3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RESCALE */,
- false /* FUSE_RELU */>(
+ FUSE_RELU,
+ false /* HAS_BIAS */>(
N,
T,
H,
@@ -3131,4 +2965,71 @@ void depthwise_3x3x3_per_channel_quantization_pad_1(
}
}
+// Dispatch FUSE_RELU
+void depthwise_3x3x3_per_channel_quantization_pad_1(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const Packed3x3x3ConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ int thread_id,
+ int num_threads) {
+ if (fuse_relu) {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_3x3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
+ }
+}
+
} // namespace fbgemm
diff --git a/src/FbgemmI8DepthwiseAvx2.h b/src/FbgemmI8DepthwiseAvx2.h
index 53c6e8a..9fd16ae 100644
--- a/src/FbgemmI8DepthwiseAvx2.h
+++ b/src/FbgemmI8DepthwiseAvx2.h
@@ -34,6 +34,7 @@ using Packed3x3x3ConvMatrix = PackedDepthWiseConvMatrix<3 * 3 * 3>;
/**
* Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8
+ *
* @params A The input image in NHWK layout
* @params Bp The pre-packed filter
*/
@@ -46,24 +47,6 @@ FBGEMM_API void depthwise_3x3_pad_1(
int stride_w,
std::int32_t A_zero_point,
const std::uint8_t* A,
- const Packed3x3ConvMatrix& Bp,
- std::int32_t* C,
- int thread_id = 0,
- int num_threads = 1);
-
-/**
- * Depth-wise 3x3 convolution with pad=1 and stride=1 and K a multiple of 8
- * This version is fused with requantization.
- */
-FBGEMM_API void depthwise_3x3_pad_1(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
std::int32_t B_zero_point,
const Packed3x3ConvMatrix& Bp,
float C_multiplier,
@@ -110,22 +93,6 @@ FBGEMM_API void depthwise_3x3x3_pad_1(
int stride_w,
std::int32_t A_zero_point,
const std::uint8_t* A,
- const Packed3x3x3ConvMatrix& Bp,
- std::int32_t* C,
- int thread_id = 0,
- int num_threads = 1);
-
-FBGEMM_API void depthwise_3x3x3_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- std::int32_t A_zero_point,
- const std::uint8_t* A,
std::int32_t B_zero_point,
const Packed3x3x3ConvMatrix& Bp,
float C_multiplier,