github.com/marian-nmt/FBGEMM.git
Diffstat (limited to 'src/FbgemmI8Depthwise.cc')
-rw-r--r--  src/FbgemmI8Depthwise.cc  1893
1 file changed, 1355 insertions(+), 538 deletions(-)
diff --git a/src/FbgemmI8Depthwise.cc b/src/FbgemmI8Depthwise.cc
index 54e2272..551e98e 100644
--- a/src/FbgemmI8Depthwise.cc
+++ b/src/FbgemmI8Depthwise.cc
@@ -18,8 +18,7 @@
using namespace std;
-namespace fbgemm2
-{
+namespace fbgemm2 {
static array<array<int, 8>, 8> masks = {{
{ 0, 0, 0, 0, 0, 0, 0, 0, },
@@ -34,7 +33,8 @@ static array<array<int, 8>, 8> masks = {{
template <int KERNEL_PROD>
PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
- int K, const int8_t *smat)
+ int K,
+ const int8_t* smat)
: K_(K) {
// Transpose the input matrix to make packing faster.
vector<int8_t> smat_transposed(K * KERNEL_PROD);
@@ -46,8 +46,12 @@ PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
// Allocate packed arrays
constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2;
- pmat_ = static_cast<int8_t *>(aligned_alloc(
- 64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t)));
+ // pmat_ = static_cast<int8_t *>(fbgemmAlignedAlloc(
+ // 64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t)));
+ posix_memalign(
+ (void**)&pmat_,
+ 64,
+ ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t));
// Pack input matrix
// The layout is optimized to use vpmaddubsw efficiently (see
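
Note on the hunk above: it replaces C11 aligned_alloc with POSIX posix_memalign (the fbgemmAlignedAlloc call is left commented out). posix_memalign reports failure through its return value rather than through a null pointer, and that value is discarded here. A minimal sketch of a checked variant follows; the helper name is hypothetical and not part of FBGEMM.

```cpp
#include <cstdint>
#include <cstdlib>
#include <new>

// Minimal sketch of a checked 64-byte-aligned allocation, mirroring the
// posix_memalign call in the hunk above (whose return value is discarded).
// The helper name is illustrative only.
static int8_t* alloc_aligned_64(std::size_t bytes) {
  void* p = nullptr;
  // posix_memalign (POSIX) returns 0 on success and an error code such as
  // ENOMEM on failure; unlike malloc it does not set errno.
  if (posix_memalign(&p, 64, bytes) != 0) {
    throw std::bad_alloc();
  }
  return static_cast<int8_t*>(p);
}
```

Memory obtained this way is released with plain free(), which is what the destructor later in this diff does.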
@@ -106,15 +110,15 @@ PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
int remainder = K - k1;
if (remainder < 32) {
__m256i mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i *>(masks[remainder / 4].data()));
+ reinterpret_cast<const __m256i*>(masks[remainder / 4].data()));
for (int i = 0; i < KERNEL_PROD; ++i) {
b_v[i] = _mm256_maskload_epi32(
- reinterpret_cast<const int *>(smat_transposed.data() + i * K + k1),
+ reinterpret_cast<const int*>(smat_transposed.data() + i * K + k1),
mask_v);
}
} else {
for (int i = 0; i < KERNEL_PROD; ++i) {
- b_v[i] = _mm256_lddqu_si256(reinterpret_cast<const __m256i *>(
+ b_v[i] = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(
smat_transposed.data() + i * K + k1));
}
}
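
The remainder path above relies on the masks table shown near the top of the diff: each row holds eight 32-bit lane masks, and _mm256_maskload_epi32 loads only the dwords whose mask has its high bit set, zeroing the rest, so a channel tail of fewer than 32 bytes can be read without running past the buffer. A minimal sketch, assuming the elided rows of masks continue the pattern of setting the first remainder / 4 dwords to -1:

```cpp
#include <immintrin.h>
#include <cstdint>

// Sketch: load `remainder` bytes (handled in 4-byte granules, matching
// masks[remainder / 4] above, remainder < 32) from `src` into a 256-bit
// register, zero-filling the tail.
static __m256i load_remainder(const uint8_t* src, int remainder) {
  alignas(32) int32_t mask[8];
  for (int i = 0; i < 8; ++i) {
    // High bit set => load this 4-byte lane; 0 => leave it zero.
    mask[i] = (i < remainder / 4) ? -1 : 0;
  }
  __m256i mask_v = _mm256_load_si256(reinterpret_cast<const __m256i*>(mask));
  return _mm256_maskload_epi32(reinterpret_cast<const int*>(src), mask_v);
}
```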
@@ -153,7 +157,7 @@ PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
for (int i = 0; i < KERNEL_PROD_ALIGNED; ++i) {
_mm256_storeu_si256(
- reinterpret_cast<__m256i *>(
+ reinterpret_cast<__m256i*>(
&pmat_[((k1 / 32) * KERNEL_PROD_ALIGNED + i) * 32]),
b_interleaved_epi32[i]);
}
@@ -161,8 +165,7 @@ PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
}
template <int KERNEL_PROD>
-PackedDepthWiseConvMatrix<KERNEL_PROD>::~PackedDepthWiseConvMatrix()
-{
+PackedDepthWiseConvMatrix<KERNEL_PROD>::~PackedDepthWiseConvMatrix() {
free(pmat_);
}
@@ -178,11 +181,16 @@ template class PackedDepthWiseConvMatrix<3 * 3 * 3>;
// c2_v: c[8:12], c[24:28]
// c3_v: c[12:16], c[28:32]
template <bool SUM_A = false>
-static inline __attribute__((always_inline))
-void madd_epi16x4_packed(
- __m256i a0_v, __m256i a1_v, __m256i a2_v, __m256i a3_v,
+static inline __attribute__((always_inline)) void madd_epi16x4_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ __m256i a2_v,
+ __m256i a3_v,
const __m256i* b,
- __m256i* c0_v, __m256i* c1_v, __m256i* c2_v, __m256i* c3_v,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
__m256i* a_sum = nullptr) {
__m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
__m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
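
The madd_epi16x4_packed family above is built around vpmaddubsw (_mm256_maddubs_epi16): the unpacklo/unpackhi_epi8 calls interleave bytes from two filter taps so the instruction can multiply unsigned activation bytes by signed weight bytes and sum adjacent pairs into 16-bit lanes. This pairing is also why the packed B matrix rounds KERNEL_PROD up to the even KERNEL_PROD_ALIGNED. A scalar model of one such step (saturation omitted here; the real instruction saturates):

```cpp
#include <cstdint>

// Scalar model of one _mm256_maddubs_epi16 step as used above: u8
// activations times s8 weights, adjacent pairs summed into int16 lanes
// (16 lanes for a 256-bit register). The intrinsic saturates; this
// simplified model does not.
static void maddubs_model(const uint8_t* a, const int8_t* b,
                          int16_t* c, int lanes) {
  for (int i = 0; i < lanes; ++i) {
    c[i] = static_cast<int16_t>(a[2 * i] * b[2 * i] +
                                a[2 * i + 1] * b[2 * i + 1]);
  }
}
```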
@@ -232,11 +240,15 @@ void madd_epi16x4_packed(
// c2_v: c[8:12], c[24:28]
// c3_v: c[12:16], c[28:32]
template <bool SUM_A = false>
-static inline __attribute__((always_inline))
-void madd_epi16x3_packed(
- __m256i a0_v, __m256i a1_v, __m256i a2_v,
+static inline __attribute__((always_inline)) void madd_epi16x3_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ __m256i a2_v,
const __m256i* b,
- __m256i* c0_v, __m256i* c1_v, __m256i* c2_v, __m256i* c3_v,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
__m256i* a_sum = nullptr) {
__m256i zero_v = _mm256_setzero_si256();
@@ -288,10 +300,15 @@ void madd_epi16x3_packed(
// c2_v: c[16:20], c[20:24]
// c3_v: c[24:28], c[28:32]
template <bool SUM_A = false>
-static inline __attribute__((always_inline)) void
-madd_epi16x2_packed(__m256i a0_v, __m256i a1_v, const __m256i *b, __m256i *c0_v,
- __m256i *c1_v, __m256i *c2_v, __m256i *c3_v,
- __m256i *a_sum = nullptr) {
+static inline __attribute__((always_inline)) void madd_epi16x2_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
__m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
__m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
@@ -324,9 +341,14 @@ madd_epi16x2_packed(__m256i a0_v, __m256i a1_v, const __m256i *b, __m256i *c0_v,
// c2_v: c[16:20], c[20:24]
// c3_v: c[24:28], c[28:32]
template <bool SUM_A = false>
-static inline __attribute__((always_inline)) void
-madd_epi16_packed(__m256i a_v, const __m256i *b, __m256i *c0_v, __m256i *c1_v,
- __m256i *c2_v, __m256i *c3_v, __m256i *a_sum = nullptr) {
+static inline __attribute__((always_inline)) void madd_epi16_packed(
+ __m256i a_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
__m256i zero_v = _mm256_setzero_si256();
__m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v);
@@ -354,21 +376,41 @@ madd_epi16_packed(__m256i a_v, const __m256i *b, __m256i *c0_v, __m256i *c1_v,
// K is the number of accumulations we're doing
template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false>
-static inline __attribute__((always_inline)) void
-inner_prod_packed_(const __m256i *a_v, const __m256i *Bp, int32_t *C,
- int remainder, __m256i *a_sum = nullptr) {
+static inline __attribute__((always_inline)) void inner_prod_packed_(
+ const __m256i* a_v,
+ const __m256i* Bp,
+ int32_t* C,
+ int remainder,
+ __m256i* a_sum = nullptr) {
array<__m256i, 4> c, c_temp;
array<__m256i, 2> a_sum_temp{};
int k = 0;
if (K >= 4) {
- madd_epi16x4_packed<SUM_A>(a_v[0], a_v[1], a_v[2], a_v[3], Bp,
- &c[0], &c[1], &c[2], &c[3], a_sum_temp.data());
+ madd_epi16x4_packed<SUM_A>(
+ a_v[0],
+ a_v[1],
+ a_v[2],
+ a_v[3],
+ Bp,
+ &c[0],
+ &c[1],
+ &c[2],
+ &c[3],
+ a_sum_temp.data());
for (k = 4; k < K / 4 * 4; k += 4) {
- madd_epi16x4_packed<SUM_A>(a_v[k + 0], a_v[k + 1], a_v[k + 2], a_v[k + 3],
- Bp + k, &c_temp[0], &c_temp[1], &c_temp[2],
- &c_temp[3], a_sum_temp.data());
+ madd_epi16x4_packed<SUM_A>(
+ a_v[k + 0],
+ a_v[k + 1],
+ a_v[k + 2],
+ a_v[k + 3],
+ Bp + k,
+ &c_temp[0],
+ &c_temp[1],
+ &c_temp[2],
+ &c_temp[3],
+ a_sum_temp.data());
c[0] = _mm256_add_epi32(c[0], c_temp[0]);
c[1] = _mm256_add_epi32(c[1], c_temp[1]);
@@ -383,9 +425,16 @@ inner_prod_packed_(const __m256i *a_v, const __m256i *Bp, int32_t *C,
}
if (K - k == 3) {
- madd_epi16x3_packed<SUM_A>(a_v[k], a_v[k + 1], a_v[k + 2], Bp + k,
- &c_temp[0], &c_temp[1], &c_temp[2], &c_temp[3],
- a_sum_temp.data());
+ madd_epi16x3_packed<SUM_A>(
+ a_v[k],
+ a_v[k + 1],
+ a_v[k + 2],
+ Bp + k,
+ &c_temp[0],
+ &c_temp[1],
+ &c_temp[2],
+ &c_temp[3],
+ a_sum_temp.data());
c[0] = _mm256_add_epi32(c[0], c_temp[0]);
c[1] = _mm256_add_epi32(c[1], c_temp[1]);
@@ -405,11 +454,18 @@ inner_prod_packed_(const __m256i *a_v, const __m256i *Bp, int32_t *C,
c[3] = c_temp[3];
} else {
if (K - k == 1) {
- madd_epi16_packed<SUM_A>(a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3],
- a_sum_temp.data());
+ madd_epi16_packed<SUM_A>(
+ a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp.data());
} else if (K - k == 2) {
- madd_epi16x2_packed<SUM_A>(a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1],
- &c[2], &c[3], a_sum_temp.data());
+ madd_epi16x2_packed<SUM_A>(
+ a_v[k],
+ a_v[k + 1],
+ Bp + k,
+ &c[0],
+ &c[1],
+ &c[2],
+ &c[3],
+ a_sum_temp.data());
}
c[0] = _mm256_add_epi32(c[0], c_temp[0]);
@@ -422,37 +478,37 @@ inner_prod_packed_(const __m256i *a_v, const __m256i *Bp, int32_t *C,
for (int r = 0; r < remainder / 8; ++r) {
if (ACC) {
_mm256_storeu_si256(
- reinterpret_cast<__m256i *>(C + r * 8),
+ reinterpret_cast<__m256i*>(C + r * 8),
_mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + r * 8)),
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)),
c[r]));
} else {
- _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + r * 8), c[r]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]);
}
}
} else {
if (ACC) {
_mm256_storeu_si256(
- reinterpret_cast<__m256i *>(C),
- _mm256_add_epi32(_mm256_loadu_si256(reinterpret_cast<__m256i *>(C)),
- c[0]));
+ reinterpret_cast<__m256i*>(C),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0]));
_mm256_storeu_si256(
- reinterpret_cast<__m256i *>(C + 8),
+ reinterpret_cast<__m256i*>(C + 8),
_mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 8)), c[1]));
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1]));
_mm256_storeu_si256(
- reinterpret_cast<__m256i *>(C + 16),
+ reinterpret_cast<__m256i*>(C + 16),
_mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 16)), c[2]));
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2]));
_mm256_storeu_si256(
- reinterpret_cast<__m256i *>(C + 24),
+ reinterpret_cast<__m256i*>(C + 24),
_mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 24)), c[3]));
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3]));
} else {
- _mm256_storeu_si256(reinterpret_cast<__m256i *>(C), c[0]);
- _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 8), c[1]);
- _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 16), c[2]);
- _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 24), c[3]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]);
}
}
@@ -467,14 +523,13 @@ inner_prod_packed_(const __m256i *a_v, const __m256i *Bp, int32_t *C,
}
template <bool SUM_A = false, bool REMAINDER = false>
-static inline __attribute__((always_inline))
-void inner_prod_3x3_packed_(const __m256i* a_v,
- const __m256i* Bp,
- int32_t* C,
- int remainder,
- __m256i* a_sum = nullptr) {
- return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder,
- a_sum);
+static inline __attribute__((always_inline)) void inner_prod_3x3_packed_(
+ const __m256i* a_v,
+ const __m256i* Bp,
+ int32_t* C,
+ int remainder,
+ __m256i* a_sum = nullptr) {
+ return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder, a_sum);
}
// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different
@@ -507,7 +562,7 @@ static inline __attribute__((always_inline)) void requantize_(
constexpr int VLEN = 8;
int j = 0;
- for ( ; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
+ for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
__m256i x_v =
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
__m256i y_v = _mm256_loadu_si256(
@@ -519,10 +574,9 @@ static inline __attribute__((always_inline)) void requantize_(
__m256i col_off_v = _mm256_mullo_epi32(
A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j)));
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(col_offsets + j)));
__m256i row_offset_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i *>(row_offsets + j));
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v);
col_off_v = _mm256_mullo_epi32(
@@ -535,23 +589,23 @@ static inline __attribute__((always_inline)) void requantize_(
col_off_v = _mm256_mullo_epi32(
A_zero_point_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
- col_offsets + j + 2 * VLEN)));
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN)));
row_offset_v = _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
z_v = _mm256_sub_epi32(_mm256_sub_epi32(z_v, col_off_v), row_offset_v);
col_off_v = _mm256_mullo_epi32(
A_zero_point_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
- col_offsets + j + 3 * VLEN)));
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN)));
row_offset_v = _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
w_v = _mm256_sub_epi32(_mm256_sub_epi32(w_v, col_off_v), row_offset_v);
if (HAS_BIAS) { // static if
x_v = _mm256_add_epi32(
- x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(bias + j)));
+ x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
y_v = _mm256_add_epi32(
y_v,
_mm256_loadu_si256(
@@ -604,21 +658,20 @@ static inline __attribute__((always_inline)) void requantize_(
reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
} // j loop vectorized and unrolled 4x
- for ( ; j < n / VLEN * VLEN; j += VLEN) {
+ for (; j < n / VLEN * VLEN; j += VLEN) {
__m256i x_v =
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
__m256i col_off_v = _mm256_mullo_epi32(
A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j)));
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(col_offsets + j)));
__m256i row_offset_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i *>(row_offsets + j));
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v);
if (HAS_BIAS) { // static if
x_v = _mm256_add_epi32(
- x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(bias + j)));
+ x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
}
if (PER_CHANNEL_QUANTIZATION) {
@@ -635,15 +688,14 @@ static inline __attribute__((always_inline)) void requantize_(
FUSE_RELU ? C_zero_point_epi8_v : min_v,
_mm256_min_epu8(x_packed_v, max_v));
- x_clamped_v =
- _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
+ x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
_mm_storel_epi64(
reinterpret_cast<__m128i*>(C_uint8 + j),
_mm256_castsi256_si128(x_clamped_v));
} // j loop vectorized
- for ( ; j < n; ++j) {
+ for (; j < n; ++j) {
int32_t raw = C_int32[j] - A_zero_point * col_offsets[j] - row_offsets[j];
if (HAS_BIAS) { // static if
raw += bias[j];
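
The scalar tail above spells out the requantization recipe the vector paths implement: subtract the zero-point cross terms (A_zero_point times col_offsets[j], plus row_offsets[j], which carries the B-side zero point times the activation sum), add the optional bias, scale by the float multiplier, round, re-center on C_zero_point, and clamp to uint8, with the lower clamp raised to C_zero_point when ReLU is fused. A hedged scalar reference for one element, filling in the rounding and clamping the hunk elides:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar reference for one requantized output, following the vector path
// above. `fuse_relu` raises the lower clamp to the output zero point.
static uint8_t requantize_one(int32_t acc, int32_t a_zero_point,
                              int32_t col_offset, int32_t row_offset,
                              int32_t bias, float c_multiplier,
                              int32_t c_zero_point, bool fuse_relu) {
  int32_t raw = acc - a_zero_point * col_offset - row_offset + bias;
  // Round to nearest (the default FP mode, matching the cvtps path) and
  // shift by the output zero point.
  int32_t q =
      static_cast<int32_t>(std::nearbyint(raw * c_multiplier)) + c_zero_point;
  int32_t lo = fuse_relu ? c_zero_point : 0;
  q = std::max(lo, std::min(q, int32_t{255}));
  return static_cast<uint8_t>(q);
}
```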
@@ -659,11 +711,16 @@ static inline __attribute__((always_inline)) void requantize_(
}
template <bool FUSE_RELU, bool HAS_BIAS>
-static inline __attribute__((always_inline)) void
-requantize_(int32_t A_zero_point, float C_multiplier,
- int32_t C_zero_point, const int32_t *C_int32, uint8_t *C_uint8,
- int n, const int32_t *row_offsets, const int32_t *col_offsets,
- const int32_t *bias) {
+static inline __attribute__((always_inline)) void requantize_(
+ int32_t A_zero_point,
+ float C_multiplier,
+ int32_t C_zero_point,
+ const int32_t* C_int32,
+ uint8_t* C_uint8,
+ int n,
+ const int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
requantize_<FUSE_RELU, HAS_BIAS, false /* PER_CHANNEL_QUANTIZATION */>(
A_zero_point,
&C_multiplier,
@@ -677,11 +734,16 @@ requantize_(int32_t A_zero_point, float C_multiplier,
}
template <bool FUSE_RELU, bool HAS_BIAS>
-static inline __attribute__((always_inline)) void
-requantize_per_channel_(int32_t A_zero_point, const float *C_multiplier,
- int32_t C_zero_point, const int32_t *C_int32,
- uint8_t *C_uint8, int n, const int32_t *row_offsets,
- const int32_t *col_offsets, const int32_t *bias) {
+static inline __attribute__((always_inline)) void requantize_per_channel_(
+ int32_t A_zero_point,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ const int32_t* C_int32,
+ uint8_t* C_uint8,
+ int n,
+ const int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
requantize_<FUSE_RELU, HAS_BIAS, true /* PER_CHANNEL_QUANTIZATION */>(
A_zero_point,
C_multiplier,
@@ -695,27 +757,38 @@ requantize_per_channel_(int32_t A_zero_point, const float *C_multiplier,
}
template <bool REMAINDER>
-static inline __attribute__((always_inline)) __m256i
-load_a(const uint8_t* A, __m256i mask_v) {
+static inline __attribute__((always_inline)) __m256i load_a(
+ const uint8_t* A,
+ __m256i mask_v) {
if (REMAINDER) {
- return _mm256_maskload_epi32(reinterpret_cast<const int *>(A), mask_v);
+ return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v);
} else {
- return _mm256_lddqu_si256(reinterpret_cast<const __m256i *>(A));
+ return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A));
}
}
-template <bool SUM_A, bool REMAINDER = false,
- bool PER_CHANNEL_QUANTIZATION = false>
-static inline __attribute__((always_inline)) void
-inner_prod_3x3_packed_(int H, int W, int K, int h_in, int w_in,
- const uint8_t *A, int32_t A_zero_point, const int8_t *Bp,
- const int32_t *B_zero_point, int32_t *C, int remainder,
- int32_t *row_offsets) {
+template <
+ bool SUM_A,
+ bool REMAINDER = false,
+ bool PER_CHANNEL_QUANTIZATION = false>
+static inline __attribute__((always_inline)) void inner_prod_3x3_packed_(
+ int H,
+ int W,
+ int K,
+ int h_in,
+ int w_in,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ const int8_t* Bp,
+ const int32_t* B_zero_point,
+ int32_t* C,
+ int remainder,
+ int32_t* row_offsets) {
__m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
__m256i mask_v = _mm256_setzero_si256();
if (REMAINDER) {
mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i *>(masks[remainder / 4].data()));
+ reinterpret_cast<const __m256i*>(masks[remainder / 4].data()));
}
// The code below can be written as a simple R*S loop but the compiler
@@ -739,15 +812,15 @@ inner_prod_3x3_packed_(int H, int W, int K, int h_in, int w_in,
// }
// }
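
The comment above refers to a straightforward gather whose commented-out body the diff elides. A hedged reconstruction of that simple R*S loop (non-remainder path only), kept self-contained:

```cpp
#include <immintrin.h>
#include <array>
#include <cstdint>

// Hedged reconstruction of the "simple R*S loop" the comment above alludes
// to: gather the 9 activation rows for output position (h_in, w_in),
// substituting the zero point for taps that fall into the padding.
static void gather_3x3(int H, int W, int K, int h_in, int w_in,
                       const uint8_t* A, __m256i A_zero_point_v,
                       std::array<__m256i, 9>& a_v) {
  for (int r = 0; r < 3; ++r) {
    for (int s = 0; s < 3; ++s) {
      int h = h_in + r, w = w_in + s;
      a_v[r * 3 + s] = (h >= 0 && h < H && w >= 0 && w < W)
          ? _mm256_lddqu_si256(
                reinterpret_cast<const __m256i*>(A + (h * W + w) * K))
          : A_zero_point_v;
    }
  }
}
```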
array<__m256i, 9> a_v = {
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
};
if (h_in >= 0 && h_in < H) {
@@ -788,35 +861,51 @@ inner_prod_3x3_packed_(int H, int W, int K, int h_in, int w_in,
array<__m256i, 4> a_sum;
inner_prod_3x3_packed_<SUM_A, REMAINDER>(
- a_v.data(), reinterpret_cast<const __m256i *>(Bp), C, remainder,
+ a_v.data(),
+ reinterpret_cast<const __m256i*>(Bp),
+ C,
+ remainder,
a_sum.data());
if (SUM_A) {
__m256i B_zero_point_v;
for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
if (PER_CHANNEL_QUANTIZATION) {
B_zero_point_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i *>(B_zero_point + i * 8));
+ reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
} else {
B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
}
- _mm256_store_si256(reinterpret_cast<__m256i *>(&row_offsets[i * 8]),
- _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
+ _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
}
}
}
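
The SUM_A bookkeeping above comes from the zero-point algebra: sum_k (A_k - A_zp)(B_k - B_zp) = sum_k A_k*B_k - A_zp*(sum_k B_k - K*B_zp) - B_zp*sum_k A_k, which is exactly the raw = C - A_zero_point*col_offsets - row_offsets form consumed by requantize_, assuming col_offsets holds sum_k B_k - K*B_zp. Filling out-of-image taps of a_v with A_zero_point_v makes their (A_k - A_zp) contribution vanish under the same identity. A small check of the algebra:

```cpp
#include <cassert>
#include <cstdint>

// Sanity check of the zero-point identity behind the SUM_A / row_offsets
// machinery above. Values are arbitrary; col_offset is assumed to hold
// sum(B) - K * B_zp, which is how requantize_ appears to consume it.
int main() {
  const int K = 4;
  int32_t A[K] = {7, 0, 255, 19}, B[K] = {-3, 5, 0, 127};
  int32_t A_zp = 2, B_zp = 1;
  int32_t acc = 0, sum_a = 0, sum_b = 0, ref = 0;
  for (int k = 0; k < K; ++k) {
    acc += A[k] * B[k];
    sum_a += A[k];
    sum_b += B[k];
    ref += (A[k] - A_zp) * (B[k] - B_zp);
  }
  int32_t col_offset = sum_b - K * B_zp; // assumed content of col_offsets
  int32_t row_offset = B_zp * sum_a;     // what the SUM_A path stores
  assert(acc - A_zp * col_offset - row_offset == ref);
  return 0;
}
```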
-template <bool SUM_A, bool REMAINDER = false,
- bool PER_CHANNEL_QUANTIZATION = false>
-static inline __attribute__((always_inline)) void
-inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in,
- int w_in, const uint8_t *A, int32_t A_zero_point,
- const int8_t *Bp, const int32_t *B_zero_point,
- int32_t *C, int remainder, int32_t *row_offsets) {
+template <
+ bool SUM_A,
+ bool REMAINDER = false,
+ bool PER_CHANNEL_QUANTIZATION = false>
+static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_(
+ int T,
+ int H,
+ int W,
+ int K,
+ int t_in,
+ int h_in,
+ int w_in,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ const int8_t* Bp,
+ const int32_t* B_zero_point,
+ int32_t* C,
+ int remainder,
+ int32_t* row_offsets) {
__m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
__m256i mask_v = _mm256_setzero_si256();
if (REMAINDER) {
mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i *>(masks[remainder / 4].data()));
+ reinterpret_cast<const __m256i*>(masks[remainder / 4].data()));
}
// The code below can be written as a simple R*S loop but the compiler
@@ -885,9 +974,12 @@ inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in,
}
array<__m256i, 4> a_sum;
- inner_prod_packed_<8, SUM_A, REMAINDER>(a_v.data(),
- reinterpret_cast<const __m256i *>(Bp),
- C, remainder, a_sum.data());
+ inner_prod_packed_<8, SUM_A, REMAINDER>(
+ a_v.data(),
+ reinterpret_cast<const __m256i*>(Bp),
+ C,
+ remainder,
+ a_sum.data());
a_v[0] = A_zero_point_v;
a_v[1] = A_zero_point_v;
@@ -940,7 +1032,10 @@ inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in,
array<__m256i, 4> a_sum_temp;
inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
- a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 8, C, remainder,
+ a_v.data(),
+ reinterpret_cast<const __m256i*>(Bp) + 8,
+ C,
+ remainder,
a_sum_temp.data());
if (SUM_A) {
a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
@@ -996,7 +1091,10 @@ inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in,
}
inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
- a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 16, C, remainder,
+ a_v.data(),
+ reinterpret_cast<const __m256i*>(Bp) + 16,
+ C,
+ remainder,
a_sum_temp.data());
if (SUM_A) {
a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
@@ -1024,7 +1122,10 @@ inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in,
}
inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>(
- a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 24, C, remainder,
+ a_v.data(),
+ reinterpret_cast<const __m256i*>(Bp) + 24,
+ C,
+ remainder,
a_sum_temp.data());
if (SUM_A) {
@@ -1037,29 +1138,37 @@ inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in,
for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
if (PER_CHANNEL_QUANTIZATION) {
B_zero_point_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i *>(B_zero_point + i * 8));
+ reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
} else {
B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
}
- _mm256_store_si256(reinterpret_cast<__m256i *>(&row_offsets[i * 8]),
- _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
+ _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
}
}
}
template <bool SUM_A, bool FUSE_RELU>
-static inline __attribute__((always_inline))
-void depthwise_3x3_kernel_(int H, int W, int K, int h, int w,
- int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t* A,
- int32_t B_zero_point, const int8_t* Bp,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32, uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t *col_offsets,
- const int32_t *bias)
-{
+static inline __attribute__((always_inline)) void depthwise_3x3_kernel_(
+ int H,
+ int W,
+ int K,
+ int h,
+ int w,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const int8_t* Bp,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
constexpr int S = 3;
constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
@@ -1069,43 +1178,72 @@ void depthwise_3x3_kernel_(int H, int W, int K, int h, int w,
int k;
for (k = 0; k < K / 32 * 32; k += 32) {
inner_prod_3x3_packed_<SUM_A>(
- H, W, K, h_in, w_in,
- A + (h_in * W + w_in) * K + k, A_zero_point,
- Bp + k * 10, &B_zero_point,
- C_int32 + k, 0, &row_offsets[k]);
+ H,
+ W,
+ K,
+ h_in,
+ w_in,
+ A + (h_in * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 10,
+ &B_zero_point,
+ C_int32 + k,
+ 0,
+ &row_offsets[k]);
}
int remainder = K - k;
if (remainder) {
inner_prod_3x3_packed_<SUM_A, true>(
- H, W, K, h_in, w_in,
- A + (h_in * W + w_in) * K + k, A_zero_point,
- Bp + k * 10, &B_zero_point,
- C_int32 + k, remainder, &row_offsets[k]);
+ H,
+ W,
+ K,
+ h_in,
+ w_in,
+ A + (h_in * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 10,
+ &B_zero_point,
+ C_int32 + k,
+ remainder,
+ &row_offsets[k]);
}
if (SUM_A) {
- requantize_<FUSE_RELU, true>
- (
- A_zero_point, C_multiplier, C_zero_point,
- C_int32, C_uint8 + (h * W_OUT + w) * K, K,
+ requantize_<FUSE_RELU, true>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + (h * W_OUT + w) * K,
+ K,
row_offsets,
- col_offsets, bias
- );
+ col_offsets,
+ bias);
}
}
template <bool SUM_A, bool FUSE_RELU>
-static inline __attribute__((always_inline))
-void depthwise_3x3x3_kernel_(int T, int H, int W, int K, int t, int h, int w,
- int stride_t, int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t* A,
- int32_t B_zero_point, const int8_t* Bp,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32, uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t *col_offsets,
- const int32_t *bias)
-{
+static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_(
+ int T,
+ int H,
+ int W,
+ int K,
+ int t,
+ int h,
+ int w,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const int8_t* Bp,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
constexpr int R = 3, S = 3;
constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
@@ -1117,39 +1255,74 @@ void depthwise_3x3x3_kernel_(int T, int H, int W, int K, int t, int h, int w,
int k;
for (k = 0; k < K / 32 * 32; k += 32) {
inner_prod_3x3x3_packed_<SUM_A>(
- T, H, W, K, t_in, h_in, w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k, A_zero_point,
- Bp + k * 28, &B_zero_point,
- C_int32 + k, 0, &row_offsets[k]);
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ &B_zero_point,
+ C_int32 + k,
+ 0,
+ &row_offsets[k]);
}
int remainder = K - k;
if (remainder) {
inner_prod_3x3x3_packed_<SUM_A, true>(
- T, H, W, K, t_in, h_in, w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k, A_zero_point,
- Bp + k * 28, &B_zero_point,
- C_int32 + k, remainder, &row_offsets[k]);
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ &B_zero_point,
+ C_int32 + k,
+ remainder,
+ &row_offsets[k]);
}
if (SUM_A) {
- requantize_<FUSE_RELU, true>
- (
- A_zero_point, C_multiplier, C_zero_point,
- C_int32, C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, K,
+ requantize_<FUSE_RELU, true>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
+ K,
row_offsets,
- col_offsets, bias
- );
+ col_offsets,
+ bias);
}
}
template <bool SUM_A>
static inline __attribute__((always_inline)) void
depthwise_3x3_per_channel_quantization_kernel_(
- int H, int W, int K, int h, int w, int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t *A,
- const int32_t *B_zero_point, const int8_t *Bp,
- const float *C_multiplier, int32_t C_zero_point,
- int32_t *C_int32, uint8_t *C_uint8,
- int32_t *row_offsets, const int32_t *col_offsets, const int32_t *bias) {
+ int H,
+ int W,
+ int K,
+ int h,
+ int w,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const int8_t* Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
constexpr int S = 3;
constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
@@ -1158,28 +1331,47 @@ depthwise_3x3_per_channel_quantization_kernel_(
int k;
for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3_packed_<SUM_A, false/*remainder*/, true/*per-channel*/>(
- H, W, K, h_in, w_in,
- A + (h_in * W + w_in) * K + k, A_zero_point,
- Bp + k * 10, B_zero_point + k,
- C_int32 + k, 0, &row_offsets[k]);
+ inner_prod_3x3_packed_<SUM_A, false /*remainder*/, true /*per-channel*/>(
+ H,
+ W,
+ K,
+ h_in,
+ w_in,
+ A + (h_in * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 10,
+ B_zero_point + k,
+ C_int32 + k,
+ 0,
+ &row_offsets[k]);
}
int remainder = K - k;
if (remainder) {
- inner_prod_3x3_packed_<SUM_A, true/*remainder*/, true/*per-channel*/>(
- H, W, K, h_in, w_in,
- A + (h_in * W + w_in) * K + k, A_zero_point,
- Bp + k * 10, B_zero_point + k,
- C_int32 + k, remainder, &row_offsets[k]);
+ inner_prod_3x3_packed_<SUM_A, true /*remainder*/, true /*per-channel*/>(
+ H,
+ W,
+ K,
+ h_in,
+ w_in,
+ A + (h_in * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 10,
+ B_zero_point + k,
+ C_int32 + k,
+ remainder,
+ &row_offsets[k]);
}
if (SUM_A) {
- requantize_per_channel_<false, true>
- (
- A_zero_point, C_multiplier, C_zero_point,
- C_int32, C_uint8 + (h * W_OUT + w) * K, K,
+ requantize_per_channel_<false, true>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + (h * W_OUT + w) * K,
+ K,
row_offsets,
- col_offsets, bias
- );
+ col_offsets,
+ bias);
}
}
@@ -1188,7 +1380,7 @@ static pair<int, int> closest_factors_(int n) {
while (n % a != 0) {
a--;
}
- return { a, n / a }; // a <= n / a
+ return {a, n / a}; // a <= n / a
}
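
closest_factors_ returns the factor pair of n closest to square; the initializer of a is elided by the diff and is presumably floor(sqrt(n)). The surrounding (elided) scheduling code appears to use it to split num_threads across two loop dimensions. A self-contained sketch under that assumption:

```cpp
#include <cmath>
#include <utility>

// Self-contained version of closest_factors_ above; the initializer of `a`
// is elided in the diff and assumed to be floor(sqrt(n)).
static std::pair<int, int> closest_factors(int n) {
  int a = static_cast<int>(std::sqrt(n));
  while (n % a != 0) {
    a--;
  }
  return {a, n / a}; // a <= n / a, e.g. closest_factors(12) == {3, 4}
}
```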
// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0
@@ -1196,16 +1388,25 @@ static pair<int, int> closest_factors_(int n) {
// filter shapes by parameterizing with R and S but restricting it to just 3x3
// for now.
template <bool FUSE_RESCALE = true, bool FUSE_RELU = false>
-static inline __attribute__((always_inline))
-void depthwise_3x3_pad_1_(int N, int H, int W, int K,
- int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t *A,
- int32_t B_zero_point, const Packed3x3ConvMatrix &B,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32, uint8_t* C_uint8,
- const int32_t *col_offsets, const int32_t *bias,
- int thread_id, int num_threads) {
+static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const Packed3x3ConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id,
+ int num_threads) {
assert(K % 8 == 0);
constexpr int R = 3, S = 3;
constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
@@ -1213,8 +1414,8 @@ void depthwise_3x3_pad_1_(int N, int H, int W, int K,
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
const int8_t* Bp = B.PackedMat();
- int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64)));
- int32_t *C_temp;
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+ int32_t* C_temp;
int n_begin, n_end;
int h_begin, h_end, w_begin, w_end;
@@ -1266,22 +1467,48 @@ void depthwise_3x3_pad_1_(int N, int H, int W, int K,
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
if (w_end == W_OUT) {
@@ -1289,11 +1516,24 @@ void depthwise_3x3_pad_1_(int N, int H, int W, int K,
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
}
@@ -1303,22 +1543,48 @@ void depthwise_3x3_pad_1_(int N, int H, int W, int K,
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
if (w_end == W_OUT) {
@@ -1326,11 +1592,24 @@ void depthwise_3x3_pad_1_(int N, int H, int W, int K,
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
}
@@ -1341,22 +1620,48 @@ void depthwise_3x3_pad_1_(int N, int H, int W, int K,
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
if (w_end == W_OUT) {
@@ -1364,28 +1669,51 @@ void depthwise_3x3_pad_1_(int N, int H, int W, int K,
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
}
} // for each n
};
template <bool FUSE_RESCALE = true, bool FUSE_RELU = false>
-static inline __attribute__((always_inline))
-void depthwise_3x3x3_pad_1_(int N, int T, int H, int W, int K,
- int stride_t, int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t *A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix &B,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32, uint8_t* C_uint8,
- const int32_t *col_offsets, const int32_t *bias,
- int thread_id, int num_threads) {
+static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const Packed3x3x3ConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id,
+ int num_threads) {
assert(K % 8 == 0);
constexpr int K_T = 3, K_H = 3, K_W = 3;
constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
@@ -1395,8 +1723,8 @@ void depthwise_3x3x3_pad_1_(int N, int T, int H, int W, int K,
int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
const int8_t* Bp = B.PackedMat();
- int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64)));
- int32_t *C_temp;
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+ int32_t* C_temp;
int n_begin, n_end;
int t_begin, t_end, h_begin, h_end;
@@ -1443,14 +1771,30 @@ void depthwise_3x3x3_pad_1_(int N, int T, int H, int W, int K,
for (int t = t_begin; t < t_end; ++t) {
for (int h = h_begin; h < h_end; ++h) {
for (int w = 0; w < W_OUT; ++w) {
- C_temp =
- FUSE_RESCALE
- ? C_int32
- : C_int32 + (((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K;
+ C_temp = FUSE_RESCALE
+ ? C_int32
+ : C_int32 + (((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- T, H, W, K, t, h, w, stride_t, stride_h, stride_w, A_zero_point,
- A_base, B_zero_point, Bp, C_multiplier,
- C_zero_point, C_temp, C_uint8_base, row_offsets, col_offsets,
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
bias);
} // w
} // h
@@ -1461,11 +1805,23 @@ void depthwise_3x3x3_pad_1_(int N, int T, int H, int W, int K,
template <bool FUSE_RESCALE = true>
static inline __attribute__((always_inline)) void
depthwise_3x3_per_channel_quantization_pad_1_(
- int N, int H, int W, int K, int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t *A, const int32_t *B_zero_point,
- const Packed3x3ConvMatrix &B, const float *C_multiplier,
- int32_t C_zero_point, int32_t *C_int32, uint8_t *C_uint8,
- const int32_t *col_offsets, const int32_t *bias, int thread_id,
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const Packed3x3ConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id,
int num_threads) {
assert(K % 8 == 0);
constexpr int R = 3, S = 3;
@@ -1474,8 +1830,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
const int8_t* Bp = B.PackedMat();
- int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64)));
- int32_t *C_temp;
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+ int32_t* C_temp;
int n_begin, n_end;
int h_begin, h_end, w_begin, w_end;
@@ -1527,22 +1883,48 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
if (w_end == W_OUT) {
@@ -1550,11 +1932,24 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
}
@@ -1564,22 +1959,48 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
if (w_end == W_OUT) {
@@ -1587,11 +2008,24 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
}
@@ -1602,22 +2036,48 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
if (w_end == W_OUT) {
@@ -1625,11 +2085,24 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
}
} // for each n
@@ -1643,163 +2116,338 @@ void depthwise_3x3_pad_1(
int K,
int stride_h,
int stride_w,
- int32_t A_zero_point, const uint8_t* A,
+ int32_t A_zero_point,
+ const uint8_t* A,
const Packed3x3ConvMatrix& B,
int32_t* C,
int thread_id,
int num_threads) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
depthwise_3x3_pad_1_<false>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- 0, B,
- 0.0f, 0, C, nullptr,
- nullptr, nullptr,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ 0,
+ B,
+ 0.0f,
+ 0,
+ C,
+ nullptr,
+ nullptr,
+ nullptr,
+ thread_id,
+ num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
depthwise_3x3_pad_1_<false>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- 0, B,
- 0.0f, 0, C, nullptr,
- nullptr, nullptr,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ 0,
+ B,
+ 0.0f,
+ 0,
+ C,
+ nullptr,
+ nullptr,
+ nullptr,
+ thread_id,
+ num_threads);
} else if (1 == stride_h && 1 == stride_w) {
depthwise_3x3_pad_1_<false>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- 0, B,
- 0.0f, 0, C, nullptr,
- nullptr, nullptr,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ 0,
+ B,
+ 0.0f,
+ 0,
+ C,
+ nullptr,
+ nullptr,
+ nullptr,
+ thread_id,
+ num_threads);
} else if (2 == stride_h && 2 == stride_w) {
depthwise_3x3_pad_1_<false>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- 0, B,
- 0.0f, 0, C, nullptr,
- nullptr, nullptr,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ 0,
+ B,
+ 0.0f,
+ 0,
+ C,
+ nullptr,
+ nullptr,
+ nullptr,
+ thread_id,
+ num_threads);
} else {
depthwise_3x3_pad_1_<false>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- 0, B,
- 0.0f, 0, C, nullptr,
- nullptr, nullptr,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ 0,
+ B,
+ 0.0f,
+ 0,
+ C,
+ nullptr,
+ nullptr,
+ nullptr,
+ thread_id,
+ num_threads);
}
}
void depthwise_3x3_pad_1(
- int N, int H, int W, int K,
- int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t* A,
- int32_t B_zero_point, const Packed3x3ConvMatrix& B,
- float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const Packed3x3ConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
- int thread_id, int num_threads,
+ int thread_id,
+ int num_threads,
bool fuse_relu) {
int32_t C_int32_temp[(K + 31) / 32 * 32];
if (fuse_relu) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (1 == stride_h && 1 == stride_w) {
depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (2 == stride_h && 2 == stride_w) {
depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else {
depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
}
} else {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
depthwise_3x3_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
depthwise_3x3_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (1 == stride_h && 1 == stride_w) {
depthwise_3x3_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (2 == stride_h && 2 == stride_w) {
depthwise_3x3_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else {
depthwise_3x3_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
}
}
}
@@ -1813,140 +2461,309 @@ void depthwise_3x3x3_pad_1(
int stride_t,
int stride_h,
int stride_w,
- int32_t A_zero_point, const uint8_t* A,
+ int32_t A_zero_point,
+ const uint8_t* A,
const Packed3x3x3ConvMatrix& B,
int32_t* C,
int thread_id,
int num_threads) {
depthwise_3x3x3_pad_1_<false /* FUSE_RESCALE */>(
- N, T, H, W, K,
- stride_t, stride_h, stride_w,
- A_zero_point, A,
- 0, B,
- 0.0f, 0, C, nullptr,
- nullptr, nullptr,
- thread_id, num_threads);
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ 0,
+ B,
+ 0.0f,
+ 0,
+ C,
+ nullptr,
+ nullptr,
+ nullptr,
+ thread_id,
+ num_threads);
}
static void depthwise_3x3x3_pad_1_(
- int N, int T, int H, int W, int K,
- int stride_t, int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t* A,
- int32_t B_zero_point, const Packed3x3x3ConvMatrix& B,
- float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const Packed3x3x3ConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
- int thread_id, int num_threads) {
+ int thread_id,
+ int num_threads) {
int32_t C_int32_temp[(K + 31) / 32 * 32];
depthwise_3x3x3_pad_1_<true /* FUSE_RESCALE */, false /* FUSE_RELU */>(
- N, T, H, W, K,
- stride_t, stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
}
static void depthwise_3x3x3_pad_1_relu_fused_(
- int N, int T, int H, int W, int K,
- int stride_t, int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t* A,
- int32_t B_zero_point, const Packed3x3x3ConvMatrix& B,
- float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const Packed3x3x3ConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
- int thread_id, int num_threads) {
+ int thread_id,
+ int num_threads) {
int32_t C_int32_temp[(K + 31) / 32 * 32];
depthwise_3x3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
- N, T, H, W, K,
- stride_t, stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
}
void depthwise_3x3x3_pad_1(
- int N, int T, int H, int W, int K,
- int stride_t, int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t* A,
- int32_t B_zero_point, const Packed3x3x3ConvMatrix& B,
- float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const Packed3x3x3ConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
bool fuse_relu,
- int thread_id, int num_threads) {
+ int thread_id,
+ int num_threads) {
// If we inline the following two functions, I see stack overflow.
if (fuse_relu) {
depthwise_3x3x3_pad_1_relu_fused_(
- N, T, H, W, K, stride_t, stride_h, stride_w, A_zero_point, A,
- B_zero_point, B, C_multiplier, C_zero_point, C,
- col_offsets, bias, thread_id, num_threads);
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else {
- depthwise_3x3x3_pad_1_(N, T, H, W, K, stride_t, stride_h, stride_w,
- A_zero_point, A, B_zero_point, B, C_multiplier,
- C_zero_point, C, col_offsets, bias,
- thread_id, num_threads);
+ depthwise_3x3x3_pad_1_(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
}
}
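
On the stack-overflow comment above: each wrapper declares int32_t C_int32_temp[(K + 31) / 32 * 32], a runtime-sized stack array, and the kernels beneath it are all force-inlined, so inlining the two wrappers as well would concentrate every scratch buffer in a single frame. A heap-backed alternative, offered purely as an illustrative assumption and not as FBGEMM's actual fix:

```cpp
#include <cstdint>
#include <vector>

// Illustrative assumption only: moving the per-call scratch buffer to the
// heap sidesteps the stack growth the comment above attributes to inlining
// the two wrappers, at the cost of an allocation per call.
static void with_heap_scratch(int K) {
  std::vector<int32_t> C_int32_temp((K + 31) / 32 * 32);
  // ... pass C_int32_temp.data() where the stack array was used ...
  (void)C_int32_temp;
}
```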
void depthwise_3x3_per_channel_quantization_pad_1(
- int N, int H, int W, int K,
- int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t* A,
- const int32_t *B_zero_point, const Packed3x3ConvMatrix& Bp,
- const float *C_multiplier, int32_t C_zero_point, uint8_t* C,
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const Packed3x3ConvMatrix& Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
- int thread_id, int num_threads) {
+ int thread_id,
+ int num_threads) {
int32_t C_int32_temp[(K + 31) / 32 * 32];
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
depthwise_3x3_per_channel_quantization_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
depthwise_3x3_per_channel_quantization_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (1 == stride_h && 1 == stride_w) {
depthwise_3x3_per_channel_quantization_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (2 == stride_h && 2 == stride_w) {
depthwise_3x3_per_channel_quantization_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else {
depthwise_3x3_per_channel_quantization_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
}
}