github.com/marian-nmt/FBGEMM.git

author    dskhudia <dskhudia@fb.com>  2018-11-04 19:22:37 +0300
committer dskhudia <dskhudia@fb.com>  2018-11-04 19:22:37 +0300
commit    690dbc29d9b0cb373fa0303b7c30c20b527e9605 (patch)
tree      56d9b3ebc1a7b5ff394e5dc9e08db9e44285e6f4
parent    505eb847185c9255526813dd39edadcd4e61d8e0 (diff)
Syncing with internal version. Fixes for Mac/clang build. Other minor fixes
-rw-r--r--  CMakeLists.txt                            |    1
-rw-r--r--  include/fbgemm/Fbgemm.h                   |   29
-rw-r--r--  include/fbgemm/FbgemmFP16.h               |  129
-rw-r--r--  src/FbgemmFP16.cc                         |   10
-rw-r--r--  src/FbgemmI8Depthwise.cc                  | 1893
-rw-r--r--  src/GenerateKernelU8S8S32ACC16.cc         |   26
-rw-r--r--  src/GenerateKernelU8S8S32ACC16_avx512.cc  |   28
-rw-r--r--  src/GenerateKernelU8S8S32ACC32.cc         |   27
-rw-r--r--  src/GenerateKernelU8S8S32ACC32_avx512.cc  |   29
-rw-r--r--  src/PackAMatrix.cc                        |    6
-rw-r--r--  src/PackAWithIm2Col.cc                    |    5
-rw-r--r--  src/PackBMatrix.cc                        |    4
-rw-r--r--  src/PackWithQuantRowOffset.cc             |   46
-rw-r--r--  src/PackWithRowOffset.cc                  |   33
-rw-r--r--  src/RefImplementations.cc                 |    3
15 files changed, 1590 insertions(+), 679 deletions(-)
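
Most of the Mac/clang portion of this sync replaces aligned_alloc with posix_memalign, since C11 aligned_alloc was not available on the macOS toolchains of the time. The following is a minimal standalone sketch of that pattern, not FBGEMM code; the 64-byte alignment and 1 KiB size are arbitrary examples.

// Illustrative only: the portable aligned-allocation pattern this commit
// switches to. posix_memalign is available on macOS where C11 aligned_alloc
// may not be; the alignment must be a power of two and a multiple of
// sizeof(void*).
#include <cstdio>
#include <cstdlib>

int main() {
  void* buf = nullptr;
  if (posix_memalign(&buf, 64, 1024) != 0) { // 64-byte aligned, 1 KiB
    return 1;                                // allocation failed
  }
  std::printf("aligned buffer at %p\n", buf);
  free(buf);                                 // posix_memalign memory is free()'d
  return 0;
}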
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8a477d6..ad8ffd9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -93,6 +93,7 @@ if(NOT TARGET asmjit)
#build asmjit
set(ASMJIT_STATIC ON)
add_subdirectory("${ASMJIT_SRC_DIR}" "${FBGEMM_BINARY_DIR}/asmjit")
+ set_property(TARGET asmjit PROPERTY POSITION_INDEPENDENT_CODE ON)
endif()
if(NOT TARGET cpuinfo)
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 988e24b..2f9ddc7 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -236,7 +236,7 @@ class PackMatrix {
return last_bcol_ != blockColSize();
}
- ~PackMatrix() {
+ virtual ~PackMatrix() {
if (bufAllocatedHere_) {
free(buf_);
}
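
The base destructor of PackMatrix becomes virtual here. Whether FBGEMM deletes packed matrices through a base pointer is not visible in this diff, so the snippet below is only a generic illustration, with hypothetical Base/Derived types, of what the change guards against.

// Hypothetical illustration of why the base destructor is now virtual:
// without it, deleting a derived object through a base pointer is undefined
// behavior and the derived destructor never runs.
#include <cstdio>

struct Base {
  virtual ~Base() { std::puts("~Base"); }        // virtual, as PackMatrix is now
};

struct Derived : Base {                          // cf. PackAMatrix final : PackMatrix<...>
  ~Derived() override { std::puts("~Derived"); }
};

int main() {
  Base* p = new Derived();
  delete p;   // prints "~Derived" then "~Base"; UB if ~Base were non-virtual
  return 0;
}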
@@ -286,7 +286,7 @@ class PackMatrix {
* accumulation type is int32.
*/
template <typename T, typename accT = std::int32_t>
-class PackAMatrix : public PackMatrix<PackAMatrix<T, accT>, T, accT> {
+class PackAMatrix final : public PackMatrix<PackAMatrix<T, accT>, T, accT> {
public:
using This = PackAMatrix<T, accT>;
using BaseType = PackMatrix<This, T, accT>;
@@ -306,7 +306,7 @@ class PackAMatrix : public PackMatrix<PackAMatrix<T, accT>, T, accT> {
std::int32_t ld,
inpType* pmat = nullptr,
std::int32_t groups = 1,
- accT zero_pt = 0);
+ std::int32_t zero_pt = 0);
/**
* Activation matrices are not constant so cannot amortize the cost of
@@ -361,7 +361,7 @@ class PackAMatrix : public PackMatrix<PackAMatrix<T, accT>, T, accT> {
* type is int32.
*/
template <typename T, typename accT = std::int32_t>
-class PackBMatrix : public PackMatrix<PackBMatrix<T, accT>, T, accT> {
+class PackBMatrix final : public PackMatrix<PackBMatrix<T, accT>, T, accT> {
public:
using This = PackBMatrix<T, accT>;
using BaseType = PackMatrix<This, T, accT>;
@@ -381,7 +381,7 @@ class PackBMatrix : public PackMatrix<PackBMatrix<T, accT>, T, accT> {
std::int32_t ld,
inpType* pmat = nullptr,
std::int32_t groups = 1,
- accT zero_pt = 0);
+ std::int32_t zero_pt = 0);
/**
* Weight matrices are usually constant so worth pre-packing.
@@ -439,7 +439,8 @@ class PackBMatrix : public PackMatrix<PackBMatrix<T, accT>, T, accT> {
* quantized.
*/
template <typename T, typename accT = std::int32_t>
-class PackAWithIm2Col : public PackMatrix<PackAWithIm2Col<T, accT>, T, accT> {
+class PackAWithIm2Col final
+ : public PackMatrix<PackAWithIm2Col<T, accT>, T, accT> {
public:
using This = PackAWithIm2Col<T, accT>;
using BaseType = PackMatrix<This, T, accT>;
@@ -499,7 +500,7 @@ class PackAWithIm2Col : public PackMatrix<PackAWithIm2Col<T, accT>, T, accT> {
* The source matrix is already quantized.
*/
template <typename T, typename accT = std::int32_t>
-class PackAWithRowOffset
+class PackAWithRowOffset final
: public PackMatrix<PackAWithRowOffset<T, accT>, T, accT> {
public:
using This = PackAWithRowOffset<T, accT>;
@@ -572,7 +573,7 @@ class PackAWithRowOffset
* The source matrix is in fp32 and quantized during packing.
*/
template <typename T, typename accT = std::int32_t>
-class PackAWithQuantRowOffset
+class PackAWithQuantRowOffset final
: public PackMatrix<PackAWithQuantRowOffset<T, accT>, T, accT> {
public:
using This = PackAWithQuantRowOffset<T, accT>;
@@ -935,7 +936,6 @@ void fbgemmPacked(
/**
* @brief Perform depthwise separable convolution
*/
-
template <
typename packingAMatrix,
typename packingBMatrix,
@@ -949,4 +949,15 @@ void convDepthwiseSeparable(
outT* out,
const processOutputType& output);
+/**
+ * @brief Allocate __size bytes of uninitialized storage whose alignment is
+ * specified by __align.
+ */
+static void* fbgemmAlignedAlloc(size_t __align, size_t __size) {
+ void* aligned_mem;
+ if (posix_memalign(&aligned_mem, __align, __size))
+ return 0;
+ return aligned_mem;
+}
+
} // namespace fbgemm2
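
A small usage sketch of the fbgemmAlignedAlloc helper added above. The buffer size, the 64-byte alignment, and the function name aligned_scratch_example are illustrative; only the helper's signature and its return-0-on-failure behavior come from the diff.

// Sketch only, assuming fbgemm/Fbgemm.h from this commit is on the include
// path. The helper wraps posix_memalign and returns 0 (nullptr) on failure,
// so callers check the result and release the buffer with free().
#include <cstdlib>
#include <cstring>
#include "fbgemm/Fbgemm.h"

void aligned_scratch_example() {
  void* buf = fbgemm2::fbgemmAlignedAlloc(64, 256 * sizeof(float)); // 64B aligned
  if (buf == nullptr) {
    return;                       // allocation failed
  }
  std::memset(buf, 0, 256 * sizeof(float));
  free(buf);                      // posix_memalign memory is released with free()
}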
diff --git a/include/fbgemm/FbgemmFP16.h b/include/fbgemm/FbgemmFP16.h
index 55718d4..0428d93 100644
--- a/include/fbgemm/FbgemmFP16.h
+++ b/include/fbgemm/FbgemmFP16.h
@@ -10,6 +10,7 @@
// upgraded to match with new fbgemm interface.
#include <cassert>
+#include <cstdlib>
#include <memory>
#include <vector>
@@ -22,7 +23,7 @@ namespace fbgemm2 {
/// row-major format into
/// internal packed blocked-row major format
class PackedGemmMatrixFP16 {
-public:
+ public:
// takes smat input matrix in row-major format;
// and packs it into gemm-friendly blocked format;
// allocates space and sets up all the internal variables;
@@ -32,30 +33,31 @@ public:
// before flushing into fp32
// the smaller the brow_, the higher overhead
// of flushing is
- PackedGemmMatrixFP16(const matrix_op_t trans, const int nrow,
- const int ncol, const float alpha,
- const float *smat,
- const int brow = 512)
+ PackedGemmMatrixFP16(
+ const matrix_op_t trans,
+ const int nrow,
+ const int ncol,
+ const float alpha,
+ const float* smat,
+ const int brow = 512)
: nrow_(nrow), ncol_(ncol), brow_(brow) {
-
bcol_ = 8 * 1; // hardwired
// set up internal packing parameters
nbrow_ = ((numRows() % blockRowSize()) == 0)
- ? (numRows() / blockRowSize())
- : ((numRows() + blockRowSize()) / blockRowSize());
+ ? (numRows() / blockRowSize())
+ : ((numRows() + blockRowSize()) / blockRowSize());
last_brow_ = ((nrow % blockRowSize()) == 0) ? blockRowSize()
- : (nrow % blockRowSize());
+ : (nrow % blockRowSize());
nbcol_ = ((numCols() % blockColSize()) == 0)
- ? (numCols() / blockColSize())
- : ((numCols() + blockColSize()) / blockColSize());
+ ? (numCols() / blockColSize())
+ : ((numCols() + blockColSize()) / blockColSize());
if (numCols() != blockColSize() * nbcol_) {
#ifdef VLOG
- VLOG(0)
- << "Packer warning: ncol(" << numCols()
- << ") is not a multiple of internal block size (" << blockColSize()
- << ")";
+ VLOG(0) << "Packer warning: ncol(" << numCols()
+ << ") is not a multiple of internal block size ("
+ << blockColSize() << ")";
VLOG(0)
<< "lefover is currently done via MKL: hence overhead will inccur";
#endif
@@ -64,7 +66,9 @@ public:
// allocate and initialize packed memory
const int padding = 1024; // required by sw pipelined kernels
size_ = (blockRowSize() * nbrow_) * (blockColSize() * nbcol_);
- pmat_ = (float16 *)aligned_alloc(64, matSize() * sizeof(float16) + padding);
+ // pmat_ = (float16 *)aligned_alloc(64, matSize() * sizeof(float16) +
+ // padding);
+ posix_memalign((void**)&pmat_, 64, matSize() * sizeof(float16) + padding);
for (auto i = 0; i < matSize(); i++) {
pmat_[i] = tconv(0.f, pmat_[i]);
}
@@ -77,7 +81,7 @@ public:
free(pmat_);
}
-// protected:
+ // protected:
// blocked row-major format address arithmetic
uint64_t addr(const int r_, const int c_) const {
uint64_t r = (uint64_t)r_;
@@ -87,10 +91,9 @@ public:
brow_offset =
(block_row_id * nbcol_) * (blockRowSize() * blockColSize());
uint64_t block_col_id = c / blockColSize(),
- bcol_offset =
- block_col_id * ((block_row_id != nbrow_ - 1)
- ? (blockRowSize() * blockColSize())
- : (last_brow_ * blockColSize()));
+ bcol_offset = block_col_id *
+ ((block_row_id != nbrow_ - 1) ? (blockRowSize() * blockColSize())
+ : (last_brow_ * blockColSize()));
uint64_t block_offset = brow_offset + bcol_offset;
uint64_t inblock_offset =
r % blockRowSize() * blockColSize() + c % blockColSize();
@@ -100,22 +103,22 @@ public:
return index;
}
- void packFromSrc(const matrix_op_t trans, const float alpha,
- const float *smat) {
+ void
+ packFromSrc(const matrix_op_t trans, const float alpha, const float* smat) {
bool tr = (trans == matrix_op_t::Transpose);
// pack
for (int i = 0; i < numRows(); i++) {
for (int j = 0; j < numCols(); j++) {
pmat_[addr(i, j)] = tconv(
- alpha * (
- (tr == false)
- ? smat[i * numCols() + j] : smat[i + numRows() * j]),
+ alpha *
+ ((tr == false) ? smat[i * numCols() + j]
+ : smat[i + numRows() * j]),
pmat_[addr(i, j)]);
}
}
}
- const float16 &operator()(const int r, const int c) const {
+ const float16& operator()(const int r, const int c) const {
uint64_t a = addr(r, c);
assert(r < numRows());
assert(c < numCols());
@@ -123,38 +126,60 @@ public:
return pmat_[a];
}
- int matSize() const { return size_; }
- int numRows() const { return nrow_; }
- int numCols() const { return ncol_; }
- inline int blockRowSize() const { return brow_; }
- inline int blockColSize() const { return bcol_; }
+ int matSize() const {
+ return size_;
+ }
+ int numRows() const {
+ return nrow_;
+ }
+ int numCols() const {
+ return ncol_;
+ }
+ inline int blockRowSize() const {
+ return brow_;
+ }
+ inline int blockColSize() const {
+ return bcol_;
+ }
int nrow_, ncol_;
int brow_, last_brow_, bcol_;
int nbrow_, nbcol_;
uint64_t size_;
- float16 *pmat_;
-
- friend void cblas_gemm_compute(const matrix_op_t transa, const int m,
- const float *A,
- const PackedGemmMatrixFP16 &Bp,
- const float beta, float *C);
- friend void cblas_gemm_compute(const matrix_op_t transa, const int m,
- const float *A,
- const PackedGemmMatrixFP16 &Bp,
- const float beta, float *C);
+ float16* pmat_;
+
+ friend void cblas_gemm_compute(
+ const matrix_op_t transa,
+ const int m,
+ const float* A,
+ const PackedGemmMatrixFP16& Bp,
+ const float beta,
+ float* C);
+ friend void cblas_gemm_compute(
+ const matrix_op_t transa,
+ const int m,
+ const float* A,
+ const PackedGemmMatrixFP16& Bp,
+ const float beta,
+ float* C);
};
/**
* restrictions: transa == CblasNoTrans
*/
-extern void cblas_gemm_compute(const matrix_op_t transa, const int m,
- const float *A,
- const PackedGemmMatrixFP16 &Bp,
- const float beta, float *C);
-extern void cblas_gemm_compute(const matrix_op_t transa, const int m,
- const float *A,
- const PackedGemmMatrixFP16 &Bp,
- const float beta, float *C);
-
-}; // namespace fbgemm
+extern void cblas_gemm_compute(
+ const matrix_op_t transa,
+ const int m,
+ const float* A,
+ const PackedGemmMatrixFP16& Bp,
+ const float beta,
+ float* C);
+extern void cblas_gemm_compute(
+ const matrix_op_t transa,
+ const int m,
+ const float* A,
+ const PackedGemmMatrixFP16& Bp,
+ const float beta,
+ float* C);
+
+}; // namespace fbgemm2
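
For context on the reformatted FP16 interface, here is a hedged sketch of the intended call sequence: pack the fp32 weight matrix once with PackedGemmMatrixFP16, then run cblas_gemm_compute, which the header restricts to transa == no-transpose. The matrix sizes, the alpha/beta values, and the matrix_op_t::NoTranspose enumerator (defined in a types header not shown in this diff) are assumptions.

// Sketch under assumptions: matrix_op_t::NoTranspose is taken from fbgemm's
// types header, and cblas_gemm_compute accumulates into C, so C is
// zero-initialized here.
#include <vector>
#include "fbgemm/FbgemmFP16.h"

void fp16_gemm_example() {
  const int m = 4, k = 8, n = 16;                  // arbitrary example sizes
  std::vector<float> A(m * k, 1.0f);               // activations, row-major
  std::vector<float> B(k * n, 0.5f);               // fp32 weights, row-major
  std::vector<float> C(m * n, 0.0f);               // output buffer

  // Pack (and convert to fp16) the constant weight matrix once.
  fbgemm2::PackedGemmMatrixFP16 Bp(
      fbgemm2::matrix_op_t::NoTranspose, k, n, /*alpha=*/1.0f, B.data());

  // transa is restricted to no-transpose per the comment in the header.
  fbgemm2::cblas_gemm_compute(
      fbgemm2::matrix_op_t::NoTranspose, m, A.data(), Bp, /*beta=*/0.0f, C.data());
}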
diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc
index 7bbfa54..eff173f 100644
--- a/src/FbgemmFP16.cc
+++ b/src/FbgemmFP16.cc
@@ -7,6 +7,8 @@
#include "fbgemm/FbgemmFP16.h"
#include <cpuinfo.h>
+#include <array>
+#include <utility>
#include "FbgemmFP16UKernels.h"
@@ -44,7 +46,7 @@ struct KernelInfo {
// autotuned kernel splits for various cases m = 1:mb_max
// may need re-autotuning for new uarch
- static constexpr array<array<pair<int, int>, 2>, 121 > partition = {
+ static constexpr array<array<array<int, 2>, 2>, 121 > partition = {
{
{{ { 0, 0 }, { 0, 0 } } },
{{ { 1, 1 }, { 0, 0 } } },
@@ -171,7 +173,7 @@ struct KernelInfo {
};
};
constexpr array<KernelInfo::knl_ptr, 15> KernelInfo::kernel;
-constexpr array<array<pair<int, int>, 2>, 121 > KernelInfo::partition;
+constexpr array<array<array<int, 2>, 2>, 121 > KernelInfo::partition;
// autotuned kernel splits for various cases m = 1:mb_max
void
@@ -220,8 +222,8 @@ cblas_gemm_compute(const matrix_op_t transa, const int m, const float *A,
auto m1 = 0;
for (auto c = 0; c < 2; c++) {
- auto kernel_nrows = KernelInfo::partition[mb][c].first;
- auto nkernel_nrows = KernelInfo::partition[mb][c].second;
+ auto kernel_nrows = KernelInfo::partition[mb][c][0];
+ auto nkernel_nrows = KernelInfo::partition[mb][c][1];
auto m_start = m1, m_end = m1 + kernel_nrows * nkernel_nrows;
for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) {
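
The kernel split table above switches its element type from pair<int, int> to array<int, 2>, so entries are now read by index rather than .first/.second, as the loop shows. Below is a stand-in sketch of that access pattern; the table values are made up for illustration, not FBGEMM's autotuned splits.

// Stand-in illustration of the new partition layout: each entry is
// {kernel_nrows, nkernel_nrows}, accessed as [0] and [1] instead of
// pair::first / pair::second. Values are invented, not FBGEMM's tuning.
#include <array>
#include <cstdio>

static constexpr std::array<std::array<std::array<int, 2>, 2>, 3> partition = {{
    {{{0, 0}, {0, 0}}},   // mb = 0 (unused)
    {{{1, 1}, {0, 0}}},   // mb = 1: one 1-row kernel
    {{{2, 1}, {0, 0}}},   // mb = 2: one 2-row kernel (hypothetical)
}};

int main() {
  int mb = 2;
  for (int c = 0; c < 2; ++c) {
    int kernel_nrows = partition[mb][c][0];
    int nkernel_nrows = partition[mb][c][1];
    std::printf("split %d: %d rows x %d kernels\n", c, kernel_nrows, nkernel_nrows);
  }
  return 0;
}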
diff --git a/src/FbgemmI8Depthwise.cc b/src/FbgemmI8Depthwise.cc
index 54e2272..551e98e 100644
--- a/src/FbgemmI8Depthwise.cc
+++ b/src/FbgemmI8Depthwise.cc
@@ -18,8 +18,7 @@
using namespace std;
-namespace fbgemm2
-{
+namespace fbgemm2 {
static array<array<int, 8>, 8> masks = {{
{ 0, 0, 0, 0, 0, 0, 0, 0, },
@@ -34,7 +33,8 @@ static array<array<int, 8>, 8> masks = {{
template <int KERNEL_PROD>
PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
- int K, const int8_t *smat)
+ int K,
+ const int8_t* smat)
: K_(K) {
// Transpose the input matrix to make packing faster.
vector<int8_t> smat_transposed(K * KERNEL_PROD);
@@ -46,8 +46,12 @@ PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
// Allocate packed arrays
constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2;
- pmat_ = static_cast<int8_t *>(aligned_alloc(
- 64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t)));
+ // pmat_ = static_cast<int8_t *>(fbgemmAlignedAlloc(
+ // 64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t)));
+ posix_memalign(
+ (void**)&pmat_,
+ 64,
+ ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t));
// Pack input matrix
// The layout is optimized to use vpmaddubsw efficiently (see
@@ -106,15 +110,15 @@ PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
int remainder = K - k1;
if (remainder < 32) {
__m256i mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i *>(masks[remainder / 4].data()));
+ reinterpret_cast<const __m256i*>(masks[remainder / 4].data()));
for (int i = 0; i < KERNEL_PROD; ++i) {
b_v[i] = _mm256_maskload_epi32(
- reinterpret_cast<const int *>(smat_transposed.data() + i * K + k1),
+ reinterpret_cast<const int*>(smat_transposed.data() + i * K + k1),
mask_v);
}
} else {
for (int i = 0; i < KERNEL_PROD; ++i) {
- b_v[i] = _mm256_lddqu_si256(reinterpret_cast<const __m256i *>(
+ b_v[i] = _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(
smat_transposed.data() + i * K + k1));
}
}
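
The remainder path above uses _mm256_maskload_epi32 with a row of the masks table to avoid reading past the end of a channel block when K is not a multiple of 32. Below is a minimal sketch of that masking idea with a hand-built mask; the contents of masks[] beyond the all-zero row shown earlier are not reproduced here.

// Illustrative AVX2 masked load (compile with -mavx2). The mask enables only
// the first three 32-bit lanes, mimicking how the depthwise kernel guards
// reads for a short remainder; the real masks[] rows are not quoted here.
#include <immintrin.h>
#include <cstdio>

int main() {
  alignas(32) int data[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  __m256i mask_v = _mm256_setr_epi32(-1, -1, -1, 0, 0, 0, 0, 0); // 3 active lanes
  __m256i v = _mm256_maskload_epi32(data, mask_v);               // lanes 3..7 read as 0

  alignas(32) int out[8];
  _mm256_store_si256(reinterpret_cast<__m256i*>(out), v);
  for (int i = 0; i < 8; ++i) {
    std::printf("%d ", out[i]);                                  // 1 2 3 0 0 0 0 0
  }
  std::printf("\n");
  return 0;
}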
@@ -153,7 +157,7 @@ PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
for (int i = 0; i < KERNEL_PROD_ALIGNED; ++i) {
_mm256_storeu_si256(
- reinterpret_cast<__m256i *>(
+ reinterpret_cast<__m256i*>(
&pmat_[((k1 / 32) * KERNEL_PROD_ALIGNED + i) * 32]),
b_interleaved_epi32[i]);
}
@@ -161,8 +165,7 @@ PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
}
template <int KERNEL_PROD>
-PackedDepthWiseConvMatrix<KERNEL_PROD>::~PackedDepthWiseConvMatrix()
-{
+PackedDepthWiseConvMatrix<KERNEL_PROD>::~PackedDepthWiseConvMatrix() {
free(pmat_);
}
@@ -178,11 +181,16 @@ template class PackedDepthWiseConvMatrix<3 * 3 * 3>;
// c2_v: c[8:12], c[24:28]
// c3_v: c[12:16], c[28:32]
template <bool SUM_A = false>
-static inline __attribute__((always_inline))
-void madd_epi16x4_packed(
- __m256i a0_v, __m256i a1_v, __m256i a2_v, __m256i a3_v,
+static inline __attribute__((always_inline)) void madd_epi16x4_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ __m256i a2_v,
+ __m256i a3_v,
const __m256i* b,
- __m256i* c0_v, __m256i* c1_v, __m256i* c2_v, __m256i* c3_v,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
__m256i* a_sum = nullptr) {
__m256i a01_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
__m256i a01_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
@@ -232,11 +240,15 @@ void madd_epi16x4_packed(
// c2_v: c[8:12], c[24:28]
// c3_v: c[12:16], c[28:32]
template <bool SUM_A = false>
-static inline __attribute__((always_inline))
-void madd_epi16x3_packed(
- __m256i a0_v, __m256i a1_v, __m256i a2_v,
+static inline __attribute__((always_inline)) void madd_epi16x3_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ __m256i a2_v,
const __m256i* b,
- __m256i* c0_v, __m256i* c1_v, __m256i* c2_v, __m256i* c3_v,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
__m256i* a_sum = nullptr) {
__m256i zero_v = _mm256_setzero_si256();
@@ -288,10 +300,15 @@ void madd_epi16x3_packed(
// c2_v: c[16:20], c[20:24]
// c3_v: c[24:28], c[28:32]
template <bool SUM_A = false>
-static inline __attribute__((always_inline)) void
-madd_epi16x2_packed(__m256i a0_v, __m256i a1_v, const __m256i *b, __m256i *c0_v,
- __m256i *c1_v, __m256i *c2_v, __m256i *c3_v,
- __m256i *a_sum = nullptr) {
+static inline __attribute__((always_inline)) void madd_epi16x2_packed(
+ __m256i a0_v,
+ __m256i a1_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
__m256i a_lo_v = _mm256_unpacklo_epi8(a0_v, a1_v);
__m256i a_hi_v = _mm256_unpackhi_epi8(a0_v, a1_v);
@@ -324,9 +341,14 @@ madd_epi16x2_packed(__m256i a0_v, __m256i a1_v, const __m256i *b, __m256i *c0_v,
// c2_v: c[16:20], c[20:24]
// c3_v: c[24:28], c[28:32]
template <bool SUM_A = false>
-static inline __attribute__((always_inline)) void
-madd_epi16_packed(__m256i a_v, const __m256i *b, __m256i *c0_v, __m256i *c1_v,
- __m256i *c2_v, __m256i *c3_v, __m256i *a_sum = nullptr) {
+static inline __attribute__((always_inline)) void madd_epi16_packed(
+ __m256i a_v,
+ const __m256i* b,
+ __m256i* c0_v,
+ __m256i* c1_v,
+ __m256i* c2_v,
+ __m256i* c3_v,
+ __m256i* a_sum = nullptr) {
__m256i zero_v = _mm256_setzero_si256();
__m256i a_lo_v = _mm256_unpacklo_epi8(a_v, zero_v);
@@ -354,21 +376,41 @@ madd_epi16_packed(__m256i a_v, const __m256i *b, __m256i *c0_v, __m256i *c1_v,
// K is the number of accumulations we're doing
template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false>
-static inline __attribute__((always_inline)) void
-inner_prod_packed_(const __m256i *a_v, const __m256i *Bp, int32_t *C,
- int remainder, __m256i *a_sum = nullptr) {
+static inline __attribute__((always_inline)) void inner_prod_packed_(
+ const __m256i* a_v,
+ const __m256i* Bp,
+ int32_t* C,
+ int remainder,
+ __m256i* a_sum = nullptr) {
array<__m256i, 4> c, c_temp;
array<__m256i, 2> a_sum_temp{};
int k = 0;
if (K >= 4) {
- madd_epi16x4_packed<SUM_A>(a_v[0], a_v[1], a_v[2], a_v[3], Bp,
- &c[0], &c[1], &c[2], &c[3], a_sum_temp.data());
+ madd_epi16x4_packed<SUM_A>(
+ a_v[0],
+ a_v[1],
+ a_v[2],
+ a_v[3],
+ Bp,
+ &c[0],
+ &c[1],
+ &c[2],
+ &c[3],
+ a_sum_temp.data());
for (k = 4; k < K / 4 * 4; k += 4) {
- madd_epi16x4_packed<SUM_A>(a_v[k + 0], a_v[k + 1], a_v[k + 2], a_v[k + 3],
- Bp + k, &c_temp[0], &c_temp[1], &c_temp[2],
- &c_temp[3], a_sum_temp.data());
+ madd_epi16x4_packed<SUM_A>(
+ a_v[k + 0],
+ a_v[k + 1],
+ a_v[k + 2],
+ a_v[k + 3],
+ Bp + k,
+ &c_temp[0],
+ &c_temp[1],
+ &c_temp[2],
+ &c_temp[3],
+ a_sum_temp.data());
c[0] = _mm256_add_epi32(c[0], c_temp[0]);
c[1] = _mm256_add_epi32(c[1], c_temp[1]);
@@ -383,9 +425,16 @@ inner_prod_packed_(const __m256i *a_v, const __m256i *Bp, int32_t *C,
}
if (K - k == 3) {
- madd_epi16x3_packed<SUM_A>(a_v[k], a_v[k + 1], a_v[k + 2], Bp + k,
- &c_temp[0], &c_temp[1], &c_temp[2], &c_temp[3],
- a_sum_temp.data());
+ madd_epi16x3_packed<SUM_A>(
+ a_v[k],
+ a_v[k + 1],
+ a_v[k + 2],
+ Bp + k,
+ &c_temp[0],
+ &c_temp[1],
+ &c_temp[2],
+ &c_temp[3],
+ a_sum_temp.data());
c[0] = _mm256_add_epi32(c[0], c_temp[0]);
c[1] = _mm256_add_epi32(c[1], c_temp[1]);
@@ -405,11 +454,18 @@ inner_prod_packed_(const __m256i *a_v, const __m256i *Bp, int32_t *C,
c[3] = c_temp[3];
} else {
if (K - k == 1) {
- madd_epi16_packed<SUM_A>(a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3],
- a_sum_temp.data());
+ madd_epi16_packed<SUM_A>(
+ a_v[k], Bp + k, &c[0], &c[1], &c[2], &c[3], a_sum_temp.data());
} else if (K - k == 2) {
- madd_epi16x2_packed<SUM_A>(a_v[k], a_v[k + 1], Bp + k, &c[0], &c[1],
- &c[2], &c[3], a_sum_temp.data());
+ madd_epi16x2_packed<SUM_A>(
+ a_v[k],
+ a_v[k + 1],
+ Bp + k,
+ &c[0],
+ &c[1],
+ &c[2],
+ &c[3],
+ a_sum_temp.data());
}
c[0] = _mm256_add_epi32(c[0], c_temp[0]);
@@ -422,37 +478,37 @@ inner_prod_packed_(const __m256i *a_v, const __m256i *Bp, int32_t *C,
for (int r = 0; r < remainder / 8; ++r) {
if (ACC) {
_mm256_storeu_si256(
- reinterpret_cast<__m256i *>(C + r * 8),
+ reinterpret_cast<__m256i*>(C + r * 8),
_mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + r * 8)),
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + r * 8)),
c[r]));
} else {
- _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + r * 8), c[r]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + r * 8), c[r]);
}
}
} else {
if (ACC) {
_mm256_storeu_si256(
- reinterpret_cast<__m256i *>(C),
- _mm256_add_epi32(_mm256_loadu_si256(reinterpret_cast<__m256i *>(C)),
- c[0]));
+ reinterpret_cast<__m256i*>(C),
+ _mm256_add_epi32(
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C)), c[0]));
_mm256_storeu_si256(
- reinterpret_cast<__m256i *>(C + 8),
+ reinterpret_cast<__m256i*>(C + 8),
_mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 8)), c[1]));
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 8)), c[1]));
_mm256_storeu_si256(
- reinterpret_cast<__m256i *>(C + 16),
+ reinterpret_cast<__m256i*>(C + 16),
_mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 16)), c[2]));
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 16)), c[2]));
_mm256_storeu_si256(
- reinterpret_cast<__m256i *>(C + 24),
+ reinterpret_cast<__m256i*>(C + 24),
_mm256_add_epi32(
- _mm256_loadu_si256(reinterpret_cast<__m256i *>(C + 24)), c[3]));
+ _mm256_loadu_si256(reinterpret_cast<__m256i*>(C + 24)), c[3]));
} else {
- _mm256_storeu_si256(reinterpret_cast<__m256i *>(C), c[0]);
- _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 8), c[1]);
- _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 16), c[2]);
- _mm256_storeu_si256(reinterpret_cast<__m256i *>(C + 24), c[3]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C), c[0]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 8), c[1]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 16), c[2]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + 24), c[3]);
}
}
@@ -467,14 +523,13 @@ inner_prod_packed_(const __m256i *a_v, const __m256i *Bp, int32_t *C,
}
template <bool SUM_A = false, bool REMAINDER = false>
-static inline __attribute__((always_inline))
-void inner_prod_3x3_packed_(const __m256i* a_v,
- const __m256i* Bp,
- int32_t* C,
- int remainder,
- __m256i* a_sum = nullptr) {
- return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder,
- a_sum);
+static inline __attribute__((always_inline)) void inner_prod_3x3_packed_(
+ const __m256i* a_v,
+ const __m256i* Bp,
+ int32_t* C,
+ int remainder,
+ __m256i* a_sum = nullptr) {
+ return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder, a_sum);
}
// Almost same as ReQuantizeOutput in OutputProcessing-inh.h but different
@@ -507,7 +562,7 @@ static inline __attribute__((always_inline)) void requantize_(
constexpr int VLEN = 8;
int j = 0;
- for ( ; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
+ for (; j < n / (VLEN * 4) * (VLEN * 4); j += (VLEN * 4)) {
__m256i x_v =
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
__m256i y_v = _mm256_loadu_si256(
@@ -519,10 +574,9 @@ static inline __attribute__((always_inline)) void requantize_(
__m256i col_off_v = _mm256_mullo_epi32(
A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j)));
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(col_offsets + j)));
__m256i row_offset_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i *>(row_offsets + j));
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v);
col_off_v = _mm256_mullo_epi32(
@@ -535,23 +589,23 @@ static inline __attribute__((always_inline)) void requantize_(
col_off_v = _mm256_mullo_epi32(
A_zero_point_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
- col_offsets + j + 2 * VLEN)));
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + 2 * VLEN)));
row_offset_v = _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(row_offsets + j + 2 * VLEN));
z_v = _mm256_sub_epi32(_mm256_sub_epi32(z_v, col_off_v), row_offset_v);
col_off_v = _mm256_mullo_epi32(
A_zero_point_v,
- _mm256_loadu_si256(reinterpret_cast<const __m256i*>(
- col_offsets + j + 3 * VLEN)));
+ _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(col_offsets + j + 3 * VLEN)));
row_offset_v = _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(row_offsets + j + 3 * VLEN));
w_v = _mm256_sub_epi32(_mm256_sub_epi32(w_v, col_off_v), row_offset_v);
if (HAS_BIAS) { // static if
x_v = _mm256_add_epi32(
- x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(bias + j)));
+ x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
y_v = _mm256_add_epi32(
y_v,
_mm256_loadu_si256(
@@ -604,21 +658,20 @@ static inline __attribute__((always_inline)) void requantize_(
reinterpret_cast<__m256i*>(C_uint8 + j), xyzw_clamped_v);
} // j loop vectorized and unrolled 4x
- for ( ; j < n / VLEN * VLEN; j += VLEN) {
+ for (; j < n / VLEN * VLEN; j += VLEN) {
__m256i x_v =
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(C_int32 + j));
__m256i col_off_v = _mm256_mullo_epi32(
A_zero_point_v,
- _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(col_offsets + j)));
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(col_offsets + j)));
__m256i row_offset_v =
- _mm256_loadu_si256(reinterpret_cast<const __m256i *>(row_offsets + j));
+ _mm256_loadu_si256(reinterpret_cast<const __m256i*>(row_offsets + j));
x_v = _mm256_sub_epi32(_mm256_sub_epi32(x_v, col_off_v), row_offset_v);
if (HAS_BIAS) { // static if
x_v = _mm256_add_epi32(
- x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i *>(bias + j)));
+ x_v, _mm256_loadu_si256(reinterpret_cast<const __m256i*>(bias + j)));
}
if (PER_CHANNEL_QUANTIZATION) {
@@ -635,15 +688,14 @@ static inline __attribute__((always_inline)) void requantize_(
FUSE_RELU ? C_zero_point_epi8_v : min_v,
_mm256_min_epu8(x_packed_v, max_v));
- x_clamped_v =
- _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
+ x_clamped_v = _mm256_permutevar8x32_epi32(x_clamped_v, permute_mask_v);
_mm_storel_epi64(
reinterpret_cast<__m128i*>(C_uint8 + j),
_mm256_castsi256_si128(x_clamped_v));
} // j loop vectorized
- for ( ; j < n; ++j) {
+ for (; j < n; ++j) {
int32_t raw = C_int32[j] - A_zero_point * col_offsets[j] - row_offsets[j];
if (HAS_BIAS) { // static if
raw += bias[j];
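
The scalar tail above begins the requantization arithmetic that the vectorized loops implement. Below is a hedged scalar model of the whole step; the rounding function, the uint8 clamp range, and the ReLU floor at the output zero point are assumptions about the elided context lines, not a quotation of them.

// Hedged scalar model of requantize_: subtract the zero-point corrections,
// add bias, scale to the output quantization, round, shift by the output
// zero point, and clamp to uint8 (with an optional ReLU floor). Details of
// the elided FBGEMM code are assumed, not copied.
#include <algorithm>
#include <cmath>
#include <cstdint>

inline uint8_t requantize_one(
    int32_t acc,            // C_int32[j]
    int32_t A_zero_point,
    int32_t col_offset,     // col_offsets[j]
    int32_t row_offset,     // row_offsets[j]
    int32_t bias,           // bias[j]; pass 0 when HAS_BIAS is false
    float C_multiplier,
    int32_t C_zero_point,
    bool fuse_relu) {
  int32_t raw = acc - A_zero_point * col_offset - row_offset + bias;
  float scaled = raw * C_multiplier;
  int32_t rounded = static_cast<int32_t>(std::nearbyint(scaled)) + C_zero_point;
  int32_t lo = fuse_relu ? C_zero_point : 0;   // ReLU clamps at the zero point
  return static_cast<uint8_t>(
      std::min<int32_t>(255, std::max<int32_t>(lo, rounded)));
}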
@@ -659,11 +711,16 @@ static inline __attribute__((always_inline)) void requantize_(
}
template <bool FUSE_RELU, bool HAS_BIAS>
-static inline __attribute__((always_inline)) void
-requantize_(int32_t A_zero_point, float C_multiplier,
- int32_t C_zero_point, const int32_t *C_int32, uint8_t *C_uint8,
- int n, const int32_t *row_offsets, const int32_t *col_offsets,
- const int32_t *bias) {
+static inline __attribute__((always_inline)) void requantize_(
+ int32_t A_zero_point,
+ float C_multiplier,
+ int32_t C_zero_point,
+ const int32_t* C_int32,
+ uint8_t* C_uint8,
+ int n,
+ const int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
requantize_<FUSE_RELU, HAS_BIAS, false /* PER_CHANNEL_QUANTIZATION */>(
A_zero_point,
&C_multiplier,
@@ -677,11 +734,16 @@ requantize_(int32_t A_zero_point, float C_multiplier,
}
template <bool FUSE_RELU, bool HAS_BIAS>
-static inline __attribute__((always_inline)) void
-requantize_per_channel_(int32_t A_zero_point, const float *C_multiplier,
- int32_t C_zero_point, const int32_t *C_int32,
- uint8_t *C_uint8, int n, const int32_t *row_offsets,
- const int32_t *col_offsets, const int32_t *bias) {
+static inline __attribute__((always_inline)) void requantize_per_channel_(
+ int32_t A_zero_point,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ const int32_t* C_int32,
+ uint8_t* C_uint8,
+ int n,
+ const int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
requantize_<FUSE_RELU, HAS_BIAS, true /* PER_CHANNEL_QUANTIZATION */>(
A_zero_point,
C_multiplier,
@@ -695,27 +757,38 @@ requantize_per_channel_(int32_t A_zero_point, const float *C_multiplier,
}
template <bool REMAINDER>
-static inline __attribute__((always_inline)) __m256i
-load_a(const uint8_t* A, __m256i mask_v) {
+static inline __attribute__((always_inline)) __m256i load_a(
+ const uint8_t* A,
+ __m256i mask_v) {
if (REMAINDER) {
- return _mm256_maskload_epi32(reinterpret_cast<const int *>(A), mask_v);
+ return _mm256_maskload_epi32(reinterpret_cast<const int*>(A), mask_v);
} else {
- return _mm256_lddqu_si256(reinterpret_cast<const __m256i *>(A));
+ return _mm256_lddqu_si256(reinterpret_cast<const __m256i*>(A));
}
}
-template <bool SUM_A, bool REMAINDER = false,
- bool PER_CHANNEL_QUANTIZATION = false>
-static inline __attribute__((always_inline)) void
-inner_prod_3x3_packed_(int H, int W, int K, int h_in, int w_in,
- const uint8_t *A, int32_t A_zero_point, const int8_t *Bp,
- const int32_t *B_zero_point, int32_t *C, int remainder,
- int32_t *row_offsets) {
+template <
+ bool SUM_A,
+ bool REMAINDER = false,
+ bool PER_CHANNEL_QUANTIZATION = false>
+static inline __attribute__((always_inline)) void inner_prod_3x3_packed_(
+ int H,
+ int W,
+ int K,
+ int h_in,
+ int w_in,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ const int8_t* Bp,
+ const int32_t* B_zero_point,
+ int32_t* C,
+ int remainder,
+ int32_t* row_offsets) {
__m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
__m256i mask_v = _mm256_setzero_si256();
if (REMAINDER) {
mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i *>(masks[remainder / 4].data()));
+ reinterpret_cast<const __m256i*>(masks[remainder / 4].data()));
}
// The code below can be written as a simple R*S loop but the compiler
@@ -739,15 +812,15 @@ inner_prod_3x3_packed_(int H, int W, int K, int h_in, int w_in,
// }
// }
array<__m256i, 9> a_v = {
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
};
if (h_in >= 0 && h_in < H) {
@@ -788,35 +861,51 @@ inner_prod_3x3_packed_(int H, int W, int K, int h_in, int w_in,
array<__m256i, 4> a_sum;
inner_prod_3x3_packed_<SUM_A, REMAINDER>(
- a_v.data(), reinterpret_cast<const __m256i *>(Bp), C, remainder,
+ a_v.data(),
+ reinterpret_cast<const __m256i*>(Bp),
+ C,
+ remainder,
a_sum.data());
if (SUM_A) {
__m256i B_zero_point_v;
for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
if (PER_CHANNEL_QUANTIZATION) {
B_zero_point_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i *>(B_zero_point + i * 8));
+ reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
} else {
B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
}
- _mm256_store_si256(reinterpret_cast<__m256i *>(&row_offsets[i * 8]),
- _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
+ _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
}
}
}
-template <bool SUM_A, bool REMAINDER = false,
- bool PER_CHANNEL_QUANTIZATION = false>
-static inline __attribute__((always_inline)) void
-inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in,
- int w_in, const uint8_t *A, int32_t A_zero_point,
- const int8_t *Bp, const int32_t *B_zero_point,
- int32_t *C, int remainder, int32_t *row_offsets) {
+template <
+ bool SUM_A,
+ bool REMAINDER = false,
+ bool PER_CHANNEL_QUANTIZATION = false>
+static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_(
+ int T,
+ int H,
+ int W,
+ int K,
+ int t_in,
+ int h_in,
+ int w_in,
+ const uint8_t* A,
+ int32_t A_zero_point,
+ const int8_t* Bp,
+ const int32_t* B_zero_point,
+ int32_t* C,
+ int remainder,
+ int32_t* row_offsets) {
__m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
__m256i mask_v = _mm256_setzero_si256();
if (REMAINDER) {
mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i *>(masks[remainder / 4].data()));
+ reinterpret_cast<const __m256i*>(masks[remainder / 4].data()));
}
// The code below can be written as a simple R*S loop but the compiler
@@ -885,9 +974,12 @@ inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in,
}
array<__m256i, 4> a_sum;
- inner_prod_packed_<8, SUM_A, REMAINDER>(a_v.data(),
- reinterpret_cast<const __m256i *>(Bp),
- C, remainder, a_sum.data());
+ inner_prod_packed_<8, SUM_A, REMAINDER>(
+ a_v.data(),
+ reinterpret_cast<const __m256i*>(Bp),
+ C,
+ remainder,
+ a_sum.data());
a_v[0] = A_zero_point_v;
a_v[1] = A_zero_point_v;
@@ -940,7 +1032,10 @@ inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in,
array<__m256i, 4> a_sum_temp;
inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
- a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 8, C, remainder,
+ a_v.data(),
+ reinterpret_cast<const __m256i*>(Bp) + 8,
+ C,
+ remainder,
a_sum_temp.data());
if (SUM_A) {
a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
@@ -996,7 +1091,10 @@ inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in,
}
inner_prod_packed_<8, SUM_A, REMAINDER, true /* acc */>(
- a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 16, C, remainder,
+ a_v.data(),
+ reinterpret_cast<const __m256i*>(Bp) + 16,
+ C,
+ remainder,
a_sum_temp.data());
if (SUM_A) {
a_sum[0] = _mm256_add_epi32(a_sum[0], a_sum_temp[0]);
@@ -1024,7 +1122,10 @@ inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in,
}
inner_prod_packed_<3, SUM_A, REMAINDER, true /* acc */>(
- a_v.data(), reinterpret_cast<const __m256i *>(Bp) + 24, C, remainder,
+ a_v.data(),
+ reinterpret_cast<const __m256i*>(Bp) + 24,
+ C,
+ remainder,
a_sum_temp.data());
if (SUM_A) {
@@ -1037,29 +1138,37 @@ inner_prod_3x3x3_packed_(int T, int H, int W, int K, int t_in, int h_in,
for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
if (PER_CHANNEL_QUANTIZATION) {
B_zero_point_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i *>(B_zero_point + i * 8));
+ reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
} else {
B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
}
- _mm256_store_si256(reinterpret_cast<__m256i *>(&row_offsets[i * 8]),
- _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
+ _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
}
}
}
template <bool SUM_A, bool FUSE_RELU>
-static inline __attribute__((always_inline))
-void depthwise_3x3_kernel_(int H, int W, int K, int h, int w,
- int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t* A,
- int32_t B_zero_point, const int8_t* Bp,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32, uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t *col_offsets,
- const int32_t *bias)
-{
+static inline __attribute__((always_inline)) void depthwise_3x3_kernel_(
+ int H,
+ int W,
+ int K,
+ int h,
+ int w,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const int8_t* Bp,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
constexpr int S = 3;
constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
@@ -1069,43 +1178,72 @@ void depthwise_3x3_kernel_(int H, int W, int K, int h, int w,
int k;
for (k = 0; k < K / 32 * 32; k += 32) {
inner_prod_3x3_packed_<SUM_A>(
- H, W, K, h_in, w_in,
- A + (h_in * W + w_in) * K + k, A_zero_point,
- Bp + k * 10, &B_zero_point,
- C_int32 + k, 0, &row_offsets[k]);
+ H,
+ W,
+ K,
+ h_in,
+ w_in,
+ A + (h_in * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 10,
+ &B_zero_point,
+ C_int32 + k,
+ 0,
+ &row_offsets[k]);
}
int remainder = K - k;
if (remainder) {
inner_prod_3x3_packed_<SUM_A, true>(
- H, W, K, h_in, w_in,
- A + (h_in * W + w_in) * K + k, A_zero_point,
- Bp + k * 10, &B_zero_point,
- C_int32 + k, remainder, &row_offsets[k]);
+ H,
+ W,
+ K,
+ h_in,
+ w_in,
+ A + (h_in * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 10,
+ &B_zero_point,
+ C_int32 + k,
+ remainder,
+ &row_offsets[k]);
}
if (SUM_A) {
- requantize_<FUSE_RELU, true>
- (
- A_zero_point, C_multiplier, C_zero_point,
- C_int32, C_uint8 + (h * W_OUT + w) * K, K,
+ requantize_<FUSE_RELU, true>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + (h * W_OUT + w) * K,
+ K,
row_offsets,
- col_offsets, bias
- );
+ col_offsets,
+ bias);
}
}
template <bool SUM_A, bool FUSE_RELU>
-static inline __attribute__((always_inline))
-void depthwise_3x3x3_kernel_(int T, int H, int W, int K, int t, int h, int w,
- int stride_t, int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t* A,
- int32_t B_zero_point, const int8_t* Bp,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32, uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t *col_offsets,
- const int32_t *bias)
-{
+static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_(
+ int T,
+ int H,
+ int W,
+ int K,
+ int t,
+ int h,
+ int w,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const int8_t* Bp,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
constexpr int R = 3, S = 3;
constexpr int PAD_P = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
@@ -1117,39 +1255,74 @@ void depthwise_3x3x3_kernel_(int T, int H, int W, int K, int t, int h, int w,
int k;
for (k = 0; k < K / 32 * 32; k += 32) {
inner_prod_3x3x3_packed_<SUM_A>(
- T, H, W, K, t_in, h_in, w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k, A_zero_point,
- Bp + k * 28, &B_zero_point,
- C_int32 + k, 0, &row_offsets[k]);
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ &B_zero_point,
+ C_int32 + k,
+ 0,
+ &row_offsets[k]);
}
int remainder = K - k;
if (remainder) {
inner_prod_3x3x3_packed_<SUM_A, true>(
- T, H, W, K, t_in, h_in, w_in,
- A + ((t_in * H + h_in) * W + w_in) * K + k, A_zero_point,
- Bp + k * 28, &B_zero_point,
- C_int32 + k, remainder, &row_offsets[k]);
+ T,
+ H,
+ W,
+ K,
+ t_in,
+ h_in,
+ w_in,
+ A + ((t_in * H + h_in) * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 28,
+ &B_zero_point,
+ C_int32 + k,
+ remainder,
+ &row_offsets[k]);
}
if (SUM_A) {
- requantize_<FUSE_RELU, true>
- (
- A_zero_point, C_multiplier, C_zero_point,
- C_int32, C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K, K,
+ requantize_<FUSE_RELU, true>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + ((t * H_OUT + h) * W_OUT + w) * K,
+ K,
row_offsets,
- col_offsets, bias
- );
+ col_offsets,
+ bias);
}
}
template <bool SUM_A>
static inline __attribute__((always_inline)) void
depthwise_3x3_per_channel_quantization_kernel_(
- int H, int W, int K, int h, int w, int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t *A,
- const int32_t *B_zero_point, const int8_t *Bp,
- const float *C_multiplier, int32_t C_zero_point,
- int32_t *C_int32, uint8_t *C_uint8,
- int32_t *row_offsets, const int32_t *col_offsets, const int32_t *bias) {
+ int H,
+ int W,
+ int K,
+ int h,
+ int w,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const int8_t* Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ int32_t* row_offsets,
+ const int32_t* col_offsets,
+ const int32_t* bias) {
constexpr int S = 3;
constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
@@ -1158,28 +1331,47 @@ depthwise_3x3_per_channel_quantization_kernel_(
int k;
for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3_packed_<SUM_A, false/*remainder*/, true/*per-channel*/>(
- H, W, K, h_in, w_in,
- A + (h_in * W + w_in) * K + k, A_zero_point,
- Bp + k * 10, B_zero_point + k,
- C_int32 + k, 0, &row_offsets[k]);
+ inner_prod_3x3_packed_<SUM_A, false /*remainder*/, true /*per-channel*/>(
+ H,
+ W,
+ K,
+ h_in,
+ w_in,
+ A + (h_in * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 10,
+ B_zero_point + k,
+ C_int32 + k,
+ 0,
+ &row_offsets[k]);
}
int remainder = K - k;
if (remainder) {
- inner_prod_3x3_packed_<SUM_A, true/*remainder*/, true/*per-channel*/>(
- H, W, K, h_in, w_in,
- A + (h_in * W + w_in) * K + k, A_zero_point,
- Bp + k * 10, B_zero_point + k,
- C_int32 + k, remainder, &row_offsets[k]);
+ inner_prod_3x3_packed_<SUM_A, true /*remainder*/, true /*per-channel*/>(
+ H,
+ W,
+ K,
+ h_in,
+ w_in,
+ A + (h_in * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * 10,
+ B_zero_point + k,
+ C_int32 + k,
+ remainder,
+ &row_offsets[k]);
}
if (SUM_A) {
- requantize_per_channel_<false, true>
- (
- A_zero_point, C_multiplier, C_zero_point,
- C_int32, C_uint8 + (h * W_OUT + w) * K, K,
+ requantize_per_channel_<false, true>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + (h * W_OUT + w) * K,
+ K,
row_offsets,
- col_offsets, bias
- );
+ col_offsets,
+ bias);
}
}
@@ -1188,7 +1380,7 @@ static pair<int, int> closest_factors_(int n) {
while (n % a != 0) {
a--;
}
- return { a, n / a }; // a <= n / a
+ return {a, n / a}; // a <= n / a
}
// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0
@@ -1196,16 +1388,25 @@ static pair<int, int> closest_factors_(int n) {
// filter shapes by parameterizing with R and S but restricting it to just 3x3
// for now.
template <bool FUSE_RESCALE = true, bool FUSE_RELU = false>
-static inline __attribute__((always_inline))
-void depthwise_3x3_pad_1_(int N, int H, int W, int K,
- int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t *A,
- int32_t B_zero_point, const Packed3x3ConvMatrix &B,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32, uint8_t* C_uint8,
- const int32_t *col_offsets, const int32_t *bias,
- int thread_id, int num_threads) {
+static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const Packed3x3ConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id,
+ int num_threads) {
assert(K % 8 == 0);
constexpr int R = 3, S = 3;
constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
@@ -1213,8 +1414,8 @@ void depthwise_3x3_pad_1_(int N, int H, int W, int K,
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
const int8_t* Bp = B.PackedMat();
- int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64)));
- int32_t *C_temp;
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+ int32_t* C_temp;
int n_begin, n_end;
int h_begin, h_end, w_begin, w_end;
@@ -1266,22 +1467,48 @@ void depthwise_3x3_pad_1_(int N, int H, int W, int K,
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
if (w_end == W_OUT) {
@@ -1289,11 +1516,24 @@ void depthwise_3x3_pad_1_(int N, int H, int W, int K,
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
}
@@ -1303,22 +1543,48 @@ void depthwise_3x3_pad_1_(int N, int H, int W, int K,
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
if (w_end == W_OUT) {
@@ -1326,11 +1592,24 @@ void depthwise_3x3_pad_1_(int N, int H, int W, int K,
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
}
@@ -1341,22 +1620,48 @@ void depthwise_3x3_pad_1_(int N, int H, int W, int K,
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
if (w_end == W_OUT) {
@@ -1364,28 +1669,51 @@ void depthwise_3x3_pad_1_(int N, int H, int W, int K,
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
}
} // for each n
};
template <bool FUSE_RESCALE = true, bool FUSE_RELU = false>
-static inline __attribute__((always_inline))
-void depthwise_3x3x3_pad_1_(int N, int T, int H, int W, int K,
- int stride_t, int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t *A,
- int32_t B_zero_point,
- const Packed3x3x3ConvMatrix &B,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32, uint8_t* C_uint8,
- const int32_t *col_offsets, const int32_t *bias,
- int thread_id, int num_threads) {
+static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_(
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const Packed3x3x3ConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id,
+ int num_threads) {
assert(K % 8 == 0);
constexpr int K_T = 3, K_H = 3, K_W = 3;
constexpr int PAD_P = 1, PAD_N = 1, PAD_T = 1, PAD_B = 1, PAD_L = 1,
@@ -1395,8 +1723,8 @@ void depthwise_3x3x3_pad_1_(int N, int T, int H, int W, int K,
int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
const int8_t* Bp = B.PackedMat();
- int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64)));
- int32_t *C_temp;
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+ int32_t* C_temp;
int n_begin, n_end;
int t_begin, t_end, h_begin, h_end;
@@ -1443,14 +1771,30 @@ void depthwise_3x3x3_pad_1_(int N, int T, int H, int W, int K,
for (int t = t_begin; t < t_end; ++t) {
for (int h = h_begin; h < h_end; ++h) {
for (int w = 0; w < W_OUT; ++w) {
- C_temp =
- FUSE_RESCALE
- ? C_int32
- : C_int32 + (((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K;
+ C_temp = FUSE_RESCALE
+ ? C_int32
+ : C_int32 + (((n * T_OUT + t) * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3x3_kernel_<FUSE_RESCALE, FUSE_RELU>(
- T, H, W, K, t, h, w, stride_t, stride_h, stride_w, A_zero_point,
- A_base, B_zero_point, Bp, C_multiplier,
- C_zero_point, C_temp, C_uint8_base, row_offsets, col_offsets,
+ T,
+ H,
+ W,
+ K,
+ t,
+ h,
+ w,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
bias);
} // w
} // h
@@ -1461,11 +1805,23 @@ void depthwise_3x3x3_pad_1_(int N, int T, int H, int W, int K,
template <bool FUSE_RESCALE = true>
static inline __attribute__((always_inline)) void
depthwise_3x3_per_channel_quantization_pad_1_(
- int N, int H, int W, int K, int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t *A, const int32_t *B_zero_point,
- const Packed3x3ConvMatrix &B, const float *C_multiplier,
- int32_t C_zero_point, int32_t *C_int32, uint8_t *C_uint8,
- const int32_t *col_offsets, const int32_t *bias, int thread_id,
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const Packed3x3ConvMatrix& B,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ int32_t* C_int32,
+ uint8_t* C_uint8,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ int thread_id,
int num_threads) {
assert(K % 8 == 0);
constexpr int R = 3, S = 3;
@@ -1474,8 +1830,8 @@ depthwise_3x3_per_channel_quantization_pad_1_(
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
const int8_t* Bp = B.PackedMat();
- int32_t row_offsets[(K + 31) / 32 * 32] __attribute__ ((aligned (64)));
- int32_t *C_temp;
+ int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+ int32_t* C_temp;
int n_begin, n_end;
int h_begin, h_end, w_begin, w_end;
@@ -1527,22 +1883,48 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
if (w_end == W_OUT) {
@@ -1550,11 +1932,24 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
}
@@ -1564,22 +1959,48 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
if (w_end == W_OUT) {
@@ -1587,11 +2008,24 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
}
@@ -1602,22 +2036,48 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
if (w_end == W_OUT) {
@@ -1625,11 +2085,24 @@ depthwise_3x3_per_channel_quantization_pad_1_(
C_temp = FUSE_RESCALE ? C_int32
: C_int32 + ((n * H_OUT + h) * W_OUT + w) * K;
depthwise_3x3_per_channel_quantization_kernel_<FUSE_RESCALE>(
- H, W, K, h, w, stride_h, stride_w,
- A_zero_point, A_base,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_temp, C_uint8_base,
- row_offsets, col_offsets, bias);
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_temp,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias);
}
}
} // for each n
@@ -1643,163 +2116,338 @@ void depthwise_3x3_pad_1(
int K,
int stride_h,
int stride_w,
- int32_t A_zero_point, const uint8_t* A,
+ int32_t A_zero_point,
+ const uint8_t* A,
const Packed3x3ConvMatrix& B,
int32_t* C,
int thread_id,
int num_threads) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
depthwise_3x3_pad_1_<false>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- 0, B,
- 0.0f, 0, C, nullptr,
- nullptr, nullptr,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ 0,
+ B,
+ 0.0f,
+ 0,
+ C,
+ nullptr,
+ nullptr,
+ nullptr,
+ thread_id,
+ num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
depthwise_3x3_pad_1_<false>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- 0, B,
- 0.0f, 0, C, nullptr,
- nullptr, nullptr,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ 0,
+ B,
+ 0.0f,
+ 0,
+ C,
+ nullptr,
+ nullptr,
+ nullptr,
+ thread_id,
+ num_threads);
} else if (1 == stride_h && 1 == stride_w) {
depthwise_3x3_pad_1_<false>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- 0, B,
- 0.0f, 0, C, nullptr,
- nullptr, nullptr,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ 0,
+ B,
+ 0.0f,
+ 0,
+ C,
+ nullptr,
+ nullptr,
+ nullptr,
+ thread_id,
+ num_threads);
} else if (2 == stride_h && 2 == stride_w) {
depthwise_3x3_pad_1_<false>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- 0, B,
- 0.0f, 0, C, nullptr,
- nullptr, nullptr,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ 0,
+ B,
+ 0.0f,
+ 0,
+ C,
+ nullptr,
+ nullptr,
+ nullptr,
+ thread_id,
+ num_threads);
} else {
depthwise_3x3_pad_1_<false>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- 0, B,
- 0.0f, 0, C, nullptr,
- nullptr, nullptr,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ 0,
+ B,
+ 0.0f,
+ 0,
+ C,
+ nullptr,
+ nullptr,
+ nullptr,
+ thread_id,
+ num_threads);
}
}
void depthwise_3x3_pad_1(
- int N, int H, int W, int K,
- int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t* A,
- int32_t B_zero_point, const Packed3x3ConvMatrix& B,
- float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const Packed3x3ConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
- int thread_id, int num_threads,
+ int thread_id,
+ int num_threads,
bool fuse_relu) {
int32_t C_int32_temp[(K + 31) / 32 * 32];
if (fuse_relu) {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (1 == stride_h && 1 == stride_w) {
depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (2 == stride_h && 2 == stride_w) {
depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else {
depthwise_3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
}
} else {
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
depthwise_3x3_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
depthwise_3x3_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (1 == stride_h && 1 == stride_w) {
depthwise_3x3_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (2 == stride_h && 2 == stride_w) {
depthwise_3x3_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else {
depthwise_3x3_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
}
}
}
@@ -1813,140 +2461,309 @@ void depthwise_3x3x3_pad_1(
int stride_t,
int stride_h,
int stride_w,
- int32_t A_zero_point, const uint8_t* A,
+ int32_t A_zero_point,
+ const uint8_t* A,
const Packed3x3x3ConvMatrix& B,
int32_t* C,
int thread_id,
int num_threads) {
depthwise_3x3x3_pad_1_<false /* FUSE_RESCALE */>(
- N, T, H, W, K,
- stride_t, stride_h, stride_w,
- A_zero_point, A,
- 0, B,
- 0.0f, 0, C, nullptr,
- nullptr, nullptr,
- thread_id, num_threads);
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ 0,
+ B,
+ 0.0f,
+ 0,
+ C,
+ nullptr,
+ nullptr,
+ nullptr,
+ thread_id,
+ num_threads);
}
static void depthwise_3x3x3_pad_1_(
- int N, int T, int H, int W, int K,
- int stride_t, int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t* A,
- int32_t B_zero_point, const Packed3x3x3ConvMatrix& B,
- float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const Packed3x3x3ConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
- int thread_id, int num_threads) {
+ int thread_id,
+ int num_threads) {
int32_t C_int32_temp[(K + 31) / 32 * 32];
depthwise_3x3x3_pad_1_<true /* FUSE_RESCALE */, false /* FUSE_RELU */>(
- N, T, H, W, K,
- stride_t, stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
}
static void depthwise_3x3x3_pad_1_relu_fused_(
- int N, int T, int H, int W, int K,
- int stride_t, int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t* A,
- int32_t B_zero_point, const Packed3x3x3ConvMatrix& B,
- float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const Packed3x3x3ConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
- int thread_id, int num_threads) {
+ int thread_id,
+ int num_threads) {
int32_t C_int32_temp[(K + 31) / 32 * 32];
depthwise_3x3x3_pad_1_<true /* FUSE_RESCALE */, true /* FUSE_RELU */>(
- N, T, H, W, K,
- stride_t, stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, B,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
}
void depthwise_3x3x3_pad_1(
- int N, int T, int H, int W, int K,
- int stride_t, int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t* A,
- int32_t B_zero_point, const Packed3x3x3ConvMatrix& B,
- float C_multiplier, int32_t C_zero_point, uint8_t* C,
+ int N,
+ int T,
+ int H,
+ int W,
+ int K,
+ int stride_t,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const Packed3x3x3ConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
bool fuse_relu,
- int thread_id, int num_threads) {
+ int thread_id,
+ int num_threads) {
// If we inline the following two functions, I see stack overflow.
if (fuse_relu) {
depthwise_3x3x3_pad_1_relu_fused_(
- N, T, H, W, K, stride_t, stride_h, stride_w, A_zero_point, A,
- B_zero_point, B, C_multiplier, C_zero_point, C,
- col_offsets, bias, thread_id, num_threads);
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else {
- depthwise_3x3x3_pad_1_(N, T, H, W, K, stride_t, stride_h, stride_w,
- A_zero_point, A, B_zero_point, B, C_multiplier,
- C_zero_point, C, col_offsets, bias,
- thread_id, num_threads);
+ depthwise_3x3x3_pad_1_(
+ N,
+ T,
+ H,
+ W,
+ K,
+ stride_t,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
}
}
void depthwise_3x3_per_channel_quantization_pad_1(
- int N, int H, int W, int K,
- int stride_h, int stride_w,
- int32_t A_zero_point, const uint8_t* A,
- const int32_t *B_zero_point, const Packed3x3ConvMatrix& Bp,
- const float *C_multiplier, int32_t C_zero_point, uint8_t* C,
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const Packed3x3ConvMatrix& Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
- int thread_id, int num_threads) {
+ int thread_id,
+ int num_threads) {
int32_t C_int32_temp[(K + 31) / 32 * 32];
if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
depthwise_3x3_per_channel_quantization_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
depthwise_3x3_per_channel_quantization_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (1 == stride_h && 1 == stride_w) {
depthwise_3x3_per_channel_quantization_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else if (2 == stride_h && 2 == stride_w) {
depthwise_3x3_per_channel_quantization_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
} else {
depthwise_3x3_per_channel_quantization_pad_1_(
- N, H, W, K,
- stride_h, stride_w,
- A_zero_point, A,
- B_zero_point, Bp,
- C_multiplier, C_zero_point, C_int32_temp, C,
- col_offsets, bias,
- thread_id, num_threads);
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ thread_id,
+ num_threads);
}
}
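The per-channel-quantization path above forwards the raw int32 accumulators together with C_multiplier, C_zero_point, row_offsets and col_offsets to depthwise_3x3_per_channel_quantization_kernel_, whose body is not part of this hunk. A minimal scalar sketch of the fused requantization it performs per output channel (zero-point and offset corrections elided; names are illustrative, not the vector kernel itself):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Scalar sketch only: scale by the per-channel multiplier, round to nearest,
// shift by the shared output zero point, then saturate to uint8.
inline std::uint8_t requantize_per_channel(
    std::int32_t acc,          // raw int32 accumulator for channel k
    float C_multiplier_k,      // per-channel output scale, C_multiplier[k]
    std::int32_t C_zero_point) {
  std::int32_t rounded =
      static_cast<std::int32_t>(std::nearbyint(acc * C_multiplier_k)) +
      C_zero_point;
  return static_cast<std::uint8_t>(std::min(255, std::max(0, rounded)));
}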
diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc
index 2ffe3ab..451592a 100644
--- a/src/GenerateKernelU8S8S32ACC16.cc
+++ b/src/GenerateKernelU8S8S32ACC16.cc
@@ -94,7 +94,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
asmjit::X86Ymm extractDest256 = x86::ymm15;
for (int i = 0; i < rowRegs; ++i) {
- a->imul(C_Offset, ldcReg, i * sizeof(int32_t));
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
for (int j = 0; j < colRegs; ++j) {
for (int idx = 0; idx < 2; ++idx) {
a->vextracti128(
@@ -214,17 +214,19 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
a->mov(kIdx, 0);
a->bind(Loopk);
// k is incremented by row_interleave
- a->add(kIdx, row_interleave);
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
genComputeBlock<inst_set_t::avx2>(
a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
// update buffer_A address for next k iteration
- a->add(buffer_A, row_interleave * sizeof(uint8_t));
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
// update buffer_B address for next k iteration
- a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
- // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(
+ buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
a->cmp(kIdx, kSize);
a->jl(Loopk);
@@ -234,9 +236,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
// increment A for next block
a->sub(buffer_A, kSize);
- a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t));
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
// increment C for next block
- a->imul(C_Offset, ldcReg, rowRegs * sizeof(int32_t));
+ a->imul(
+ C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
a->add(CBase, C_Offset);
// reset B
a->mov(buffer_B, buffer_B_saved);
@@ -258,16 +262,18 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
a->bind(LoopkRem);
// k is incremented by row_interleave
- a->add(kIdx, row_interleave);
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
genComputeBlock<inst_set_t::avx2>(
a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
// update buffer_A address for next k iteration
- a->add(buffer_A, row_interleave * sizeof(uint8_t));
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
// update buffer_B address for next k iteration
- a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(
+ buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
// a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
a->cmp(kIdx, kSize);
diff --git a/src/GenerateKernelU8S8S32ACC16_avx512.cc b/src/GenerateKernelU8S8S32ACC16_avx512.cc
index e613cf1..cab43ed 100644
--- a/src/GenerateKernelU8S8S32ACC16_avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC16_avx512.cc
@@ -94,7 +94,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
asmjit::X86Zmm extractDest512 = x86::zmm31;
for (int i = 0; i < rowRegs; ++i) {
- a->imul(C_Offset, ldcReg, i * sizeof(int32_t));
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
for (int j = 0; j < colRegs; ++j) {
for (int idx = 0; idx < 2; ++idx) {
a->vextracti32x8(
@@ -215,17 +215,19 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
a->mov(kIdx, 0);
a->bind(Loopk);
// k is incremented by row_interleave
- a->add(kIdx, row_interleave);
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
genComputeBlock<inst_set_t::avx512>(
a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
// update buffer_A address for next k iteration
- a->add(buffer_A, row_interleave * sizeof(uint8_t));
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
// update buffer_B address for next k iteration
- a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
- // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(
+ buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
a->cmp(kIdx, kSize);
a->jl(Loopk);
@@ -236,9 +238,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
// increment A for next block
a->sub(buffer_A, kSize);
- a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t));
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
// increment C for next block
- a->imul(C_Offset, ldcReg, rowRegs * sizeof(int32_t));
+ a->imul(
+ C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs * sizeof(int32_t)));
a->add(CBase, C_Offset);
// reset B
a->mov(buffer_B, buffer_B_saved);
@@ -260,17 +264,19 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx512>(
a->bind(LoopkRem);
// k is incremented by row_interleave
- a->add(kIdx, row_interleave);
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
genComputeBlock<inst_set_t::avx512>(
a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock);
// update buffer_A address for next k iteration
- a->add(buffer_A, row_interleave * sizeof(uint8_t));
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
// update buffer_B address for next k iteration
- a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
- // a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(
+ buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+ // a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
a->cmp(kIdx, kSize);
a->jl(LoopkRem);
diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc
index dc8c6d3..9529f5d 100644
--- a/src/GenerateKernelU8S8S32ACC32.cc
+++ b/src/GenerateKernelU8S8S32ACC32.cc
@@ -201,7 +201,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
// and so on
a->vpcmpeqw(oneReg, oneReg, oneReg);
a->vpsrlw(oneReg, oneReg, 15);
- a->imul(ldcReg, ldcReg, sizeof(int32_t));
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
a->mov(C_Offset, 0);
int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
@@ -226,17 +226,19 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
a->bind(Loopk);
// k is incremented by row_interleave
- a->add(kIdx, row_interleave);
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
genComputeBlock<inst_set_t::avx2>(
a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
// update buffer_A address for next k iteration
- a->add(buffer_A, row_interleave * sizeof(uint8_t));
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
// update buffer_B address for next k iteration
- a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
- a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(
+ buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
// a->add(B_pf, 32*sizeof(float));
@@ -249,10 +251,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
// increment A for next block
a->sub(buffer_A, kSize);
- a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t));
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
// increment C for next block
- a->imul(C_Offset, ldcReg, rowRegs);
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
a->add(CBase, C_Offset);
a->mov(C_Offset, 0);
@@ -275,17 +278,19 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx2>(
a->bind(LoopkRem);
// k is incremented by row_interleave
- a->add(kIdx, row_interleave);
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
genComputeBlock<inst_set_t::avx2>(
a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
// update buffer_A address for next k iteration
- a->add(buffer_A, row_interleave * sizeof(uint8_t));
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
// update buffer_B address for next k iteration
- a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
- a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(
+ buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
a->cmp(kIdx, kSize);
a->jl(LoopkRem);
diff --git a/src/GenerateKernelU8S8S32ACC32_avx512.cc b/src/GenerateKernelU8S8S32ACC32_avx512.cc
index 5cd5684..251a8b8 100644
--- a/src/GenerateKernelU8S8S32ACC32_avx512.cc
+++ b/src/GenerateKernelU8S8S32ACC32_avx512.cc
@@ -203,7 +203,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
// a->vpcmpeqw(oneReg, oneReg, oneReg);
a->vpternlogd(oneReg, oneReg, oneReg, 0xff);
a->vpsrlw(oneReg, oneReg, 15);
- a->imul(ldcReg, ldcReg, sizeof(int32_t));
+ a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
a->mov(C_Offset, 0);
int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
@@ -228,19 +228,21 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
a->bind(Loopk);
// k is incremented by row_interleave
- a->add(kIdx, row_interleave);
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
genComputeBlock<inst_set_t::avx512>(
a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
// update buffer_A address for next k iteration
- a->add(buffer_A, row_interleave * sizeof(uint8_t));
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
// update buffer_B address for next k iteration
- a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
- a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(
+ buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
- // a->add(B_pf, 32*sizeof(float));
+ // a->add(B_pf, static_cast<asmjit::Imm>(32*sizeof(float)));
a->cmp(kIdx, kSize);
a->jl(Loopk);
@@ -251,10 +253,11 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
// increment A for next block
a->sub(buffer_A, kSize);
- a->add(buffer_A, (rowRegs)*kBlock * sizeof(uint8_t));
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>((rowRegs)*kBlock * sizeof(uint8_t)));
// increment C for next block
- a->imul(C_Offset, ldcReg, rowRegs);
+ a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(rowRegs));
a->add(CBase, C_Offset);
a->mov(C_Offset, 0);
@@ -277,17 +280,19 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate<inst_set_t::avx512>(
a->bind(LoopkRem);
// k is incremented by row_interleave
- a->add(kIdx, row_interleave);
+ a->add(kIdx, static_cast<asmjit::Imm>(row_interleave));
genComputeBlock<inst_set_t::avx512>(
a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs);
// update buffer_A address for next k iteration
- a->add(buffer_A, row_interleave * sizeof(uint8_t));
+ a->add(
+ buffer_A, static_cast<asmjit::Imm>(row_interleave * sizeof(uint8_t)));
// update buffer_B address for next k iteration
- a->add(buffer_B, VLEN_ * colRegs * sizeof(int8_t));
- a->add(B_pf, VLEN_ * colRegs * sizeof(int8_t));
+ a->add(
+ buffer_B, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
+ a->add(B_pf, static_cast<asmjit::Imm>(VLEN_ * colRegs * sizeof(int8_t)));
a->cmp(kIdx, kSize);
a->jl(LoopkRem);
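The asmjit::Imm casts repeated across the four kernel generators above all wrap products that mix int with sizeof(...). A standalone check of the underlying type promotion (reading this as the cause of the clang/Mac build trouble is an assumption; the constants stand in for the generator's locals):

#include <cstddef>
#include <cstdint>
#include <type_traits>

int main() {
  constexpr int row_interleave = 4;
  constexpr int rowRegs = 12, kBlock = 256;
  // int * sizeof(...) promotes to size_t (unsigned, 64-bit on LP64 targets),
  // so the explicit asmjit::Imm cast leaves no doubt about which emitter
  // overload is intended (again, an assumption about the original failure).
  static_assert(
      std::is_same<decltype(row_interleave * sizeof(std::uint8_t)),
                   std::size_t>::value,
      "k-step advance is size_t, not int");
  static_assert(
      std::is_same<decltype(rowRegs * kBlock * sizeof(std::uint8_t)),
                   std::size_t>::value,
      "A-block stride is size_t, not int");
  return 0;
}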
diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc
index 543d99b..8f260ba 100644
--- a/src/PackAMatrix.cc
+++ b/src/PackAMatrix.cc
@@ -21,7 +21,7 @@ PackAMatrix<T, accT>::PackAMatrix(
int32_t ld,
inpType* pmat,
int32_t groups,
- accT zero_pt)
+ std::int32_t zero_pt)
: PackMatrix<PackAMatrix<T, accT>, T, accT>(nRow, nCol, pmat, zero_pt),
trans_(trans),
smat_(smat),
@@ -44,8 +44,8 @@ PackAMatrix<T, accT>::PackAMatrix(
assert(0 && "unsupported architecture");
}
if (!pmat) {
- BaseType::buf_ =
- (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T));
+ BaseType::buf_ = (T*)fbgemmAlignedAlloc(
+ 64, BaseType::brow_ * BaseType::bcol_ * sizeof(T));
}
}
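PackAMatrix, and the other packing classes below, now obtain their buffers from fbgemmAlignedAlloc rather than aligned_alloc. The helper's definition is not shown in this patch; a plausible sketch, assuming the usual motivation that C11 aligned_alloc is missing from older macOS SDKs:

#include <cstdlib>

// Assumed shape only, not the repository's implementation. posix_memalign
// accepts the same 64-byte alignment and its result can still be released
// with plain free().
static void* fbgemmAlignedAlloc_sketch(std::size_t align, std::size_t size) {
  void* p = nullptr;
  return posix_memalign(&p, align, size) == 0 ? p : nullptr;
}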
diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc
index a007685..e067a3e 100644
--- a/src/PackAWithIm2Col.cc
+++ b/src/PackAWithIm2Col.cc
@@ -49,7 +49,8 @@ PackAWithIm2Col<T, accT>::PackAWithIm2Col(
} else {
BaseType::bufAllocatedHere_ = true;
BaseType::buf_ = static_cast<T*>(
- aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)));
+ fbgemmAlignedAlloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)));
+ //aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T)));
}
if (row_offset) {
rowOffsetAllocatedHere = false;
@@ -57,7 +58,7 @@ PackAWithIm2Col<T, accT>::PackAWithIm2Col(
} else {
rowOffsetAllocatedHere = true;
row_offset_ = static_cast<int32_t*>(
- aligned_alloc(64, BaseType::brow_ * sizeof(int32_t)));
+ fbgemmAlignedAlloc(64, BaseType::brow_ * sizeof(int32_t)));
}
}
diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc
index 30d94f8..7b9ba41 100644
--- a/src/PackBMatrix.cc
+++ b/src/PackBMatrix.cc
@@ -21,7 +21,7 @@ PackBMatrix<T, accT>::PackBMatrix(
int32_t ld,
inpType* pmat,
int32_t groups,
- accT zero_pt)
+ std::int32_t zero_pt)
: PackMatrix<PackBMatrix<T, accT>, T, accT>(nRow, nCol, pmat, zero_pt),
trans_(trans),
smat_(smat),
@@ -46,7 +46,7 @@ PackBMatrix<T, accT>::PackBMatrix(
BaseType::packedBlock(block);
if (!pmat) {
BaseType::bufAllocatedHere_ = true;
- BaseType::buf_ = (T*)aligned_alloc(
+ BaseType::buf_ = (T*)fbgemmAlignedAlloc(
64,
BaseType::blockRows() * BaseType::brow_ * BaseType::blockCols() *
BaseType::bcol_ * sizeof(T));
diff --git a/src/PackWithQuantRowOffset.cc b/src/PackWithQuantRowOffset.cc
index 74eaade..5f60faa 100644
--- a/src/PackWithQuantRowOffset.cc
+++ b/src/PackWithQuantRowOffset.cc
@@ -60,13 +60,13 @@ PackAWithQuantRowOffset<T, accT>::PackAWithQuantRowOffset(
BaseType::buf_ = pmat;
} else {
BaseType::bufAllocatedHere_ = true;
- BaseType::buf_ =
- (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T));
+ BaseType::buf_ = (T*)fbgemmAlignedAlloc(
+ 64, BaseType::brow_ * BaseType::bcol_ * sizeof(T));
}
if (!row_offset_) {
rowOffsetAllocatedHere = true;
row_offset_ = reinterpret_cast<int32_t*>(
- aligned_alloc(64, BaseType::brow_ * sizeof(accT)));
+ fbgemmAlignedAlloc(64, BaseType::brow_ * sizeof(accT)));
}
}
@@ -109,12 +109,40 @@ void PackAWithQuantRowOffset<T, accT>::pack(const block_type_t& block) {
constexpr int VLEN = 8;
__m256 inverse_scale_v = _mm256_set1_ps(1.0f / scale_);
__m256i shuffle_mask_v = _mm256_set_epi8(
- 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
- 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00,
- 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
- 0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00);
- __m256i permute_mask_v = _mm256_set_epi32(
- 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0x0c,
+ 0x08,
+ 0x04,
+ 0x00,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0xff,
+ 0x0c,
+ 0x08,
+ 0x04,
+ 0x00);
+ __m256i permute_mask_v =
+ _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
#endif
for (int i = 0; i < block.row_size; ++i) {
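For reference, the shuffle and permute masks reflowed above implement a 32-bit to 8-bit low-byte gather: within each 128-bit lane the shuffle collects bytes 0, 4, 8 and 12, and the permute then makes the two 4-byte groups contiguous. A standalone demonstration using the same mask values (requires AVX2 at compile time; the input values are illustrative):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  alignas(32) std::int32_t in[8] = {10, 11, 12, 13, 14, 15, 16, 17};
  __m256i x = _mm256_load_si256(reinterpret_cast<const __m256i*>(in));
  __m256i shuffle_mask_v = _mm256_set_epi8(
      0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00,
      0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00);
  __m256i permute_mask_v =
      _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
  // Gather the low byte of each int32 lane, then pull the two lanes together.
  __m256i packed = _mm256_shuffle_epi8(x, shuffle_mask_v);
  packed = _mm256_permutevar8x32_epi32(packed, permute_mask_v);
  alignas(32) std::uint8_t out[32];
  _mm256_store_si256(reinterpret_cast<__m256i*>(out), packed);
  for (int i = 0; i < 8; ++i) {
    std::printf("%d ", out[i]);  // prints 10 11 12 13 14 15 16 17
  }
  std::printf("\n");
  return 0;
}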
diff --git a/src/PackWithRowOffset.cc b/src/PackWithRowOffset.cc
index 8722723..fa1f2b0 100644
--- a/src/PackWithRowOffset.cc
+++ b/src/PackWithRowOffset.cc
@@ -4,12 +4,12 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
+#include <cpuinfo.h>
#include <cassert>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <stdexcept>
-#include <cpuinfo.h>
#include "fbgemm/Fbgemm.h"
namespace fbgemm2 {
@@ -50,20 +50,20 @@ PackAWithRowOffset<T, accT>::PackAWithRowOffset(
row_interleave_B_ =
PackingTraits<T, accT, inst_set_t::avx2>::ROW_INTERLEAVE;
} else {
- //TODO: Have default slower path
+ // TODO: Have default slower path
assert(0 && "unknown architecture");
}
if (pmat) {
BaseType::buf_ = pmat;
} else {
BaseType::bufAllocatedHere_ = true;
- BaseType::buf_ =
- (T*)aligned_alloc(64, BaseType::brow_ * BaseType::bcol_ * sizeof(T));
+ BaseType::buf_ = (T*)fbgemmAlignedAlloc(
+ 64, BaseType::brow_ * BaseType::bcol_ * sizeof(T));
}
if (!row_offset_) {
rowOffsetAllocatedHere = true;
- row_offset_ = static_cast<int32_t*>(aligned_alloc(64,
- BaseType::brow_ * sizeof(int32_t)));
+ row_offset_ = static_cast<int32_t*>(
+ fbgemmAlignedAlloc(64, BaseType::brow_ * sizeof(int32_t)));
}
}
@@ -89,8 +89,8 @@ void PackAWithRowOffset<T, accT>::pack(const block_type_t& block) {
int32_t* row_offset_buf = getRowOffsetBuffer();
if (tr) {
for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
- int32_t row_sum = row_offset_acc ?
- row_offset_buf[i - block.row_start] : 0;
+ int32_t row_sum =
+ row_offset_acc ? row_offset_buf[i - block.row_start] : 0;
for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
T val = smat_[i + ld_ * j];
row_sum += val;
@@ -101,7 +101,8 @@ void PackAWithRowOffset<T, accT>::pack(const block_type_t& block) {
// zero fill
// Please see the comment in PackAMatrix.cc on zero vs zero_pt fill.
for (int j = block.col_start + block.col_size;
- j < block_p.col_start + block_p.col_size; ++j) {
+ j < block_p.col_start + block_p.col_size;
+ ++j) {
out[(i - block.row_start) * BaseType::blockColSize() +
(j - block.col_start)] = 0;
}
@@ -117,8 +118,8 @@ void PackAWithRowOffset<T, accT>::pack(const block_type_t& block) {
for (int j = block.col_size; j < block_p.col_size; ++j) {
out[buf_idx * BaseType::blockColSize() + j] = 0;
}
- int32_t row_sum = row_offset_acc ?
- row_offset_buf[i - block.row_start] : 0;
+ int32_t row_sum =
+ row_offset_acc ? row_offset_buf[i - block.row_start] : 0;
__m256i sum_v = _mm256_setzero_si256();
__m256i one_epi16_v = _mm256_set1_epi16(1);
__m256i one_epi8_v = _mm256_set1_epi8(1);
@@ -137,8 +138,10 @@ void PackAWithRowOffset<T, accT>::pack(const block_type_t& block) {
++j) {
row_sum += smat_[i * ld_ + j];
}
- alignas(64) std::array<int32_t, 8> temp;
- _mm256_store_si256(reinterpret_cast<__m256i*>(temp.data()), sum_v);
+ // alignas(64) std::array<int32_t, 8> temp;
+ alignas(64) std::int32_t temp[8];
+ //_mm256_store_si256(reinterpret_cast<__m256i*>(temp.data()), sum_v);
+ _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
for (int k = 0; k < 8; ++k) {
row_sum += temp[k];
}
@@ -190,13 +193,13 @@ void PackAWithRowOffset<T, accT>::printPackedMatrix(std::string name) {
template <typename T, typename accT>
int PackAWithRowOffset<T, accT>::rowOffsetBufferSize() {
- if(cpuinfo_initialize()){
+ if (cpuinfo_initialize()) {
if (cpuinfo_has_x86_avx512f()) {
return PackingTraits<T, accT, inst_set_t::avx512>::MCB;
} else if (cpuinfo_has_x86_avx2()) {
return PackingTraits<T, accT, inst_set_t::avx2>::MCB;
} else {
- //TODO: Have default slower path
+ // TODO: Have default slower path
assert(0 && "unsupported architecture");
return -1;
}
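The row-sum tail touched above spills the vector accumulator into an aligned scalar buffer and finishes the reduction in scalar code. The same idiom in isolation (requires AVX2 at compile time; the function name is illustrative):

#include <immintrin.h>
#include <cstdint>

// Eight int32 partial sums held in sum_v are stored to a plain aligned array,
// mirroring the change in this hunk, and accumulated into one total.
static std::int32_t horizontal_sum_epi32(__m256i sum_v) {
  alignas(64) std::int32_t temp[8];
  _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
  std::int32_t total = 0;
  for (int k = 0; k < 8; ++k) {
    total += temp[k];
  }
  return total;
}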
diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc
index 10e581f..6bf2d65 100644
--- a/src/RefImplementations.cc
+++ b/src/RefImplementations.cc
@@ -9,6 +9,7 @@
#include <cassert>
#include <cmath>
#include <cstring>
+#include <algorithm>
using namespace std;
@@ -45,7 +46,7 @@ void requantize_u8acc32_ref(
out[i * ld + j] = std::max(
fuse_relu ? static_cast<int64_t>(C_zero_point) : 0l,
- std::min(255l, rounded));
+ std::min(static_cast<int64_t>(255l), rounded));
}
}
}
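The std::min change in requantize_u8acc32_ref above matters on LP64 macOS, where int64_t is long long while the literal 255l is long, so the two-argument std::min cannot deduce a single template type (reading the clang failure this way is an assumption). A minimal illustration:

#include <algorithm>
#include <cstdint>

int main() {
  std::int64_t rounded = 300;
  // std::int64_t clipped = std::min(255l, rounded);  // ill-formed where
  //                                                  // int64_t is long long
  std::int64_t clipped = std::min(static_cast<std::int64_t>(255l), rounded);
  return clipped == 255 ? 0 : 1;
}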