
github.com/marian-nmt/FBGEMM.git
path: root/src
author     Young Jin Kim <youki@microsoft.com>    2019-12-07 01:52:35 +0300
committer  Young Jin Kim <youki@microsoft.com>    2019-12-07 01:52:35 +0300
commit     6f7ad8fb91e8ab94d35f847035a1b5ab8e5c5b44 (patch)
tree       a0c9e0d7b9acc9c8e711b36ea184ef494733c7e0 /src
parent     21f93c950b8b27918cd59c8f3139fb41ad1bd2c6 (diff)
parent     0d7da7c36f50276b5a550d46508516d139522687 (diff)

Fixing merge error (youki/fp16avx512)
Diffstat (limited to 'src')
-rw-r--r--  src/FbgemmConv.cc                              53
-rw-r--r--  src/FbgemmFP16.cc                             204
-rw-r--r--  src/FbgemmFP16UKernelsAvx2.cc                 621
-rw-r--r--  src/FbgemmFP16UKernelsAvx2.h                    4
-rw-r--r--  src/FbgemmFP16UKernelsAvx512.cc              2558
-rw-r--r--  src/FbgemmFP16UKernelsAvx512.h                 32
-rw-r--r--  src/FbgemmI8Depthwise2DAvx2-inl.h            1623
-rw-r--r--  src/FbgemmI8Depthwise3DAvx2.cc                 91
-rw-r--r--  src/FbgemmI8Depthwise3x3Avx2.cc               618
-rw-r--r--  src/FbgemmI8DepthwiseAvx2.cc                 1892
-rw-r--r--  src/FbgemmI8DepthwisePerChannelQuantAvx2.cc   154
-rw-r--r--  src/GroupwiseConvAcc32Avx2.cc                   9
-rw-r--r--  src/PackWeightsForConv.cc                       4
-rw-r--r--  src/codegen_fp16fp32.cc                       351
14 files changed, 6194 insertions(+), 2020 deletions(-)
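
This merge adds AVX512 FP16 micro-kernels next to the existing AVX2 ones and picks a kernel table at run time (see the FbgemmFP16.cc hunks below, which use fbgemmHasAvx512Support() inside cblas_gemm_compute()). A minimal, self-contained sketch of that table-based dispatch pattern, using stand-in names rather than the real FBGEMM entry points, looks like this:

    // Sketch only: runtime selection between two kernel tables, mirroring the
    // pattern used in cblas_gemm_compute() below. All names are illustrative.
    #include <cstdio>

    struct GemmParams;                              // opaque in this sketch
    using funcptr_fp16 = void (*)(GemmParams*);

    static void kernel_1x2_avx2(GemmParams*)   { std::puts("1x2 AVX2"); }
    static void kernel_1x2_avx512(GemmParams*) { std::puts("1x2 AVX512"); }

    static bool hasAvx512() { return false; }       // stand-in for fbgemmHasAvx512Support()

    int main() {
      static const funcptr_fp16 kernel_avx2[]   = {nullptr, kernel_1x2_avx2};
      static const funcptr_fp16 kernel_avx512[] = {nullptr, kernel_1x2_avx512};
      const funcptr_fp16* kernels = hasAvx512() ? kernel_avx512 : kernel_avx2;
      const int simd_width = hasAvx512() ? 16 : 8;  // fp32 lanes per vector register
      std::printf("simd_width = %d\n", simd_width);
      kernels[1](nullptr);                          // dispatch the 1-row kernel
      return 0;
    }
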
diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc
index de833d2..da3fa88 100644
--- a/src/FbgemmConv.cc
+++ b/src/FbgemmConv.cc
@@ -6,9 +6,9 @@
*/
#include <algorithm>
+#include <functional>
#include <numeric>
#include <vector>
-#include <functional>
#include "fbgemm/Fbgemm.h"
namespace fbgemm {
@@ -24,13 +24,16 @@ bool takeDepthWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
conv_p.stride.end(),
[](int i) { return i == 1 || i == 2; }) &&
std::all_of(
- conv_p.K.begin(), conv_p.K.end(), [](int i) { return i == 3; }) &&
+ conv_p.K.begin(),
+ conv_p.K.end(),
+ [&conv_p](int i) { return i == conv_p.K[0]; }) &&
+ (conv_p.K[0] == 3 || (SPATIAL_DIM == 2 && conv_p.K[0] == 5)) &&
std::all_of(
conv_p.dilation.begin(),
conv_p.dilation.end(),
[](int i) { return i == 1; }) &&
- std::all_of(conv_p.pad.begin(), conv_p.pad.end(), [](int i) {
- return i == 1;
+ std::all_of(conv_p.pad.begin(), conv_p.pad.end(), [&conv_p](int i) {
+ return i == (conv_p.K[0] - 1) / 2;
});
}
@@ -151,9 +154,9 @@ int fbgemmConv(
"not supported";
throw std::runtime_error(msg);
}
- } else {
+ } else if (SPATIAL_DIM == 2) {
if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
- depthwise_3x3_pad_1(
+ depthwise_2d_same_pad(
conv_p.MB, // mini batch
conv_p.IN_DIM[0], // H
conv_p.IN_DIM[1], // W
@@ -178,7 +181,7 @@ int fbgemmConv(
QuantizationGranularity::OUT_CHANNEL ||
processOutputType::QGRANType == QuantizationGranularity::GROUP) {
// The number of channels == groups for depthwise convolutions
- depthwise_3x3_per_channel_quantization_pad_1(
+ depthwise_2d_per_channel_quantization_same_pad(
conv_p.MB, // mini batch
conv_p.IN_DIM[0], // H
conv_p.IN_DIM[1], // W
@@ -204,6 +207,10 @@ int fbgemmConv(
"not supported";
throw std::runtime_error(msg);
}
+ } else {
+ std::string msg =
+ "[FBGEMM_CONV_ERROR] This spatial dim is not supported";
+ throw std::runtime_error(msg);
}
break;
}
@@ -212,20 +219,24 @@ int fbgemmConv(
// std::cout << "Groupwise fast path" << std::endl;
assert(
SPATIAL_DIM == 2 && "Only 2D groupwise convolutions are supported");
- std::vector<int32_t> row_offset_buf(
- rowOffsetBufferSizeGConv<SPATIAL_DIM>(conv_p));
- outProcess.setRowOffsets(row_offset_buf.data());
- fbgemmGroupwiseConv(
- conv_p,
- activations,
- outProcess.getAZeroPoint(),
- row_offset_buf.data(),
- *(packed_weights.getPackedWForGroupwise()),
- out,
- outBuffer,
- outProcess,
- thread_id,
- num_threads);
+ // thread 0 does all the work.
+ // TODO: Remove this when fbgemmGroupwiseConv supports threads
+ if (thread_id == 0) {
+ std::vector<int32_t> row_offset_buf(
+ rowOffsetBufferSizeGConv<SPATIAL_DIM>(conv_p));
+ outProcess.setRowOffsets(row_offset_buf.data());
+ fbgemmGroupwiseConv(
+ conv_p,
+ activations,
+ outProcess.getAZeroPoint(),
+ row_offset_buf.data(),
+ *(packed_weights.getPackedWForGroupwise()),
+ out,
+ outBuffer,
+ outProcess,
+ 0,
+ 1);
+ }
break;
}
case optimized_conv_t::pointwise: {
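
The takeDepthWiseFastPath() hunk above widens the depthwise fast path from 3x3-only to square 3x3 or, for 2D convolutions, 5x5 kernels, with unit dilation and "same" padding (pad = (K - 1) / 2). A simplified standalone sketch of that predicate follows; the struct is a stand-in for conv_param_t, and the group/channel checks of the real function are omitted.

    #include <algorithm>
    #include <vector>

    struct ConvParams {                 // simplified stand-in for conv_param_t
      std::vector<int> K, stride, dilation, pad;
    };

    bool takeDepthWiseFastPath2D(const ConvParams& p) {
      bool stride_ok = std::all_of(p.stride.begin(), p.stride.end(),
                                   [](int s) { return s == 1 || s == 2; });
      bool square    = std::all_of(p.K.begin(), p.K.end(),
                                   [&](int k) { return k == p.K[0]; });
      bool size_ok   = (p.K[0] == 3 || p.K[0] == 5);   // 5x5 is 2D-only
      bool dil_ok    = std::all_of(p.dilation.begin(), p.dilation.end(),
                                   [](int d) { return d == 1; });
      bool same_pad  = std::all_of(p.pad.begin(), p.pad.end(),
                                   [&](int x) { return x == (p.K[0] - 1) / 2; });
      return stride_ok && square && size_ok && dil_ok && same_pad;
    }

For example, K = {5, 5}, stride = {1, 1}, dilation = {1, 1}, pad = {2, 2} now takes the fast path, while pad = {0, 0} does not.
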
diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc
index b034f2c..e40277d 100644
--- a/src/FbgemmFP16.cc
+++ b/src/FbgemmFP16.cc
@@ -13,6 +13,7 @@
#include <utility>
#include "FbgemmFP16UKernelsAvx2.h"
+#include "FbgemmFP16UKernelsAvx512.h"
using namespace std;
@@ -32,26 +33,50 @@ inline void PackA(int nrow, int ncol, const float* from, int ldim, float* to) {
transpose_simd(nrow, ncol, from, ldim, to, nrow);
}
+// Each kernel does the following computation that multiplies
+// mb x k A sub-matrix with k x b_block_cols*64 B sub-matrix
+// for (int j = 0; j < b_block_cols * 64; j += 64) {
+// for (int kk = 0; kk < k; ++kk) {
+// for (int i = 0; i < mb; ++i) {
+// c[i][j:j+64] += a[i][kk] * b[kk][j:j+64]
+// }
+// }
+// }
+
struct KernelInfo {
using knl_ptr = funcptr_fp16;
// optimized kernels to cover all cases
// 2 in ?x2 should be the same as kernel_ncol_blocks.
// Here with kernel_ncol_blocks = 2, we can provide up to 6x2 kernels, due to
// the restrictions of ymm register numbers (16).
- static constexpr knl_ptr kernel[7] = {
- nullptr,
- gemmkernel_1x2_AVX2_fA0fB0fC0,
- gemmkernel_2x2_AVX2_fA0fB0fC0,
- gemmkernel_3x2_AVX2_fA0fB0fC0,
- gemmkernel_4x2_AVX2_fA0fB0fC0,
- gemmkernel_5x2_AVX2_fA0fB0fC0,
- gemmkernel_6x2_AVX2_fA0fB0fC0
- };
+ static constexpr knl_ptr kernel_avx2[] = {nullptr,
+ gemmkernel_1x2_AVX2_fA0fB0fC0,
+ gemmkernel_2x2_AVX2_fA0fB0fC0,
+ gemmkernel_3x2_AVX2_fA0fB0fC0,
+ gemmkernel_4x2_AVX2_fA0fB0fC0,
+ gemmkernel_5x2_AVX2_fA0fB0fC0,
+ gemmkernel_6x2_AVX2_fA0fB0fC0};
+
+ static constexpr knl_ptr kernel_avx512[] = {nullptr,
+ gemmkernel_1x2_AVX512_fA0fB0fC0,
+ gemmkernel_2x2_AVX512_fA0fB0fC0,
+ gemmkernel_3x2_AVX512_fA0fB0fC0,
+ gemmkernel_4x2_AVX512_fA0fB0fC0,
+ gemmkernel_5x2_AVX512_fA0fB0fC0,
+ gemmkernel_6x2_AVX512_fA0fB0fC0,
+ gemmkernel_7x2_AVX512_fA0fB0fC0,
+ gemmkernel_8x2_AVX512_fA0fB0fC0,
+ gemmkernel_9x2_AVX512_fA0fB0fC0,
+ gemmkernel_10x2_AVX512_fA0fB0fC0,
+ gemmkernel_11x2_AVX512_fA0fB0fC0,
+ gemmkernel_12x2_AVX512_fA0fB0fC0,
+ gemmkernel_13x2_AVX512_fA0fB0fC0,
+ gemmkernel_14x2_AVX512_fA0fB0fC0};
// autotuned kernel splits for various cases m = 1:mb_max
// may need re-autotuning for new uarch
// clang-format off
- static constexpr int partition[121][2][2] = {
+ static constexpr int partition_avx2[121][2][2] = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
{ { 0, 0 }, { 0, 0 } }, // 0
@@ -176,10 +201,139 @@ struct KernelInfo {
{ { 6, 19 }, { 5, 1 } }, // 119
{ { 6, 20 }, { 0, 0 } }, // 120
};
+ static constexpr array<array<array<int, 2>, 2>, 121> partition_avx512 = {
+ // NOTE: clang-format wants to use a different formatting but the current
+ // formatting should be easier to read.
+ {
+ {{ { 0, 0 }, { 0, 0 } } }, // 0
+ {{ { 1, 1 }, { 0, 0 } } }, // 1
+ {{ { 2, 1 }, { 0, 0 } } }, // 2
+ {{ { 3, 1 }, { 0, 0 } } }, // 3
+ {{ { 4, 1 }, { 0, 0 } } }, // 4
+ {{ { 5, 1 }, { 0, 0 } } }, // 5
+ {{ { 6, 1 }, { 0, 0 } } }, // 6
+ {{ { 7, 1 }, { 0, 0 } } }, // 7
+ {{ { 8, 1 }, { 0, 0 } } }, // 8
+ {{ { 9, 1 }, { 0, 0 } } }, // 9
+ {{ { 10, 1 }, { 0, 0 } } }, // 10
+ {{ { 11, 1 }, { 0, 0 } } }, // 11
+ {{ { 12, 1 }, { 0, 0 } } }, // 12
+ {{ { 13, 1 }, { 0, 0 } } }, // 13
+ {{ { 14, 1 }, { 0, 0 } } }, // 14
+ {{ { 8, 1 }, { 7, 1 } } }, // 15
+ {{ { 8, 2 }, { 0, 0 } } }, // 16
+ {{ { 9, 1 }, { 8, 1 } } }, // 17
+ {{ { 9, 2 }, { 0, 0 } } }, // 18
+ {{ { 10, 1 }, { 9, 1 } } }, // 19
+ {{ { 10, 2 }, { 0, 0 } } }, // 20
+ {{ { 11, 1 }, { 10, 1 } } }, // 21
+ {{ { 11, 2 }, { 0, 0 } } }, // 22
+ {{ { 12, 1 }, { 11, 1 } } }, // 23
+ {{ { 12, 2 }, { 0, 0 } } }, // 24
+ {{ { 13, 1 }, { 12, 1 } } }, // 25
+ {{ { 13, 2 }, { 0, 0 } } }, // 26
+ {{ { 14, 1 }, { 13, 1 } } }, // 27
+ {{ { 14, 2 }, { 0, 0 } } }, // 28
+ {{ { 10, 2 }, { 9, 1 } } }, // 29
+ {{ { 10, 3 }, { 0, 0 } } }, // 30
+ {{ { 11, 2 }, { 9, 1 } } }, // 31
+ {{ { 11, 2 }, { 10, 1 } } }, // 32
+ {{ { 11, 3 }, { 0, 0 } } }, // 33
+ {{ { 12, 2 }, { 10, 1 } } }, // 34
+ {{ { 12, 2 }, { 11, 1 } } }, // 35
+ {{ { 12, 3 }, { 0, 0 } } }, // 36
+ {{ { 13, 2 }, { 11, 1 } } }, // 37
+ {{ { 13, 2 }, { 12, 1 } } }, // 38
+ {{ { 13, 3 }, { 0, 0 } } }, // 39
+ {{ { 14, 2 }, { 12, 1 } } }, // 40
+ {{ { 14, 2 }, { 13, 1 } } }, // 41
+ {{ { 14, 3 }, { 0, 0 } } }, // 42
+ {{ { 11, 3 }, { 10, 1 } } }, // 43
+ {{ { 11, 4 }, { 0, 0 } } }, // 44
+ {{ { 12, 3 }, { 9, 1 } } }, // 45
+ {{ { 12, 3 }, { 10, 1 } } }, // 46
+ {{ { 12, 3 }, { 11, 1 } } }, // 47
+ {{ { 12, 4 }, { 0, 0 } } }, // 48
+ {{ { 13, 3 }, { 10, 1 } } }, // 49
+ {{ { 13, 3 }, { 11, 1 } } }, // 50
+ {{ { 13, 3 }, { 12, 1 } } }, // 51
+ {{ { 13, 4 }, { 0, 0 } } }, // 52
+ {{ { 14, 3 }, { 11, 1 } } }, // 53
+ {{ { 14, 3 }, { 12, 1 } } }, // 54
+ {{ { 14, 3 }, { 13, 1 } } }, // 55
+ {{ { 14, 4 }, { 0, 0 } } }, // 56
+ {{ { 12, 4 }, { 9, 1 } } }, // 57
+ {{ { 12, 4 }, { 10, 1 } } }, // 58
+ {{ { 12, 4 }, { 11, 1 } } }, // 59
+ {{ { 12, 5 }, { 0, 0 } } }, // 60
+ {{ { 13, 4 }, { 9, 1 } } }, // 61
+ {{ { 13, 4 }, { 10, 1 } } }, // 62
+ {{ { 13, 4 }, { 11, 1 } } }, // 63
+ {{ { 13, 4 }, { 12, 1 } } }, // 64
+ {{ { 13, 5 }, { 0, 0 } } }, // 65
+ {{ { 14, 4 }, { 10, 1 } } }, // 66
+ {{ { 14, 4 }, { 11, 1 } } }, // 67
+ {{ { 14, 4 }, { 12, 1 } } }, // 68
+ {{ { 14, 4 }, { 13, 1 } } }, // 69
+ {{ { 14, 5 }, { 0, 0 } } }, // 70
+ {{ { 12, 5 }, { 11, 1 } } }, // 71
+ {{ { 12, 6 }, { 0, 0 } } }, // 72
+ {{ { 13, 5 }, { 8, 1 } } }, // 73
+ {{ { 13, 5 }, { 9, 1 } } }, // 74
+ {{ { 13, 5 }, { 10, 1 } } }, // 75
+ {{ { 13, 5 }, { 11, 1 } } }, // 76
+ {{ { 13, 5 }, { 12, 1 } } }, // 77
+ {{ { 13, 6 }, { 0, 0 } } }, // 78
+ {{ { 14, 5 }, { 9, 1 } } }, // 79
+ {{ { 14, 5 }, { 10, 1 } } }, // 80
+ {{ { 14, 5 }, { 11, 1 } } }, // 81
+ {{ { 14, 5 }, { 12, 1 } } }, // 82
+ {{ { 14, 5 }, { 13, 1 } } }, // 83
+ {{ { 14, 6 }, { 0, 0 } } }, // 84
+ {{ { 13, 6 }, { 7, 1 } } }, // 85
+ {{ { 13, 6 }, { 8, 1 } } }, // 86
+ {{ { 13, 6 }, { 9, 1 } } }, // 87
+ {{ { 13, 6 }, { 10, 1 } } }, // 88
+ {{ { 13, 6 }, { 11, 1 } } }, // 89
+ {{ { 13, 6 }, { 12, 1 } } }, // 90
+ {{ { 13, 7 }, { 0, 0 } } }, // 91
+ {{ { 14, 6 }, { 8, 1 } } }, // 92
+ {{ { 14, 6 }, { 9, 1 } } }, // 93
+ {{ { 14, 6 }, { 10, 1 } } }, // 94
+ {{ { 14, 6 }, { 11, 1 } } }, // 95
+ {{ { 14, 6 }, { 12, 1 } } }, // 96
+ {{ { 14, 6 }, { 13, 1 } } }, // 97
+ {{ { 14, 7 }, { 0, 0 } } }, // 98
+ {{ { 13, 7 }, { 8, 1 } } }, // 99
+ {{ { 13, 7 }, { 9, 1 } } }, // 100
+ {{ { 13, 7 }, { 10, 1 } } }, // 101
+ {{ { 13, 7 }, { 11, 1 } } }, // 102
+ {{ { 13, 7 }, { 12, 1 } } }, // 103
+ {{ { 13, 8 }, { 0, 0 } } }, // 104
+ {{ { 14, 7 }, { 7, 1 } } }, // 105
+ {{ { 14, 7 }, { 8, 1 } } }, // 106
+ {{ { 14, 7 }, { 9, 1 } } }, // 107
+ {{ { 14, 7 }, { 10, 1 } } }, // 108
+ {{ { 14, 7 }, { 11, 1 } } }, // 109
+ {{ { 14, 7 }, { 12, 1 } } }, // 110
+ {{ { 14, 7 }, { 13, 1 } } }, // 111
+ {{ { 14, 8 }, { 0, 0 } } }, // 112
+ {{ { 13, 8 }, { 9, 1 } } }, // 113
+ {{ { 13, 8 }, { 10, 1 } } }, // 114
+ {{ { 13, 8 }, { 11, 1 } } }, // 115
+ {{ { 13, 8 }, { 12, 1 } } }, // 116
+ {{ { 13, 9 }, { 0, 0 } } }, // 117
+ {{ { 14, 8 }, { 6, 1 } } }, // 118
+ {{ { 14, 8 }, { 7, 1 } } }, // 119
+ {{ { 14, 8 }, { 8, 1 } } }, // 120
+ }
+ };
// clang-format on
};
-constexpr KernelInfo::knl_ptr KernelInfo::kernel[7];;
-constexpr int KernelInfo::partition[121][2][2];
+constexpr KernelInfo::knl_ptr KernelInfo::kernel_avx2[];
+constexpr KernelInfo::knl_ptr KernelInfo::kernel_avx512[];
+constexpr int KernelInfo::partition_avx2[121][2][2];
+constexpr array<array<array<int, 2>, 2>, 121> KernelInfo::partition_avx512;
// autotuned kernel splits for various cases m = 1:mb_max
void cblas_gemm_compute(
@@ -200,7 +354,8 @@ void cblas_gemm_compute(
// constants
const int n = Bp.numCols(), k = Bp.numRows(), ldc = n;
const int mb_max = 120;
- constexpr int simd_width = 8;
+ bool has_avx512 = fbgemmHasAvx512Support();
+ int simd_width = has_avx512 ? 16 : 8;
int kernel_ncol_blocks = Bp.kernelNumColBlocks();
int kernel_ncols = kernel_ncol_blocks * simd_width;
@@ -210,6 +365,11 @@ void cblas_gemm_compute(
GemmParams gp;
+ const funcptr_fp16* kernels =
+ has_avx512 ? KernelInfo::kernel_avx512 : KernelInfo::kernel_avx2;
+ const array<array<array<int, 2>, 2>, 121>& partition =
+ has_avx512 ? KernelInfo::partition_avx512 : KernelInfo::partition_avx2;
+
int i_begin, i_end;
// fbgemmGetRange(num_threads, thread_id, m, 1, i_begin, i_end);
i_begin = 0;
@@ -236,8 +396,8 @@ void cblas_gemm_compute(
auto m1 = m0;
for (auto c = 0; c < 2; c++) {
- auto kernel_nrows = KernelInfo::partition[mb][c][0];
- auto nkernel_nrows = KernelInfo::partition[mb][c][1];
+ auto kernel_nrows = partition[mb][c][0];
+ auto nkernel_nrows = partition[mb][c][1];
auto m_start = m1, m_end = m1 + kernel_nrows * nkernel_nrows;
for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) {
@@ -271,7 +431,7 @@ void cblas_gemm_compute(
gp.C += Bp.blockColSize() * jb_begin;
gp.b_block_cols = jb_end - jb_begin;
if (gp.b_block_cols) {
- KernelInfo::kernel[kernel_nrows](&gp);
+ kernels[kernel_nrows](&gp);
}
} else {
int last_blk_col = nbcol * Bp.blockColSize();
@@ -283,7 +443,7 @@ void cblas_gemm_compute(
gp.C += Bp.blockColSize() * jb_begin;
gp.b_block_cols = jb_end - jb_begin;
if (gp.b_block_cols) {
- KernelInfo::kernel[kernel_nrows](&gp);
+ kernels[kernel_nrows](&gp);
}
}
@@ -297,19 +457,21 @@ void cblas_gemm_compute(
gp.B = &(Bp(k_ind, last_blk_col));
gp.C = &C[m2 * ldc + last_blk_col];
gp.b_block_cols = 1;
- KernelInfo::kernel[kernel_nrows](&gp);
+ kernels[kernel_nrows](&gp);
} else {
// small temporary buffer: the size should be larger than the
// required kernel_nrow x kernel_ncols elements computed in the
// registers.
- float c_tmp[16 * 24] = {0};
- assert((16 * 24) > kernel_nrows * kernel_ncols);
+ float c_tmp[14 * 32] = {0};
+ assert(
+ sizeof(c_tmp) / sizeof(c_tmp[0]) >=
+ kernel_nrows * kernel_ncols);
gp.B = &(Bp(k_ind, last_blk_col));
gp.C = c_tmp;
gp.ldc = kernel_ncols * sizeof(C[0]);
gp.b_block_cols = 1;
- KernelInfo::kernel[kernel_nrows](&gp);
+ kernels[kernel_nrows](&gp);
for (int i = 0; i < kernel_nrows; i++) {
// Todo: use assembly
for (int j = last_blk_col; j < n; j++) {
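
The comment added near the top of FbgemmFP16.cc describes the computation each micro-kernel performs. A plain scalar reference of that loop nest, with "block" standing in for the B block-column width (which depends on the ISA) and an illustrative row-major layout, is:

    #include <vector>

    // C (mb x n) += A (mb x k) * B (k x n), with n = b_block_cols * block,
    // walked one B block column at a time as in the comment above.
    void kernel_reference(int mb, int k, int b_block_cols, int block,
                          const std::vector<float>& A,
                          const std::vector<float>& B,
                          std::vector<float>& C) {
      const int n = b_block_cols * block;
      for (int j = 0; j < n; j += block)        // one block column of B
        for (int kk = 0; kk < k; ++kk)          // reduction over k
          for (int i = 0; i < mb; ++i)          // rows handled by one kernel call
            for (int jj = j; jj < j + block; ++jj)
              C[i * n + jj] += A[i * k + kk] * B[kk * n + jj];
    }

The real kernels keep the mb x block accumulator tile in ymm (AVX2) or zmm (AVX512) registers, which is why the AVX2 table stops at 6 rows and the AVX512 table at 14.
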
diff --git a/src/FbgemmFP16UKernelsAvx2.cc b/src/FbgemmFP16UKernelsAvx2.cc
index 5f7492f..204fa03 100644
--- a/src/FbgemmFP16UKernelsAvx2.cc
+++ b/src/FbgemmFP16UKernelsAvx2.cc
@@ -32,6 +32,7 @@ void NOINLINE_ATTR gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp) {
// b_block_size
uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n"
// Make copies of A and C
+<<<<<<< HEAD
float* rax = r9; //"mov rax, r9\t\n"
float* rcx = r12; //"mov rcx, r12\t\n"
@@ -75,6 +76,76 @@ void NOINLINE_ATTR gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp) {
r12 = rcx; //"mov r12, rcx\t\n"
r9 = rax; //"mov r9, rax\t\n"
} //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
+=======
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps ymm0,ymm0,ymm0\t\n"
+ "vxorps ymm1,ymm1,ymm1\t\n"
+
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps ymm3,XMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps ymm4,XMMWORD PTR [r10 + 16]\t\n"
+ "vbroadcastss ymm2,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps ymm0,ymm3,ymm2\t\n"
+ "vfmadd231ps ymm1,ymm4,ymm2\t\n"
+ "add r9,4\t\n"
+ "add r10,32\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+ "vfmadd231ps ymm0,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
+ "vfmadd231ps ymm1,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 64\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+>>>>>>> upstream
}
void NOINLINE_ATTR gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp) {
@@ -100,6 +171,7 @@ void NOINLINE_ATTR gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp) {
// b_block_size
uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n"
// Make copies of A and C
+<<<<<<< HEAD
float* rax = r9; //"mov rax, r9\t\n"
float* rcx = r12; //"mov rcx, r12\t\n"
@@ -158,6 +230,89 @@ void NOINLINE_ATTR gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp) {
r12 = rcx; //"mov r12, rcx\t\n"
r9 = rax; //"mov r9, rax\t\n"
} //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
+=======
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps ymm0,ymm0,ymm0\t\n"
+ "vxorps ymm1,ymm1,ymm1\t\n"
+ "vxorps ymm2,ymm2,ymm2\t\n"
+ "vxorps ymm3,ymm3,ymm3\t\n"
+
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps ymm5,XMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps ymm6,XMMWORD PTR [r10 + 16]\t\n"
+ "vbroadcastss ymm4,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps ymm0,ymm5,ymm4\t\n"
+ "vfmadd231ps ymm1,ymm6,ymm4\t\n"
+ "vbroadcastss ymm4,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps ymm2,ymm5,ymm4\t\n"
+ "vfmadd231ps ymm3,ymm6,ymm4\t\n"
+ "add r9,8\t\n"
+ "add r10,32\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm2\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm3\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+ "vfmadd231ps ymm0,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
+ "vfmadd231ps ymm1,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm2,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm2\t\n"
+ "vfmadd231ps ymm3,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm3\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 64\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+>>>>>>> upstream
}
void NOINLINE_ATTR gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp) {
@@ -183,6 +338,7 @@ void NOINLINE_ATTR gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp) {
// b_block_size
uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n"
// Make copies of A and C
+<<<<<<< HEAD
float* rax = r9; //"mov rax, r9\t\n"
float* rcx = r12; //"mov rcx, r12\t\n"
@@ -256,6 +412,102 @@ void NOINLINE_ATTR gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp) {
r12 = rcx; //"mov r12, rcx\t\n"
r9 = rax; //"mov r9, rax\t\n"
} //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
+=======
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps ymm0,ymm0,ymm0\t\n"
+ "vxorps ymm1,ymm1,ymm1\t\n"
+ "vxorps ymm2,ymm2,ymm2\t\n"
+ "vxorps ymm3,ymm3,ymm3\t\n"
+ "vxorps ymm4,ymm4,ymm4\t\n"
+ "vxorps ymm5,ymm5,ymm5\t\n"
+
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps ymm7,XMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps ymm8,XMMWORD PTR [r10 + 16]\t\n"
+ "vbroadcastss ymm6,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps ymm0,ymm7,ymm6\t\n"
+ "vfmadd231ps ymm1,ymm8,ymm6\t\n"
+ "vbroadcastss ymm6,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps ymm2,ymm7,ymm6\t\n"
+ "vfmadd231ps ymm3,ymm8,ymm6\t\n"
+ "vbroadcastss ymm6,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps ymm4,ymm7,ymm6\t\n"
+ "vfmadd231ps ymm5,ymm8,ymm6\t\n"
+ "add r9,12\t\n"
+ "add r10,32\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm2\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm4\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm5\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+ "vfmadd231ps ymm0,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
+ "vfmadd231ps ymm1,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm2,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm2\t\n"
+ "vfmadd231ps ymm3,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm4,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm4\t\n"
+ "vfmadd231ps ymm5,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm5\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 64\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+>>>>>>> upstream
}
void NOINLINE_ATTR gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp) {
@@ -281,6 +533,7 @@ void NOINLINE_ATTR gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp) {
// b_block_size
uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n"
// Make copies of A and C
+<<<<<<< HEAD
float* rax = r9; //"mov rax, r9\t\n"
float* rcx = r12; //"mov rcx, r12\t\n"
@@ -369,6 +622,115 @@ void NOINLINE_ATTR gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp) {
r12 = rcx; //"mov r12, rcx\t\n"
r9 = rax; //"mov r9, rax\t\n"
} //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
+=======
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps ymm0,ymm0,ymm0\t\n"
+ "vxorps ymm1,ymm1,ymm1\t\n"
+ "vxorps ymm2,ymm2,ymm2\t\n"
+ "vxorps ymm3,ymm3,ymm3\t\n"
+ "vxorps ymm4,ymm4,ymm4\t\n"
+ "vxorps ymm5,ymm5,ymm5\t\n"
+ "vxorps ymm6,ymm6,ymm6\t\n"
+ "vxorps ymm7,ymm7,ymm7\t\n"
+
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps ymm9,XMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps ymm10,XMMWORD PTR [r10 + 16]\t\n"
+ "vbroadcastss ymm8,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps ymm0,ymm9,ymm8\t\n"
+ "vfmadd231ps ymm1,ymm10,ymm8\t\n"
+ "vbroadcastss ymm8,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps ymm2,ymm9,ymm8\t\n"
+ "vfmadd231ps ymm3,ymm10,ymm8\t\n"
+ "vbroadcastss ymm8,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps ymm4,ymm9,ymm8\t\n"
+ "vfmadd231ps ymm5,ymm10,ymm8\t\n"
+ "vbroadcastss ymm8,DWORD PTR [r9+12]\t\n"
+ "vfmadd231ps ymm6,ymm9,ymm8\t\n"
+ "vfmadd231ps ymm7,ymm10,ymm8\t\n"
+ "add r9,16\t\n"
+ "add r10,32\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm2\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm4\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm5\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm6\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm7\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+ "vfmadd231ps ymm0,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
+ "vfmadd231ps ymm1,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm2,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm2\t\n"
+ "vfmadd231ps ymm3,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm4,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm4\t\n"
+ "vfmadd231ps ymm5,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm5\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm6,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm6\t\n"
+ "vfmadd231ps ymm7,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm7\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 64\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+>>>>>>> upstream
}
void NOINLINE_ATTR gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp) {
@@ -394,6 +756,7 @@ void NOINLINE_ATTR gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp) {
// b_block_size
uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n"
// Make copies of A and C
+<<<<<<< HEAD
float* rax = r9; //"mov rax, r9\t\n"
float* rcx = r12; //"mov rcx, r12\t\n"
@@ -497,6 +860,128 @@ void NOINLINE_ATTR gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp) {
r12 = rcx; //"mov r12, rcx\t\n"
r9 = rax; //"mov r9, rax\t\n"
} //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
+=======
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps ymm0,ymm0,ymm0\t\n"
+ "vxorps ymm1,ymm1,ymm1\t\n"
+ "vxorps ymm2,ymm2,ymm2\t\n"
+ "vxorps ymm3,ymm3,ymm3\t\n"
+ "vxorps ymm4,ymm4,ymm4\t\n"
+ "vxorps ymm5,ymm5,ymm5\t\n"
+ "vxorps ymm6,ymm6,ymm6\t\n"
+ "vxorps ymm7,ymm7,ymm7\t\n"
+ "vxorps ymm8,ymm8,ymm8\t\n"
+ "vxorps ymm9,ymm9,ymm9\t\n"
+
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps ymm11,XMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps ymm12,XMMWORD PTR [r10 + 16]\t\n"
+ "vbroadcastss ymm10,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps ymm0,ymm11,ymm10\t\n"
+ "vfmadd231ps ymm1,ymm12,ymm10\t\n"
+ "vbroadcastss ymm10,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps ymm2,ymm11,ymm10\t\n"
+ "vfmadd231ps ymm3,ymm12,ymm10\t\n"
+ "vbroadcastss ymm10,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps ymm4,ymm11,ymm10\t\n"
+ "vfmadd231ps ymm5,ymm12,ymm10\t\n"
+ "vbroadcastss ymm10,DWORD PTR [r9+12]\t\n"
+ "vfmadd231ps ymm6,ymm11,ymm10\t\n"
+ "vfmadd231ps ymm7,ymm12,ymm10\t\n"
+ "vbroadcastss ymm10,DWORD PTR [r9+16]\t\n"
+ "vfmadd231ps ymm8,ymm11,ymm10\t\n"
+ "vfmadd231ps ymm9,ymm12,ymm10\t\n"
+ "add r9,20\t\n"
+ "add r10,32\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm2\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm4\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm5\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm6\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm7\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm8\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm9\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+ "vfmadd231ps ymm0,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
+ "vfmadd231ps ymm1,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm2,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm2\t\n"
+ "vfmadd231ps ymm3,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm4,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm4\t\n"
+ "vfmadd231ps ymm5,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm5\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm6,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm6\t\n"
+ "vfmadd231ps ymm7,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm7\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm8,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm8\t\n"
+ "vfmadd231ps ymm9,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm9\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 64\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+>>>>>>> upstream
}
void NOINLINE_ATTR gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp) {
@@ -522,6 +1007,7 @@ void NOINLINE_ATTR gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp) {
// b_block_size
uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n"
// Make copies of A and C
+<<<<<<< HEAD
float* rax = r9; //"mov rax, r9\t\n"
float* rcx = r12; //"mov rcx, r12\t\n"
@@ -640,6 +1126,141 @@ void NOINLINE_ATTR gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp) {
r12 = rcx; //"mov r12, rcx\t\n"
r9 = rax; //"mov r9, rax\t\n"
} //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
+=======
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps ymm0,ymm0,ymm0\t\n"
+ "vxorps ymm1,ymm1,ymm1\t\n"
+ "vxorps ymm2,ymm2,ymm2\t\n"
+ "vxorps ymm3,ymm3,ymm3\t\n"
+ "vxorps ymm4,ymm4,ymm4\t\n"
+ "vxorps ymm5,ymm5,ymm5\t\n"
+ "vxorps ymm6,ymm6,ymm6\t\n"
+ "vxorps ymm7,ymm7,ymm7\t\n"
+ "vxorps ymm8,ymm8,ymm8\t\n"
+ "vxorps ymm9,ymm9,ymm9\t\n"
+ "vxorps ymm10,ymm10,ymm10\t\n"
+ "vxorps ymm11,ymm11,ymm11\t\n"
+
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps ymm13,XMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps ymm14,XMMWORD PTR [r10 + 16]\t\n"
+ "vbroadcastss ymm12,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps ymm0,ymm13,ymm12\t\n"
+ "vfmadd231ps ymm1,ymm14,ymm12\t\n"
+ "vbroadcastss ymm12,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps ymm2,ymm13,ymm12\t\n"
+ "vfmadd231ps ymm3,ymm14,ymm12\t\n"
+ "vbroadcastss ymm12,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps ymm4,ymm13,ymm12\t\n"
+ "vfmadd231ps ymm5,ymm14,ymm12\t\n"
+ "vbroadcastss ymm12,DWORD PTR [r9+12]\t\n"
+ "vfmadd231ps ymm6,ymm13,ymm12\t\n"
+ "vfmadd231ps ymm7,ymm14,ymm12\t\n"
+ "vbroadcastss ymm12,DWORD PTR [r9+16]\t\n"
+ "vfmadd231ps ymm8,ymm13,ymm12\t\n"
+ "vfmadd231ps ymm9,ymm14,ymm12\t\n"
+ "vbroadcastss ymm12,DWORD PTR [r9+20]\t\n"
+ "vfmadd231ps ymm10,ymm13,ymm12\t\n"
+ "vfmadd231ps ymm11,ymm14,ymm12\t\n"
+ "add r9,24\t\n"
+ "add r10,32\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm2\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm4\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm5\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm6\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm7\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm8\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm9\t\n"
+ "add r12, r13\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm10\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm11\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+ "vfmadd231ps ymm0,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm0\t\n"
+ "vfmadd231ps ymm1,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm2,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm2\t\n"
+ "vfmadd231ps ymm3,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm4,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm4\t\n"
+ "vfmadd231ps ymm5,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm5\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm6,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm6\t\n"
+ "vfmadd231ps ymm7,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm7\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm8,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm8\t\n"
+ "vfmadd231ps ymm9,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm9\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps ymm10,ymm15,ymmword PTR [r12 + 0]\t\n"
+ "vmovups ymmword PTR [r12 + 0], ymm10\t\n"
+ "vfmadd231ps ymm11,ymm15,ymmword PTR [r12 + 32]\t\n"
+ "vmovups ymmword PTR [r12 + 32], ymm11\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 64\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+>>>>>>> upstream
}
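
Every AVX2 kernel above shares one inner-loop pattern: convert two groups of eight fp16 B values to fp32 (vcvtph2ps), broadcast one fp32 A element (vbroadcastss), and accumulate into the C tile with FMA (vfmadd231ps). An intrinsics sketch of a single 1x2 step, illustrative rather than the generated kernels themselves (compile with -mavx2 -mfma -mf16c):

    #include <immintrin.h>
    #include <cstdint>

    // c0/c1 accumulate a[0] * b[0:8] and a[0] * b[8:16], matching
    // "vfmadd231ps ymm0,ymm3,ymm2" / "vfmadd231ps ymm1,ymm4,ymm2" above.
    static inline void fma_step_1x2(const uint16_t* b_fp16, const float* a,
                                    __m256& c0, __m256& c1) {
      __m256 b0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(b_fp16 + 0)));
      __m256 b1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)(b_fp16 + 8)));
      __m256 a0 = _mm256_broadcast_ss(a);
      c0 = _mm256_fmadd_ps(b0, a0, c0);
      c1 = _mm256_fmadd_ps(b1, a0, c1);
    }
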
diff --git a/src/FbgemmFP16UKernelsAvx2.h b/src/FbgemmFP16UKernelsAvx2.h
index d48a88e..95fea41 100644
--- a/src/FbgemmFP16UKernelsAvx2.h
+++ b/src/FbgemmFP16UKernelsAvx2.h
@@ -4,8 +4,7 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
-#ifndef FBGEMM_UKERNELS
-#define FBGEMM_UKERNELS
+#pragma once
#include <cstdint>
#include "fbgemm/Types.h"
@@ -40,4 +39,3 @@ typedef void (*funcptr_fp16)(GemmParams* gp);
} // namespace fbgemm
-#endif
diff --git a/src/FbgemmFP16UKernelsAvx512.cc b/src/FbgemmFP16UKernelsAvx512.cc
new file mode 100644
index 0000000..a7927c5
--- /dev/null
+++ b/src/FbgemmFP16UKernelsAvx512.cc
@@ -0,0 +1,2558 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "FbgemmFP16UKernelsAvx512.h"
+
+namespace fbgemm {
+
+void __attribute__((noinline)) gemmkernel_1x2_AVX512_fA0fB0fC0(GemmParams* gp) {
+ asm volatile(
+#if !defined(__clang__)
+ "mov r14, %[gp]\t\n"
+#else
+ "mov %[gp], %%r14\t\n"
+ ".intel_syntax noprefix\t\n"
+#endif
+
+ // Copy parameters
+ // k
+ "mov r8, [r14 + 0]\t\n"
+ // A
+ "mov r9, [r14 + 8]\t\n"
+ // B
+ "mov r10, [r14 + 16]\t\n"
+ // beta
+ "mov r15, [r14 + 24]\t\n"
+ // accum
+ "mov rdx, [r14 + 32]\t\n"
+ // C
+ "mov r12, [r14 + 40]\t\n"
+ // ldc
+ "mov r13, [r14 + 48]\t\n"
+ // b_block_cols
+ "mov rdi, [r14 + 56]\t\n"
+ // b_block_size
+ "mov rsi, [r14 + 64]\t\n"
+ // Make copies of A and C
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps zmm0,zmm0,zmm0\t\n"
+ "vxorps zmm1,zmm1,zmm1\t\n"
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps zmm3,YMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps zmm4,YMMWORD PTR [r10 + 32]\t\n"
+ "vbroadcastss zmm2,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps zmm0,zmm3,zmm2\t\n"
+ "vfmadd231ps zmm1,zmm4,zmm2\t\n"
+ "add r9,4\t\n"
+ "add r10,64\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss zmm31,DWORD PTR [r15]\t\n"
+ "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 128\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+}
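
The new AVX512 kernels keep the same structure, but each B block column is two groups of sixteen fp16 values widened into zmm registers (hence the YMMWORD loads, the 64-byte B stride, and the 128-byte C column stride). A matching intrinsics sketch of one 1x2 step, again illustrative only (requires AVX512F):

    #include <immintrin.h>
    #include <cstdint>

    // c0/c1 accumulate a[0] * b[0:16] and a[0] * b[16:32], matching
    // "vfmadd231ps zmm0,zmm3,zmm2" / "vfmadd231ps zmm1,zmm4,zmm2" above.
    static inline void fma_step_1x2_avx512(const uint16_t* b_fp16, const float* a,
                                           __m512& c0, __m512& c1) {
      __m512 b0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_fp16 + 0)));
      __m512 b1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)(b_fp16 + 16)));
      __m512 a0 = _mm512_set1_ps(*a);
      c0 = _mm512_fmadd_ps(b0, a0, c0);
      c1 = _mm512_fmadd_ps(b1, a0, c1);
    }
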
+void __attribute__((noinline)) gemmkernel_2x2_AVX512_fA0fB0fC0(GemmParams* gp) {
+ asm volatile(
+#if !defined(__clang__)
+ "mov r14, %[gp]\t\n"
+#else
+ "mov %[gp], %%r14\t\n"
+ ".intel_syntax noprefix\t\n"
+#endif
+
+ // Copy parameters
+ // k
+ "mov r8, [r14 + 0]\t\n"
+ // A
+ "mov r9, [r14 + 8]\t\n"
+ // B
+ "mov r10, [r14 + 16]\t\n"
+ // beta
+ "mov r15, [r14 + 24]\t\n"
+ // accum
+ "mov rdx, [r14 + 32]\t\n"
+ // C
+ "mov r12, [r14 + 40]\t\n"
+ // ldc
+ "mov r13, [r14 + 48]\t\n"
+ // b_block_cols
+ "mov rdi, [r14 + 56]\t\n"
+ // b_block_size
+ "mov rsi, [r14 + 64]\t\n"
+ // Make copies of A and C
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps zmm0,zmm0,zmm0\t\n"
+ "vxorps zmm1,zmm1,zmm1\t\n"
+ "vxorps zmm2,zmm2,zmm2\t\n"
+ "vxorps zmm3,zmm3,zmm3\t\n"
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps zmm5,YMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps zmm6,YMMWORD PTR [r10 + 32]\t\n"
+ "vbroadcastss zmm4,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps zmm0,zmm5,zmm4\t\n"
+ "vfmadd231ps zmm1,zmm6,zmm4\t\n"
+ "vbroadcastss zmm4,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps zmm2,zmm5,zmm4\t\n"
+ "vfmadd231ps zmm3,zmm6,zmm4\t\n"
+ "add r9,8\t\n"
+ "add r10,64\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss zmm31,DWORD PTR [r15]\t\n"
+ "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 128\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+}
+void __attribute__((noinline)) gemmkernel_3x2_AVX512_fA0fB0fC0(GemmParams* gp) {
+ asm volatile(
+#if !defined(__clang__)
+ "mov r14, %[gp]\t\n"
+#else
+ "mov %[gp], %%r14\t\n"
+ ".intel_syntax noprefix\t\n"
+#endif
+
+ // Copy parameters
+ // k
+ "mov r8, [r14 + 0]\t\n"
+ // A
+ "mov r9, [r14 + 8]\t\n"
+ // B
+ "mov r10, [r14 + 16]\t\n"
+ // beta
+ "mov r15, [r14 + 24]\t\n"
+ // accum
+ "mov rdx, [r14 + 32]\t\n"
+ // C
+ "mov r12, [r14 + 40]\t\n"
+ // ldc
+ "mov r13, [r14 + 48]\t\n"
+ // b_block_cols
+ "mov rdi, [r14 + 56]\t\n"
+ // b_block_size
+ "mov rsi, [r14 + 64]\t\n"
+ // Make copies of A and C
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps zmm0,zmm0,zmm0\t\n"
+ "vxorps zmm1,zmm1,zmm1\t\n"
+ "vxorps zmm2,zmm2,zmm2\t\n"
+ "vxorps zmm3,zmm3,zmm3\t\n"
+ "vxorps zmm4,zmm4,zmm4\t\n"
+ "vxorps zmm5,zmm5,zmm5\t\n"
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps zmm7,YMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps zmm8,YMMWORD PTR [r10 + 32]\t\n"
+ "vbroadcastss zmm6,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps zmm0,zmm7,zmm6\t\n"
+ "vfmadd231ps zmm1,zmm8,zmm6\t\n"
+ "vbroadcastss zmm6,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps zmm2,zmm7,zmm6\t\n"
+ "vfmadd231ps zmm3,zmm8,zmm6\t\n"
+ "vbroadcastss zmm6,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps zmm4,zmm7,zmm6\t\n"
+ "vfmadd231ps zmm5,zmm8,zmm6\t\n"
+ "add r9,12\t\n"
+ "add r10,64\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss zmm31,DWORD PTR [r15]\t\n"
+ "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 128\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+}
+void __attribute__((noinline)) gemmkernel_4x2_AVX512_fA0fB0fC0(GemmParams* gp) {
+ asm volatile(
+#if !defined(__clang__)
+ "mov r14, %[gp]\t\n"
+#else
+ "mov %[gp], %%r14\t\n"
+ ".intel_syntax noprefix\t\n"
+#endif
+
+ // Copy parameters
+ // k
+ "mov r8, [r14 + 0]\t\n"
+ // A
+ "mov r9, [r14 + 8]\t\n"
+ // B
+ "mov r10, [r14 + 16]\t\n"
+ // beta
+ "mov r15, [r14 + 24]\t\n"
+ // accum
+ "mov rdx, [r14 + 32]\t\n"
+ // C
+ "mov r12, [r14 + 40]\t\n"
+ // ldc
+ "mov r13, [r14 + 48]\t\n"
+ // b_block_cols
+ "mov rdi, [r14 + 56]\t\n"
+ // b_block_size
+ "mov rsi, [r14 + 64]\t\n"
+ // Make copies of A and C
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps zmm0,zmm0,zmm0\t\n"
+ "vxorps zmm1,zmm1,zmm1\t\n"
+ "vxorps zmm2,zmm2,zmm2\t\n"
+ "vxorps zmm3,zmm3,zmm3\t\n"
+ "vxorps zmm4,zmm4,zmm4\t\n"
+ "vxorps zmm5,zmm5,zmm5\t\n"
+ "vxorps zmm6,zmm6,zmm6\t\n"
+ "vxorps zmm7,zmm7,zmm7\t\n"
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps zmm9,YMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps zmm10,YMMWORD PTR [r10 + 32]\t\n"
+ "vbroadcastss zmm8,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps zmm0,zmm9,zmm8\t\n"
+ "vfmadd231ps zmm1,zmm10,zmm8\t\n"
+ "vbroadcastss zmm8,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps zmm2,zmm9,zmm8\t\n"
+ "vfmadd231ps zmm3,zmm10,zmm8\t\n"
+ "vbroadcastss zmm8,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps zmm4,zmm9,zmm8\t\n"
+ "vfmadd231ps zmm5,zmm10,zmm8\t\n"
+ "vbroadcastss zmm8,DWORD PTR [r9+12]\t\n"
+ "vfmadd231ps zmm6,zmm9,zmm8\t\n"
+ "vfmadd231ps zmm7,zmm10,zmm8\t\n"
+ "add r9,16\t\n"
+ "add r10,64\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss zmm31,DWORD PTR [r15]\t\n"
+ "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 128\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+}
+void __attribute__((noinline)) gemmkernel_5x2_AVX512_fA0fB0fC0(GemmParams* gp) {
+ asm volatile(
+#if !defined(__clang__)
+ "mov r14, %[gp]\t\n"
+#else
+ "mov %[gp], %%r14\t\n"
+ ".intel_syntax noprefix\t\n"
+#endif
+
+ // Copy parameters
+ // k
+ "mov r8, [r14 + 0]\t\n"
+ // A
+ "mov r9, [r14 + 8]\t\n"
+ // B
+ "mov r10, [r14 + 16]\t\n"
+ // beta
+ "mov r15, [r14 + 24]\t\n"
+ // accum
+ "mov rdx, [r14 + 32]\t\n"
+ // C
+ "mov r12, [r14 + 40]\t\n"
+ // ldc
+ "mov r13, [r14 + 48]\t\n"
+ // b_block_cols
+ "mov rdi, [r14 + 56]\t\n"
+ // b_block_size
+ "mov rsi, [r14 + 64]\t\n"
+ // Make copies of A and C
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps zmm0,zmm0,zmm0\t\n"
+ "vxorps zmm1,zmm1,zmm1\t\n"
+ "vxorps zmm2,zmm2,zmm2\t\n"
+ "vxorps zmm3,zmm3,zmm3\t\n"
+ "vxorps zmm4,zmm4,zmm4\t\n"
+ "vxorps zmm5,zmm5,zmm5\t\n"
+ "vxorps zmm6,zmm6,zmm6\t\n"
+ "vxorps zmm7,zmm7,zmm7\t\n"
+ "vxorps zmm8,zmm8,zmm8\t\n"
+ "vxorps zmm9,zmm9,zmm9\t\n"
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps zmm11,YMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps zmm12,YMMWORD PTR [r10 + 32]\t\n"
+ "vbroadcastss zmm10,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps zmm0,zmm11,zmm10\t\n"
+ "vfmadd231ps zmm1,zmm12,zmm10\t\n"
+ "vbroadcastss zmm10,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps zmm2,zmm11,zmm10\t\n"
+ "vfmadd231ps zmm3,zmm12,zmm10\t\n"
+ "vbroadcastss zmm10,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps zmm4,zmm11,zmm10\t\n"
+ "vfmadd231ps zmm5,zmm12,zmm10\t\n"
+ "vbroadcastss zmm10,DWORD PTR [r9+12]\t\n"
+ "vfmadd231ps zmm6,zmm11,zmm10\t\n"
+ "vfmadd231ps zmm7,zmm12,zmm10\t\n"
+ "vbroadcastss zmm10,DWORD PTR [r9+16]\t\n"
+ "vfmadd231ps zmm8,zmm11,zmm10\t\n"
+ "vfmadd231ps zmm9,zmm12,zmm10\t\n"
+ "add r9,20\t\n"
+ "add r10,64\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss zmm31,DWORD PTR [r15]\t\n"
+ "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 128\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+}
+void __attribute__((noinline)) gemmkernel_6x2_AVX512_fA0fB0fC0(GemmParams* gp) {
+ asm volatile(
+#if !defined(__clang__)
+ "mov r14, %[gp]\t\n"
+#else
+ "mov %[gp], %%r14\t\n"
+ ".intel_syntax noprefix\t\n"
+#endif
+
+ // Copy parameters
+ // k
+ "mov r8, [r14 + 0]\t\n"
+ // A
+ "mov r9, [r14 + 8]\t\n"
+ // B
+ "mov r10, [r14 + 16]\t\n"
+ // beta
+ "mov r15, [r14 + 24]\t\n"
+ // accum
+ "mov rdx, [r14 + 32]\t\n"
+ // C
+ "mov r12, [r14 + 40]\t\n"
+ // ldc
+ "mov r13, [r14 + 48]\t\n"
+ // b_block_cols
+ "mov rdi, [r14 + 56]\t\n"
+ // b_block_size
+ "mov rsi, [r14 + 64]\t\n"
+ // Make copies of A and C
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps zmm0,zmm0,zmm0\t\n"
+ "vxorps zmm1,zmm1,zmm1\t\n"
+ "vxorps zmm2,zmm2,zmm2\t\n"
+ "vxorps zmm3,zmm3,zmm3\t\n"
+ "vxorps zmm4,zmm4,zmm4\t\n"
+ "vxorps zmm5,zmm5,zmm5\t\n"
+ "vxorps zmm6,zmm6,zmm6\t\n"
+ "vxorps zmm7,zmm7,zmm7\t\n"
+ "vxorps zmm8,zmm8,zmm8\t\n"
+ "vxorps zmm9,zmm9,zmm9\t\n"
+ "vxorps zmm10,zmm10,zmm10\t\n"
+ "vxorps zmm11,zmm11,zmm11\t\n"
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps zmm13,YMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps zmm14,YMMWORD PTR [r10 + 32]\t\n"
+ "vbroadcastss zmm12,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps zmm0,zmm13,zmm12\t\n"
+ "vfmadd231ps zmm1,zmm14,zmm12\t\n"
+ "vbroadcastss zmm12,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps zmm2,zmm13,zmm12\t\n"
+ "vfmadd231ps zmm3,zmm14,zmm12\t\n"
+ "vbroadcastss zmm12,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps zmm4,zmm13,zmm12\t\n"
+ "vfmadd231ps zmm5,zmm14,zmm12\t\n"
+ "vbroadcastss zmm12,DWORD PTR [r9+12]\t\n"
+ "vfmadd231ps zmm6,zmm13,zmm12\t\n"
+ "vfmadd231ps zmm7,zmm14,zmm12\t\n"
+ "vbroadcastss zmm12,DWORD PTR [r9+16]\t\n"
+ "vfmadd231ps zmm8,zmm13,zmm12\t\n"
+ "vfmadd231ps zmm9,zmm14,zmm12\t\n"
+ "vbroadcastss zmm12,DWORD PTR [r9+20]\t\n"
+ "vfmadd231ps zmm10,zmm13,zmm12\t\n"
+ "vfmadd231ps zmm11,zmm14,zmm12\t\n"
+ "add r9,24\t\n"
+ "add r10,64\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss zmm31,DWORD PTR [r15]\t\n"
+ "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 128\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+}
+void __attribute__((noinline)) gemmkernel_7x2_AVX512_fA0fB0fC0(GemmParams* gp) {
+ asm volatile(
+#if !defined(__clang__)
+ "mov r14, %[gp]\t\n"
+#else
+ "mov %[gp], %%r14\t\n"
+ ".intel_syntax noprefix\t\n"
+#endif
+
+ // Copy parameters
+ // k
+ "mov r8, [r14 + 0]\t\n"
+ // A
+ "mov r9, [r14 + 8]\t\n"
+ // B
+ "mov r10, [r14 + 16]\t\n"
+ // beta
+ "mov r15, [r14 + 24]\t\n"
+ // accum
+ "mov rdx, [r14 + 32]\t\n"
+ // C
+ "mov r12, [r14 + 40]\t\n"
+ // ldc
+ "mov r13, [r14 + 48]\t\n"
+ // b_block_cols
+ "mov rdi, [r14 + 56]\t\n"
+ // b_block_size
+ "mov rsi, [r14 + 64]\t\n"
+ // Make copies of A and C
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps zmm0,zmm0,zmm0\t\n"
+ "vxorps zmm1,zmm1,zmm1\t\n"
+ "vxorps zmm2,zmm2,zmm2\t\n"
+ "vxorps zmm3,zmm3,zmm3\t\n"
+ "vxorps zmm4,zmm4,zmm4\t\n"
+ "vxorps zmm5,zmm5,zmm5\t\n"
+ "vxorps zmm6,zmm6,zmm6\t\n"
+ "vxorps zmm7,zmm7,zmm7\t\n"
+ "vxorps zmm8,zmm8,zmm8\t\n"
+ "vxorps zmm9,zmm9,zmm9\t\n"
+ "vxorps zmm10,zmm10,zmm10\t\n"
+ "vxorps zmm11,zmm11,zmm11\t\n"
+ "vxorps zmm12,zmm12,zmm12\t\n"
+ "vxorps zmm13,zmm13,zmm13\t\n"
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps zmm15,YMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps zmm16,YMMWORD PTR [r10 + 32]\t\n"
+ "vbroadcastss zmm14,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps zmm0,zmm15,zmm14\t\n"
+ "vfmadd231ps zmm1,zmm16,zmm14\t\n"
+ "vbroadcastss zmm14,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps zmm2,zmm15,zmm14\t\n"
+ "vfmadd231ps zmm3,zmm16,zmm14\t\n"
+ "vbroadcastss zmm14,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps zmm4,zmm15,zmm14\t\n"
+ "vfmadd231ps zmm5,zmm16,zmm14\t\n"
+ "vbroadcastss zmm14,DWORD PTR [r9+12]\t\n"
+ "vfmadd231ps zmm6,zmm15,zmm14\t\n"
+ "vfmadd231ps zmm7,zmm16,zmm14\t\n"
+ "vbroadcastss zmm14,DWORD PTR [r9+16]\t\n"
+ "vfmadd231ps zmm8,zmm15,zmm14\t\n"
+ "vfmadd231ps zmm9,zmm16,zmm14\t\n"
+ "vbroadcastss zmm14,DWORD PTR [r9+20]\t\n"
+ "vfmadd231ps zmm10,zmm15,zmm14\t\n"
+ "vfmadd231ps zmm11,zmm16,zmm14\t\n"
+ "vbroadcastss zmm14,DWORD PTR [r9+24]\t\n"
+ "vfmadd231ps zmm12,zmm15,zmm14\t\n"
+ "vfmadd231ps zmm13,zmm16,zmm14\t\n"
+ "add r9,28\t\n"
+ "add r10,64\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss zmm31,DWORD PTR [r15]\t\n"
+ "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 128\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+}
+void __attribute__((noinline)) gemmkernel_8x2_AVX512_fA0fB0fC0(GemmParams* gp) {
+ asm volatile(
+#if !defined(__clang__)
+ "mov r14, %[gp]\t\n"
+#else
+ "mov %[gp], %%r14\t\n"
+ ".intel_syntax noprefix\t\n"
+#endif
+
+ // Copy parameters
+ // k
+ "mov r8, [r14 + 0]\t\n"
+ // A
+ "mov r9, [r14 + 8]\t\n"
+ // B
+ "mov r10, [r14 + 16]\t\n"
+ // beta
+ "mov r15, [r14 + 24]\t\n"
+ // accum
+ "mov rdx, [r14 + 32]\t\n"
+ // C
+ "mov r12, [r14 + 40]\t\n"
+ // ldc
+ "mov r13, [r14 + 48]\t\n"
+ // b_block_cols
+ "mov rdi, [r14 + 56]\t\n"
+ // b_block_size
+ "mov rsi, [r14 + 64]\t\n"
+ // Make copies of A and C
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps zmm0,zmm0,zmm0\t\n"
+ "vxorps zmm1,zmm1,zmm1\t\n"
+ "vxorps zmm2,zmm2,zmm2\t\n"
+ "vxorps zmm3,zmm3,zmm3\t\n"
+ "vxorps zmm4,zmm4,zmm4\t\n"
+ "vxorps zmm5,zmm5,zmm5\t\n"
+ "vxorps zmm6,zmm6,zmm6\t\n"
+ "vxorps zmm7,zmm7,zmm7\t\n"
+ "vxorps zmm8,zmm8,zmm8\t\n"
+ "vxorps zmm9,zmm9,zmm9\t\n"
+ "vxorps zmm10,zmm10,zmm10\t\n"
+ "vxorps zmm11,zmm11,zmm11\t\n"
+ "vxorps zmm12,zmm12,zmm12\t\n"
+ "vxorps zmm13,zmm13,zmm13\t\n"
+ "vxorps zmm14,zmm14,zmm14\t\n"
+ "vxorps zmm15,zmm15,zmm15\t\n"
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps zmm17,YMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps zmm18,YMMWORD PTR [r10 + 32]\t\n"
+ "vbroadcastss zmm16,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps zmm0,zmm17,zmm16\t\n"
+ "vfmadd231ps zmm1,zmm18,zmm16\t\n"
+ "vbroadcastss zmm16,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps zmm2,zmm17,zmm16\t\n"
+ "vfmadd231ps zmm3,zmm18,zmm16\t\n"
+ "vbroadcastss zmm16,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps zmm4,zmm17,zmm16\t\n"
+ "vfmadd231ps zmm5,zmm18,zmm16\t\n"
+ "vbroadcastss zmm16,DWORD PTR [r9+12]\t\n"
+ "vfmadd231ps zmm6,zmm17,zmm16\t\n"
+ "vfmadd231ps zmm7,zmm18,zmm16\t\n"
+ "vbroadcastss zmm16,DWORD PTR [r9+16]\t\n"
+ "vfmadd231ps zmm8,zmm17,zmm16\t\n"
+ "vfmadd231ps zmm9,zmm18,zmm16\t\n"
+ "vbroadcastss zmm16,DWORD PTR [r9+20]\t\n"
+ "vfmadd231ps zmm10,zmm17,zmm16\t\n"
+ "vfmadd231ps zmm11,zmm18,zmm16\t\n"
+ "vbroadcastss zmm16,DWORD PTR [r9+24]\t\n"
+ "vfmadd231ps zmm12,zmm17,zmm16\t\n"
+ "vfmadd231ps zmm13,zmm18,zmm16\t\n"
+ "vbroadcastss zmm16,DWORD PTR [r9+28]\t\n"
+ "vfmadd231ps zmm14,zmm17,zmm16\t\n"
+ "vfmadd231ps zmm15,zmm18,zmm16\t\n"
+ "add r9,32\t\n"
+ "add r10,64\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm14\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm15\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss zmm31,DWORD PTR [r15]\t\n"
+ "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm14,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm14\t\n"
+ "vfmadd231ps zmm15,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm15\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 128\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+}
+void __attribute__((noinline)) gemmkernel_9x2_AVX512_fA0fB0fC0(GemmParams* gp) {
+ asm volatile(
+#if !defined(__clang__)
+ "mov r14, %[gp]\t\n"
+#else
+ "mov %[gp], %%r14\t\n"
+ ".intel_syntax noprefix\t\n"
+#endif
+
+ // Copy parameters
+ // k
+ "mov r8, [r14 + 0]\t\n"
+ // A
+ "mov r9, [r14 + 8]\t\n"
+ // B
+ "mov r10, [r14 + 16]\t\n"
+ // beta
+ "mov r15, [r14 + 24]\t\n"
+ // accum
+ "mov rdx, [r14 + 32]\t\n"
+ // C
+ "mov r12, [r14 + 40]\t\n"
+ // ldc
+ "mov r13, [r14 + 48]\t\n"
+ // b_block_cols
+ "mov rdi, [r14 + 56]\t\n"
+ // b_block_size
+ "mov rsi, [r14 + 64]\t\n"
+ // Make copies of A and C
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps zmm0,zmm0,zmm0\t\n"
+ "vxorps zmm1,zmm1,zmm1\t\n"
+ "vxorps zmm2,zmm2,zmm2\t\n"
+ "vxorps zmm3,zmm3,zmm3\t\n"
+ "vxorps zmm4,zmm4,zmm4\t\n"
+ "vxorps zmm5,zmm5,zmm5\t\n"
+ "vxorps zmm6,zmm6,zmm6\t\n"
+ "vxorps zmm7,zmm7,zmm7\t\n"
+ "vxorps zmm8,zmm8,zmm8\t\n"
+ "vxorps zmm9,zmm9,zmm9\t\n"
+ "vxorps zmm10,zmm10,zmm10\t\n"
+ "vxorps zmm11,zmm11,zmm11\t\n"
+ "vxorps zmm12,zmm12,zmm12\t\n"
+ "vxorps zmm13,zmm13,zmm13\t\n"
+ "vxorps zmm14,zmm14,zmm14\t\n"
+ "vxorps zmm15,zmm15,zmm15\t\n"
+ "vxorps zmm16,zmm16,zmm16\t\n"
+ "vxorps zmm17,zmm17,zmm17\t\n"
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps zmm19,YMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps zmm20,YMMWORD PTR [r10 + 32]\t\n"
+ "vbroadcastss zmm18,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps zmm0,zmm19,zmm18\t\n"
+ "vfmadd231ps zmm1,zmm20,zmm18\t\n"
+ "vbroadcastss zmm18,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps zmm2,zmm19,zmm18\t\n"
+ "vfmadd231ps zmm3,zmm20,zmm18\t\n"
+ "vbroadcastss zmm18,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps zmm4,zmm19,zmm18\t\n"
+ "vfmadd231ps zmm5,zmm20,zmm18\t\n"
+ "vbroadcastss zmm18,DWORD PTR [r9+12]\t\n"
+ "vfmadd231ps zmm6,zmm19,zmm18\t\n"
+ "vfmadd231ps zmm7,zmm20,zmm18\t\n"
+ "vbroadcastss zmm18,DWORD PTR [r9+16]\t\n"
+ "vfmadd231ps zmm8,zmm19,zmm18\t\n"
+ "vfmadd231ps zmm9,zmm20,zmm18\t\n"
+ "vbroadcastss zmm18,DWORD PTR [r9+20]\t\n"
+ "vfmadd231ps zmm10,zmm19,zmm18\t\n"
+ "vfmadd231ps zmm11,zmm20,zmm18\t\n"
+ "vbroadcastss zmm18,DWORD PTR [r9+24]\t\n"
+ "vfmadd231ps zmm12,zmm19,zmm18\t\n"
+ "vfmadd231ps zmm13,zmm20,zmm18\t\n"
+ "vbroadcastss zmm18,DWORD PTR [r9+28]\t\n"
+ "vfmadd231ps zmm14,zmm19,zmm18\t\n"
+ "vfmadd231ps zmm15,zmm20,zmm18\t\n"
+ "vbroadcastss zmm18,DWORD PTR [r9+32]\t\n"
+ "vfmadd231ps zmm16,zmm19,zmm18\t\n"
+ "vfmadd231ps zmm17,zmm20,zmm18\t\n"
+ "add r9,36\t\n"
+ "add r10,64\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm14\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm15\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm16\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm17\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss zmm31,DWORD PTR [r15]\t\n"
+ "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm14,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm14\t\n"
+ "vfmadd231ps zmm15,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm15\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm16,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm16\t\n"
+ "vfmadd231ps zmm17,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm17\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 128\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+}
+void __attribute__((noinline))
+gemmkernel_10x2_AVX512_fA0fB0fC0(GemmParams* gp) {
+ asm volatile(
+#if !defined(__clang__)
+ "mov r14, %[gp]\t\n"
+#else
+ "mov %[gp], %%r14\t\n"
+ ".intel_syntax noprefix\t\n"
+#endif
+
+ // Copy parameters
+ // k
+ "mov r8, [r14 + 0]\t\n"
+ // A
+ "mov r9, [r14 + 8]\t\n"
+ // B
+ "mov r10, [r14 + 16]\t\n"
+ // beta
+ "mov r15, [r14 + 24]\t\n"
+ // accum
+ "mov rdx, [r14 + 32]\t\n"
+ // C
+ "mov r12, [r14 + 40]\t\n"
+ // ldc
+ "mov r13, [r14 + 48]\t\n"
+ // b_block_cols
+ "mov rdi, [r14 + 56]\t\n"
+ // b_block_size
+ "mov rsi, [r14 + 64]\t\n"
+ // Make copies of A and C
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps zmm0,zmm0,zmm0\t\n"
+ "vxorps zmm1,zmm1,zmm1\t\n"
+ "vxorps zmm2,zmm2,zmm2\t\n"
+ "vxorps zmm3,zmm3,zmm3\t\n"
+ "vxorps zmm4,zmm4,zmm4\t\n"
+ "vxorps zmm5,zmm5,zmm5\t\n"
+ "vxorps zmm6,zmm6,zmm6\t\n"
+ "vxorps zmm7,zmm7,zmm7\t\n"
+ "vxorps zmm8,zmm8,zmm8\t\n"
+ "vxorps zmm9,zmm9,zmm9\t\n"
+ "vxorps zmm10,zmm10,zmm10\t\n"
+ "vxorps zmm11,zmm11,zmm11\t\n"
+ "vxorps zmm12,zmm12,zmm12\t\n"
+ "vxorps zmm13,zmm13,zmm13\t\n"
+ "vxorps zmm14,zmm14,zmm14\t\n"
+ "vxorps zmm15,zmm15,zmm15\t\n"
+ "vxorps zmm16,zmm16,zmm16\t\n"
+ "vxorps zmm17,zmm17,zmm17\t\n"
+ "vxorps zmm18,zmm18,zmm18\t\n"
+ "vxorps zmm19,zmm19,zmm19\t\n"
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps zmm21,YMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps zmm22,YMMWORD PTR [r10 + 32]\t\n"
+ "vbroadcastss zmm20,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps zmm0,zmm21,zmm20\t\n"
+ "vfmadd231ps zmm1,zmm22,zmm20\t\n"
+ "vbroadcastss zmm20,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps zmm2,zmm21,zmm20\t\n"
+ "vfmadd231ps zmm3,zmm22,zmm20\t\n"
+ "vbroadcastss zmm20,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps zmm4,zmm21,zmm20\t\n"
+ "vfmadd231ps zmm5,zmm22,zmm20\t\n"
+ "vbroadcastss zmm20,DWORD PTR [r9+12]\t\n"
+ "vfmadd231ps zmm6,zmm21,zmm20\t\n"
+ "vfmadd231ps zmm7,zmm22,zmm20\t\n"
+ "vbroadcastss zmm20,DWORD PTR [r9+16]\t\n"
+ "vfmadd231ps zmm8,zmm21,zmm20\t\n"
+ "vfmadd231ps zmm9,zmm22,zmm20\t\n"
+ "vbroadcastss zmm20,DWORD PTR [r9+20]\t\n"
+ "vfmadd231ps zmm10,zmm21,zmm20\t\n"
+ "vfmadd231ps zmm11,zmm22,zmm20\t\n"
+ "vbroadcastss zmm20,DWORD PTR [r9+24]\t\n"
+ "vfmadd231ps zmm12,zmm21,zmm20\t\n"
+ "vfmadd231ps zmm13,zmm22,zmm20\t\n"
+ "vbroadcastss zmm20,DWORD PTR [r9+28]\t\n"
+ "vfmadd231ps zmm14,zmm21,zmm20\t\n"
+ "vfmadd231ps zmm15,zmm22,zmm20\t\n"
+ "vbroadcastss zmm20,DWORD PTR [r9+32]\t\n"
+ "vfmadd231ps zmm16,zmm21,zmm20\t\n"
+ "vfmadd231ps zmm17,zmm22,zmm20\t\n"
+ "vbroadcastss zmm20,DWORD PTR [r9+36]\t\n"
+ "vfmadd231ps zmm18,zmm21,zmm20\t\n"
+ "vfmadd231ps zmm19,zmm22,zmm20\t\n"
+ "add r9,40\t\n"
+ "add r10,64\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm14\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm15\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm16\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm17\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm18\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm19\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss zmm31,DWORD PTR [r15]\t\n"
+ "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm14,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm14\t\n"
+ "vfmadd231ps zmm15,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm15\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm16,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm16\t\n"
+ "vfmadd231ps zmm17,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm17\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm18,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm18\t\n"
+ "vfmadd231ps zmm19,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm19\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 128\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+}
+void __attribute__((noinline))
+gemmkernel_11x2_AVX512_fA0fB0fC0(GemmParams* gp) {
+ asm volatile(
+#if !defined(__clang__)
+ "mov r14, %[gp]\t\n"
+#else
+ "mov %[gp], %%r14\t\n"
+ ".intel_syntax noprefix\t\n"
+#endif
+
+ // Copy parameters
+ // k
+ "mov r8, [r14 + 0]\t\n"
+ // A
+ "mov r9, [r14 + 8]\t\n"
+ // B
+ "mov r10, [r14 + 16]\t\n"
+ // beta
+ "mov r15, [r14 + 24]\t\n"
+ // accum
+ "mov rdx, [r14 + 32]\t\n"
+ // C
+ "mov r12, [r14 + 40]\t\n"
+ // ldc
+ "mov r13, [r14 + 48]\t\n"
+ // b_block_cols
+ "mov rdi, [r14 + 56]\t\n"
+ // b_block_size
+ "mov rsi, [r14 + 64]\t\n"
+ // Make copies of A and C
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps zmm0,zmm0,zmm0\t\n"
+ "vxorps zmm1,zmm1,zmm1\t\n"
+ "vxorps zmm2,zmm2,zmm2\t\n"
+ "vxorps zmm3,zmm3,zmm3\t\n"
+ "vxorps zmm4,zmm4,zmm4\t\n"
+ "vxorps zmm5,zmm5,zmm5\t\n"
+ "vxorps zmm6,zmm6,zmm6\t\n"
+ "vxorps zmm7,zmm7,zmm7\t\n"
+ "vxorps zmm8,zmm8,zmm8\t\n"
+ "vxorps zmm9,zmm9,zmm9\t\n"
+ "vxorps zmm10,zmm10,zmm10\t\n"
+ "vxorps zmm11,zmm11,zmm11\t\n"
+ "vxorps zmm12,zmm12,zmm12\t\n"
+ "vxorps zmm13,zmm13,zmm13\t\n"
+ "vxorps zmm14,zmm14,zmm14\t\n"
+ "vxorps zmm15,zmm15,zmm15\t\n"
+ "vxorps zmm16,zmm16,zmm16\t\n"
+ "vxorps zmm17,zmm17,zmm17\t\n"
+ "vxorps zmm18,zmm18,zmm18\t\n"
+ "vxorps zmm19,zmm19,zmm19\t\n"
+ "vxorps zmm20,zmm20,zmm20\t\n"
+ "vxorps zmm21,zmm21,zmm21\t\n"
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps zmm23,YMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps zmm24,YMMWORD PTR [r10 + 32]\t\n"
+ "vbroadcastss zmm22,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps zmm0,zmm23,zmm22\t\n"
+ "vfmadd231ps zmm1,zmm24,zmm22\t\n"
+ "vbroadcastss zmm22,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps zmm2,zmm23,zmm22\t\n"
+ "vfmadd231ps zmm3,zmm24,zmm22\t\n"
+ "vbroadcastss zmm22,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps zmm4,zmm23,zmm22\t\n"
+ "vfmadd231ps zmm5,zmm24,zmm22\t\n"
+ "vbroadcastss zmm22,DWORD PTR [r9+12]\t\n"
+ "vfmadd231ps zmm6,zmm23,zmm22\t\n"
+ "vfmadd231ps zmm7,zmm24,zmm22\t\n"
+ "vbroadcastss zmm22,DWORD PTR [r9+16]\t\n"
+ "vfmadd231ps zmm8,zmm23,zmm22\t\n"
+ "vfmadd231ps zmm9,zmm24,zmm22\t\n"
+ "vbroadcastss zmm22,DWORD PTR [r9+20]\t\n"
+ "vfmadd231ps zmm10,zmm23,zmm22\t\n"
+ "vfmadd231ps zmm11,zmm24,zmm22\t\n"
+ "vbroadcastss zmm22,DWORD PTR [r9+24]\t\n"
+ "vfmadd231ps zmm12,zmm23,zmm22\t\n"
+ "vfmadd231ps zmm13,zmm24,zmm22\t\n"
+ "vbroadcastss zmm22,DWORD PTR [r9+28]\t\n"
+ "vfmadd231ps zmm14,zmm23,zmm22\t\n"
+ "vfmadd231ps zmm15,zmm24,zmm22\t\n"
+ "vbroadcastss zmm22,DWORD PTR [r9+32]\t\n"
+ "vfmadd231ps zmm16,zmm23,zmm22\t\n"
+ "vfmadd231ps zmm17,zmm24,zmm22\t\n"
+ "vbroadcastss zmm22,DWORD PTR [r9+36]\t\n"
+ "vfmadd231ps zmm18,zmm23,zmm22\t\n"
+ "vfmadd231ps zmm19,zmm24,zmm22\t\n"
+ "vbroadcastss zmm22,DWORD PTR [r9+40]\t\n"
+ "vfmadd231ps zmm20,zmm23,zmm22\t\n"
+ "vfmadd231ps zmm21,zmm24,zmm22\t\n"
+ "add r9,44\t\n"
+ "add r10,64\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm14\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm15\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm16\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm17\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm18\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm19\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm20\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm21\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss zmm31,DWORD PTR [r15]\t\n"
+ "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm14,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm14\t\n"
+ "vfmadd231ps zmm15,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm15\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm16,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm16\t\n"
+ "vfmadd231ps zmm17,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm17\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm18,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm18\t\n"
+ "vfmadd231ps zmm19,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm19\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm20,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm20\t\n"
+ "vfmadd231ps zmm21,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm21\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 128\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+}
+void __attribute__((noinline))
+gemmkernel_12x2_AVX512_fA0fB0fC0(GemmParams* gp) {
+ asm volatile(
+#if !defined(__clang__)
+ "mov r14, %[gp]\t\n"
+#else
+ "mov %[gp], %%r14\t\n"
+ ".intel_syntax noprefix\t\n"
+#endif
+
+ // Copy parameters
+ // k
+ "mov r8, [r14 + 0]\t\n"
+ // A
+ "mov r9, [r14 + 8]\t\n"
+ // B
+ "mov r10, [r14 + 16]\t\n"
+ // beta
+ "mov r15, [r14 + 24]\t\n"
+ // accum
+ "mov rdx, [r14 + 32]\t\n"
+ // C
+ "mov r12, [r14 + 40]\t\n"
+ // ldc
+ "mov r13, [r14 + 48]\t\n"
+ // b_block_cols
+ "mov rdi, [r14 + 56]\t\n"
+ // b_block_size
+ "mov rsi, [r14 + 64]\t\n"
+ // Make copies of A and C
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps zmm0,zmm0,zmm0\t\n"
+ "vxorps zmm1,zmm1,zmm1\t\n"
+ "vxorps zmm2,zmm2,zmm2\t\n"
+ "vxorps zmm3,zmm3,zmm3\t\n"
+ "vxorps zmm4,zmm4,zmm4\t\n"
+ "vxorps zmm5,zmm5,zmm5\t\n"
+ "vxorps zmm6,zmm6,zmm6\t\n"
+ "vxorps zmm7,zmm7,zmm7\t\n"
+ "vxorps zmm8,zmm8,zmm8\t\n"
+ "vxorps zmm9,zmm9,zmm9\t\n"
+ "vxorps zmm10,zmm10,zmm10\t\n"
+ "vxorps zmm11,zmm11,zmm11\t\n"
+ "vxorps zmm12,zmm12,zmm12\t\n"
+ "vxorps zmm13,zmm13,zmm13\t\n"
+ "vxorps zmm14,zmm14,zmm14\t\n"
+ "vxorps zmm15,zmm15,zmm15\t\n"
+ "vxorps zmm16,zmm16,zmm16\t\n"
+ "vxorps zmm17,zmm17,zmm17\t\n"
+ "vxorps zmm18,zmm18,zmm18\t\n"
+ "vxorps zmm19,zmm19,zmm19\t\n"
+ "vxorps zmm20,zmm20,zmm20\t\n"
+ "vxorps zmm21,zmm21,zmm21\t\n"
+ "vxorps zmm22,zmm22,zmm22\t\n"
+ "vxorps zmm23,zmm23,zmm23\t\n"
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps zmm25,YMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps zmm26,YMMWORD PTR [r10 + 32]\t\n"
+ "vbroadcastss zmm24,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps zmm0,zmm25,zmm24\t\n"
+ "vfmadd231ps zmm1,zmm26,zmm24\t\n"
+ "vbroadcastss zmm24,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps zmm2,zmm25,zmm24\t\n"
+ "vfmadd231ps zmm3,zmm26,zmm24\t\n"
+ "vbroadcastss zmm24,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps zmm4,zmm25,zmm24\t\n"
+ "vfmadd231ps zmm5,zmm26,zmm24\t\n"
+ "vbroadcastss zmm24,DWORD PTR [r9+12]\t\n"
+ "vfmadd231ps zmm6,zmm25,zmm24\t\n"
+ "vfmadd231ps zmm7,zmm26,zmm24\t\n"
+ "vbroadcastss zmm24,DWORD PTR [r9+16]\t\n"
+ "vfmadd231ps zmm8,zmm25,zmm24\t\n"
+ "vfmadd231ps zmm9,zmm26,zmm24\t\n"
+ "vbroadcastss zmm24,DWORD PTR [r9+20]\t\n"
+ "vfmadd231ps zmm10,zmm25,zmm24\t\n"
+ "vfmadd231ps zmm11,zmm26,zmm24\t\n"
+ "vbroadcastss zmm24,DWORD PTR [r9+24]\t\n"
+ "vfmadd231ps zmm12,zmm25,zmm24\t\n"
+ "vfmadd231ps zmm13,zmm26,zmm24\t\n"
+ "vbroadcastss zmm24,DWORD PTR [r9+28]\t\n"
+ "vfmadd231ps zmm14,zmm25,zmm24\t\n"
+ "vfmadd231ps zmm15,zmm26,zmm24\t\n"
+ "vbroadcastss zmm24,DWORD PTR [r9+32]\t\n"
+ "vfmadd231ps zmm16,zmm25,zmm24\t\n"
+ "vfmadd231ps zmm17,zmm26,zmm24\t\n"
+ "vbroadcastss zmm24,DWORD PTR [r9+36]\t\n"
+ "vfmadd231ps zmm18,zmm25,zmm24\t\n"
+ "vfmadd231ps zmm19,zmm26,zmm24\t\n"
+ "vbroadcastss zmm24,DWORD PTR [r9+40]\t\n"
+ "vfmadd231ps zmm20,zmm25,zmm24\t\n"
+ "vfmadd231ps zmm21,zmm26,zmm24\t\n"
+ "vbroadcastss zmm24,DWORD PTR [r9+44]\t\n"
+ "vfmadd231ps zmm22,zmm25,zmm24\t\n"
+ "vfmadd231ps zmm23,zmm26,zmm24\t\n"
+ "add r9,48\t\n"
+ "add r10,64\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm14\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm15\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm16\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm17\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm18\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm19\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm20\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm21\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm22\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm23\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss zmm31,DWORD PTR [r15]\t\n"
+ "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm14,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm14\t\n"
+ "vfmadd231ps zmm15,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm15\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm16,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm16\t\n"
+ "vfmadd231ps zmm17,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm17\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm18,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm18\t\n"
+ "vfmadd231ps zmm19,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm19\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm20,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm20\t\n"
+ "vfmadd231ps zmm21,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm21\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm22,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm22\t\n"
+ "vfmadd231ps zmm23,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm23\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 128\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+}
+void __attribute__((noinline))
+gemmkernel_13x2_AVX512_fA0fB0fC0(GemmParams* gp) {
+ asm volatile(
+#if !defined(__clang__)
+ "mov r14, %[gp]\t\n"
+#else
+ "mov %[gp], %%r14\t\n"
+ ".intel_syntax noprefix\t\n"
+#endif
+
+ // Copy parameters
+ // k
+ "mov r8, [r14 + 0]\t\n"
+ // A
+ "mov r9, [r14 + 8]\t\n"
+ // B
+ "mov r10, [r14 + 16]\t\n"
+ // beta
+ "mov r15, [r14 + 24]\t\n"
+ // accum
+ "mov rdx, [r14 + 32]\t\n"
+ // C
+ "mov r12, [r14 + 40]\t\n"
+ // ldc
+ "mov r13, [r14 + 48]\t\n"
+ // b_block_cols
+ "mov rdi, [r14 + 56]\t\n"
+ // b_block_size
+ "mov rsi, [r14 + 64]\t\n"
+ // Make copies of A and C
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps zmm0,zmm0,zmm0\t\n"
+ "vxorps zmm1,zmm1,zmm1\t\n"
+ "vxorps zmm2,zmm2,zmm2\t\n"
+ "vxorps zmm3,zmm3,zmm3\t\n"
+ "vxorps zmm4,zmm4,zmm4\t\n"
+ "vxorps zmm5,zmm5,zmm5\t\n"
+ "vxorps zmm6,zmm6,zmm6\t\n"
+ "vxorps zmm7,zmm7,zmm7\t\n"
+ "vxorps zmm8,zmm8,zmm8\t\n"
+ "vxorps zmm9,zmm9,zmm9\t\n"
+ "vxorps zmm10,zmm10,zmm10\t\n"
+ "vxorps zmm11,zmm11,zmm11\t\n"
+ "vxorps zmm12,zmm12,zmm12\t\n"
+ "vxorps zmm13,zmm13,zmm13\t\n"
+ "vxorps zmm14,zmm14,zmm14\t\n"
+ "vxorps zmm15,zmm15,zmm15\t\n"
+ "vxorps zmm16,zmm16,zmm16\t\n"
+ "vxorps zmm17,zmm17,zmm17\t\n"
+ "vxorps zmm18,zmm18,zmm18\t\n"
+ "vxorps zmm19,zmm19,zmm19\t\n"
+ "vxorps zmm20,zmm20,zmm20\t\n"
+ "vxorps zmm21,zmm21,zmm21\t\n"
+ "vxorps zmm22,zmm22,zmm22\t\n"
+ "vxorps zmm23,zmm23,zmm23\t\n"
+ "vxorps zmm24,zmm24,zmm24\t\n"
+ "vxorps zmm25,zmm25,zmm25\t\n"
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps zmm27,YMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps zmm28,YMMWORD PTR [r10 + 32]\t\n"
+ "vbroadcastss zmm26,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps zmm0,zmm27,zmm26\t\n"
+ "vfmadd231ps zmm1,zmm28,zmm26\t\n"
+ "vbroadcastss zmm26,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps zmm2,zmm27,zmm26\t\n"
+ "vfmadd231ps zmm3,zmm28,zmm26\t\n"
+ "vbroadcastss zmm26,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps zmm4,zmm27,zmm26\t\n"
+ "vfmadd231ps zmm5,zmm28,zmm26\t\n"
+ "vbroadcastss zmm26,DWORD PTR [r9+12]\t\n"
+ "vfmadd231ps zmm6,zmm27,zmm26\t\n"
+ "vfmadd231ps zmm7,zmm28,zmm26\t\n"
+ "vbroadcastss zmm26,DWORD PTR [r9+16]\t\n"
+ "vfmadd231ps zmm8,zmm27,zmm26\t\n"
+ "vfmadd231ps zmm9,zmm28,zmm26\t\n"
+ "vbroadcastss zmm26,DWORD PTR [r9+20]\t\n"
+ "vfmadd231ps zmm10,zmm27,zmm26\t\n"
+ "vfmadd231ps zmm11,zmm28,zmm26\t\n"
+ "vbroadcastss zmm26,DWORD PTR [r9+24]\t\n"
+ "vfmadd231ps zmm12,zmm27,zmm26\t\n"
+ "vfmadd231ps zmm13,zmm28,zmm26\t\n"
+ "vbroadcastss zmm26,DWORD PTR [r9+28]\t\n"
+ "vfmadd231ps zmm14,zmm27,zmm26\t\n"
+ "vfmadd231ps zmm15,zmm28,zmm26\t\n"
+ "vbroadcastss zmm26,DWORD PTR [r9+32]\t\n"
+ "vfmadd231ps zmm16,zmm27,zmm26\t\n"
+ "vfmadd231ps zmm17,zmm28,zmm26\t\n"
+ "vbroadcastss zmm26,DWORD PTR [r9+36]\t\n"
+ "vfmadd231ps zmm18,zmm27,zmm26\t\n"
+ "vfmadd231ps zmm19,zmm28,zmm26\t\n"
+ "vbroadcastss zmm26,DWORD PTR [r9+40]\t\n"
+ "vfmadd231ps zmm20,zmm27,zmm26\t\n"
+ "vfmadd231ps zmm21,zmm28,zmm26\t\n"
+ "vbroadcastss zmm26,DWORD PTR [r9+44]\t\n"
+ "vfmadd231ps zmm22,zmm27,zmm26\t\n"
+ "vfmadd231ps zmm23,zmm28,zmm26\t\n"
+ "vbroadcastss zmm26,DWORD PTR [r9+48]\t\n"
+ "vfmadd231ps zmm24,zmm27,zmm26\t\n"
+ "vfmadd231ps zmm25,zmm28,zmm26\t\n"
+ "add r9,52\t\n"
+ "add r10,64\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm14\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm15\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm16\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm17\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm18\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm19\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm20\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm21\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm22\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm23\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm24\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm25\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss zmm31,DWORD PTR [r15]\t\n"
+ "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm14,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm14\t\n"
+ "vfmadd231ps zmm15,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm15\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm16,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm16\t\n"
+ "vfmadd231ps zmm17,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm17\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm18,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm18\t\n"
+ "vfmadd231ps zmm19,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm19\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm20,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm20\t\n"
+ "vfmadd231ps zmm21,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm21\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm22,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm22\t\n"
+ "vfmadd231ps zmm23,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm23\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm24,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm24\t\n"
+ "vfmadd231ps zmm25,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm25\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 128\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+}
+void __attribute__((noinline))
+gemmkernel_14x2_AVX512_fA0fB0fC0(GemmParams* gp) {
+ asm volatile(
+#if !defined(__clang__)
+ "mov r14, %[gp]\t\n"
+#else
+ "mov %[gp], %%r14\t\n"
+ ".intel_syntax noprefix\t\n"
+#endif
+
+ // Copy parameters
+ // k
+ "mov r8, [r14 + 0]\t\n"
+ // A
+ "mov r9, [r14 + 8]\t\n"
+ // B
+ "mov r10, [r14 + 16]\t\n"
+ // beta
+ "mov r15, [r14 + 24]\t\n"
+ // accum
+ "mov rdx, [r14 + 32]\t\n"
+ // C
+ "mov r12, [r14 + 40]\t\n"
+ // ldc
+ "mov r13, [r14 + 48]\t\n"
+ // b_block_cols
+ "mov rdi, [r14 + 56]\t\n"
+ // b_block_size
+ "mov rsi, [r14 + 64]\t\n"
+ // Make copies of A and C
+ "mov rax, r9\t\n"
+ "mov rcx, r12\t\n"
+
+ "mov rbx, 0\t\n"
+ "loop_outter%=:\t\n"
+ "mov r14, 0\t\n"
+ "vxorps zmm0,zmm0,zmm0\t\n"
+ "vxorps zmm1,zmm1,zmm1\t\n"
+ "vxorps zmm2,zmm2,zmm2\t\n"
+ "vxorps zmm3,zmm3,zmm3\t\n"
+ "vxorps zmm4,zmm4,zmm4\t\n"
+ "vxorps zmm5,zmm5,zmm5\t\n"
+ "vxorps zmm6,zmm6,zmm6\t\n"
+ "vxorps zmm7,zmm7,zmm7\t\n"
+ "vxorps zmm8,zmm8,zmm8\t\n"
+ "vxorps zmm9,zmm9,zmm9\t\n"
+ "vxorps zmm10,zmm10,zmm10\t\n"
+ "vxorps zmm11,zmm11,zmm11\t\n"
+ "vxorps zmm12,zmm12,zmm12\t\n"
+ "vxorps zmm13,zmm13,zmm13\t\n"
+ "vxorps zmm14,zmm14,zmm14\t\n"
+ "vxorps zmm15,zmm15,zmm15\t\n"
+ "vxorps zmm16,zmm16,zmm16\t\n"
+ "vxorps zmm17,zmm17,zmm17\t\n"
+ "vxorps zmm18,zmm18,zmm18\t\n"
+ "vxorps zmm19,zmm19,zmm19\t\n"
+ "vxorps zmm20,zmm20,zmm20\t\n"
+ "vxorps zmm21,zmm21,zmm21\t\n"
+ "vxorps zmm22,zmm22,zmm22\t\n"
+ "vxorps zmm23,zmm23,zmm23\t\n"
+ "vxorps zmm24,zmm24,zmm24\t\n"
+ "vxorps zmm25,zmm25,zmm25\t\n"
+ "vxorps zmm26,zmm26,zmm26\t\n"
+ "vxorps zmm27,zmm27,zmm27\t\n"
+
+ "loop_inner%=:\t\n"
+
+ "vcvtph2ps zmm29,YMMWORD PTR [r10 + 0]\t\n"
+ "vcvtph2ps zmm30,YMMWORD PTR [r10 + 32]\t\n"
+ "vbroadcastss zmm28,DWORD PTR [r9+0]\t\n"
+ "vfmadd231ps zmm0,zmm29,zmm28\t\n"
+ "vfmadd231ps zmm1,zmm30,zmm28\t\n"
+ "vbroadcastss zmm28,DWORD PTR [r9+4]\t\n"
+ "vfmadd231ps zmm2,zmm29,zmm28\t\n"
+ "vfmadd231ps zmm3,zmm30,zmm28\t\n"
+ "vbroadcastss zmm28,DWORD PTR [r9+8]\t\n"
+ "vfmadd231ps zmm4,zmm29,zmm28\t\n"
+ "vfmadd231ps zmm5,zmm30,zmm28\t\n"
+ "vbroadcastss zmm28,DWORD PTR [r9+12]\t\n"
+ "vfmadd231ps zmm6,zmm29,zmm28\t\n"
+ "vfmadd231ps zmm7,zmm30,zmm28\t\n"
+ "vbroadcastss zmm28,DWORD PTR [r9+16]\t\n"
+ "vfmadd231ps zmm8,zmm29,zmm28\t\n"
+ "vfmadd231ps zmm9,zmm30,zmm28\t\n"
+ "vbroadcastss zmm28,DWORD PTR [r9+20]\t\n"
+ "vfmadd231ps zmm10,zmm29,zmm28\t\n"
+ "vfmadd231ps zmm11,zmm30,zmm28\t\n"
+ "vbroadcastss zmm28,DWORD PTR [r9+24]\t\n"
+ "vfmadd231ps zmm12,zmm29,zmm28\t\n"
+ "vfmadd231ps zmm13,zmm30,zmm28\t\n"
+ "vbroadcastss zmm28,DWORD PTR [r9+28]\t\n"
+ "vfmadd231ps zmm14,zmm29,zmm28\t\n"
+ "vfmadd231ps zmm15,zmm30,zmm28\t\n"
+ "vbroadcastss zmm28,DWORD PTR [r9+32]\t\n"
+ "vfmadd231ps zmm16,zmm29,zmm28\t\n"
+ "vfmadd231ps zmm17,zmm30,zmm28\t\n"
+ "vbroadcastss zmm28,DWORD PTR [r9+36]\t\n"
+ "vfmadd231ps zmm18,zmm29,zmm28\t\n"
+ "vfmadd231ps zmm19,zmm30,zmm28\t\n"
+ "vbroadcastss zmm28,DWORD PTR [r9+40]\t\n"
+ "vfmadd231ps zmm20,zmm29,zmm28\t\n"
+ "vfmadd231ps zmm21,zmm30,zmm28\t\n"
+ "vbroadcastss zmm28,DWORD PTR [r9+44]\t\n"
+ "vfmadd231ps zmm22,zmm29,zmm28\t\n"
+ "vfmadd231ps zmm23,zmm30,zmm28\t\n"
+ "vbroadcastss zmm28,DWORD PTR [r9+48]\t\n"
+ "vfmadd231ps zmm24,zmm29,zmm28\t\n"
+ "vfmadd231ps zmm25,zmm30,zmm28\t\n"
+ "vbroadcastss zmm28,DWORD PTR [r9+52]\t\n"
+ "vfmadd231ps zmm26,zmm29,zmm28\t\n"
+ "vfmadd231ps zmm27,zmm30,zmm28\t\n"
+ "add r9,56\t\n"
+ "add r10,64\t\n"
+ "inc r14\t\n"
+ "cmp r14, r8\t\n"
+ "jl loop_inner%=\t\n"
+
+ "L_exit%=:\t\n"
+
+ "cmp rdx, 1\t\n"
+ "je L_accum%=\t\n"
+ // Dump C
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm14\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm15\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm16\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm17\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm18\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm19\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm20\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm21\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm22\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm23\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm24\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm25\t\n"
+ "add r12, r13\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm26\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm27\t\n"
+ "add r12, r13\t\n"
+ "jmp L_done%=\t\n"
+
+ "L_accum%=:\t\n"
+ // Dump C with accumulate
+ "vbroadcastss zmm31,DWORD PTR [r15]\t\n"
+ "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm0\t\n"
+ "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm1\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm2\t\n"
+ "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm3\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm4\t\n"
+ "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm5\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm6\t\n"
+ "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm7\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm8\t\n"
+ "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm9\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm10\t\n"
+ "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm11\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm12\t\n"
+ "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm13\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm14,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm14\t\n"
+ "vfmadd231ps zmm15,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm15\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm16,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm16\t\n"
+ "vfmadd231ps zmm17,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm17\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm18,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm18\t\n"
+ "vfmadd231ps zmm19,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm19\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm20,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm20\t\n"
+ "vfmadd231ps zmm21,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm21\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm22,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm22\t\n"
+ "vfmadd231ps zmm23,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm23\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm24,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm24\t\n"
+ "vfmadd231ps zmm25,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm25\t\n"
+ "add r12, r13\t\n"
+ "vfmadd231ps zmm26,zmm31,zmmword PTR [r12 + 0]\t\n"
+ "vmovups zmmword PTR [r12 + 0], zmm26\t\n"
+ "vfmadd231ps zmm27,zmm31,zmmword PTR [r12 + 64]\t\n"
+ "vmovups zmmword PTR [r12 + 64], zmm27\t\n"
+ "add r12, r13\t\n"
+
+ "L_done%=:\t\n"
+
+ // next outer iteration
+ "add rcx, 128\t\n"
+ "mov r12, rcx\t\n"
+ "mov r9, rax\t\n"
+ "inc rbx\t\n"
+ "cmp rbx, rdi\t\n"
+ "jl loop_outter%=\t\n"
+ :
+ : [gp] "rm"(gp)
+ : "r8",
+ "r9",
+ "r10",
+ "r11",
+ "r15",
+ "r13",
+ "r14",
+ "rax",
+ "rcx",
+ "rdx",
+ "rsi",
+ "rdi",
+ "rbx",
+ "r12",
+ "memory");
+}
+
+} // namespace fbgemm
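
Each AVX512 kernel above starts from the same "Copy parameters" block: r14 holds the GemmParams pointer and the fields are read at fixed 8-byte offsets (k at +0, A at +8, B at +16, beta at +24, accum at +32, C at +40, ldc at +48, b_block_cols at +56, b_block_size at +64). Below is a minimal C++ sketch of a struct with that layout, inferred only from those offsets and from how each register is used in the kernels; the real GemmParams is declared in FbgemmFP16UKernelsAvx2.h and its exact member types may differ.

#include <cstdint>

// Sketch only: field order and byte offsets follow the "Copy parameters"
// comments in the kernels above; the member types are assumptions.
struct GemmParamsSketch {
  uint64_t k;            // r8  <- [r14 + 0]  inner-loop trip count
  const float* A;        // r9  <- [r14 + 8]  fp32 A panel (vbroadcastss source)
  const uint16_t* B;     // r10 <- [r14 + 16] packed fp16 B panel (vcvtph2ps source)
  const float* beta;     // r15 <- [r14 + 24] scale broadcast in the L_accum path
  uint64_t accum;        // rdx <- [r14 + 32] 1 = accumulate into C, 0 = overwrite
  float* C;              // r12 <- [r14 + 40] output tile
  uint64_t ldc;          // r13 <- [r14 + 48] row stride of C, in bytes
  uint64_t b_block_cols; // rdi <- [r14 + 56] outer-loop trip count
  uint64_t b_block_size; // rsi <- [r14 + 64] loaded but not used by these kernels
};
static_assert(sizeof(GemmParamsSketch) == 72, "nine 8-byte fields, no padding");
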
diff --git a/src/FbgemmFP16UKernelsAvx512.h b/src/FbgemmFP16UKernelsAvx512.h
new file mode 100644
index 0000000..bbb41f5
--- /dev/null
+++ b/src/FbgemmFP16UKernelsAvx512.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+#include <cstdint>
+#include "FbgemmFP16UKernelsAvx2.h"
+#include "fbgemm/Types.h"
+
+namespace fbgemm {
+
+void __attribute__((noinline)) gemmkernel_1x2_AVX512_fA0fB0fC0(GemmParams* gp);
+void __attribute__((noinline)) gemmkernel_2x2_AVX512_fA0fB0fC0(GemmParams* gp);
+void __attribute__((noinline)) gemmkernel_3x2_AVX512_fA0fB0fC0(GemmParams* gp);
+void __attribute__((noinline)) gemmkernel_4x2_AVX512_fA0fB0fC0(GemmParams* gp);
+void __attribute__((noinline)) gemmkernel_5x2_AVX512_fA0fB0fC0(GemmParams* gp);
+void __attribute__((noinline)) gemmkernel_6x2_AVX512_fA0fB0fC0(GemmParams* gp);
+void __attribute__((noinline)) gemmkernel_7x2_AVX512_fA0fB0fC0(GemmParams* gp);
+void __attribute__((noinline)) gemmkernel_8x2_AVX512_fA0fB0fC0(GemmParams* gp);
+void __attribute__((noinline)) gemmkernel_9x2_AVX512_fA0fB0fC0(GemmParams* gp);
+void __attribute__((noinline)) gemmkernel_10x2_AVX512_fA0fB0fC0(GemmParams* gp);
+void __attribute__((noinline)) gemmkernel_11x2_AVX512_fA0fB0fC0(GemmParams* gp);
+void __attribute__((noinline)) gemmkernel_12x2_AVX512_fA0fB0fC0(GemmParams* gp);
+void __attribute__((noinline)) gemmkernel_13x2_AVX512_fA0fB0fC0(GemmParams* gp);
+void __attribute__((noinline)) gemmkernel_14x2_AVX512_fA0fB0fC0(GemmParams* gp);
+typedef void (*funcptr_fp16)(GemmParams* gp);
+
+} // namespace fbgemm
+
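
The header declares one kernel per supported number of A rows (1 through 14, each writing a 2-wide block of 16-float columns per row) plus the funcptr_fp16 typedef, which suggests dispatch through a function-pointer table keyed by row count. A hedged sketch of such a table follows; the table and helper names are illustrative, not part of this header, and the actual dispatch logic lives elsewhere in the library.

#include "FbgemmFP16UKernelsAvx512.h" // include path assumes compiling from within src/

namespace fbgemm {
// Illustrative only: map the number of A rows in the current tile (1..14)
// to the matching AVX512 kernel via the funcptr_fp16 typedef.
static const funcptr_fp16 kKernelsAvx512Sketch[15] = {
    nullptr, // index 0 unused
    gemmkernel_1x2_AVX512_fA0fB0fC0,  gemmkernel_2x2_AVX512_fA0fB0fC0,
    gemmkernel_3x2_AVX512_fA0fB0fC0,  gemmkernel_4x2_AVX512_fA0fB0fC0,
    gemmkernel_5x2_AVX512_fA0fB0fC0,  gemmkernel_6x2_AVX512_fA0fB0fC0,
    gemmkernel_7x2_AVX512_fA0fB0fC0,  gemmkernel_8x2_AVX512_fA0fB0fC0,
    gemmkernel_9x2_AVX512_fA0fB0fC0,  gemmkernel_10x2_AVX512_fA0fB0fC0,
    gemmkernel_11x2_AVX512_fA0fB0fC0, gemmkernel_12x2_AVX512_fA0fB0fC0,
    gemmkernel_13x2_AVX512_fA0fB0fC0, gemmkernel_14x2_AVX512_fA0fB0fC0};

// Run the kernel that matches an mRows-row tile (mRows must be in [1, 14]).
inline void runKernelSketch(int mRows, GemmParams* gp) {
  kKernelsAvx512Sketch[mRows](gp);
}
} // namespace fbgemm
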
diff --git a/src/FbgemmI8Depthwise2DAvx2-inl.h b/src/FbgemmI8Depthwise2DAvx2-inl.h
new file mode 100644
index 0000000..1dda555
--- /dev/null
+++ b/src/FbgemmI8Depthwise2DAvx2-inl.h
@@ -0,0 +1,1623 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <tuple> // for tie
+
+#include "src/FbgemmI8DepthwiseAvx2-inl.h"
+
+namespace fbgemm {
+
+template <int S = 3, bool SUM_A = false, bool REMAINDER = false>
+static inline __attribute__((always_inline)) void inner_prod_2d_packed_(
+ const __m256i* a_v,
+ const __m256i* Bp,
+ std::int32_t* C,
+ int remainder,
+ __m256i* a_sum = nullptr) {
+ return inner_prod_packed_<S * S, SUM_A, REMAINDER>(
+ a_v, Bp, C, remainder, a_sum);
+}
+
+template <
+ bool SUM_A,
+ bool REMAINDER = false,
+ bool PER_CHANNEL_QUANTIZATION = false>
+static inline __attribute__((always_inline)) void inner_prod_3x3_packed_(
+ int H,
+ int W,
+ int K,
+ int h_in,
+ int w_in,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ const std::int8_t* Bp,
+ const std::int32_t* B_zero_point,
+ std::int32_t* C,
+ int remainder,
+ std::int32_t* row_offsets) {
+ __m256i A_zero_point_v =
+ _mm256_set1_epi8(static_cast<std::uint8_t>(A_zero_point));
+ __m256i mask_v = _mm256_setzero_si256();
+ if (REMAINDER) {
+ mask_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(masks[remainder / 4]));
+ }
+
+  // The code below could be written as a simple R*S loop, but the compiler
+  // doesn't unroll that loop, so we unroll it manually.
+ // constexpr int R = 3, S = 3;
+ // array<__m256i, R * S> a_v;
+ // for (int r = 0; r < R; ++r) {
+ // for (int s = 0; s < S; ++s) {
+ // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
+ // if (REMAINDER) {
+ // a_v[r * S + s] =
+ // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
+ // mask_v);
+ // } else {
+ // a_v[r * S + s] =
+ // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
+ // }
+ // } else {
+ // a_v[r * S + s] = A_zero_point_v;
+ // }
+ // }
+ // }
+ __m256i a_v[9] = {
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ A_zero_point_v,
+ };
+
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[0] = load_a<REMAINDER>(A + (0 * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[1] = load_a<REMAINDER>(A + (0 * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[2] = load_a<REMAINDER>(A + (0 * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[3] = load_a<REMAINDER>(A + (1 * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[4] = load_a<REMAINDER>(A + (1 * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[5] = load_a<REMAINDER>(A + (1 * W + 2) * K, mask_v);
+ }
+ }
+
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[6] = load_a<REMAINDER>(A + (2 * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[7] = load_a<REMAINDER>(A + (2 * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[8] = load_a<REMAINDER>(A + (2 * W + 2) * K, mask_v);
+ }
+ }
+
+ __m256i a_sum[4];
+ inner_prod_2d_packed_<3, SUM_A, REMAINDER>(
+ a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum);
+ if (SUM_A) {
+ __m256i B_zero_point_v;
+ for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
+ if (PER_CHANNEL_QUANTIZATION) {
+ B_zero_point_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
+ }
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
+ _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
+ }
+ }
+}
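+
+// Note on the SUM_A path above (the 5x5 variant below follows the same
+// pattern): each 32-bit lane of a_sum holds the sum of one channel's S x S
+// activation window, and row_offsets[c] is filled with
+// B_zero_point * (window sum of A for channel c). requantize_() later
+// subtracts this product as the usual zero-point correction, which is why the
+// whole block can be skipped when B is symmetric (B_zero_point == 0).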
+
+template <
+ bool SUM_A,
+ bool REMAINDER = false,
+ bool PER_CHANNEL_QUANTIZATION = false>
+static inline __attribute__((always_inline)) void inner_prod_5x5_packed_(
+ int H,
+ int W,
+ int K,
+ int h_in,
+ int w_in,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ const std::int8_t* Bp,
+ const std::int32_t* B_zero_point,
+ std::int32_t* C,
+ int remainder,
+ std::int32_t* row_offsets) {
+ __m256i A_zero_point_v =
+ _mm256_set1_epi8(static_cast<std::uint8_t>(A_zero_point));
+ __m256i mask_v = _mm256_setzero_si256();
+ if (REMAINDER) {
+ mask_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(masks[remainder / 4]));
+ }
+
+  // The code below could be written as a simple R*S loop, but the compiler
+  // doesn't unroll it, so we unroll manually.
+ // constexpr int R = 5, S = 5;
+ // array<__m256i, R * S> a_v;
+ // for (int r = 0; r < R; ++r) {
+ // for (int s = 0; s < S; ++s) {
+ // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
+ // if (REMAINDER) {
+ // a_v[r * S + s] =
+ // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
+ // mask_v);
+ // } else {
+ // a_v[r * S + s] =
+ // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
+ // }
+ // } else {
+ // a_v[r * S + s] = A_zero_point_v;
+ // }
+ // }
+ // }
+ __m256i a_v[25] = {
+ A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v,
+ A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v,
+ A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v,
+ A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v,
+ A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v,
+ A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v,
+ A_zero_point_v,
+ };
+
+ if (h_in >= 0 && h_in < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[0] = load_a<REMAINDER>(A + (0 * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[1] = load_a<REMAINDER>(A + (0 * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[2] = load_a<REMAINDER>(A + (0 * W + 2) * K, mask_v);
+ }
+ if (w_in + 3 >= 0 && w_in + 3 < W) {
+ a_v[3] = load_a<REMAINDER>(A + (0 * W + 3) * K, mask_v);
+ }
+ if (w_in + 4 >= 0 && w_in + 4 < W) {
+ a_v[4] = load_a<REMAINDER>(A + (0 * W + 4) * K, mask_v);
+ }
+ }
+
+ if (h_in + 1 >= 0 && h_in + 1 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[5] = load_a<REMAINDER>(A + (1 * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[6] = load_a<REMAINDER>(A + (1 * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[7] = load_a<REMAINDER>(A + (1 * W + 2) * K, mask_v);
+ }
+ if (w_in + 3 >= 0 && w_in + 3 < W) {
+ a_v[8] = load_a<REMAINDER>(A + (1 * W + 3) * K, mask_v);
+ }
+ if (w_in + 4 >= 0 && w_in + 4 < W) {
+ a_v[9] = load_a<REMAINDER>(A + (1 * W + 4) * K, mask_v);
+ }
+ }
+
+ if (h_in + 2 >= 0 && h_in + 2 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[10] = load_a<REMAINDER>(A + (2 * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[11] = load_a<REMAINDER>(A + (2 * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[12] = load_a<REMAINDER>(A + (2 * W + 2) * K, mask_v);
+ }
+ if (w_in + 3 >= 0 && w_in + 3 < W) {
+ a_v[13] = load_a<REMAINDER>(A + (2 * W + 3) * K, mask_v);
+ }
+ if (w_in + 4 >= 0 && w_in + 4 < W) {
+ a_v[14] = load_a<REMAINDER>(A + (2 * W + 4) * K, mask_v);
+ }
+ }
+
+ if (h_in + 3 >= 0 && h_in + 3 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[15] = load_a<REMAINDER>(A + (3 * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[16] = load_a<REMAINDER>(A + (3 * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[17] = load_a<REMAINDER>(A + (3 * W + 2) * K, mask_v);
+ }
+ if (w_in + 3 >= 0 && w_in + 3 < W) {
+ a_v[18] = load_a<REMAINDER>(A + (3 * W + 3) * K, mask_v);
+ }
+ if (w_in + 4 >= 0 && w_in + 4 < W) {
+ a_v[19] = load_a<REMAINDER>(A + (3 * W + 4) * K, mask_v);
+ }
+ }
+
+ if (h_in + 4 >= 0 && h_in + 4 < H) {
+ if (w_in >= 0 && w_in < W) {
+ a_v[20] = load_a<REMAINDER>(A + (4 * W + 0) * K, mask_v);
+ }
+ if (w_in + 1 >= 0 && w_in + 1 < W) {
+ a_v[21] = load_a<REMAINDER>(A + (4 * W + 1) * K, mask_v);
+ }
+ if (w_in + 2 >= 0 && w_in + 2 < W) {
+ a_v[22] = load_a<REMAINDER>(A + (4 * W + 2) * K, mask_v);
+ }
+ if (w_in + 3 >= 0 && w_in + 3 < W) {
+ a_v[23] = load_a<REMAINDER>(A + (4 * W + 3) * K, mask_v);
+ }
+ if (w_in + 4 >= 0 && w_in + 4 < W) {
+ a_v[24] = load_a<REMAINDER>(A + (4 * W + 4) * K, mask_v);
+ }
+ }
+
+ __m256i a_sum[4];
+ inner_prod_2d_packed_<5, SUM_A, REMAINDER>(
+ a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum);
+ if (SUM_A) {
+ __m256i B_zero_point_v;
+ for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
+ if (PER_CHANNEL_QUANTIZATION) {
+ B_zero_point_v = _mm256_loadu_si256(
+ reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
+ } else {
+ B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
+ }
+ _mm256_store_si256(
+ reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
+ _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
+ }
+ }
+}
+
+template <
+ int S,
+ bool SUM_A,
+ bool REMAINDER = false,
+ bool PER_CHANNEL_QUANTIZATION = false>
+static inline __attribute__((always_inline)) void inner_prod_2d_packed_(
+ int H,
+ int W,
+ int K,
+ int h_in,
+ int w_in,
+ const std::uint8_t* A,
+ std::int32_t A_zero_point,
+ const std::int8_t* Bp,
+ const std::int32_t* B_zero_point,
+ std::int32_t* C,
+ int remainder,
+ std::int32_t* row_offsets) {
+ if (S == 3) {
+ inner_prod_3x3_packed_<SUM_A, REMAINDER, PER_CHANNEL_QUANTIZATION>(
+ H,
+ W,
+ K,
+ h_in,
+ w_in,
+ A,
+ A_zero_point,
+ Bp,
+ B_zero_point,
+ C,
+ remainder,
+ row_offsets);
+ } else {
+ assert(S == 5);
+ inner_prod_5x5_packed_<SUM_A, REMAINDER, PER_CHANNEL_QUANTIZATION>(
+ H,
+ W,
+ K,
+ h_in,
+ w_in,
+ A,
+ A_zero_point,
+ Bp,
+ B_zero_point,
+ C,
+ remainder,
+ row_offsets);
+ }
+}
+
+template <
+ int S,
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void depthwise_2d_kernel_(
+ int H,
+ int W,
+ int K,
+ int h,
+ int w,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ std::int32_t B_zero_point,
+ const std::int8_t* Bp,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::int32_t* C_int32,
+ std::uint8_t* C_uint8,
+ std::int32_t* row_offsets,
+ const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale) {
+ constexpr int PAD_T = (S - 1) / 2, PAD_L = (S - 1) / 2, PAD_R = (S - 1) / 2;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
+ constexpr int KERNEL_PROD_ALIGNED = (S * S + 1) / 2 * 2;
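+  // For reference: (S * S + 1) / 2 * 2 rounds the kernel product up to an even
+  // count, e.g. 10 for S = 3 and 26 for S = 5, since the packed inner product
+  // consumes filter taps in pairs. Each group of 32 channels in the packed B
+  // therefore occupies 32 * KERNEL_PROD_ALIGNED bytes, which is why Bp is
+  // advanced by k * KERNEL_PROD_ALIGNED below.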
+
+ int k;
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_2d_packed_<S, !B_SYMMETRIC /*SUM_A*/>(
+ H,
+ W,
+ K,
+ h_in,
+ w_in,
+ A + (h_in * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * KERNEL_PROD_ALIGNED,
+ &B_zero_point,
+ C_int32 + k,
+ 0,
+ B_SYMMETRIC ? nullptr : &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_2d_packed_<S, !B_SYMMETRIC, true>(
+ H,
+ W,
+ K,
+ h_in,
+ w_in,
+ A + (h_in * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * KERNEL_PROD_ALIGNED,
+ &B_zero_point,
+ C_int32 + k,
+ remainder,
+ B_SYMMETRIC ? nullptr : &row_offsets[k]);
+ }
+
+ requantize_<
+ FUSE_RELU,
+ HAS_BIAS,
+ false, /*PER_CHAN_QUANT*/
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
+ A_zero_point,
+ &C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + (h * W_OUT + w) * K,
+ K,
+ row_offsets,
+ col_offsets,
+ bias,
+ &act_times_w_scale);
+}
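+
+// depthwise_2d_kernel_ above computes one output pixel (h, w) across all K
+// channels: full groups of 32 channels go through the unmasked path, the
+// K % 32 tail uses the masked REMAINDER path, and requantize_ then converts
+// the 32-bit accumulators to uint8 at C_uint8 + (h * W_OUT + w) * K.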
+
+template <
+ int S,
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void
+depthwise_2d_per_channel_quantization_kernel_(
+ int H,
+ int W,
+ int K,
+ int h,
+ int w,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int32_t* B_zero_point,
+ const std::int8_t* Bp,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ std::int32_t* C_int32,
+ std::uint8_t* C_uint8,
+ std::int32_t* row_offsets,
+ const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale) {
+ constexpr int PAD_T = (S - 1) / 2, PAD_L = (S - 1) / 2, PAD_R = (S - 1) / 2;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ int h_in = -PAD_T + h * stride_h;
+ int w_in = -PAD_L + w * stride_w;
+
+ constexpr int KERNEL_PROD_ALIGNED = (S * S + 1) / 2 * 2;
+
+ int k;
+ for (k = 0; k < K / 32 * 32; k += 32) {
+ inner_prod_2d_packed_<
+ S,
+ true, /*SUM_A*/
+ false, /*remainder*/
+ true /*per-channel*/>(
+ H,
+ W,
+ K,
+ h_in,
+ w_in,
+ A + (h_in * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * KERNEL_PROD_ALIGNED,
+ B_zero_point + k,
+ C_int32 + k,
+ 0,
+ &row_offsets[k]);
+ }
+ int remainder = K - k;
+ if (remainder) {
+ inner_prod_2d_packed_<
+ S,
+ true, /*SUM_A*/
+ true, /*remainder*/
+ true /*per-channel*/>(
+ H,
+ W,
+ K,
+ h_in,
+ w_in,
+ A + (h_in * W + w_in) * K + k,
+ A_zero_point,
+ Bp + k * KERNEL_PROD_ALIGNED,
+ B_zero_point + k,
+ C_int32 + k,
+ remainder,
+ &row_offsets[k]);
+ }
+
+ requantize_<
+ FUSE_RELU,
+ HAS_BIAS,
+ true, /*PER_CHAN_QUANT*/
+ A_SYMMETRIC,
+ false, /*B_SYMM*/
+ BIAS_TYPE>(
+ A_zero_point,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8 + (h * W_OUT + w) * K,
+ K,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+}
+
+// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0
+// This implementation is parameterized by the (square) filter size S and
+// currently supports the 3x3 and 5x5 cases handled by inner_prod_2d_packed_
+// above.
+template <
+ int S,
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ bool B_SYMMETRIC,
+ typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void depthwise_2d_(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ std::int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::int32_t* C_int32,
+ std::uint8_t* C_uint8,
+ const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int R = S;
+ constexpr int PAD_T = (R - 1) / 2, PAD_B = PAD_T, PAD_L = (S - 1) / 2,
+ PAD_R = PAD_L;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ const std::int8_t* Bp = B.PackedMat();
+
+ std::int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+
+ int n_begin, n_end;
+ int h_begin, h_end, w_begin, w_end;
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ h_begin = 0;
+ h_end = H_OUT;
+ w_begin = 0;
+ w_end = W_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+ // n is processed by num_threads_h * num_threads_w 2D grid of threads
+ int num_threads_h, num_threads_w;
+ // num_threads_w <= num_threads_h
+ std::tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_h = tid_within_n / num_threads_w;
+ int tid_w = tid_within_n % num_threads_w;
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+
+ int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w;
+ w_begin = std::min(tid_w * w_per_thread, W_OUT);
+ w_end = std::min(w_begin + w_per_thread, W_OUT);
+ }
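+  // Illustrative walk-through of the partition above: with N = 1 and
+  // num_threads = 6, nthreads_of_n = 6; assuming closest_factors_ returns the
+  // closest factor pair with the smaller value first (as the comment above
+  // indicates), num_threads_w = 2 and num_threads_h = 3, so the single image
+  // is tiled by a 3-row x 2-column grid of threads, and thread_id 4 gets
+  // tid_h = 2, tid_w = 0 (the bottom-left tile of the output).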
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const std::uint8_t* A_base = A + n * H * W * K;
+ std::uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K;
+
+ int h = 0;
+ int w = 0;
+
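+    // The loops below walk the assigned output tile in three row bands (rows
+    // touching the top padding, interior rows, rows touching the bottom
+    // padding) and three column bands per row. Every band invokes the same
+    // depthwise_2d_kernel_, so the split appears to preserve the structure of
+    // the earlier 3x3-specific code rather than dispatching different kernels.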
+ for (h = h_begin; h < std::max(PAD_T, h_begin); ++h) {
+ for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) {
+ depthwise_2d_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+
+ for (; w < std::min(W_OUT - PAD_R, w_end); ++w) {
+ depthwise_2d_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+
+ for (; w < w_end; ++w) {
+ depthwise_2d_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+ }
+
+ for (; h < std::min(H - PAD_B, h_end); ++h) {
+ for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) {
+ depthwise_2d_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+
+ for (; w < std::min(W_OUT - PAD_R, w_end); ++w) {
+ depthwise_2d_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+
+ for (; w < w_end; ++w) {
+ depthwise_2d_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+ }
+
+ for (; h < h_end; ++h) {
+ for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) {
+ depthwise_2d_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+
+ for (; w < std::min(W_OUT - PAD_R, w_end); ++w) {
+ depthwise_2d_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+
+ for (; w < w_end; ++w) {
+ depthwise_2d_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ B_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+ }
+ } // for each n
+}
+
+template <
+ int S,
+ bool FUSE_RELU,
+ bool HAS_BIAS,
+ bool A_SYMMETRIC,
+ typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void
+depthwise_2d_per_channel_quantization_(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ std::int32_t* C_int32,
+ std::uint8_t* C_uint8,
+ const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ assert(K % 8 == 0);
+ constexpr int R = S;
+ constexpr int PAD_T = (R - 1) / 2, PAD_B = PAD_T, PAD_L = (S - 1) / 2,
+ PAD_R = PAD_L;
+ int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
+ int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+ const std::int8_t* Bp = B.PackedMat();
+
+ std::int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+
+ int n_begin, n_end;
+ int h_begin, h_end, w_begin, w_end;
+ if (N >= num_threads) {
+ int n_per_thread = (N + num_threads - 1) / num_threads;
+ n_begin = std::min(thread_id * n_per_thread, N);
+ n_end = std::min(n_begin + n_per_thread, N);
+ h_begin = 0;
+ h_end = H_OUT;
+ w_begin = 0;
+ w_end = W_OUT;
+ } else {
+ int nthreads_per_n = num_threads / N;
+ n_begin = std::min(thread_id / nthreads_per_n, N);
+ n_end = std::min(n_begin + 1, N);
+
+ int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
+ int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
+ int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
+ int tid_within_n = thread_id - tid_of_n_begin;
+ assert(tid_within_n >= 0);
+ assert(tid_within_n < nthreads_of_n);
+
+ // n is processed by num_threads_h * num_threads_w 2D grid of threads
+ int num_threads_h, num_threads_w;
+ // num_threads_w <= num_threads_h
+ std::tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n);
+ int tid_h = tid_within_n / num_threads_w;
+ int tid_w = tid_within_n % num_threads_w;
+
+ int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
+ h_begin = std::min(tid_h * h_per_thread, H_OUT);
+ h_end = std::min(h_begin + h_per_thread, H_OUT);
+
+ int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w;
+ w_begin = std::min(tid_w * w_per_thread, W_OUT);
+ w_end = std::min(w_begin + w_per_thread, W_OUT);
+ }
+
+ for (int n = n_begin; n < n_end; ++n) {
+ const std::uint8_t* A_base = A + n * H * W * K;
+ std::uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K;
+
+ int h = 0;
+ int w = 0;
+
+ for (h = h_begin; h < std::max(PAD_T, h_begin); ++h) {
+ for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) {
+ depthwise_2d_per_channel_quantization_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+
+ for (; w < std::min(W_OUT - PAD_R, w_end); ++w) {
+ depthwise_2d_per_channel_quantization_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+
+ for (; w < w_end; ++w) {
+ depthwise_2d_per_channel_quantization_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+ }
+
+ for (; h < std::min(H - PAD_B, h_end); ++h) {
+ for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) {
+ depthwise_2d_per_channel_quantization_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+
+ for (; w < std::min(W_OUT - PAD_R, w_end); ++w) {
+ depthwise_2d_per_channel_quantization_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+
+ for (; w < w_end; ++w) {
+ depthwise_2d_per_channel_quantization_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+ }
+
+ for (; h < h_end; ++h) {
+ for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) {
+ depthwise_2d_per_channel_quantization_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+
+ for (; w < std::min(W_OUT - PAD_R, w_end); ++w) {
+ depthwise_2d_per_channel_quantization_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+
+ for (; w < w_end; ++w) {
+ depthwise_2d_per_channel_quantization_kernel_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ A_SYMMETRIC,
+ BIAS_TYPE>(
+ H,
+ W,
+ K,
+ h,
+ w,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A_base,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32,
+ C_uint8_base,
+ row_offsets,
+ col_offsets,
+ bias,
+ act_times_w_scale);
+ }
+ }
+ } // for each n
+}
+
+// Dispatch A_SYMMETRIC and B_SYMMETRIC
+template <int S, bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
+static void depthwise_2d_(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ std::int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ std::int32_t C_int32_temp[(K + 31) / 32 * 32];
+ if (A_zero_point == 0 || col_offsets == nullptr) {
+ if (B_zero_point == 0) {
+ depthwise_2d_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ true /*A_symmetric*/,
+ true /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_2d_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ true /*A_symmetric*/,
+ false /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+ } else {
+ if (B_zero_point == 0) {
+ depthwise_2d_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ false /*A_symmetric*/,
+ true /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_2d_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ false /*A_symmetric*/,
+ false /*B_symmetric*/,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+ }
+}
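+
+// The dispatcher above turns run-time conditions into compile-time template
+// flags: A is treated as symmetric when A_zero_point == 0 or col_offsets is
+// null, and B as symmetric when B_zero_point == 0, so the specialized kernels
+// can drop the corresponding row/column offset corrections entirely.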
+
+// Dispatch HAS_BIAS
+template <int S, bool FUSE_RELU, typename BIAS_TYPE>
+static void depthwise_2d_(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ std::int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (bias) {
+ depthwise_2d_<S, FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_2d_<S, FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+// Dispatch A_SYMMETRIC
+template <int S, bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
+static void depthwise_2d_per_channel_quantization_(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ int32_t C_int32_temp[(K + 31) / 32 * 32];
+ if (A_zero_point == 0 || col_offsets == nullptr) {
+ depthwise_2d_per_channel_quantization_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ true /*A_SYMM*/,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_2d_per_channel_quantization_<
+ S,
+ FUSE_RELU,
+ HAS_BIAS,
+ false /*A_SYMM*/,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C_int32_temp,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+// Dispatch HAS_BIAS
+template <int S, bool FUSE_RELU, typename BIAS_TYPE>
+static void depthwise_2d_per_channel_quantization_(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (bias) {
+ depthwise_2d_per_channel_quantization_<
+ S,
+ FUSE_RELU,
+ true /* HAS_BIAS */,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_2d_per_channel_quantization_<
+ S,
+ FUSE_RELU,
+ false /* HAS_BIAS */,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+template <typename BIAS_TYPE = std::int32_t>
+FBGEMM_API void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ std::int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ float C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu = false,
+ float act_times_w_scale = 1.0f,
+ int thread_id = 0,
+ int num_threads = 1);
+
+template <typename BIAS_TYPE = std::int32_t>
+FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ std::int32_t A_zero_point,
+ const std::uint8_t* A,
+ const std::int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ std::int32_t C_zero_point,
+ std::uint8_t* C,
+ const std::int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu = false,
+ const float* act_times_w_scale = nullptr,
+ int thread_id = 0,
+ int num_threads = 1);
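+
+// A minimal usage sketch for the entry points declared above. The
+// PackedDepthWiseConvMatrix(K, kernel_prod, weights) packing call is an
+// assumption based on how other FBGEMM depthwise code constructs it; it is not
+// defined in this header:
+//
+//   fbgemm::PackedDepthWiseConvMatrix Bp(K, 3 * 3, B_int8);  // pack 3x3 weights
+//   fbgemm::depthwise_3x3_pad_1<std::int32_t>(
+//       N, H, W, K, /*stride_h=*/1, /*stride_w=*/1,
+//       A_zero_point, A_uint8, B_zero_point, Bp,
+//       C_multiplier, C_zero_point, C_uint8,
+//       col_offsets, bias);  // remaining arguments use their defaults (no ReLU)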
+
+} // namespace fbgemm
diff --git a/src/FbgemmI8Depthwise3DAvx2.cc b/src/FbgemmI8Depthwise3DAvx2.cc
index 2114b20..ee8cc29 100644
--- a/src/FbgemmI8Depthwise3DAvx2.cc
+++ b/src/FbgemmI8Depthwise3DAvx2.cc
@@ -1237,97 +1237,6 @@ void depthwise_3x3x3_per_channel_quantization_pad_1(
}
}
-// To be removed
-void depthwise_3x3x3_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const PackedDepthWiseConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- bool fuse_relu,
- int thread_id,
- int num_threads) {
- depthwise_3x3x3_pad_1<int32_t>(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- fuse_relu,
- 1.0f, // act_scale * weight_scale
- thread_id,
- num_threads);
-}
-
-void depthwise_3x3x3_per_channel_quantization_pad_1(
- int N,
- int T,
- int H,
- int W,
- int K,
- int stride_t,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const PackedDepthWiseConvMatrix& B,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- bool fuse_relu,
- int thread_id,
- int num_threads) {
- depthwise_3x3x3_per_channel_quantization_pad_1(
- N,
- T,
- H,
- W,
- K,
- stride_t,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- fuse_relu,
- nullptr, // act_scale * weight_scale
- thread_id,
- num_threads);
-}
-
template void depthwise_3x3x3_pad_1(
int N,
int T,
diff --git a/src/FbgemmI8Depthwise3x3Avx2.cc b/src/FbgemmI8Depthwise3x3Avx2.cc
new file mode 100644
index 0000000..5226a04
--- /dev/null
+++ b/src/FbgemmI8Depthwise3x3Avx2.cc
@@ -0,0 +1,618 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+
+#include <string>
+
+#include "FbgemmI8Depthwise2DAvx2-inl.h"
+
+using namespace std;
+
+namespace fbgemm {
+
+// Dispatch input shape and FUSE_RELU
+// assumption: W > 3 and H > 3
+template <typename BIAS_TYPE>
+void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (B.GetKernelProduct() != 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3) + " but has " + to_string(B.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we should just early return.
+ return;
+ }
+ if (fuse_relu) {
+ if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
+ depthwise_2d_<3, true /* FUSE_RELU */, BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
+ depthwise_2d_<3, true /* FUSE_RELU */, BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else if (1 == stride_h && 1 == stride_w) {
+ depthwise_2d_<3, true /* FUSE_RELU */, BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else if (2 == stride_h && 2 == stride_w) {
+ depthwise_2d_<3, true /* FUSE_RELU */, BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_2d_<3, true /* FUSE_RELU */, BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+ } else {
+ if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
+ depthwise_2d_<3, false /* FUSE_RELU */, BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
+ depthwise_2d_<3, false /* FUSE_RELU */, BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else if (1 == stride_h && 1 == stride_w) {
+ depthwise_2d_<3, false /* FUSE_RELU */, BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else if (2 == stride_h && 2 == stride_w) {
+ depthwise_2d_<3, false /* FUSE_RELU */, BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_2d_<3, false /* FUSE_RELU */, BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ B,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+ }
+}
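+
+// Note: every shape/stride branch above forwards the same arguments to the
+// same depthwise_2d_<3, FUSE_RELU, BIAS_TYPE> instantiation, so the 7x7/s1 and
+// 14x14/s2 special cases are currently equivalent to the generic fallback;
+// they appear to be kept to mirror the shape-specialized structure of the
+// earlier 3x3-only implementation. The same holds for the per-channel
+// quantization dispatcher below.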
+
+// Dispatch input shape and FUSE_RELU
+template <typename BIAS_TYPE>
+void depthwise_3x3_per_channel_quantization_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (Bp.GetKernelProduct() != 3 * 3) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(3 * 3) + " but has " + to_string(Bp.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
+ assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
+ return;
+ }
+ if (N == 0) {
+ // In C2, batch size 0 is allowed, so we should just early return.
+ return;
+ }
+ if (fuse_relu) {
+ if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
+ depthwise_2d_per_channel_quantization_<
+ 3,
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
+ depthwise_2d_per_channel_quantization_<
+ 3,
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else if (1 == stride_h && 1 == stride_w) {
+ depthwise_2d_per_channel_quantization_<
+ 3,
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else if (2 == stride_h && 2 == stride_w) {
+ depthwise_2d_per_channel_quantization_<
+ 3,
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_2d_per_channel_quantization_<
+ 3,
+ true /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+ } else {
+ if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
+ depthwise_2d_per_channel_quantization_<
+ 3,
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
+ depthwise_2d_per_channel_quantization_<
+ 3,
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else if (1 == stride_h && 1 == stride_w) {
+ depthwise_2d_per_channel_quantization_<
+ 3,
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else if (2 == stride_h && 2 == stride_w) {
+ depthwise_2d_per_channel_quantization_<
+ 3,
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_2d_per_channel_quantization_<
+ 3,
+ false /* FUSE_RELU */,
+ BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+ }
+}
+
+template void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ int32_t B_zero_point,
+ const PackedDepthWiseConvMatrix& B,
+ float C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ float act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3_per_channel_quantization_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_3x3_per_channel_quantization_pad_1(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+} // namespace fbgemm
diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc
index 994f206..ada1c75 100644
--- a/src/FbgemmI8DepthwiseAvx2.cc
+++ b/src/FbgemmI8DepthwiseAvx2.cc
@@ -8,1126 +8,16 @@
#include "fbgemm/Utils.h"
#include <string>
-#include <tuple> // for tie
-#include "FbgemmI8DepthwiseAvx2-inl.h"
+#include "FbgemmI8Depthwise2DAvx2-inl.h"
using namespace std;
namespace fbgemm {
-template <bool SUM_A = false, bool REMAINDER = false>
-static inline ALWAYS_INLINE void inner_prod_3x3_packed_(
- const __m256i* a_v,
- const __m256i* Bp,
- int32_t* C,
- int remainder,
- __m256i* a_sum = nullptr) {
- return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder, a_sum);
-}
-
-template <
- bool SUM_A,
- bool REMAINDER = false,
- bool PER_CHANNEL_QUANTIZATION = false>
-static inline ALWAYS_INLINE void inner_prod_3x3_packed_(
- int H,
- int W,
- int K,
- int h_in,
- int w_in,
- const uint8_t* A,
- int32_t A_zero_point,
- const int8_t* Bp,
- const int32_t* B_zero_point,
- int32_t* C,
- int remainder,
- int32_t* row_offsets) {
- __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point));
- __m256i mask_v = _mm256_setzero_si256();
- if (REMAINDER) {
- mask_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(masks[remainder / 4]));
- }
-
- // The code below can be written as a simple R*S loop but the compiler
- // doesn't unroll so we're manually unrolling it.
- // constexpr int R = 3, S = 3;
- // array<__m256i, R * S> a_v;
- // for (int r = 0; r < R; ++r) {
- // for (int s = 0; s < S; ++s) {
- // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) {
- // if (REMAINDER) {
- // a_v[r * S + s] =
- // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K),
- // mask_v);
- // } else {
- // a_v[r * S + s] =
- // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K));
- // }
- // } else {
- // a_v[r * S + s] = A_zero_point_v;
- // }
- // }
- // }
- __m256i a_v[9] = {
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- A_zero_point_v,
- };
-
- if (h_in >= 0 && h_in < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[0] = load_a<REMAINDER>(A + (0 * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[1] = load_a<REMAINDER>(A + (0 * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[2] = load_a<REMAINDER>(A + (0 * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 1 >= 0 && h_in + 1 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[3] = load_a<REMAINDER>(A + (1 * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[4] = load_a<REMAINDER>(A + (1 * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[5] = load_a<REMAINDER>(A + (1 * W + 2) * K, mask_v);
- }
- }
-
- if (h_in + 2 >= 0 && h_in + 2 < H) {
- if (w_in >= 0 && w_in < W) {
- a_v[6] = load_a<REMAINDER>(A + (2 * W + 0) * K, mask_v);
- }
- if (w_in + 1 >= 0 && w_in + 1 < W) {
- a_v[7] = load_a<REMAINDER>(A + (2 * W + 1) * K, mask_v);
- }
- if (w_in + 2 >= 0 && w_in + 2 < W) {
- a_v[8] = load_a<REMAINDER>(A + (2 * W + 2) * K, mask_v);
- }
- }
-
- __m256i a_sum[4];
- inner_prod_3x3_packed_<SUM_A, REMAINDER>(
- a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum);
- if (SUM_A) {
- __m256i B_zero_point_v;
- for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) {
- if (PER_CHANNEL_QUANTIZATION) {
- B_zero_point_v = _mm256_loadu_si256(
- reinterpret_cast<const __m256i*>(B_zero_point + i * 8));
- } else {
- B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]);
- }
- _mm256_store_si256(
- reinterpret_cast<__m256i*>(&row_offsets[i * 8]),
- _mm256_mullo_epi32(a_sum[i], B_zero_point_v));
- }
- }
-}
-
-template <
- bool FUSE_RELU,
- bool HAS_BIAS,
- bool A_SYMMETRIC,
- bool B_SYMMETRIC,
- typename BIAS_TYPE>
-static inline ALWAYS_INLINE void depthwise_3x3_kernel_(
- int H,
- int W,
- int K,
- int h,
- int w,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const int8_t* Bp,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- float act_times_w_scale) {
- constexpr int S = 3;
- constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- int h_in = -PAD_T + h * stride_h;
- int w_in = -PAD_L + w * stride_w;
-
- int k;
- for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3_packed_<!B_SYMMETRIC /*SUM_A*/>(
- H,
- W,
- K,
- h_in,
- w_in,
- A + (h_in * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 10,
- &B_zero_point,
- C_int32 + k,
- 0,
- B_SYMMETRIC ? nullptr : &row_offsets[k]);
- }
- int remainder = K - k;
- if (remainder) {
- inner_prod_3x3_packed_<!B_SYMMETRIC, true>(
- H,
- W,
- K,
- h_in,
- w_in,
- A + (h_in * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 10,
- &B_zero_point,
- C_int32 + k,
- remainder,
- B_SYMMETRIC ? nullptr : &row_offsets[k]);
- }
-
- requantize_<
- FUSE_RELU,
- HAS_BIAS,
- false, /*PER_CHAN_QUANT*/
- A_SYMMETRIC,
- B_SYMMETRIC,
- BIAS_TYPE>(
- A_zero_point,
- &C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8 + (h * W_OUT + w) * K,
- K,
- row_offsets,
- col_offsets,
- bias,
- &act_times_w_scale);
-}
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
-static inline ALWAYS_INLINE void
-depthwise_3x3_per_channel_quantization_kernel_(
- int H,
- int W,
- int K,
- int h,
- int w,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const int8_t* Bp,
- const float* C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- int32_t* row_offsets,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- const float* act_times_w_scale) {
- constexpr int S = 3;
- constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- int h_in = -PAD_T + h * stride_h;
- int w_in = -PAD_L + w * stride_w;
-
- int k;
- for (k = 0; k < K / 32 * 32; k += 32) {
- inner_prod_3x3_packed_<
- true, /*SUM_A*/
- false, /*remainder*/
- true /*per-channel*/>(
- H,
- W,
- K,
- h_in,
- w_in,
- A + (h_in * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 10,
- B_zero_point + k,
- C_int32 + k,
- 0,
- &row_offsets[k]);
- }
- int remainder = K - k;
- if (remainder) {
- inner_prod_3x3_packed_<
- true, /*SUM_A*/
- true, /*remainder*/
- true /*per-channel*/>(
- H,
- W,
- K,
- h_in,
- w_in,
- A + (h_in * W + w_in) * K + k,
- A_zero_point,
- Bp + k * 10,
- B_zero_point + k,
- C_int32 + k,
- remainder,
- &row_offsets[k]);
- }
-
- requantize_<
- FUSE_RELU,
- HAS_BIAS,
- true, /*PER_CHAN_QUANT*/
- A_SYMMETRIC,
- false, /*B_SYMM*/
- BIAS_TYPE>(
- A_zero_point,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8 + (h * W_OUT + w) * K,
- K,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
-}
-
-// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0
-// This implemntation should be general enough to handle not just 3x3 but other
-// filter shapes by parameterizing with R and S but restricting it to just 3x3
-// for now.
-template <
- bool FUSE_RELU,
- bool HAS_BIAS,
- bool A_SYMMETRIC,
- bool B_SYMMETRIC,
- typename BIAS_TYPE>
-static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const PackedDepthWiseConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- float act_times_w_scale,
- int thread_id,
- int num_threads) {
- assert(K % 8 == 0);
- constexpr int R = 3, S = 3;
- constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- const int8_t* Bp = B.PackedMat();
-
- int32_t* row_offsets = static_cast<int32_t *>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64));
-
- int n_begin, n_end;
- int h_begin, h_end, w_begin, w_end;
- if (N >= num_threads) {
- int n_per_thread = (N + num_threads - 1) / num_threads;
- n_begin = std::min(thread_id * n_per_thread, N);
- n_end = std::min(n_begin + n_per_thread, N);
- h_begin = 0;
- h_end = H_OUT;
- w_begin = 0;
- w_end = W_OUT;
- } else {
- int nthreads_per_n = num_threads / N;
- n_begin = std::min(thread_id / nthreads_per_n, N);
- n_end = std::min(n_begin + 1, N);
-
- int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
- int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
- int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
- int tid_within_n = thread_id - tid_of_n_begin;
- assert(tid_within_n >= 0);
- assert(tid_within_n < nthreads_of_n);
-
- // n is processed by num_threads_h * num_threads_w 2D grid of threads
- int num_threads_h, num_threads_w;
- // num_threads_w <= num_threads_h
- tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n);
- int tid_h = tid_within_n / num_threads_w;
- int tid_w = tid_within_n % num_threads_w;
-
- int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
- h_begin = std::min(tid_h * h_per_thread, H_OUT);
- h_end = std::min(h_begin + h_per_thread, H_OUT);
-
- int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w;
- w_begin = std::min(tid_w * w_per_thread, W_OUT);
- w_end = std::min(w_begin + w_per_thread, W_OUT);
- }
-
- for (int n = n_begin; n < n_end; ++n) {
- const uint8_t* A_base = A + n * H * W * K;
- uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K;
-
- int h = 0;
- int w = 0;
-
- if (h_begin == 0) {
- if (w_begin == 0) {
- depthwise_3x3_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- B_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- B_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- if (w_end == W_OUT) {
- w = W_OUT - 1;
- depthwise_3x3_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- B_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
- }
-
- for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) {
- if (w_begin == 0) {
- w = 0;
- depthwise_3x3_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- B_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- B_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- if (w_end == W_OUT) {
- w = W_OUT - 1;
- depthwise_3x3_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- B_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
- }
-
- if (h_end == H_OUT) {
- h = H_OUT - 1;
- w = 0;
- if (w_begin == 0) {
- depthwise_3x3_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- B_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- B_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- if (w_end == W_OUT) {
- w = W_OUT - 1;
- depthwise_3x3_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- B_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
- }
- } // for each n
- FREE(row_offsets);
-};
-
-template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE>
-static inline ALWAYS_INLINE void
-depthwise_3x3_per_channel_quantization_pad_1_(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const PackedDepthWiseConvMatrix& B,
- const float* C_multiplier,
- int32_t C_zero_point,
- int32_t* C_int32,
- uint8_t* C_uint8,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- const float* act_times_w_scale,
- int thread_id,
- int num_threads) {
- assert(K % 8 == 0);
- constexpr int R = 3, S = 3;
- constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1;
- int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1;
- int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
- const int8_t* Bp = B.PackedMat();
-
- int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64)));
-
- int n_begin, n_end;
- int h_begin, h_end, w_begin, w_end;
- if (N >= num_threads) {
- int n_per_thread = (N + num_threads - 1) / num_threads;
- n_begin = std::min(thread_id * n_per_thread, N);
- n_end = std::min(n_begin + n_per_thread, N);
- h_begin = 0;
- h_end = H_OUT;
- w_begin = 0;
- w_end = W_OUT;
- } else {
- int nthreads_per_n = num_threads / N;
- n_begin = std::min(thread_id / nthreads_per_n, N);
- n_end = std::min(n_begin + 1, N);
-
- int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads);
- int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
- int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
- int tid_within_n = thread_id - tid_of_n_begin;
- assert(tid_within_n >= 0);
- assert(tid_within_n < nthreads_of_n);
-
- // n is processed by num_threads_h * num_threads_w 2D grid of threads
- int num_threads_h, num_threads_w;
- // num_threads_w <= num_threads_h
- tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n);
- int tid_h = tid_within_n / num_threads_w;
- int tid_w = tid_within_n % num_threads_w;
-
- int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
- h_begin = std::min(tid_h * h_per_thread, H_OUT);
- h_end = std::min(h_begin + h_per_thread, H_OUT);
-
- int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w;
- w_begin = std::min(tid_w * w_per_thread, W_OUT);
- w_end = std::min(w_begin + w_per_thread, W_OUT);
- }
-
- for (int n = n_begin; n < n_end; ++n) {
- const uint8_t* A_base = A + n * H * W * K;
- uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K;
-
- int h = 0;
- int w = 0;
-
- if (h_begin == 0) {
- if (w_begin == 0) {
- depthwise_3x3_per_channel_quantization_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_per_channel_quantization_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- if (w_end == W_OUT) {
- w = W_OUT - 1;
- depthwise_3x3_per_channel_quantization_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
- }
-
- for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) {
- if (w_begin == 0) {
- w = 0;
- depthwise_3x3_per_channel_quantization_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_per_channel_quantization_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- if (w_end == W_OUT) {
- w = W_OUT - 1;
- depthwise_3x3_per_channel_quantization_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
- }
-
- if (h_end == H_OUT) {
- h = H_OUT - 1;
- w = 0;
- if (w_begin == 0) {
- depthwise_3x3_per_channel_quantization_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) {
- depthwise_3x3_per_channel_quantization_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
-
- if (w_end == W_OUT) {
- w = W_OUT - 1;
- depthwise_3x3_per_channel_quantization_kernel_<
- FUSE_RELU,
- HAS_BIAS,
- A_SYMMETRIC,
- BIAS_TYPE>(
- H,
- W,
- K,
- h,
- w,
- stride_h,
- stride_w,
- A_zero_point,
- A_base,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32,
- C_uint8_base,
- row_offsets,
- col_offsets,
- bias,
- act_times_w_scale);
- }
- }
- } // for each n
-};
-
-// Dispatch A_SYMMETRIC and B_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
-static void depthwise_3x3_pad_1_(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const PackedDepthWiseConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- float act_times_w_scale,
- int thread_id,
- int num_threads) {
- int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
- if (A_zero_point == 0 || col_offsets == nullptr) {
- if (B_zero_point == 0) {
- depthwise_3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_symmetric*/,
- true /*B_symmetric*/,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_symmetric*/,
- false /*B_symmetric*/,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- }
- } else {
- if (B_zero_point == 0) {
- depthwise_3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_symmetric*/,
- true /*B_symmetric*/,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_symmetric*/,
- false /*B_symmetric*/,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- }
- }
- delete[] C_int32_temp;
-}
-
-// Dispatch HAS_BIAS
-template <bool FUSE_RELU, typename BIAS_TYPE>
-static void depthwise_3x3_pad_1_(
+// Dispatch on kernel size (3x3 vs. 5x5) and FUSE_RELU
+template <typename BIAS_TYPE /*=std::int32_t*/>
+void depthwise_2d_same_pad(
int N,
int H,
int W,
@@ -1143,31 +33,12 @@ static void depthwise_3x3_pad_1_(
uint8_t* C,
const int32_t* col_offsets,
const BIAS_TYPE* bias,
+ bool fuse_relu,
float act_times_w_scale,
int thread_id,
int num_threads) {
- if (bias) {
- depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>(
+ if (B.GetKernelProduct() == 3 * 3) {
+ depthwise_3x3_pad_1(
N,
H,
W,
@@ -1183,284 +54,21 @@ static void depthwise_3x3_pad_1_(
C,
col_offsets,
bias,
+ fuse_relu,
act_times_w_scale,
thread_id,
num_threads);
+ return;
}
-}
-// Dispatch input shape and FUSE_RELU
-// assumption: W > 3 and H > 3
-template <typename BIAS_TYPE>
-void depthwise_3x3_pad_1(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const PackedDepthWiseConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- bool fuse_relu,
- float act_times_w_scale,
- int thread_id,
- int num_threads) {
- if (B.GetKernelProduct() != 3 * 3) {
+ if (B.GetKernelProduct() != 5 * 5) {
string msg =
"[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
- to_string(3 * 3) + " but has " + to_string(B.GetKernelProduct());
+ to_string(5 * 5) + " but has " + to_string(B.GetKernelProduct());
throw logic_error(msg);
}
- if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
- assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
- return;
- }
- if (N == 0) {
- // In C2, batch size 0 is allowed, so we should just early return.
- return;
- }
if (fuse_relu) {
- if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- }
- } else {
- if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- }
- }
-}
-
-// Dispatch A_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
-static void depthwise_3x3_per_channel_quantization_pad_1_(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const PackedDepthWiseConvMatrix& Bp,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- const float* act_times_w_scale,
- int thread_id,
- int num_threads) {
- int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
- if (A_zero_point == 0 || col_offsets == nullptr) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- true /*A_SYMM*/,
- BIAS_TYPE>(
+ depthwise_2d_<5, true /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -1470,81 +78,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
A_zero_point,
A,
B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- HAS_BIAS,
- false /*A_SYMM*/,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C_int32_temp,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- }
- delete[] C_int32_temp;
-}
-
-// Dispatch HAS_BIAS
-template <bool FUSE_RELU, typename BIAS_TYPE>
-static void depthwise_3x3_per_channel_quantization_pad_1_(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const PackedDepthWiseConvMatrix& Bp,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- const float* act_times_w_scale,
- int thread_id,
- int num_threads) {
- if (bias) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- true /* HAS_BIAS */,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
+ B,
C_multiplier,
C_zero_point,
C,
@@ -1554,10 +88,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
thread_id,
num_threads);
} else {
- depthwise_3x3_per_channel_quantization_pad_1_<
- FUSE_RELU,
- false /* HAS_BIAS */,
- BIAS_TYPE>(
+ depthwise_2d_<5, false /* FUSE_RELU */, BIAS_TYPE>(
N,
H,
W,
@@ -1567,7 +98,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
A_zero_point,
A,
B_zero_point,
- Bp,
+ B,
C_multiplier,
C_zero_point,
C,
@@ -1579,354 +110,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
}
}
-// Dispatch input shape and FUSE_RELU
-template <typename BIAS_TYPE>
-void depthwise_3x3_per_channel_quantization_pad_1(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const PackedDepthWiseConvMatrix& Bp,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const BIAS_TYPE* bias,
- bool fuse_relu,
- const float* act_times_w_scale,
- int thread_id,
- int num_threads) {
- if (Bp.GetKernelProduct() != 3 * 3) {
- string msg =
- "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
- to_string(3 * 3) + " but has " + to_string(Bp.GetKernelProduct());
- throw logic_error(msg);
- }
- if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
- assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
- return;
- }
- if (N == 0) {
- // In C2, batch size 0 is allowed, so we should just early return.
- return;
- }
- if (fuse_relu) {
- if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RELU */,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RELU */,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RELU */,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RELU */,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3_per_channel_quantization_pad_1_<
- true /* FUSE_RELU */,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- }
- } else {
- if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- false /* FUSE_RELU */,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- false /* FUSE_RELU */,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else if (1 == stride_h && 1 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- false /* FUSE_RELU */,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else if (2 == stride_h && 2 == stride_w) {
- depthwise_3x3_per_channel_quantization_pad_1_<
- false /* FUSE_RELU */,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- } else {
- depthwise_3x3_per_channel_quantization_pad_1_<
- false /* FUSE_RELU */,
- BIAS_TYPE>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- act_times_w_scale,
- thread_id,
- num_threads);
- }
- }
-}
-
-// To be removed
-void depthwise_3x3_pad_1(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- int32_t B_zero_point,
- const PackedDepthWiseConvMatrix& B,
- float C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- bool fuse_relu,
- int thread_id,
- int num_threads) {
- depthwise_3x3_pad_1<std::int32_t>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- B,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- fuse_relu,
- 1.0f,
- thread_id,
- num_threads);
-}
-
-// To be removed
-void depthwise_3x3_per_channel_quantization_pad_1(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const PackedDepthWiseConvMatrix& Bp,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- bool fuse_relu,
- int thread_id,
- int num_threads) {
- depthwise_3x3_per_channel_quantization_pad_1<std::int32_t>(
- N,
- H,
- W,
- K,
- stride_h,
- stride_w,
- A_zero_point,
- A,
- B_zero_point,
- Bp,
- C_multiplier,
- C_zero_point,
- C,
- col_offsets,
- bias,
- fuse_relu,
- nullptr,
- thread_id,
- num_threads);
-}
-
-template void depthwise_3x3_pad_1(
+template void depthwise_2d_same_pad<int32_t>(
int N,
int H,
int W,
@@ -1947,7 +131,7 @@ template void depthwise_3x3_pad_1(
int thread_id,
int num_threads);
-template void depthwise_3x3_pad_1(
+template void depthwise_2d_same_pad<float>(
int N,
int H,
int W,
@@ -1968,46 +152,4 @@ template void depthwise_3x3_pad_1(
int thread_id,
int num_threads);
-template void depthwise_3x3_per_channel_quantization_pad_1(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const PackedDepthWiseConvMatrix& Bp,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const int32_t* bias,
- bool fuse_relu,
- const float* act_times_w_scale,
- int thread_id,
- int num_threads);
-
-template void depthwise_3x3_per_channel_quantization_pad_1(
- int N,
- int H,
- int W,
- int K,
- int stride_h,
- int stride_w,
- int32_t A_zero_point,
- const uint8_t* A,
- const int32_t* B_zero_point,
- const PackedDepthWiseConvMatrix& Bp,
- const float* C_multiplier,
- int32_t C_zero_point,
- uint8_t* C,
- const int32_t* col_offsets,
- const float* bias,
- bool fuse_relu,
- const float* act_times_w_scale,
- int thread_id,
- int num_threads);
-
} // namespace fbgemm
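The depthwise bodies deleted above all share one thread-decomposition scheme: when the batch N is at least as large as num_threads, threads split the batch dimension; otherwise the threads assigned to a single image are arranged as a num_threads_w x num_threads_h grid over the H_OUT x W_OUT output plane, which is why the border rows and columns are guarded by the h_begin/h_end and w_begin/w_end checks. The following is a minimal standalone sketch of that partitioning; closest_factors_ is not defined in this diff, so the helper below is an assumed reconstruction that only preserves the stated invariant num_threads_w <= num_threads_h.

#include <algorithm>
#include <cmath>
#include <utility>

// Assumed stand-in for the closest_factors_() helper referenced above:
// split n threads into a (w, h) grid with w <= h and w * h == n,
// choosing the factor pair closest to a square.
static std::pair<int, int> closest_factors(int n) {
  int w = static_cast<int>(std::sqrt(static_cast<double>(n)));
  while (n % w != 0) {
    --w;
  }
  return {w, n / w}; // first <= second
}

struct WorkSlice {
  int n_begin, n_end, h_begin, h_end, w_begin, w_end;
};

// Mirrors the partitioning in the deleted depthwise_*_pad_1_ bodies.
static WorkSlice partition_depthwise_work(
    int N, int H_OUT, int W_OUT, int thread_id, int num_threads) {
  WorkSlice s{};
  if (N >= num_threads) {
    // Enough images: each thread takes a contiguous slice of the batch.
    int n_per_thread = (N + num_threads - 1) / num_threads;
    s.n_begin = std::min(thread_id * n_per_thread, N);
    s.n_end = std::min(s.n_begin + n_per_thread, N);
    s.h_begin = 0;
    s.h_end = H_OUT;
    s.w_begin = 0;
    s.w_end = W_OUT;
  } else {
    // Several threads share one image: tile the output plane in 2D.
    int nthreads_per_n = num_threads / N;
    s.n_begin = std::min(thread_id / nthreads_per_n, N);
    s.n_end = std::min(s.n_begin + 1, N);

    int tid_of_n_begin = std::min(s.n_begin * nthreads_per_n, num_threads);
    int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads);
    int nthreads_of_n = tid_of_n_end - tid_of_n_begin;
    int tid_within_n = thread_id - tid_of_n_begin;

    auto [num_threads_w, num_threads_h] = closest_factors(nthreads_of_n);
    int tid_h = tid_within_n / num_threads_w;
    int tid_w = tid_within_n % num_threads_w;

    int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h;
    s.h_begin = std::min(tid_h * h_per_thread, H_OUT);
    s.h_end = std::min(s.h_begin + h_per_thread, H_OUT);

    int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w;
    s.w_begin = std::min(tid_w * w_per_thread, W_OUT);
    s.w_end = std::min(s.w_begin + w_per_thread, W_OUT);
  }
  return s;
}

Each thread then visits output rows h in [h_begin, h_end) and columns w in [w_begin, w_end) of its images, handling the w == 0, w == W_OUT - 1, h == 0 and h == H_OUT - 1 border cases separately, exactly as the deleted loops do.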
diff --git a/src/FbgemmI8DepthwisePerChannelQuantAvx2.cc b/src/FbgemmI8DepthwisePerChannelQuantAvx2.cc
new file mode 100644
index 0000000..f429214
--- /dev/null
+++ b/src/FbgemmI8DepthwisePerChannelQuantAvx2.cc
@@ -0,0 +1,154 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+
+#include <string>
+
+#include "FbgemmI8Depthwise2DAvx2-inl.h"
+
+using namespace std;
+
+namespace fbgemm {
+
+// Dispatch on kernel size (3x3 vs. 5x5) and FUSE_RELU
+template <typename BIAS_TYPE /*=std::int32_t*/>
+void depthwise_2d_per_channel_quantization_same_pad(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const BIAS_TYPE* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads) {
+ if (Bp.GetKernelProduct() == 3 * 3) {
+ depthwise_3x3_per_channel_quantization_pad_1(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ fuse_relu,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ return;
+ }
+
+ if (Bp.GetKernelProduct() != 5 * 5) {
+ string msg =
+ "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+ to_string(5 * 5) + " but has " + to_string(Bp.GetKernelProduct());
+ throw logic_error(msg);
+ }
+ if (fuse_relu) {
+ depthwise_2d_per_channel_quantization_<5, true /* FUSE_RELU */, BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ } else {
+ depthwise_2d_per_channel_quantization_<5, false /* FUSE_RELU */, BIAS_TYPE>(
+ N,
+ H,
+ W,
+ K,
+ stride_h,
+ stride_w,
+ A_zero_point,
+ A,
+ B_zero_point,
+ Bp,
+ C_multiplier,
+ C_zero_point,
+ C,
+ col_offsets,
+ bias,
+ act_times_w_scale,
+ thread_id,
+ num_threads);
+ }
+}
+
+template void depthwise_2d_per_channel_quantization_same_pad<int32_t>(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const int32_t* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+template void depthwise_2d_per_channel_quantization_same_pad<float>(
+ int N,
+ int H,
+ int W,
+ int K,
+ int stride_h,
+ int stride_w,
+ int32_t A_zero_point,
+ const uint8_t* A,
+ const int32_t* B_zero_point,
+ const PackedDepthWiseConvMatrix& Bp,
+ const float* C_multiplier,
+ int32_t C_zero_point,
+ uint8_t* C,
+ const int32_t* col_offsets,
+ const float* bias,
+ bool fuse_relu,
+ const float* act_times_w_scale,
+ int thread_id,
+ int num_threads);
+
+} // namespace fbgemm
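The new translation unit keeps a single public entry point and routes to the 3x3 or 5x5 kernels based on PackedDepthWiseConvMatrix::GetKernelProduct(), throwing for any other kernel size. A hypothetical call-site sketch for the 5x5 per-channel path follows; the buffer contents are placeholders, K is assumed to be a multiple of 8 (the kernels assert this), and the declaration is assumed to be visible through fbgemm/FbgemmI8DepthwiseAvx2.h as included above.

#include <cstdint>
#include <vector>

#include "fbgemm/FbgemmI8DepthwiseAvx2.h"

// Hypothetical call site for the 5x5 per-channel path. A is N x H x W x K
// (uint8, NHWK layout); with stride 1 and "same" padding C has the same shape.
// Per-channel quantization passes one zero point, one multiplier and one
// act_times_w_scale per channel; the values below are placeholders.
void run_depthwise_5x5_per_channel(
    int N, int H, int W, int K, // K assumed to be a multiple of 8
    const std::uint8_t* A,
    std::uint8_t* C,
    const fbgemm::PackedDepthWiseConvMatrix& Bp, // packed with kernel_prod 5*5
    int thread_id, int num_threads) {
  std::vector<std::int32_t> B_zero_point(K, 0);
  std::vector<float> C_multiplier(K, 0.5f);
  std::vector<std::int32_t> col_offsets(K, 0);
  std::vector<std::int32_t> bias(K, 0);
  std::vector<float> act_times_w_scale(K, 1.0f);

  fbgemm::depthwise_2d_per_channel_quantization_same_pad<std::int32_t>(
      N, H, W, K,
      /*stride_h=*/1, /*stride_w=*/1,
      /*A_zero_point=*/0, A,
      B_zero_point.data(), Bp,
      C_multiplier.data(),
      /*C_zero_point=*/0, C,
      col_offsets.data(), bias.data(),
      /*fuse_relu=*/false,
      act_times_w_scale.data(),
      thread_id, num_threads);
}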
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index 396e792..dfd4498 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -1764,6 +1764,11 @@ void fbgemmGroupwiseConv(
const processOutputType& outProcess,
int thread_id,
int num_threads) {
+ // TODO: Remove this when threading is supported.
+ if (thread_id > 0) {
+ return;
+ }
+
return fbgemmGroupwiseConvBase_<
packed_W,
outType,
@@ -1804,6 +1809,10 @@ void fbgemmGroupwiseConv(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
+ // TODO: Remove this when threading is supported.
+ if (thread_id > 0) {
+ return;
+ }
if (!fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param) ||
(!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
return fbgemmGroupwiseConvBase_<
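The guard added to both fbgemmGroupwiseConv overloads serializes the groupwise path for now: every worker still calls in with its own thread_id, but only thread 0 performs the convolution until the TODO is resolved. A minimal sketch of that call pattern, using generic names rather than the fbgemm API:

#include <thread>
#include <vector>

// Stand-in for the guarded entry point: all threads call it, thread 0 works.
void group_conv_stub(int thread_id, int /*num_threads*/, int* out) {
  if (thread_id > 0) {
    return; // mirrors the temporary guard above
  }
  *out = 42; // stands in for the actual groupwise convolution
}

int main() {
  const int num_threads = 4;
  int result = 0;
  std::vector<std::thread> pool;
  for (int tid = 0; tid < num_threads; ++tid) {
    pool.emplace_back(group_conv_stub, tid, num_threads, &result);
  }
  for (auto& t : pool) {
    t.join();
  }
  return result == 42 ? 0 : 1;
}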
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
index 192fb00..abbe9ad 100644
--- a/src/PackWeightsForConv.cc
+++ b/src/PackWeightsForConv.cc
@@ -24,7 +24,9 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) {
case optimized_conv_t::depthwise: {
W_dw_packed_ = std::make_shared<PackedDepthWiseConvMatrix>(
- conv_p.G, SPATIAL_DIM == 3 ? 3 * 3 * 3 : 3 * 3, sdata);
+ conv_p.G,
+ SPATIAL_DIM == 3 ? 3 * 3 * 3 : conv_p.K[0] * conv_p.K[1],
+ sdata);
break;
}
case optimized_conv_t::groupwise: {
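With this change the 2D depthwise packing size follows the declared filter (conv_p.K[0] * conv_p.K[1], e.g. 25 for a 5x5 layer) instead of being hard-coded to 3x3, while the 3D path stays at 3*3*3. A hedged sketch of the same decision in isolation; the helper name is illustrative and the return by value relies on C++17 guaranteed copy elision.

#include <cstdint>

#include "fbgemm/Fbgemm.h"
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"

// Illustrative helper: choose kernel_prod the same way the constructor call
// above does, then pack the raw int8 depthwise weights (sdata) for G groups.
template <int SPATIAL_DIM>
fbgemm::PackedDepthWiseConvMatrix pack_depthwise_weights(
    const fbgemm::conv_param_t<SPATIAL_DIM>& conv_p,
    const std::int8_t* sdata) {
  int kernel_prod =
      SPATIAL_DIM == 3 ? 3 * 3 * 3 : conv_p.K[0] * conv_p.K[1];
  return fbgemm::PackedDepthWiseConvMatrix(conv_p.G, kernel_prod, sdata);
}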
diff --git a/src/codegen_fp16fp32.cc b/src/codegen_fp16fp32.cc
index 8f80593..edd1e83 100644
--- a/src/codegen_fp16fp32.cc
+++ b/src/codegen_fp16fp32.cc
@@ -82,8 +82,28 @@ int main() {
// {12, 1, 0},
// {13, 1, 0},
// {14, 1, 0},
+ }},
+ {3,
+ "AVX512",
+ {
+ // 6x2 register layout
+ {1, 2, 0},
+ {2, 2, 0},
+ {3, 2, 0},
+ {4, 2, 0},
+ {5, 2, 0},
+ {6, 2, 0},
+ {7, 2, 0},
+ {8, 2, 0},
+ {9, 2, 0},
+ {10, 2, 0},
+ {11, 2, 0},
+ {12, 2, 0},
+ {13, 2, 0},
+ {14, 2, 0},
}}};
+<<<<<<< HEAD
// open all files
ofstream srcfile;
srcfile.open("FbgemmFP16UKernelsAvx2.cc");
@@ -669,15 +689,330 @@ int main() {
fptr_typedef[B_type] =
"typedef void (*funcptr_" + B_type + ")" + fargs;
- }
+=======
+ for (auto s : isa) {
+ string isa_file_name = s.avx <= 2 ? "Avx2" : "Avx512";
+
+ // open all files
+ ofstream srcfile;
+ srcfile.open("FbgemmFP16UKernels" + isa_file_name + ".cc");
+ srcfile
+ << "/*\n"
+ " * Copyright (c) Facebook, Inc. and its affiliates.\n"
+ " * All rights reserved.\n"
+ " * This source code is licensed under the BSD-style license found in the\n"
+ " * LICENSE file in the root directory of this source tree.\n"
+ " */\n";
+ srcfile << "#include \"FbgemmFP16UKernels" + isa_file_name + ".h\"\n\n";
+ srcfile << "namespace fbgemm {\n\n";
+ if (iaca) {
+ srcfile << "#include \"iacaMarks.h\"\n";
+ }
+
+ ofstream hdrfile;
+ hdrfile.open("FbgemmFP16UKernels" + isa_file_name + ".h");
+ hdrfile
+ << "/*\n"
+ " * Copyright (c) Facebook, Inc. and its affiliates.\n"
+ " * All rights reserved.\n"
+ " * This source code is licensed under the BSD-style license found in the\n"
+ " * LICENSE file in the root directory of this source tree.\n"
+ " */\n";
+
+ hdrfile << "#pragma once\n";
+ hdrfile << "#include <cstdint>\n";
+ if (s.avx == 3) {
+ hdrfile << "#include \"FbgemmFP16UKernelsAvx2.h\"\n";
+ }
+ hdrfile << "#include \"fbgemm/Types.h\"\n\n";
+ hdrfile << "namespace fbgemm {\n\n";
+ if (s.avx == 2) {
+ hdrfile << "using fp16 = float16;\n";
+ hdrfile << "using fp32 = float;\n";
+ hdrfile
+ << "struct GemmParams {\n uint64_t k;\n float* A;\n const fp16* B;\n"
+ " float* beta;\n uint64_t accum;\n float* C;\n uint64_t ldc;\n"
+ " uint64_t b_block_cols;\n uint64_t b_block_size;\n};\n";
+ }
+
+ map<string, string> fptr_typedef;
+ fptr_typedef["fp16"] = "";
+ fptr_typedef["fp32"] = "";
+
+ unsigned labelId = 0;
+
+ bool fixedA = false, fixedB = false, fixedC = false;
+ bool fp16 = true;
+
+ vector<vector<unsigned>>& ukernel_shape = s.shapes;
+
+ vector<string> funcname(ukernel_shape.size()),
+ fheader(ukernel_shape.size());
+ string fargs;
+
+ string B_type = ((fp16) ? "fp16" : "fp32");
+ string prefix = s.name + /*"_" + B_type */ +"_" + "fA" + to_string(fixedA) +
+ "fB" + to_string(fixedB) + "fC" + to_string(fixedC);
+ cout << "Generating code for " << s.name << " " << B_type << "\n";
+
+ string vec_reg_prefix = s.avx <= 2 ? "ymm" : "zmm";
+ int num_vec_regs = s.avx <= 2 ? 16 : 32;
+ int vec_len_in_bytes = s.avx <= 2 ? 32 : 64;
+
+ for (unsigned k = 0; k < ukernel_shape.size(); k++) {
+ printf("shape: %d x %d * 32\n", ukernel_shape[k][0], ukernel_shape[k][1]);
+
+ string p1 = "GemmParams* gp";
+
+ funcname[k] = "gemmkernel_" + to_string(ukernel_shape[k][0]) + "x" +
+ to_string(ukernel_shape[k][1]) + "_";
+ funcname[k] += prefix;
+
+ fargs = "(" + p1 + ")";
+
+ fheader[k] = "void __attribute__((noinline)) " + funcname[k] + fargs;
+ srcfile << fheader[k] << " {\n";
+
+ unsigned last_free_vecreg = 0;
+ // produce register block of C
+ vector<vector<string>> vCtile(ukernel_shape[k][0]);
+ for (auto r = 0; r < ukernel_shape[k][0]; r++)
+ for (auto c = 0; c < ukernel_shape[k][1]; c++) {
+ vCtile[r].push_back(vec_reg_prefix + to_string(last_free_vecreg));
+ last_free_vecreg++;
}
+ assert(last_free_vecreg <= num_vec_regs - 2);
- srcfile << "\n} // namespace fbgemm\n";
- srcfile.close();
+ string vAtmp = vec_reg_prefix + to_string(last_free_vecreg++);
+ // produce register block of B col
+ vector<string> vBcol(ukernel_shape[k][1]);
- hdrfile << fptr_typedef["fp16"] << ";\n";
- hdrfile << fptr_typedef["fp32"] << ";\n";
- hdrfile << "\n} // namespace fbgemm\n\n";
- hdrfile << "#endif\n";
- hdrfile.close();
+ for (auto c = 0; c < ukernel_shape[k][1]; c++) {
+ vBcol[c] = (vec_reg_prefix + to_string(last_free_vecreg));
+ last_free_vecreg++;
+ }
+
+ assert(last_free_vecreg <= num_vec_regs);
+
+ srcfile << " asm volatile(\n";
+
+ srcfile << "#if !defined(__clang__)"
+ << "\n";
+ addi(srcfile, "mov r14, %[gp]");
+ srcfile << "#else\n";
+ addi(srcfile, "mov %[gp], %%r14");
+ addi(srcfile, ".intel_syntax noprefix");
+ srcfile << "#endif\n";
+
+ srcfile << "\n // Copy parameters\n";
+ srcfile << " // k\n";
+ addi(srcfile, "mov r8, [r14 + 0]");
+ srcfile << " // A\n";
+ addi(srcfile, "mov r9, [r14 + 8]");
+ srcfile << " // B\n";
+ addi(srcfile, "mov r10, [r14 + 16]");
+ srcfile << " // beta\n";
+ addi(srcfile, "mov r15, [r14 + 24]");
+ srcfile << " // accum\n";
+ addi(srcfile, "mov rdx, [r14 + 32]");
+ srcfile << " // C\n";
+ addi(srcfile, "mov r12, [r14 + 40]");
+ srcfile << " // ldc\n";
+ addi(srcfile, "mov r13, [r14 + 48]");
+ srcfile << " // b_block_cols\n";
+ addi(srcfile, "mov rdi, [r14 + 56]");
+ srcfile << " // b_block_size\n";
+ addi(srcfile, "mov rsi, [r14 + 64]");
+ srcfile << " // Make copies of A and C\n";
+ addi(srcfile, "mov rax, r9");
+ addi(srcfile, "mov rcx, r12");
+ srcfile << "\n";
+
+ addi(srcfile, "mov rbx, 0");
+
+ string exitlabel = "L_exit%=";
+ string label2 = "loop_outter%=";
+ addi(srcfile, label2 + ":");
+ addi(srcfile, "mov r14, 0");
+
+ // set all vCtile regs to zeros
+ for (auto r = 0; r < vCtile.size(); r++) {
+ for (auto c = 0; c < vCtile[r].size(); c++) {
+ addi(
+ srcfile,
+ "vxorps " + vCtile[r][c] + "," + vCtile[r][c] + "," +
+ vCtile[r][c]);
+ }
+ }
+
+ // start marker
+ if (iaca) {
+ addi(srcfile, "mov ebx, 111");
+ addi(srcfile, ".byte 0x64, 0x67, 0x90");
+ }
+
+ srcfile << "\n";
+
+ srcfile << "\n";
+ string label = "loop_inner%=";
+ addi(srcfile, label + ":");
+ srcfile << "\n";
+
+ for (int c = 0; c < vCtile[0].size(); c++) {
+ addi(
+ srcfile,
+ "vcvtph2ps " + vBcol[c] + "," + (s.avx <= 2 ? "XMM" : "YMM") +
+ "WORD PTR [r10 + " + to_string(vec_len_in_bytes / 2 * c) + "]");
+ }
+
+ for (int r = 0; r < vCtile.size(); r++) {
+ addi(
+ srcfile,
+ "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + to_string(4 * r) +
+ "]");
+ for (int c = 0; c < vCtile[0].size(); c++) {
+ addi(
+ srcfile,
+ "vfmadd231ps " + vCtile[r][c] + "," + vBcol[c] + "," + vAtmp);
+ }
+ }
+
+ addi(
+ srcfile,
+ "add r9," + to_string(4 * ukernel_shape[k][0]),
+ fixedA); // move A ptr
+
+ addi(
+ srcfile,
+ "add r10," + to_string(vec_len_in_bytes / 2 * ukernel_shape[k][1]),
+          fixedA); // move B ptr
+
+ addi(srcfile, "inc r14");
+ addi(srcfile, "cmp r14, r8");
+ addi(srcfile, "jl " + label);
+
+ srcfile << "\n";
+
+ addi(srcfile, exitlabel + ":");
+
+ // addi(srcfile, "add r10, rsi");
+ srcfile << "\n";
+
+ // end marker
+ if (iaca) {
+ addi(srcfile, "mov ebx, 222");
+ addi(srcfile, ".byte 0x64, 0x67, 0x90");
+ }
+
+ addi(srcfile, "cmp rdx, 1");
+ addi(srcfile, "je L_accum%=");
+ srcfile << " // Dump C\n";
+
+ for (auto r = 0; r < vCtile.size(); r++) {
+ for (auto c = 0; c < vCtile[r].size(); c++) {
+ addi(
+ srcfile,
+ "vmovups " + vec_reg_prefix + "word PTR [r12 + " +
+ to_string(vec_len_in_bytes * c) + "], " + vCtile[r][c],
+ fixedC);
+ }
+ addi(srcfile, "add r12, r13", fixedC); // move C ptr
+ }
+ addi(srcfile, "jmp L_done%=");
+
+ srcfile << "\n";
+ addi(srcfile, "L_accum%=:");
+ srcfile << " // Dump C with accumulate\n";
+
+ string r_spare =
+ vec_reg_prefix + to_string(num_vec_regs - (s.avx == 1 ? 2 : 1));
+ string r_last = vec_reg_prefix + to_string(num_vec_regs - 1);
+ addi(
+ srcfile,
+ "vbroadcastss " + r_spare + string(",DWORD PTR [r15]"),
+ fixedC);
+ // store out C
+ for (auto r = 0; r < vCtile.size(); r++) {
+ for (auto c = 0; c < vCtile[r].size(); c++) {
+ switch (s.avx) {
+ case 1:
+ addi(
+ srcfile,
+ string("vmulps " + r_last + ", ") + r_spare + comma +
+ "YMMWORD PTR [r12 + " + to_string(vec_len_in_bytes * c) +
+ "]",
+ fixedC);
+ addi(
+ srcfile,
+ "vaddps " + vCtile[r][c] + "," + vCtile[r][c] + "," + r_last,
+ fixedC);
+ break;
+ case 2:
+ case 3:
+ addi(
+ srcfile,
+ "vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," +
+ vec_reg_prefix + "word PTR [r12 + " +
+ to_string(vec_len_in_bytes * c) + "]",
+ fixedC);
+ break;
+ default:
+ assert(0);
+>>>>>>> upstream
+ }
+ addi(
+ srcfile,
+ "vmovups " + vec_reg_prefix + "word PTR [r12 + " +
+ to_string(vec_len_in_bytes * c) + "], " + vCtile[r][c],
+ fixedC);
+ }
+ addi(srcfile, "add r12, r13", fixedC); // move C ptr
+ }
+
+ srcfile << "\n";
+ addi(srcfile, "L_done%=:");
+
+ srcfile << "\n // next outer iteration\n";
+ // C
+ addi(
+ srcfile,
+ "add rcx, " + to_string(vec_len_in_bytes * ukernel_shape[k][1]),
+ fixedC);
+ addi(srcfile, "mov r12, rcx", fixedC);
+ // A
+ addi(srcfile, "mov r9, rax");
+
+ addi(srcfile, "inc rbx");
+ addi(srcfile, "cmp rbx, rdi");
+ addi(srcfile, "jl " + label2);
+
+ // output
+ srcfile << " :\n";
+ // input
+ srcfile << " : [gp] \"rm\"(gp)\n";
+
+ // clobbered
+ srcfile << " : \"r8\",\n \"r9\",\n \"r10\",\n"
+ " \"r11\",\n \"r15\",\n \"r13\",\n"
+ " \"r14\",\n \"rax\",\n \"rcx\",\n"
+ " \"rdx\",\n \"rsi\",\n \"rdi\",\n"
+ " \"rbx\",\n \"r12\",\n"
+ " \"memory\");\n";
+ srcfile << "}\n";
+ }
+
+ for (unsigned k = 0; k < ukernel_shape.size(); k++) {
+ hdrfile << fheader[k] << ";\n";
+ }
+
+ fptr_typedef[B_type] = "typedef void (*funcptr_" + B_type + ")" + fargs;
+
+ srcfile << "\n} // namespace fbgemm\n";
+ srcfile.close();
+
+ hdrfile << fptr_typedef["fp16"] << ";\n";
+ hdrfile << fptr_typedef["fp32"] << ";\n";
+ hdrfile << "\n} // namespace fbgemm\n\n";
+ hdrfile.close();
+ } // isa
}
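Every regenerated micro-kernel takes a single GemmParams* whose layout is emitted above (k, A, B, beta, accum, C, ldc, b_block_cols, b_block_size; ldc is consumed as a byte stride because the kernel adds it directly to the C pointer). The sketch below shows how one AVX512 kernel could be driven; the symbol name is derived from the generator's naming scheme (rows x cols, s.name, fA/fB/fC flags) rather than copied from the emitted header, and the internal src/ headers are assumed to be on the include path, so treat both as illustrative.

#include <cstdint>

#include "FbgemmFP16UKernelsAvx512.h" // internal src/ header; for AVX512 it
                                      // also pulls in the Avx2 header that
                                      // defines GemmParams and fp16

// Illustrative driver for one generated 14x2 AVX512 micro-kernel.
void run_one_tile(
    float* A,               // 14 x k panel of A, laid out as consumed above
    const fbgemm::fp16* Bp, // packed fp16 B: b_block_cols blocks of k x 32
    float* C,               // output tile, rows ldc_bytes apart
    std::uint64_t k,
    std::uint64_t ldc_bytes,     // row stride of C in bytes
    std::uint64_t b_block_cols) {
  float beta = 0.0f;
  fbgemm::GemmParams gp;
  gp.k = k;
  gp.A = A;
  gp.B = Bp;
  gp.beta = &beta;
  gp.accum = 0;        // 0: overwrite C; 1: take the L_accum%= path
  gp.C = C;
  gp.ldc = ldc_bytes;
  gp.b_block_cols = b_block_cols;
  gp.b_block_size = 0; // loaded into rsi but unused (the add is commented out)
  fbgemm::gemmkernel_14x2_AVX512_fA0fB0fC0(&gp);
}

With accum set to 1 the kernel broadcasts *beta, scales the existing C tile by it via vfmadd231ps (vmulps/vaddps on the AVX path), and adds the freshly computed product before storing.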