
github.com/marian-nmt/FBGEMM.git
Diffstat (limited to 'src/FbgemmI8DepthwiseAvx2.cc')
-rw-r--r--  src/FbgemmI8DepthwiseAvx2.cc  2329
1 file changed, 398 insertions(+), 1931 deletions(-)
diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc
index f96d1d2..994f206 100644
--- a/src/FbgemmI8DepthwiseAvx2.cc
+++ b/src/FbgemmI8DepthwiseAvx2.cc
@@ -7,523 +7,15 @@
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
#include "fbgemm/Utils.h"
-#include <algorithm> // for min and max
-#include <cassert>
-#include <cmath> // for lrintf and sqrt
+#include <string>
#include <tuple> // for tie
-#include <immintrin.h>
+#include "FbgemmI8DepthwiseAvx2-inl.h"
using namespace std;
namespace fbgemm {
-static int masks[8][8] = {
- // NOTE: clang-format wants to use a different formatting but the current
- // formatting should be easier to read.
- { 0, 0, 0, 0, 0, 0, 0, 0, },
- { -1, 0, 0, 0, 0, 0, 0, 0, },
- { -1, -1, 0, 0, 0, 0, 0, 0, },
- { -1, -1, -1, 0, 0, 0, 0, 0, },
- { -1, -1, -1, -1, 0, 0, 0, 0, },
- { -1, -1, -1, -1, -1, 0, 0, 0, },
- { -1, -1, -1, -1, -1, -1, 0, 0, },
- { -1, -1, -1, -1, -1, -1, -1, 0, },
-};
-
-template <int KERNEL_PROD>
-PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
- int K,
- const int8_t* smat)
- : K_(K) {
- // Transpose the input matrix to make packing faster.
- int8_t* smat_transposed = static_cast<int8_t *>(ALIGNED_MALLOC(
- K * KERNEL_PROD * sizeof(int8_t), 64));
- for (int i = 0; i < KERNEL_PROD; ++i) {
- for (int j = 0; j < K; ++j) {
- smat_transposed[i * K + j] = smat[i + j * KERNEL_PROD];
- }
- }
-
- // Allocate packed arrays
- constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2;
-#ifdef _MSC_VER
- pmat_ = static_cast<int8_t *>(_aligned_malloc(
- ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t), 64));
-#else
- posix_memalign(
- (void**)&pmat_,
- 64,
- ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t));
-#endif
-
- // Pack input matrix
- // The layout is optimized to use vpmaddubsw efficiently (see
- // madd_epi16x4_packed function).
- // For a group of 32 channels, we have 10 32B SIMD registers.
- // Denote ith channel jth filter as (i, j)
- // 0th SIMD register:
- // (0, 0), (0, 1), (0, 2), (0, 3), ..., (3, 0), (3, 1), (3, 2), (3, 3)
- // (16, 0), (16, 1), (16, 2), (16, 3), ..., (19, 0), (19, 1), (19, 2), (19, 3)
- // 1st SIMD register:
- // (4, 0), (4, 1), (4, 2), (4, 3), ..., (7, 0), (7, 1), (7, 2), (7, 3)
- // (20, 0), (20, 1), (20, 2), (20, 3), ..., (23, 0), (23, 1), (23, 2), (23, 3)
- // 2nd SIMD register:
- // (8, 0), (8, 1), (8, 2), (8, 3), ..., (11, 0), (11, 1), (11, 2), (11, 3)
- // (24, 0), (24, 1), (24, 2), (24, 3), ..., (27, 0), (27, 1), (27, 2), (27, 3)
- // 3rd SIMD register:
- // (12, 0), (12, 1), (12, 2), (12, 3), ..., (15, 0), (15, 1), (15, 2), (15, 3)
- // (28, 0), (28, 1), (28, 2), (28, 3), ..., (31, 0), (31, 1), (31, 2), (31, 3)
- // 4-7th SIMD register: same as the previous 4 registers but for 4-7th filter
- // coefficients
- // ...
- //
- // REMAINDER
- // If KERNEL_PROD % 4 == 1, for example when KERNEL_PROD == 9
- // 8th SIMD register:
- // (0, 8), zero, ..., (7, 8), zero
- // (16, 8), zero, ..., (23, 8), zero
- // 9th SIMD register:
- // (8, 8), zero, ..., (15, 8), zero
- // (24, 8), zero, ..., (31, 8), zero
- // We use madd_epi16_packed for this case
- //
- // If KERNEL_PROD % 4 == 2, for example when KERNEL_PROD == 10
- // 8th SIMD register:
- // (0, 8), (0, 9), ..., (7, 8), (7, 9)
- // (16, 8), (16, 9), ..., (23, 8), (23, 9)
- // 9th SIMD register:
- // (8, 8), (8, 9), ..., (15, 8), (15, 9)
- // (24, 8), (24, 9), ..., (31, 8), (31, 9)
- //
- // If KERNEL_PROD % 4 == 3, for example when KERNEL_PROD == 11
- // 8th SIMD register:
- // (0, 8), (0, 9), (0, 10), zero, ..., (3, 8), (3, 9), (3, 10), zero
- // (16, 8), (16, 9), (16, 10), zero, ..., (19, 8), (19, 9), (19, 10), zero
- // 9th SIMD register:
- // (4, 8), (4, 9), (4, 10), zero, ..., (7, 8), (7, 9), (7, 10), zero
- // (20, 8), (20, 9), (20, 10), zero, ..., (23, 8), (23, 9), (23, 10), zero
- // 10th SIMD register:
- // (8, 8), (8, 9), (8, 10), zero, ..., (11, 8), (11, 9), (11, 10), zero
- // (24, 8), (24, 9), (24, 10), zero, ..., (27, 8), (27, 9), (27, 10), zero
- // 11th SIMD register:
- // (12, 8), (12, 9), (12, 10), zero, ..., (15, 8), (15, 9), (15, 10), zero
- // (28, 8), (28, 9), (28, 10), zero, ..., (31, 8), (31, 9), (31, 10), zero
- for (int k1 = 0; k1 < K; k1 += 32) {
- __m256i b_v[KERNEL_PROD];
- int remainder = K - k1;
- if (remainder < 32) {
- __m256i mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(masks[remainder / 4]));
- for (int i = 0; i < KERNEL_PROD; ++i) {
- b_v[i] = _mm256_maskload_epi32(
- reinterpret_cast<const int*>(smat_transposed + i * K + k1), mask_v);
- }
- } else {
- for (int i = 0; i < KERNEL_PROD; ++i) {
- b_v[i] = _mm256_lddqu_si256(
- reinterpret_cast<const __m256i*>(smat_transposed + i * K + k1));
- }
- }
-
- // Interleave 2 SIMD registers
- __m256i b_interleaved_epi16[KERNEL_PROD_ALIGNED];
- __m256i zero_v = _mm256_setzero_si256();
- for (int i = 0; i < KERNEL_PROD_ALIGNED / 2; ++i) {
- if (2 * i + 1 >= KERNEL_PROD) {
- b_interleaved_epi16[2 * i] = _mm256_unpacklo_epi8(b_v[2 * i], zero_v);
- b_interleaved_epi16[2 * i + 1] =
- _mm256_unpackhi_epi8(b_v[2 * i], zero_v);
- } else {
- b_interleaved_epi16[2 * i] =
- _mm256_unpacklo_epi8(b_v[2 * i], b_v[2 * i + 1]);
- b_interleaved_epi16[2 * i + 1] =
- _mm256_unpackhi_epi8(b_v[2 * i], b_v[2 * i + 1]);
- }
- }
-
- // Interleave 4 SIMD registers
- __m256i b_interleaved_epi32[KERNEL_PROD_ALIGNED];
- for (int i = 0; i < KERNEL_PROD_ALIGNED / 4; ++i) {
- b_interleaved_epi32[4 * i] = _mm256_unpacklo_epi16(
- b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
- b_interleaved_epi32[4 * i + 1] = _mm256_unpackhi_epi16(
- b_interleaved_epi16[4 * i], b_interleaved_epi16[4 * i + 2]);
- b_interleaved_epi32[4 * i + 2] = _mm256_unpacklo_epi16(
- b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
- b_interleaved_epi32[4 * i + 3] = _mm256_unpackhi_epi16(
- b_interleaved_epi16[4 * i + 1], b_interleaved_epi16[4 * i + 3]);
- }
- for (int i = KERNEL_PROD_ALIGNED / 4 * 4; i < KERNEL_PROD_ALIGNED; ++i) {
- b_interleaved_epi32[i] = b_interleaved_epi16[i];
- }
-
- for (int i = 0; i < KERNEL_PROD_ALIGNED; ++i) {
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(
- &pmat_[((k1 / 32) * KERNEL_PROD_ALIGNED + i) * 32]),
- b_interleaved_epi32[i]);
- }
- }
-
- FREE(smat_transposed);
-}
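
The packed layout above leans on a key AVX2 property: _mm256_unpacklo_epi8 and _mm256_unpackhi_epi8 interleave bytes within each 128-bit lane independently, which is exactly why channels 16-31 shadow channels 0-15 in every register of the layout comment. A minimal scalar model of the low-half unpack (Bytes32 and the function name are illustrative, not FBGEMM API):

#include <array>
#include <cstdint>

using Bytes32 = std::array<int8_t, 32>; // stand-in for one __m256i of bytes

// Mirrors _mm256_unpacklo_epi8: interleaves the low 8 bytes of a and b
// separately within each of the two 128-bit lanes.
static Bytes32 unpacklo_epi8_model(const Bytes32& a, const Bytes32& b) {
  Bytes32 r{};
  for (int lane = 0; lane < 2; ++lane) {
    for (int i = 0; i < 8; ++i) {
      r[lane * 16 + 2 * i + 0] = a[lane * 16 + i];
      r[lane * 16 + 2 * i + 1] = b[lane * 16 + i];
    }
  }
  return r;
}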
-
-template <int KERNEL_PROD>
-PackedDepthWiseConvMatrix<KERNEL_PROD>::~PackedDepthWiseConvMatrix() {
-#ifdef _MSC_VER
- _aligned_free(pmat_);
-#else
- free(pmat_);
-#endif
-}
-
-template class PackedDepthWiseConvMatrix<3 * 3>;
-template class PackedDepthWiseConvMatrix<3 * 3 * 3>;
-
-// c = a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[16:20]
-// c1_v: c[4:8], c[20:24]
-// c2_v: c[8:12], c[24:28]
-// c3_v: c[12:16], c[28:32]
-template <bool SUM_A = false>
-static inline ALWAYS_INLINE void madd_epi16x4_packed(
- __m256i a0_v,
- __m256i a1_v,
- __m256i a2_v,
- __m256i a3_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
- __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, a3_v);
- __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, a3_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
- __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
- __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
- __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
- __m256i b2_v = _mm256_load_si256(b + 2);
- __m256i b3_v = _mm256_load_si256(b + 3);
-
- __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
- __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
- __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
- __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
-
- __m256i one_v = _mm256_set1_epi16(1);
- *c0_v = _mm256_madd_epi16(ab0, one_v);
- *c1_v = _mm256_madd_epi16(ab1, one_v);
- *c2_v = _mm256_madd_epi16(ab2, one_v);
- *c3_v = _mm256_madd_epi16(ab3, one_v);
-}
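
Arithmetically, the sequence above is maddubs (u8 x s8 -> saturating s16 pairs) followed by madd against 1 (s16 pairs -> s32), so per channel the result is a four-term dot product. A scalar sketch of the intent, ignoring both the c0_v..c3_v register permutation documented above and the possible int16 saturation in the intermediate maddubs step:

#include <cstdint>

static void madd_x4_ref(const uint8_t* a0, const uint8_t* a1,
                        const uint8_t* a2, const uint8_t* a3,
                        const int8_t* b0, const int8_t* b1,
                        const int8_t* b2, const int8_t* b3,
                        int32_t* c) { // 32 outputs, one per channel
  for (int i = 0; i < 32; ++i) {
    c[i] = int32_t(a0[i]) * b0[i] + int32_t(a1[i]) * b1[i] +
           int32_t(a2[i]) * b2[i] + int32_t(a3[i]) * b3[i];
  }
}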
-
-// c = a0 * b0 + a1 * b1 + a2 * b2
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[16:20]
-// c1_v: c[4:8], c[20:24]
-// c2_v: c[8:12], c[24:28]
-// c3_v: c[12:16], c[28:32]
-template <bool SUM_A = false>
-static inline ALWAYS_INLINE void madd_epi16x3_packed(
- __m256i a0_v,
- __m256i a1_v,
- __m256i a2_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i zero_v = _mm256_setzero_si256();
-
- __m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
- __m256i a23_lo_v = _mm256_unpacklo_epi8(a2_v, zero_v);
- __m256i a23_hi_v = _mm256_unpackhi_epi8(a2_v, zero_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a01_hi_v, one_epi8_v), a_sum[1]);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a23_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i a0_interleaved_v = _mm256_unpacklo_epi16(a01_lo_v, a23_lo_v);
- __m256i a1_interleaved_v = _mm256_unpackhi_epi16(a01_lo_v, a23_lo_v);
- __m256i a2_interleaved_v = _mm256_unpacklo_epi16(a01_hi_v, a23_hi_v);
- __m256i a3_interleaved_v = _mm256_unpackhi_epi16(a01_hi_v, a23_hi_v);
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
- __m256i b2_v = _mm256_load_si256(b + 2);
- __m256i b3_v = _mm256_load_si256(b + 3);
-
- __m256i ab0 = _mm256_maddubs_epi16(a0_interleaved_v, b0_v);
- __m256i ab1 = _mm256_maddubs_epi16(a1_interleaved_v, b1_v);
- __m256i ab2 = _mm256_maddubs_epi16(a2_interleaved_v, b2_v);
- __m256i ab3 = _mm256_maddubs_epi16(a3_interleaved_v, b3_v);
-
- __m256i one_v = _mm256_set1_epi16(1);
- *c0_v = _mm256_madd_epi16(ab0, one_v);
- *c1_v = _mm256_madd_epi16(ab1, one_v);
- *c2_v = _mm256_madd_epi16(ab2, one_v);
- *c3_v = _mm256_madd_epi16(ab3, one_v);
-}
-
-// c = a0 * b0 + a1 * b1
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[4:8]
-// c1_v: c[8:12], c[12:16]
-// c2_v: c[16:20], c[20:24]
-// c3_v: c[24:28], c[28:32]
-template <bool SUM_A = false>
-static inline ALWAYS_INLINE void madd_epi16x2_packed(
- __m256i a0_v,
- __m256i a1_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
- __m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
-
- __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
- __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
-
- *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
- *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
- *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
- *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
-}
-
-// c = a0 * b0
-// A is in uint8_t
-// B is in int8_t and pre-interleaved
-// C is in int32_t and 4 registers have results in the following layout:
-// c0_v: c[0:4], c[4:8]
-// c1_v: c[8:12], c[12:16]
-// c2_v: c[16:20], c[20:24]
-// c3_v: c[24:28], c[28:32]
-template <bool SUM_A = false>
-static inline ALWAYS_INLINE void madd_epi16_packed(
- __m256i a_v,
- const __m256i* b,
- __m256i* c0_v,
- __m256i* c1_v,
- __m256i* c2_v,
- __m256i* c3_v,
- __m256i* a_sum = nullptr) {
- __m256i zero_v = _mm256_setzero_si256();
-
- __m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v);
- __m256i a_hi_v = _mm256_unpackhi_epi8(a_v, zero_v);
-
- if (SUM_A) {
- __m256i one_epi8_v = _mm256_set1_epi8(1);
- a_sum[0] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_lo_v, one_epi8_v), a_sum[0]);
- a_sum[1] =
- _mm256_adds_epi16(_mm256_maddubs_epi16(a_hi_v, one_epi8_v), a_sum[1]);
- }
-
- __m256i b0_v = _mm256_load_si256(b + 0);
- __m256i b1_v = _mm256_load_si256(b + 1);
-
- __m256i ab_lo_v = _mm256_maddubs_epi16(a_lo_v, b0_v);
- __m256i ab_hi_v = _mm256_maddubs_epi16(a_hi_v, b1_v);
-
- *c0_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_lo_v));
- *c1_v = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(ab_hi_v));
- *c2_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_lo_v, 1));
- *c3_v = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(ab_hi_v, 1));
-}
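
Note the layout difference from the x4/x3 variants: with only one or two terms, the 16-bit products are widened per 128-bit half with _mm256_cvtepi16_epi32, so c0_v..c3_v already hold channels 0-7, 8-15, 16-23, and 24-31 in order; the x4/x3 results instead need the _mm256_permute2f128_si256 fix-up performed in inner_prod_packed_ below. The widening itself is just per-element sign extension, as in this sketch (illustrative name):

#include <cstdint>

static void widen16to32_model(const int16_t in[16], int32_t out[16]) {
  for (int i = 0; i < 16; ++i) {
    out[i] = in[i]; // sign-extend each 16-bit partial sum to 32 bits
  }
}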
-
-// K is the number of accumulations we're doing
-template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false>
-static inline ALWAYS_INLINE void inner_prod_packed_(
- const __m256i* a_v,
- const __m256i* Bp,
- int32_t* C,
- int remainder,
- __m256i* a_sum = nullptr) {
- __m256i c[4], c_temp[4];
- __m256i a_sum_temp[2] = {0, 0};
-
- int k = 0;
- if (K >= 4) {
- madd_epi16x4_packed<SUM_A>(
- a_v[0],
- a_v[1],
- a_v[2],
- a_v[3],
- Bp,
- &c[0],
- &c[1],
- &c[2],
- &c[3],
- a_sum_temp);
-
- for (k = 4; k < K / 4 * 4; k += 4) {
- madd_epi16x4_packed<SUM_A>(
- a_v[k + 0],
- a_v[k + 1],
- a_v[k + 2],
- a_v[k + 3],
- Bp + k,
- &c_temp[0],
- &c_temp[1],
- &c_temp[2],
- &c_temp[3],
- a_sum_temp);
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
- } else {
- c[0] = _mm256_setzero_si256();
- c[1] = _mm256_setzero_si256();
- c[2] = _mm256_setzero_si256();
- c[3] = _mm256_setzero_si256();
- }
-
- if (K - k == 3) {
- madd_epi16x3_packed<SUM_A>(
- a_v[k],
- a_v[k + 1],
- a_v[k + 2],
- Bp + k,
- &c_temp[0],
- &c_temp[1],
- &c_temp[2],
- &c_temp[3],
- a_sum_temp);
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
-
- c_temp[0] = _mm256_permute2f128_si256(c[0], c[1], 0x20);
- c_temp[1] = _mm256_permute2f128_si256(c[2], c[3], 0x20);
- c_temp[2] = _mm256_permute2f128_si256(c[0], c[1], 0x31);
- c_temp[3] = _mm256_permute2f128_si256(c[2], c[3], 0x31);
-
- if (K - k == 0 || K - k == 3) {
- c[0] = c_temp[0];
- c[1] = c_temp[1];
- c[2] = c_temp[2];
- c[3] = c_temp[3];
- } else {
- if (K - k == 1) {
- madd_epi16_packed<SUM_A>(
- a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
- } else if (K - k == 2) {
- madd_epi16x2_packed<SUM_A>(
- a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp);
- }
-
- c[0] = _mm256_add_epi32(c[0], c_temp[0]);
- c[1] = _mm256_add_epi32(c[1], c_temp[1]);
- c[2] = _mm256_add_epi32(c[2], c_temp[2]);
- c[3] = _mm256_add_epi32(c[3], c_temp[3]);
- }
-
- if (REMAINDER) {
- for (int r = 0; r < remainder / 8; ++r) {
- if (ACC) {
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + r * 8),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)),
- c[r]));
- } else {
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]);
- }
- }
- } else {
- if (ACC) {
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 8),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 16),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2]));
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C + 24),
- _mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3]));
- } else {
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]);
- }
- }
-
- if (SUM_A) {
- a_sum[0] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[0]));
- a_sum[1] = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(a_sum_temp[1]));
- a_sum[2] =
- _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[0], 1));
- a_sum[3] =
- _mm256_cvtepi16_epi32(_mm256_extracti128_si256(a_sum_temp[1], 1));
- }
-}
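
Putting the pieces together, inner_prod_packed_<K> produces, for each of the 32 channels in a group, a K-term dot product, optionally accumulated into C. A scalar reference of the intent (a sketch: the SUM_A activation-sum bookkeeping and intermediate int16 saturation are omitted):

#include <cstdint>

template <int K>
static void inner_prod_ref(const uint8_t a[][32], const int8_t b[][32],
                           int32_t* C, bool acc) {
  for (int i = 0; i < 32; ++i) {
    int32_t sum = acc ? C[i] : 0;
    for (int k = 0; k < K; ++k) {
      sum += int32_t(a[k][i]) * b[k][i]; // one filter tap per k
    }
    C[i] = sum;
  }
}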
-
template <bool SUM_A = false, bool REMAINDER = false>
static inline ALWAYS_INLINE void inner_prod_3x3_packed_(
const __m256i* a_v,
@@ -534,238 +26,6 @@ static inline ALWAYS_INLINE void inner_prod_3x3_packed_(
return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder, a_sum);
}
-// Almost the same as ReQuantizeOutput in OutputProcessing-inh.h, but with
-// different row_offsets for each row because of depth-wise convolution
-template <
- bool FUSE_RELU,
- bool HAS_BIAS,
- bool PER_CHANNEL_QUANTIZATION,
- bool A_SYMMETRIC,
- bool B_SYMMETRIC>
-static inline ALWAYS_INLINE void requantize_(
- int32_t A_zero_point,
- const float* C_multiplier,
- int32_t C_zero_point,
- const int32_t* C_int32,
- uint8_t* C_uint8,
- int n,
- const int32_t* row_offsets,
- const int32_t* col_offsets,
- const int32_t* bias) {
- __m256 multiplier_v = _mm256_setzero_ps();
- if (!PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_set1_ps(*C_multiplier);
- }
-
- __m256i min_v = _mm256_set1_epi8(static_cast<uint8_t>(0));
- __m256i max_v = _mm256_set1_epi8(static_cast<uint8_t>(255));
-
- if (A_SYMMETRIC) {
- assert(A_zero_point == 0 || col_offsets == nullptr);
- }
- __m256i A_zero_point_v = _mm256_set1_epi32(A_zero_point);
- __m256i C_zero_point_epi16_v = _mm256_set1_epi16(C_zero_point);
- __m256i C_zero_point_epi8_v = _mm256_set1_epi8(C_zero_point);
-
- __m256i permute_mask_v =
- _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
-
- constexpr int VLEN = 8;
- int j = 0;
- for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
- __m256i x_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
- __m256i y_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(C_int32 + j + VLEN));
- __m256i z_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(C_int32 + j + 2 * VLEN));
- __m256i w_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(C_int32 + j + 3 * VLEN));
-
- __m256i row_offset_v;
- if (!B_SYMMETRIC) {
- row_offset_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
- x_v = _mm256_sub_epi32(x_v, row_offset_v);
- }
- __m256i col_off_v;
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j)));
- x_v = _mm256_sub_epi32(x_v, col_off_v);
- }
-
- if (!B_SYMMETRIC) {
- row_offset_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(row_offsets + j + VLEN));
- y_v = _mm256_sub_epi32(y_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j + VLEN)));
- y_v = _mm256_sub_epi32(y_v, col_off_v);
- }
-
- if (!B_SYMMETRIC) {
- row_offset_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
- z_v = _mm256_sub_epi32(z_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN)));
- z_v = _mm256_sub_epi32(z_v, col_off_v);
- }
-
- if (!B_SYMMETRIC) {
- row_offset_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
- w_v = _mm256_sub_epi32(w_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN)));
- w_v = _mm256_sub_epi32(w_v, col_off_v);
- }
-
- if (HAS_BIAS) { // static if
- x_v = _mm256_add_epi32(
- x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
- y_v = _mm256_add_epi32(
- y_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(bias + j + VLEN)));
- z_v = _mm256_add_epi32(
- z_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(bias + j + 2 * VLEN)));
- w_v = _mm256_add_epi32(
- w_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(bias + j + 3 * VLEN)));
- }
-
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j);
- }
- __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j + VLEN);
- }
- __m256 y_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(y_v), multiplier_v);
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j + 2 * VLEN);
- }
- __m256 z_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(z_v), multiplier_v);
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j + 3 * VLEN);
- }
- __m256 w_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(w_v), multiplier_v);
-
- __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
- __m256i y_rounded_v = _mm256_cvtps_epi32(y_scaled_v);
- __m256i z_rounded_v = _mm256_cvtps_epi32(z_scaled_v);
- __m256i w_rounded_v = _mm256_cvtps_epi32(w_scaled_v);
-
- __m256i xy_packed_v = _mm256_adds_epi16(
- _mm256_packs_epi32(x_rounded_v, y_rounded_v), C_zero_point_epi16_v);
- __m256i zw_packed_v = _mm256_adds_epi16(
- _mm256_packs_epi32(z_rounded_v, w_rounded_v), C_zero_point_epi16_v);
- __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
- __m256i xyzw_clamped_v = _mm256_max_epu8(
- FUSE_RELU ? C_zero_point_epi8_v : min_v,
- _mm256_min_epu8(xyzw_packed_v, max_v));
-
- xyzw_clamped_v =
- _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
-
- _mm256_storeu_si256(
- reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
- } // j loop vectorized and unrolled 4x
-
- for (; j < n / VLEN * VLEN; j += VLEN) {
- __m256i x_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
-
- if (!B_SYMMETRIC) {
- __m256i row_offset_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
- x_v = _mm256_sub_epi32(x_v, row_offset_v);
- }
- if (!A_SYMMETRIC) {
- __m256i col_off_v = _mm256_mullo_epi32(
- A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j)));
- x_v = _mm256_sub_epi32(x_v, col_off_v);
- }
-
- if (HAS_BIAS) { // static if
- x_v = _mm256_add_epi32(
- x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
- }
-
- if (PER_CHANNEL_QUANTIZATION) {
- multiplier_v = _mm256_loadu_ps(C_multiplier + j);
- }
- __m256 x_scaled_v = _mm256_mul_ps(_mm256_cvtepi32_ps(x_v), multiplier_v);
- __m256i x_rounded_v = _mm256_cvtps_epi32(x_scaled_v);
-
- __m256i x_packed_v = _mm256_adds_epi16(
- _mm256_packs_epi32(x_rounded_v, _mm256_setzero_si256()),
- C_zero_point_epi16_v);
- x_packed_v = _mm256_packus_epi16(x_packed_v, _mm256_setzero_si256());
- __m256i x_clamped_v = _mm256_max_epu8(
- FUSE_RELU ? C_zero_point_epi8_v : min_v,
- _mm256_min_epu8(x_packed_v, max_v));
-
- x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
-
- _mm_storel_epi64(
- reinterpret_cast<__m128i*>(C_uint8 + j),
- _mm256_castsi256_si128(x_clamped_v));
- } // j loop vectorized
-
- for (; j < n; ++j) {
- int32_t raw = C_int32[j];
- if (!B_SYMMETRIC) {
- raw -= row_offsets[j];
- }
- if (!A_SYMMETRIC) {
- raw -= A_zero_point * col_offsets[j];
- }
- if (HAS_BIAS) { // static if
- raw += bias[j];
- }
-
- float ab = raw * C_multiplier[PER_CHANNEL_QUANTIZATION ? j : 0];
- long rounded = lrintf(ab) + C_zero_point;
-
- C_uint8[j] = std::max(
- FUSE_RELU ? static_cast<long>(C_zero_point) : 0l,
- std::min(255l, rounded));
- }
-}
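
The scalar tail above is the whole requantization recipe in miniature: subtract the row and column offsets, add the bias, scale, round, shift by the output zero point, and clamp. A worked example with assumed values:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Assumed inputs: raw accumulator 1000, row offset 100,
// A_zero_point * col_offset = 50, bias 25, C_multiplier 0.01,
// C_zero_point 5, no fused ReLU.
static uint8_t requantize_one_example() {
  int32_t raw = 1000 - 100 - 50 + 25;      // 875
  long rounded = lrintf(raw * 0.01f) + 5;  // lrintf(8.75f) = 9, so 14
  return static_cast<uint8_t>(std::max(0l, std::min(255l, rounded)));
}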
-
-template <bool REMAINDER>
-static inline ALWAYS_INLINE __m256i load_a(
- const uint8_t* A,
- __m256i mask_v) {
- if (REMAINDER) {
- return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v);
- } else {
- return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A));
- }
-}
-
template <
bool SUM_A,
bool REMAINDER = false,
@@ -878,257 +138,11 @@ static inline ALWAYS_INLINE void inner_prod_3x3_packed_(
}
template <
- bool SUM_A,
- bool REMAINDER = false,
- bool PER_CHANNEL_QUANTIZATION = false>
-static inline ALWAYS_INLINE void inner_prod_3x3x3_packed_(
- int T,
- int H,
- int W,
- int K,
- int t_in,
- int h_in,
- int w_in,
- const uint8_t* A,
- int32_t A_zero_point,
- const int8_t* Bp,
- const int32_t* B_zero_point,
- int32_t* C,
- int remainder,
- int32_t* row_offsets) {
- __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
- __m256i mask_v = _mm256_setzero_si256();
- if (REMAINDER) {
- mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(masks[remainder / 4]));
- }
-
- // The code below can be written as a simple R*S loop, but the compiler
- // doesn't unroll it, so we unroll it manually.
- // constexpr int R = 3, S = 3;
- // array<__m256i, R * S> a_v;
- // for (int r = 0; r < R; ++r) {
- // for (int s = 0; s < S; ++s) {
- // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
- // if (REMAINDER) {
- // a_v[r * S + s] =
- // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
- // mask_v);
- // } else {
- // a_v[r * S + s] =
- // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
- // }
- // } else {
- // a_v[r * S + s] = A_zero_point_v;
- // }
- // }
- // }
- __m256i a_v[8];
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in >= 0 && t_in < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[0] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((0 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[3] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[4] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[5] = load_a<REMAINDER>(A + ((0 * H + 1) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[6] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[7] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 1) * K, mask_v);
- }
- }
- }
-
- __m256i a_sum[4];
- inner_prod_packed_<8, SUM_A, REMAINDER>(
- a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum);
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in >= 0 && t_in < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[0] = load_a<REMAINDER>(A + ((0 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- if (t_in + 1 >= 0 && t_in + 1 < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[1] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[3] = load_a<REMAINDER>(A + ((1 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[4] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[5] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[6] = load_a<REMAINDER>(A + ((1 * H + 1) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[7] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 0) * K, mask_v);
- }
- }
- }
-
- __m256i a_sum_temp[4];
- inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 8, C, remainder, a_sum_temp);
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
- }
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
- a_v[3] = A_zero_point_v;
- a_v[4] = A_zero_point_v;
- a_v[5] = A_zero_point_v;
- a_v[6] = A_zero_point_v;
- a_v[7] = A_zero_point_v;
-
- if (t_in + 1 >= 0 && t_in + 1 < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[0] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((1 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- if (t_in + 2 >= 0 && t_in + 2 < T) {
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[2] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[3] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[4] = load_a<REMAINDER>(A + ((2 * H + 0) * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[5] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[6] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[7] = load_a<REMAINDER>(A + ((2 * H + 1) * W + 2) * K, mask_v);
- }
- }
- }
-
- inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 16, C, remainder, a_sum_temp);
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
- }
-
- a_v[0] = A_zero_point_v;
- a_v[1] = A_zero_point_v;
- a_v[2] = A_zero_point_v;
-
- if (t_in + 2 >= 0 && t_in + 2 < T) {
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[0] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[1] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[2] = load_a<REMAINDER>(A + ((2 * H + 2) * W + 2) * K, mask_v);
- }
- }
- }
-
- inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>(
- a_v, reinterpret_cast<const __m256i*>(Bp) + 24, C, remainder, a_sum_temp);
-
- if (SUM_A) {
- a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
- a_sum[1] = _mm256_add_epi32(a_sum[1], a_sum_temp[1]);
- a_sum[2] = _mm256_add_epi32(a_sum[2], a_sum_temp[2]);
- a_sum[3] = _mm256_add_epi32(a_sum[3], a_sum_temp[3]);
-
- __m256i B_zero_point_v;
- for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
- if (PER_CHANNEL_QUANTIZATION) {
- B_zero_point_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
- } else {
- B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
- }
- _mm256_store_si256(
- reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
- _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
- }
- }
-}
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
static inline ALWAYS_INLINE void depthwise_3x3_kernel_(
int H,
int W,
@@ -1147,7 +161,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_(
uint8_t* C_uint8,
int32_t* row_offsets,
const int32_t* col_offsets,
- const int32_t* bias) {
+ const BIAS_TYPE* bias,
+ float act_times_w_scale) {
constexpr int S = 3;
constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
@@ -1192,7 +207,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_(
HAS_BIAS,
false, /*PER_CHAN_QUANT*/
A_SYMMETRIC,
- B_SYMMETRIC>(
+ B_SYMMETRIC,
+ BIAS_TYPE>(
A_zero_point,
&C_multiplier,
C_zero_point,
@@ -1201,95 +217,11 @@ static inline ALWAYS_INLINE void depthwise_3x3_kernel_(
K,
row_offsets,
col_offsets,
- bias);
-}
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
-static inline ALWAYS_INLINE void depthwise_3x3x3_kernel_(
- int T,
- int H,
- int W,
- int K,
- int t,
- int h,
- int w,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const int8_t* Bp,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t* col_offsets,
- const int32_t* bias) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- int t_in = -PAD_P + t * stride_t;
- int h_in = -PAD_T + h * stride_h;
- int w_in = -PAD_L + w * stride_w;
-
- int k;
- for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- &B_zero_point,
- C_int32 + k,
- 0,
- B_SYMMETRIC ? nullptr : &row_offsets[k]);
- }
- int remainder = K - k;
- if (remainder) {
- inner_prod_3x3x3_packed_<!B_SYMMETRIC /*SUM_A*/, true>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- &B_zero_point,
- C_int32 + k,
- remainder,
- B_SYMMETRIC ? nullptr : &row_offsets[k]);
- }
-
- requantize_<
- FUSE_RELU,
- HAS_BIAS,
- false, /*PER_CHAN_QUANT*/
- A_SYMMETRIC,
- B_SYMMETRIC>(
- A_zero_point,
- &C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
- K,
- row_offsets,
- col_offsets,
- bias);
+ bias,
+ &act_times_w_scale);
}
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
+template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
static inline ALWAYS_INLINE void
depthwise_3x3_per_channel_quantization_kernel_(
int H,
@@ -1309,7 +241,8 @@ depthwise_3x3_per_channel_quantization_kernel_(
uint8_t* C_uint8,
int32_t* row_offsets,
const int32_t* col_offsets,
- const int32_t* bias) {
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale) {
constexpr int S = 3;
constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
@@ -1360,7 +293,8 @@ depthwise_3x3_per_channel_quantization_kernel_(
HAS_BIAS,
true, /*PER_CHAN_QUANT*/
A_SYMMETRIC,
- false /*B_SYMM*/>(
+ false, /*B_SYMM*/
+ BIAS_TYPE>(
A_zero_point,
C_multiplier,
C_zero_point,
@@ -1369,113 +303,20 @@ depthwise_3x3_per_channel_quantization_kernel_(
K,
row_offsets,
col_offsets,
- bias);
-}
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
-static inline ALWAYS_INLINE void
-depthwise_3x3x3_per_channel_quantization_kernel_(
- int T,
- int H,
- int W,
- int K,
- int t,
- int h,
- int w,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const int8_t* Bp,
- const float* C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t* col_offsets,
- const int32_t* bias) {
- constexpr int R = 3, S = 3;
- constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- int t_in = -PAD_P + t * stride_t;
- int h_in = -PAD_T + h * stride_h;
- int w_in = -PAD_L + w * stride_w;
-
- int k;
- for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3x3_packed_<
- true, /*SUM_A*/
- false, /*remainder*/
- true /*per-channel*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- B_zero_point + k,
- C_int32 + k,
- 0,
- &row_offsets[k]);
- }
- int remainder = K - k;
- if (remainder) {
- inner_prod_3x3x3_packed_<
- true, /*SUM_A*/
- true, /*remainder*/
- true /*per-channel*/>(
- T,
- H,
- W,
- K,
- t_in,
- h_in,
- w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 28,
- B_zero_point + k,
- C_int32 + k,
- remainder,
- &row_offsets[k]);
- }
- requantize_<
- FUSE_RELU,
- HAS_BIAS,
- true, /*PER_CHAN_QUANT*/
- A_SYMMETRIC,
- false /*B_SYMM*/>(
- A_zero_point,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
- K,
- row_offsets,
- col_offsets,
- bias);
-}
-
-static pair<int, int> closest_factors_(int n) {
- int a = (int)std::sqrt(n);
- while (n % a != 0) {
- a--;
- }
- return {a, n / a}; // a <= n / a
+ bias,
+ act_times_w_scale);
}
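
As a usage note for the removed closest_factors_ helper: it returns the factor pair of n closest to a square, smaller factor first, which the thread-partitioning code uses to shape its 2D grids. For example:

// closest_factors_(12) == {3, 4}
// closest_factors_(16) == {4, 4}
// closest_factors_(7)  == {1, 7}  // a prime thread count degrades to a 1-D split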
// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0
// This implementation should be general enough to handle not just 3x3 but other
// filter shapes by parameterizing with R and S, but we restrict it to 3x3
// for now.
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
+template <
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
int N,
int H,
@@ -1486,13 +327,14 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
int32_t* C_int32,
uint8_t* C_uint8,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
int thread_id,
int num_threads) {
assert(K % 8 == 0);
@@ -1551,7 +393,12 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
if (h_begin == 0) {
if (w_begin == 0) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1569,11 +416,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1591,12 +444,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
w = W_OUT - 1;
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1614,14 +473,20 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) {
if (w_begin == 0) {
w = 0;
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1639,11 +504,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1661,12 +532,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
w = W_OUT - 1;
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1684,7 +561,8 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
@@ -1692,7 +570,12 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
h = H_OUT - 1;
w = 0;
if (w_begin == 0) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1710,11 +593,17 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1732,12 +621,18 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
w = W_OUT - 1;
- depthwise_3x3_kernel_<FUSE_RELU, HAS_BIAS, A_SYMMETRIC, B_SYMMETRIC>(
+ depthwise_3x3_kernel_<
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1755,126 +650,15 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
} // for each n
FREE(row_offsets);
};
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
-static inline ALWAYS_INLINE void depthwise_3x3x3_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- const int32_t* col_offsets,
- const int32_t* bias,
- int thread_id,
- int num_threads) {
- assert(K % 8 == 0);
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
- const int8_t* Bp = B.PackedMat();
-
- int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64)));
-
- int n_begin, n_end;
- int t_begin, t_end, h_begin, h_end;
- if (N >= num_threads) {
- int n_per_thread = (N + num_threads - 1) / num_threads;
- n_begin = std::min(thread_id * n_per_thread, N);
- n_end = std::min(n_begin + n_per_thread, N);
- t_begin = 0;
- t_end = T_OUT;
- h_begin = 0;
- h_end = H_OUT;
- } else {
- int nthreads_per_n = num_threads / N;
- n_begin = std::min(thread_id / nthreads_per_n, N);
- n_end = std::min(n_begin + 1, N);
-
- int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
- int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
- int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
- int tid_within_n = thread_id - tid_of_n_begin;
- assert(tid_within_n >= 0);
- assert(tid_within_n < nthreads_of_n);
-
- // n is processed by a num_threads_t * num_threads_h 2D grid of threads
- int num_threads_t, num_threads_h;
- // num_threads_t <= num_threads_h
- tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
- int tid_t = tid_within_n / num_threads_h;
- int tid_h = tid_within_n % num_threads_h;
-
- int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
- t_begin = std::min(tid_t * t_per_thread, T_OUT);
- t_end = std::min(t_begin + t_per_thread, T_OUT);
-
- int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
- h_begin = std::min(tid_h * h_per_thread, H_OUT);
- h_end = std::min(h_begin + h_per_thread, H_OUT);
- }
-
- for (int n = n_begin; n < n_end; ++n) {
- const uint8_t* A_base = A + n * T * H * W * K;
- uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
-
- for (int t = t_begin; t < t_end; ++t) {
- for (int h = h_begin; h < h_end; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- depthwise_3x3x3_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- B_SYMMETRIC>(
- T,
- H,
- W,
- K,
- t,
- h,
- w,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias);
- } // w
- } // h
- } // t
- } // for each n
-
- FREE(row_offsets);
-};
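
To make the thread-partitioning branch above concrete, a worked trace under assumed sizes: with N = 1, num_threads = 6, T_OUT = 9, and H_OUT = 8, every thread works on batch 0, and closest_factors_(6) = {2, 3} gives a 2x3 (t, h) grid. Thread 4 then gets tid_t = 4 / 3 = 1 and tid_h = 4 % 3 = 1, hence t in [5, 9) (t_per_thread = (9 + 1) / 2 = 5) and h in [3, 6) (h_per_thread = (8 + 2) / 3 = 3).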
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
+template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
static inline ALWAYS_INLINE void
depthwise_3x3_per_channel_quantization_pad_1_(
int N,
@@ -1886,13 +670,14 @@ depthwise_3x3_per_channel_quantization_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
const float* C_multiplier,
int32_t C_zero_point,
int32_t* C_int32,
uint8_t* C_uint8,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
int thread_id,
int num_threads) {
assert(K % 8 == 0);
@@ -1954,7 +739,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1972,14 +758,16 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -1997,7 +785,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
@@ -2005,7 +794,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2023,7 +813,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
@@ -2033,7 +824,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2051,14 +843,16 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2076,7 +870,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
@@ -2084,7 +879,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2102,7 +898,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
@@ -2113,7 +910,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2131,14 +929,16 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2156,7 +956,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
if (w_end == W_OUT) {
@@ -2164,7 +965,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_kernel_<
FUSE_RELU,
HAS_BIAS,
- A_SYMMETRIC>(
+ A_SYMMETRIC,
+ BIAS_TYPE>(
H,
W,
K,
@@ -2182,128 +984,15 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_uint8_base,
row_offsets,
col_offsets,
- bias);
+ bias,
+ act_times_w_scale);
}
}
} // for each n
-
- FREE(row_offsets);
-};
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
-static inline ALWAYS_INLINE void
-depthwise_3x3x3_per_channel_quantization_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- const float* C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- const int32_t* col_offsets,
- const int32_t* bias,
- int thread_id,
- int num_threads) {
- assert(K % 8 == 0);
- constexpr int K_T = 3, K_H = 3, K_W = 3;
- constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
- PAD_R = 1;
- int T_OUT = (T + PAD_P + PAD_N - K_T) / stride_t + 1;
- int H_OUT = (H + PAD_T + PAD_B - K_H) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
- const int8_t* Bp = B.PackedMat();
-
- int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64)));
-
- int n_begin, n_end;
- int t_begin, t_end, h_begin, h_end;
- if (N >= num_threads) {
- int n_per_thread = (N + num_threads - 1) / num_threads;
- n_begin = std::min(thread_id * n_per_thread, N);
- n_end = std::min(n_begin + n_per_thread, N);
- t_begin = 0;
- t_end = T_OUT;
- h_begin = 0;
- h_end = H_OUT;
- } else {
- int nthreads_per_n = num_threads / N;
- n_begin = std::min(thread_id / nthreads_per_n, N);
- n_end = std::min(n_begin + 1, N);
-
- int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
- int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
- int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
- int tid_within_n = thread_id - tid_of_n_begin;
- assert(tid_within_n >= 0);
- assert(tid_within_n < nthreads_of_n);
-
- // n is processed by a num_threads_t * num_threads_h 2D grid of threads
- int num_threads_t, num_threads_h;
- // num_threads_t <= num_threads_h
- tie(num_threads_t, num_threads_h) = closest_factors_(nthreads_of_n);
- int tid_t = tid_within_n / num_threads_h;
- int tid_h = tid_within_n % num_threads_h;
-
- int t_per_thread = (T_OUT + num_threads_t - 1) / num_threads_t;
- t_begin = std::min(tid_t * t_per_thread, T_OUT);
- t_end = std::min(t_begin + t_per_thread, T_OUT);
-
- int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
- h_begin = std::min(tid_h * h_per_thread, H_OUT);
- h_end = std::min(h_begin + h_per_thread, H_OUT);
- }
-
- for (int n = n_begin; n < n_end; ++n) {
- const uint8_t* A_base = A + n * T * H * W * K;
- uint8_t* C_uint8_base = C_uint8 + n * T_OUT * H_OUT * W_OUT * K;
-
- for (int t = t_begin; t < t_end; ++t) {
- for (int h = h_begin; h < h_end; ++h) {
- for (int w = 0; w < W_OUT; ++w) {
- depthwise_3x3x3_per_channel_quantization_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC>(
- T,
- H,
- W,
- K,
- t,
- h,
- w,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias);
- } // w
- } // h
- } // t
- } // for each n
-
- FREE(row_offsets);
};
// Dispatch A_SYMMETRIC and B_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS>
+template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
static void depthwise_3x3_pad_1_(
int N,
int H,
@@ -2314,12 +1003,13 @@ static void depthwise_3x3_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
int thread_id,
int num_threads) {
int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
@@ -2329,7 +1019,8 @@ static void depthwise_3x3_pad_1_(
FUSE_RELU,
HAS_BIAS,
true /*A_symmetric*/,
- true /*B_symmetric*/>(
+ true /*B_symmetric*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -2346,6 +1037,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
@@ -2353,7 +1045,8 @@ static void depthwise_3x3_pad_1_(
FUSE_RELU,
HAS_BIAS,
true /*A_symmetric*/,
- false /*B_symmetric*/>(
+ false /*B_symmetric*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -2370,6 +1063,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
@@ -2379,7 +1073,8 @@ static void depthwise_3x3_pad_1_(
FUSE_RELU,
HAS_BIAS,
false /*A_symmetric*/,
- true /*B_symmetric*/>(
+ true /*B_symmetric*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -2396,6 +1091,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
@@ -2403,7 +1099,8 @@ static void depthwise_3x3_pad_1_(
FUSE_RELU,
HAS_BIAS,
false /*A_symmetric*/,
- false /*B_symmetric*/>(
+ false /*B_symmetric*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -2420,6 +1117,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
@@ -2428,7 +1126,7 @@ static void depthwise_3x3_pad_1_(
}
// Dispatch HAS_BIAS
-template <bool FUSE_RELU>
+template <bool FUSE_RELU, typename BIAS_TYPE>
static void depthwise_3x3_pad_1_(
int N,
int H,
@@ -2439,16 +1137,17 @@ static void depthwise_3x3_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
int thread_id,
int num_threads) {
if (bias) {
- depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>(
+ depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>(
N,
H,
W,
@@ -2464,10 +1163,11 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/>(
+ depthwise_3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>(
N,
H,
W,
@@ -2483,6 +1183,7 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
@@ -2490,6 +1191,7 @@ static void depthwise_3x3_pad_1_(
// Dispatch input shape and FUSE_RELU
// assumption: W > 3 and H > 3
+template <typename BIAS_TYPE>
void depthwise_3x3_pad_1(
int N,
int H,
@@ -2500,18 +1202,33 @@ void depthwise_3x3_pad_1(
int32_t A_zero_point,
const uint8_t* A,
int32_t B_zero_point,
- const Packed3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& B,
float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
bool fuse_relu,
+ float act_times_w_scale,
int thread_id,
int num_threads) {
+ if (B.GetKernelProduct() != 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3) + " but has " + to_string(B.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we just return early.
+ return;
+ }
if (fuse_relu) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2527,10 +1244,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2546,10 +1264,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2565,10 +1284,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2584,10 +1304,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2603,12 +1324,13 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
} else {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2624,10 +1346,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2643,10 +1366,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2662,10 +1386,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2681,10 +1406,11 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -2700,283 +1426,15 @@ void depthwise_3x3_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
}
}
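
Putting the new public signature together, a hedged usage sketch follows. All sizes and quantization parameters are made up for illustration, and the PackedDepthWiseConvMatrix constructor signature is assumed from this commit's refactor rather than shown in this hunk:

#include <cstdint>
#include <vector>

#include "fbgemm/FbgemmI8DepthwiseAvx2.h"

// Hedged sketch: calling the templated entry point with a float bias.
void run_3x3_with_float_bias() {
  constexpr int N = 1, H = 7, W = 7, K = 32;
  std::vector<std::uint8_t> A(N * H * W * K, 128); // NHWC activations
  std::vector<std::int8_t> B(K * 3 * 3, 1);        // one 3x3 filter per channel
  std::vector<std::uint8_t> C(N * H * W * K);      // pad 1, stride 1 => same H, W
  std::vector<std::int32_t> col_offsets(K, 9);     // per-channel sums of B (9 ones)
  std::vector<float> bias(K, 0.25f);               // float bias path

  fbgemm::PackedDepthWiseConvMatrix Bp(K, 3 * 3, B.data()); // ctor assumed
  fbgemm::depthwise_3x3_pad_1<float>(
      N, H, W, K,
      /*stride_h=*/1, /*stride_w=*/1,
      /*A_zero_point=*/128, A.data(),
      /*B_zero_point=*/0, Bp,
      /*C_multiplier=*/0.01f, /*C_zero_point=*/0,
      C.data(), col_offsets.data(), bias.data(),
      /*fuse_relu=*/false,
      /*act_times_w_scale=*/0.5f, // activation_scale * weight_scale, made up
      /*thread_id=*/0, /*num_threads=*/1);
}
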
-// Dispatch A_SYMMETRIC and B_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS>
-static void depthwise_3x3x3_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- int thread_id,
- int num_threads) {
- int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
- if (A_zero_point == 0 || col_offsets == nullptr) {
- if (B_zero_point == 0) {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_symmetric*/,
- true /*B_symmetric*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_symmetric*/,
- false /*B_symmetric*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
- } else {
- if (B_zero_point == 0) {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_symmetric*/,
- true /*B_symmetric*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_symmetric*/,
- false /*B_symmetric*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
- }
- delete[] C_int32_temp;
-}
-
-// Dispatch HAS_BIAS
-template <bool FUSE_RELU>
-static void depthwise_3x3x3_pad_1_(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- int thread_id,
- int num_threads) {
- if (bias) {
- depthwise_3x3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
-}
-
-// Dispatch FUSE_RELU
-void depthwise_3x3x3_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- bool fuse_relu,
- int thread_id,
- int num_threads) {
- if (fuse_relu) {
- depthwise_3x3x3_pad_1_<true /*FUSE_RELU*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_pad_1_<false /*FUSE_RELU*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
-}
-
// Dispatch A_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS>
+template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
static void depthwise_3x3_per_channel_quantization_pad_1_(
int N,
int H,
@@ -2987,12 +1445,13 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
int thread_id,
int num_threads) {
int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
@@ -3000,7 +1459,8 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
depthwise_3x3_per_channel_quantization_pad_1_<
FUSE_RELU,
HAS_BIAS,
- true /*A_SYMM*/>(
+ true /*A_SYMM*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3017,13 +1477,15 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
depthwise_3x3_per_channel_quantization_pad_1_<
FUSE_RELU,
HAS_BIAS,
- false /*A_SYMM*/>(
+ false /*A_SYMM*/,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3040,6 +1502,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
@@ -3047,7 +1510,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
}
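
Unlike the tensor-wise path, this dispatcher specializes only on A's symmetry: with per-channel quantization, B carries one zero point per output channel, so the B-side correction cannot be removed by a single compile-time flag. A sketch of that asymmetry; the offset bookkeeping and names are illustrative, not FBGEMM's:

#include <cstdint>

// Only the activation side can still be specialized at compile time; the
// per-channel weight correction stays a runtime term.
template <bool A_SYMMETRIC>
void per_channel_correction(
    std::int32_t* acc,                 // length-K accumulators for one output pixel
    int K,
    std::int32_t A_zero_point,
    const std::int32_t* B_zero_point,  // per-channel weight zero points
    const std::int32_t* col_offsets,   // per-channel sums of weights
    const std::int32_t* row_offsets) { // per-channel window sums of activations
  for (int k = 0; k < K; ++k) {
    if (!A_SYMMETRIC) {
      acc[k] -= A_zero_point * col_offsets[k]; // compiled out when A is symmetric
    }
    acc[k] -= B_zero_point[k] * row_offsets[k]; // must stay runtime, per channel
  }
}
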
// Dispatch HAS_BIAS
-template <bool FUSE_RELU>
+template <bool FUSE_RELU, typename BIAS_TYPE>
static void depthwise_3x3_per_channel_quantization_pad_1_(
int N,
int H,
@@ -3058,18 +1521,20 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
int thread_id,
int num_threads) {
if (bias) {
depthwise_3x3_per_channel_quantization_pad_1_<
FUSE_RELU,
- true /* HAS_BIAS */>(
+ true /* HAS_BIAS */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3085,12 +1550,14 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
depthwise_3x3_per_channel_quantization_pad_1_<
FUSE_RELU,
- false /* HAS_BIAS */>(
+ false /* HAS_BIAS */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3106,12 +1573,14 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
}
// Dispatch input shape and FUSE_RELU
+template <typename BIAS_TYPE>
void depthwise_3x3_per_channel_quantization_pad_1(
int N,
int H,
@@ -3122,18 +1591,35 @@ void depthwise_3x3_per_channel_quantization_pad_1(
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3ConvMatrix& Bp,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
- const int32_t* bias,
+ const BIAS_TYPE* bias,
bool fuse_relu,
+ const float* act_times_w_scale,
int thread_id,
int num_threads) {
+ if (Bp.GetKernelProduct() != 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3) + " but has " + to_string(Bp.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In Caffe2 (C2), a batch size of 0 is allowed, so return early.
+ return;
+ }
if (fuse_relu) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3149,10 +1635,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3168,10 +1657,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3187,10 +1679,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3206,10 +1701,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3225,12 +1723,15 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
} else {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3246,10 +1747,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3265,10 +1769,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3284,10 +1791,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3303,10 +1813,13 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
} else {
- depthwise_3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
+ depthwise_3x3_per_channel_quantization_pad_1_<
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
N,
H,
W,
@@ -3322,225 +1835,179 @@ void depthwise_3x3_per_channel_quantization_pad_1(
C,
col_offsets,
bias,
+ act_times_w_scale,
thread_id,
num_threads);
}
}
}
-// Dispatch A_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS>
-static void depthwise_3x3x3_per_channel_quantization_pad_1_(
+// To be removed: backward-compatibility wrapper for the pre-BIAS_TYPE API.
+void depthwise_3x3_pad_1(
int N,
- int T,
int H,
int W,
int K,
- int stride_t,
int stride_h,
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
- const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
- const float* C_multiplier,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
+ bool fuse_relu,
int thread_id,
int num_threads) {
- int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
- if (A_zero_point == 0 || col_offsets == nullptr) {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_SYMM*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_SYMM*/>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
- delete[] C_int32_temp;
+ depthwise_3x3_pad_1<std::int32_t>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ fuse_relu,
+ 1.0f,
+ thread_id,
+ num_threads);
}
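
This wrapper keeps old call sites compiling: it forwards to the int32_t instantiation with act_times_w_scale fixed at 1.0f, which is harmless because an int32 bias never consults the scale (the per-channel wrapper below forwards nullptr for the same reason). A sketch of the equivalence, assuming identical arguments; the function name here is illustrative:

#include <cstdint>

#include "fbgemm/FbgemmI8DepthwiseAvx2.h"

// Sketch: the legacy overload (first call) and the templated entry point it
// forwards to (second call) produce identical results for an int32 bias.
void legacy_vs_templated(
    int N, int H, int W, int K, int stride_h, int stride_w,
    std::int32_t A_zp, const std::uint8_t* A, std::int32_t B_zp,
    const fbgemm::PackedDepthWiseConvMatrix& B, float C_mult,
    std::int32_t C_zp, std::uint8_t* C, const std::int32_t* col_offsets,
    const std::int32_t* bias, bool fuse_relu, int tid, int nthreads) {
  fbgemm::depthwise_3x3_pad_1(
      N, H, W, K, stride_h, stride_w, A_zp, A, B_zp, B, C_mult, C_zp, C,
      col_offsets, bias, fuse_relu, tid, nthreads); // legacy overload
  fbgemm::depthwise_3x3_pad_1<std::int32_t>(
      N, H, W, K, stride_h, stride_w, A_zp, A, B_zp, B, C_mult, C_zp, C,
      col_offsets, bias, fuse_relu,
      /*act_times_w_scale=*/1.0f, tid, nthreads);   // what it forwards to
}
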
-// Dispatch HAS_BIAS
-template <bool FUSE_RELU>
-static void depthwise_3x3x3_per_channel_quantization_pad_1_(
+// To be removed: backward-compatibility wrapper for the pre-BIAS_TYPE per-channel API.
+void depthwise_3x3_per_channel_quantization_pad_1(
int N,
- int T,
int H,
int W,
int K,
- int stride_t,
int stride_h,
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
+ bool fuse_relu,
int thread_id,
int num_threads) {
- if (bias) {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- true /* HAS_BIAS */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- false /* HAS_BIAS */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
+ depthwise_3x3_per_channel_quantization_pad_1<std::int32_t>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ fuse_relu,
+ nullptr,
+ thread_id,
+ num_threads);
}
-// Dispatch FUSE_RELU
-void depthwise_3x3x3_per_channel_quantization_pad_1(
+template void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3_per_channel_quantization_pad_1(
int N,
- int T,
int H,
int W,
int K,
- int stride_t,
int stride_h,
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
- const Packed3x3x3ConvMatrix& B,
+ const PackedDepthWiseConvMatrix& Bp,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
bool fuse_relu,
+ const float* act_times_w_scale,
int thread_id,
- int num_threads) {
- if (fuse_relu) {
- depthwise_3x3x3_per_channel_quantization_pad_1_<true /* FUSE_RELU */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3x3_per_channel_quantization_pad_1_<false /* FUSE_RELU */>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- thread_id,
- num_threads);
- }
-}
+ int num_threads);
+
+template void depthwise_3x3_per_channel_quantization_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads);
} // namespace fbgemm
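
The explicit instantiations above are needed because the template definitions stay in this .cc file: any BIAS_TYPE a caller uses must have its symbols emitted here, which currently means int32_t and float for each entry point. A self-contained sketch of that split, with illustrative names:

#include <iostream>

// Header side (conceptually): callers see only the declaration.
template <typename BIAS_TYPE>
void entry(BIAS_TYPE bias);

// .cc side (conceptually): the definition plus forced instantiations.
template <typename BIAS_TYPE>
void entry(BIAS_TYPE bias) {
  std::cout << "bias = " << bias << '\n';
}
template void entry(int);    // symbol emitted for int bias
template void entry(float);  // symbol emitted for float bias

int main() {
  entry(3);     // fine: int instantiation exists
  entry(0.5f);  // fine: float instantiation exists
  // entry(0.5); // double: would fail to link in the real split-file setup
  return 0;
}
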