author    | Young Jin Kim <youki@microsoft.com> | 2019-12-07 01:52:35 +0300
committer | Young Jin Kim <youki@microsoft.com> | 2019-12-07 01:52:35 +0300
commit    | 6f7ad8fb91e8ab94d35f847035a1b5ab8e5c5b44 (patch)
tree      | a0c9e0d7b9acc9c8e711b36ea184ef494733c7e0 /src
parent    | 21f93c950b8b27918cd59c8f3139fb41ad1bd2c6 (diff)
parent    | 0d7da7c36f50276b5a550d46508516d139522687 (diff)

Fixing merge error (youki/fp16avx512)
Diffstat (limited to 'src')

-rw-r--r-- | src/FbgemmConv.cc                           |   53
-rw-r--r-- | src/FbgemmFP16.cc                           |  204
-rw-r--r-- | src/FbgemmFP16UKernelsAvx2.cc               |  621
-rw-r--r-- | src/FbgemmFP16UKernelsAvx2.h                |    4
-rw-r--r-- | src/FbgemmFP16UKernelsAvx512.cc             | 2558
-rw-r--r-- | src/FbgemmFP16UKernelsAvx512.h              |   32
-rw-r--r-- | src/FbgemmI8Depthwise2DAvx2-inl.h           | 1623
-rw-r--r-- | src/FbgemmI8Depthwise3DAvx2.cc              |   91
-rw-r--r-- | src/FbgemmI8Depthwise3x3Avx2.cc             |  618
-rw-r--r-- | src/FbgemmI8DepthwiseAvx2.cc                | 1892
-rw-r--r-- | src/FbgemmI8DepthwisePerChannelQuantAvx2.cc |  154
-rw-r--r-- | src/GroupwiseConvAcc32Avx2.cc               |    9
-rw-r--r-- | src/PackWeightsForConv.cc                   |    4
-rw-r--r-- | src/codegen_fp16fp32.cc                     |  351

14 files changed, 6194 insertions, 2020 deletions
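The FbgemmFP16.cc hunk below replaces the single AVX2 kernel table with per-ISA tables (kernel_avx2, kernel_avx512) and picks one at run time via fbgemmHasAvx512Support(). A minimal sketch of that dispatch pattern, with stub kernels standing in for the real gemmkernel_*x2_AVX2/AVX512_fA0fB0fC0 functions; only the table, width, and query names come from the patch, the scaffolding is illustrative:

// Sketch: runtime ISA dispatch as added to cblas_gemm_compute.
#include <cstdio>

struct GemmParams {};                       // opaque for this sketch
typedef void (*funcptr_fp16)(GemmParams*);  // matches the patch's typedef

static void stub_avx2(GemmParams*) { std::puts("AVX2 1x2 kernel"); }
static void stub_avx512(GemmParams*) { std::puts("AVX512 1x2 kernel"); }

// Index i holds the i-row micro-kernel; entry 0 is unused.
static const funcptr_fp16 kernel_avx2[] = {nullptr, stub_avx2 /* ...6x2 */};
static const funcptr_fp16 kernel_avx512[] = {nullptr, stub_avx512 /* ...14x2 */};

int main() {
  bool has_avx512 = false;  // real code: fbgemmHasAvx512Support()
  // zmm registers hold 16 floats vs. 8 in ymm, so the column tile doubles.
  int simd_width = has_avx512 ? 16 : 8;
  int kernel_ncol_blocks = 2;  // Bp.kernelNumColBlocks() in the patch
  int kernel_ncols = kernel_ncol_blocks * simd_width;

  const funcptr_fp16* kernels = has_avx512 ? kernel_avx512 : kernel_avx2;
  GemmParams gp;
  kernels[1](&gp);  // row counts go up to 6 (AVX2) or 14 (AVX512)
  std::printf("column tile = %d floats\n", kernel_ncols);
  return 0;
}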
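The FbgemmConv.cc hunk generalizes takeDepthWiseFastPath from hard-coded 3x3 kernels with pad 1 to any uniform kernel of size 3 (or 5 in the 2D case) with "same" padding, i.e. pad == (K - 1) / 2 on every side. A sketch of the new predicate, assuming a conv_param_t reduced to just the fields the shown hunk reads:

#include <algorithm>
#include <array>
#include <cassert>

template <int SPATIAL_DIM>
struct conv_param_t {  // stand-in; the real struct has many more fields
  std::array<int, SPATIAL_DIM> K, stride, dilation;
  std::array<int, SPATIAL_DIM * 2> pad;
};

template <int SPATIAL_DIM>
bool takeDepthWiseFastPath(const conv_param_t<SPATIAL_DIM>& p) {
  auto all = [](const auto& c, auto pred) {
    return std::all_of(c.begin(), c.end(), pred);
  };
  return all(p.stride, [](int i) { return i == 1 || i == 2; }) &&
      // every spatial kernel dim must equal K[0] ...
      all(p.K, [&](int i) { return i == p.K[0]; }) &&
      // ... and be 3, or 5 when the convolution is 2D
      (p.K[0] == 3 || (SPATIAL_DIM == 2 && p.K[0] == 5)) &&
      all(p.dilation, [](int i) { return i == 1; }) &&
      // "same" padding on every side
      all(p.pad, [&](int i) { return i == (p.K[0] - 1) / 2; });
}

int main() {
  conv_param_t<2> c5{{5, 5}, {1, 1}, {1, 1}, {2, 2, 2, 2}};
  assert(takeDepthWiseFastPath(c5));  // 5x5, pad 2: newly fast-pathed
  conv_param_t<2> c3{{3, 3}, {2, 2}, {1, 1}, {1, 1, 1, 1}};
  assert(takeDepthWiseFastPath(c3));  // 3x3, pad 1, stride 2: as before
  return 0;
}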
diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc index de833d2..da3fa88 100644 --- a/src/FbgemmConv.cc +++ b/src/FbgemmConv.cc @@ -6,9 +6,9 @@ */ #include <algorithm> +#include <functional> #include <numeric> #include <vector> -#include <functional> #include "fbgemm/Fbgemm.h" namespace fbgemm { @@ -24,13 +24,16 @@ bool takeDepthWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) { conv_p.stride.end(), [](int i) { return i == 1 || i == 2; }) && std::all_of( - conv_p.K.begin(), conv_p.K.end(), [](int i) { return i == 3; }) && + conv_p.K.begin(), + conv_p.K.end(), + [&conv_p](int i) { return i == conv_p.K[0]; }) && + (conv_p.K[0] == 3 || (SPATIAL_DIM == 2 && conv_p.K[0] == 5)) && std::all_of( conv_p.dilation.begin(), conv_p.dilation.end(), [](int i) { return i == 1; }) && - std::all_of(conv_p.pad.begin(), conv_p.pad.end(), [](int i) { - return i == 1; + std::all_of(conv_p.pad.begin(), conv_p.pad.end(), [&conv_p](int i) { + return i == (conv_p.K[0] - 1) / 2; }); } @@ -151,9 +154,9 @@ int fbgemmConv( "not supported"; throw std::runtime_error(msg); } - } else { + } else if (SPATIAL_DIM == 2) { if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) { - depthwise_3x3_pad_1( + depthwise_2d_same_pad( conv_p.MB, // mini batch conv_p.IN_DIM[0], // H conv_p.IN_DIM[1], // W @@ -178,7 +181,7 @@ int fbgemmConv( QuantizationGranularity::OUT_CHANNEL || processOutputType::QGRANType == QuantizationGranularity::GROUP) { // The number of channels == groups for depthwise convolutions - depthwise_3x3_per_channel_quantization_pad_1( + depthwise_2d_per_channel_quantization_same_pad( conv_p.MB, // mini batch conv_p.IN_DIM[0], // H conv_p.IN_DIM[1], // W @@ -204,6 +207,10 @@ int fbgemmConv( "not supported"; throw std::runtime_error(msg); } + } else { + std::string msg = + "[FBGEMM_CONV_ERROR] This spatial dim is not supported"; + throw std::runtime_error(msg); } break; } @@ -212,20 +219,24 @@ int fbgemmConv( // std::cout << "Groupwise fast path" << std::endl; assert( SPATIAL_DIM == 2 && "Only 2D groupwise convolutions are supported"); - std::vector<int32_t> row_offset_buf( - rowOffsetBufferSizeGConv<SPATIAL_DIM>(conv_p)); - outProcess.setRowOffsets(row_offset_buf.data()); - fbgemmGroupwiseConv( - conv_p, - activations, - outProcess.getAZeroPoint(), - row_offset_buf.data(), - *(packed_weights.getPackedWForGroupwise()), - out, - outBuffer, - outProcess, - thread_id, - num_threads); + // thread 0 does all the work. 
+ // TODO: Remove this when fbgemmGroupwiseConv supports threads + if (thread_id == 0) { + std::vector<int32_t> row_offset_buf( + rowOffsetBufferSizeGConv<SPATIAL_DIM>(conv_p)); + outProcess.setRowOffsets(row_offset_buf.data()); + fbgemmGroupwiseConv( + conv_p, + activations, + outProcess.getAZeroPoint(), + row_offset_buf.data(), + *(packed_weights.getPackedWForGroupwise()), + out, + outBuffer, + outProcess, + 0, + 1); + } break; } case optimized_conv_t::pointwise: { diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc index b034f2c..e40277d 100644 --- a/src/FbgemmFP16.cc +++ b/src/FbgemmFP16.cc @@ -13,6 +13,7 @@ #include <utility> #include "FbgemmFP16UKernelsAvx2.h" +#include "FbgemmFP16UKernelsAvx512.h" using namespace std; @@ -32,26 +33,50 @@ inline void PackA(int nrow, int ncol, const float* from, int ldim, float* to) { transpose_simd(nrow, ncol, from, ldim, to, nrow); } +// Each kernel does the following computation that multiplies +// mb x k A sub-matrix with k x b_block_cols*64 B sub-matrix +// for (int j = 0; j < b_block_cols * 64; j += 64) { +// for (int kk = 0; kk < k; ++k) { +// for (int i = 0; i < mb; ++i) { +// c[i][j:j+64] += a[i][kk] * b[kk][j:j+64] +// } +// } +// } + struct KernelInfo { using knl_ptr = funcptr_fp16; // optimized kernels to cover all cases // 2 in ?x2 should be the same as kernel_ncol_blocks. // Here with kernel_ncol_blocks = 2, we can provide up to 6x2 kernels, due to // the restrictions of ymm register numbers (16). - static constexpr knl_ptr kernel[7] = { - nullptr, - gemmkernel_1x2_AVX2_fA0fB0fC0, - gemmkernel_2x2_AVX2_fA0fB0fC0, - gemmkernel_3x2_AVX2_fA0fB0fC0, - gemmkernel_4x2_AVX2_fA0fB0fC0, - gemmkernel_5x2_AVX2_fA0fB0fC0, - gemmkernel_6x2_AVX2_fA0fB0fC0 - }; + static constexpr knl_ptr kernel_avx2[] = {nullptr, + gemmkernel_1x2_AVX2_fA0fB0fC0, + gemmkernel_2x2_AVX2_fA0fB0fC0, + gemmkernel_3x2_AVX2_fA0fB0fC0, + gemmkernel_4x2_AVX2_fA0fB0fC0, + gemmkernel_5x2_AVX2_fA0fB0fC0, + gemmkernel_6x2_AVX2_fA0fB0fC0}; + + static constexpr knl_ptr kernel_avx512[] = {nullptr, + gemmkernel_1x2_AVX512_fA0fB0fC0, + gemmkernel_2x2_AVX512_fA0fB0fC0, + gemmkernel_3x2_AVX512_fA0fB0fC0, + gemmkernel_4x2_AVX512_fA0fB0fC0, + gemmkernel_5x2_AVX512_fA0fB0fC0, + gemmkernel_6x2_AVX512_fA0fB0fC0, + gemmkernel_7x2_AVX512_fA0fB0fC0, + gemmkernel_8x2_AVX512_fA0fB0fC0, + gemmkernel_9x2_AVX512_fA0fB0fC0, + gemmkernel_10x2_AVX512_fA0fB0fC0, + gemmkernel_11x2_AVX512_fA0fB0fC0, + gemmkernel_12x2_AVX512_fA0fB0fC0, + gemmkernel_13x2_AVX512_fA0fB0fC0, + gemmkernel_14x2_AVX512_fA0fB0fC0}; // autotuned kernel splits for various cases m = 1:mb_max // may need re-autotuning for new uarch // clang-format off - static constexpr int partition[121][2][2] = { + static constexpr int partition_avx2[121][2][2] = { // NOTE: clang-format wants to use a different formatting but the current // formatting should be easier to read. { { 0, 0 }, { 0, 0 } }, // 0 @@ -176,10 +201,139 @@ struct KernelInfo { { { 6, 19 }, { 5, 1 } }, // 119 { { 6, 20 }, { 0, 0 } }, // 120 }; + static constexpr partition_avx512[121][2][2] = { + // NOTE: clang-format wants to use a different formatting but the current + // formatting should be easier to read. 
+ { + {{ { 0, 0 }, { 0, 0 } } }, // 0 + {{ { 1, 1 }, { 0, 0 } } }, // 1 + {{ { 2, 1 }, { 0, 0 } } }, // 2 + {{ { 3, 1 }, { 0, 0 } } }, // 3 + {{ { 4, 1 }, { 0, 0 } } }, // 4 + {{ { 5, 1 }, { 0, 0 } } }, // 5 + {{ { 6, 1 }, { 0, 0 } } }, // 6 + {{ { 7, 1 }, { 0, 0 } } }, // 7 + {{ { 8, 1 }, { 0, 0 } } }, // 8 + {{ { 9, 1 }, { 0, 0 } } }, // 9 + {{ { 10, 1 }, { 0, 0 } } }, // 10 + {{ { 11, 1 }, { 0, 0 } } }, // 11 + {{ { 12, 1 }, { 0, 0 } } }, // 12 + {{ { 13, 1 }, { 0, 0 } } }, // 13 + {{ { 14, 1 }, { 0, 0 } } }, // 14 + {{ { 8, 1 }, { 7, 1 } } }, // 15 + {{ { 8, 2 }, { 0, 0 } } }, // 16 + {{ { 9, 1 }, { 8, 1 } } }, // 17 + {{ { 9, 2 }, { 0, 0 } } }, // 18 + {{ { 10, 1 }, { 9, 1 } } }, // 19 + {{ { 10, 2 }, { 0, 0 } } }, // 20 + {{ { 11, 1 }, { 10, 1 } } }, // 21 + {{ { 11, 2 }, { 0, 0 } } }, // 22 + {{ { 12, 1 }, { 11, 1 } } }, // 23 + {{ { 12, 2 }, { 0, 0 } } }, // 24 + {{ { 13, 1 }, { 12, 1 } } }, // 25 + {{ { 13, 2 }, { 0, 0 } } }, // 26 + {{ { 14, 1 }, { 13, 1 } } }, // 27 + {{ { 14, 2 }, { 0, 0 } } }, // 28 + {{ { 10, 2 }, { 9, 1 } } }, // 29 + {{ { 10, 3 }, { 0, 0 } } }, // 30 + {{ { 11, 2 }, { 9, 1 } } }, // 31 + {{ { 11, 2 }, { 10, 1 } } }, // 32 + {{ { 11, 3 }, { 0, 0 } } }, // 33 + {{ { 12, 2 }, { 10, 1 } } }, // 34 + {{ { 12, 2 }, { 11, 1 } } }, // 35 + {{ { 12, 3 }, { 0, 0 } } }, // 36 + {{ { 13, 2 }, { 11, 1 } } }, // 37 + {{ { 13, 2 }, { 12, 1 } } }, // 38 + {{ { 13, 3 }, { 0, 0 } } }, // 39 + {{ { 14, 2 }, { 12, 1 } } }, // 40 + {{ { 14, 2 }, { 13, 1 } } }, // 41 + {{ { 14, 3 }, { 0, 0 } } }, // 42 + {{ { 11, 3 }, { 10, 1 } } }, // 43 + {{ { 11, 4 }, { 0, 0 } } }, // 44 + {{ { 12, 3 }, { 9, 1 } } }, // 45 + {{ { 12, 3 }, { 10, 1 } } }, // 46 + {{ { 12, 3 }, { 11, 1 } } }, // 47 + {{ { 12, 4 }, { 0, 0 } } }, // 48 + {{ { 13, 3 }, { 10, 1 } } }, // 49 + {{ { 13, 3 }, { 11, 1 } } }, // 50 + {{ { 13, 3 }, { 12, 1 } } }, // 51 + {{ { 13, 4 }, { 0, 0 } } }, // 52 + {{ { 14, 3 }, { 11, 1 } } }, // 53 + {{ { 14, 3 }, { 12, 1 } } }, // 54 + {{ { 14, 3 }, { 13, 1 } } }, // 55 + {{ { 14, 4 }, { 0, 0 } } }, // 56 + {{ { 12, 4 }, { 9, 1 } } }, // 57 + {{ { 12, 4 }, { 10, 1 } } }, // 58 + {{ { 12, 4 }, { 11, 1 } } }, // 59 + {{ { 12, 5 }, { 0, 0 } } }, // 60 + {{ { 13, 4 }, { 9, 1 } } }, // 61 + {{ { 13, 4 }, { 10, 1 } } }, // 62 + {{ { 13, 4 }, { 11, 1 } } }, // 63 + {{ { 13, 4 }, { 12, 1 } } }, // 64 + {{ { 13, 5 }, { 0, 0 } } }, // 65 + {{ { 14, 4 }, { 10, 1 } } }, // 66 + {{ { 14, 4 }, { 11, 1 } } }, // 67 + {{ { 14, 4 }, { 12, 1 } } }, // 68 + {{ { 14, 4 }, { 13, 1 } } }, // 69 + {{ { 14, 5 }, { 0, 0 } } }, // 70 + {{ { 12, 5 }, { 11, 1 } } }, // 71 + {{ { 12, 6 }, { 0, 0 } } }, // 72 + {{ { 13, 5 }, { 8, 1 } } }, // 73 + {{ { 13, 5 }, { 9, 1 } } }, // 74 + {{ { 13, 5 }, { 10, 1 } } }, // 75 + {{ { 13, 5 }, { 11, 1 } } }, // 76 + {{ { 13, 5 }, { 12, 1 } } }, // 77 + {{ { 13, 6 }, { 0, 0 } } }, // 78 + {{ { 14, 5 }, { 9, 1 } } }, // 79 + {{ { 14, 5 }, { 10, 1 } } }, // 80 + {{ { 14, 5 }, { 11, 1 } } }, // 81 + {{ { 14, 5 }, { 12, 1 } } }, // 82 + {{ { 14, 5 }, { 13, 1 } } }, // 83 + {{ { 14, 6 }, { 0, 0 } } }, // 84 + {{ { 13, 6 }, { 7, 1 } } }, // 85 + {{ { 13, 6 }, { 8, 1 } } }, // 86 + {{ { 13, 6 }, { 9, 1 } } }, // 87 + {{ { 13, 6 }, { 10, 1 } } }, // 88 + {{ { 13, 6 }, { 11, 1 } } }, // 89 + {{ { 13, 6 }, { 12, 1 } } }, // 90 + {{ { 13, 7 }, { 0, 0 } } }, // 91 + {{ { 14, 6 }, { 8, 1 } } }, // 92 + {{ { 14, 6 }, { 9, 1 } } }, // 93 + {{ { 14, 6 }, { 10, 1 } } }, // 94 + {{ { 14, 6 }, { 11, 1 } } }, // 95 + {{ { 14, 6 }, { 12, 1 } } }, // 96 + {{ { 14, 6 }, { 13, 1 } } }, // 97 + 
{{ { 14, 7 }, { 0, 0 } } }, // 98 + {{ { 13, 7 }, { 8, 1 } } }, // 99 + {{ { 13, 7 }, { 9, 1 } } }, // 100 + {{ { 13, 7 }, { 10, 1 } } }, // 101 + {{ { 13, 7 }, { 11, 1 } } }, // 102 + {{ { 13, 7 }, { 12, 1 } } }, // 103 + {{ { 13, 8 }, { 0, 0 } } }, // 104 + {{ { 14, 7 }, { 7, 1 } } }, // 105 + {{ { 14, 7 }, { 8, 1 } } }, // 106 + {{ { 14, 7 }, { 9, 1 } } }, // 107 + {{ { 14, 7 }, { 10, 1 } } }, // 108 + {{ { 14, 7 }, { 11, 1 } } }, // 109 + {{ { 14, 7 }, { 12, 1 } } }, // 110 + {{ { 14, 7 }, { 13, 1 } } }, // 111 + {{ { 14, 8 }, { 0, 0 } } }, // 112 + {{ { 13, 8 }, { 9, 1 } } }, // 113 + {{ { 13, 8 }, { 10, 1 } } }, // 114 + {{ { 13, 8 }, { 11, 1 } } }, // 115 + {{ { 13, 8 }, { 12, 1 } } }, // 116 + {{ { 13, 9 }, { 0, 0 } } }, // 117 + {{ { 14, 8 }, { 6, 1 } } }, // 118 + {{ { 14, 8 }, { 7, 1 } } }, // 119 + {{ { 14, 8 }, { 8, 1 } } }, // 120 + } + }; // clang-format on }; -constexpr KernelInfo::knl_ptr KernelInfo::kernel[7];; -constexpr int KernelInfo::partition[121][2][2]; +constexpr KernelInfo::knl_ptr KernelInfo::kernel_avx2[]; +constexpr KernelInfo::knl_ptr KernelInfo::kernel_avx512[]; +constexpr KernelInfo::partition[121][2][2] KernelInfo::partition_avx2; +constexpr KernelInfo::partition[121][2][2] KernelInfo::partition_avx512; // autotuned kernel splits for various cases m = 1:mb_max void cblas_gemm_compute( @@ -200,7 +354,8 @@ void cblas_gemm_compute( // constants const int n = Bp.numCols(), k = Bp.numRows(), ldc = n; const int mb_max = 120; - constexpr int simd_width = 8; + bool has_avx512 = fbgemmHasAvx512Support(); + int simd_width = has_avx512 ? 16 : 8; int kernel_ncol_blocks = Bp.kernelNumColBlocks(); int kernel_ncols = kernel_ncol_blocks * simd_width; @@ -210,6 +365,11 @@ void cblas_gemm_compute( GemmParams gp; + const funcptr_fp16* kernels = + has_avx512 ? KernelInfo::kernel_avx512 : KernelInfo::kernel_avx2; + const array<array<array<int, 2>, 2>, 121>& partition = + has_avx512 ? KernelInfo::partition_avx512 : KernelInfo::partition_avx2; + int i_begin, i_end; // fbgemmGetRange(num_threads, thread_id, m, 1, i_begin, i_end); i_begin = 0; @@ -236,8 +396,8 @@ void cblas_gemm_compute( auto m1 = m0; for (auto c = 0; c < 2; c++) { - auto kernel_nrows = KernelInfo::partition[mb][c][0]; - auto nkernel_nrows = KernelInfo::partition[mb][c][1]; + auto kernel_nrows = partition[mb][c][0]; + auto nkernel_nrows = partition[mb][c][1]; auto m_start = m1, m_end = m1 + kernel_nrows * nkernel_nrows; for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) { @@ -271,7 +431,7 @@ void cblas_gemm_compute( gp.C += Bp.blockColSize() * jb_begin; gp.b_block_cols = jb_end - jb_begin; if (gp.b_block_cols) { - KernelInfo::kernel[kernel_nrows](&gp); + kernels[kernel_nrows](&gp); } } else { int last_blk_col = nbcol * Bp.blockColSize(); @@ -283,7 +443,7 @@ void cblas_gemm_compute( gp.C += Bp.blockColSize() * jb_begin; gp.b_block_cols = jb_end - jb_begin; if (gp.b_block_cols) { - KernelInfo::kernel[kernel_nrows](&gp); + kernels[kernel_nrows](&gp); } } @@ -297,19 +457,21 @@ void cblas_gemm_compute( gp.B = &(Bp(k_ind, last_blk_col)); gp.C = &C[m2 * ldc + last_blk_col]; gp.b_block_cols = 1; - KernelInfo::kernel[kernel_nrows](&gp); + kernels[kernel_nrows](&gp); } else { // small temporary buffer: the size should be larger than the // required kernel_nrow x kernel_ncols elements computed in the // registers. 
- float c_tmp[16 * 24] = {0}; - assert((16 * 24) > kernel_nrows * kernel_ncols); + float c_tmp[14 * 32] = {0}; + assert( + sizeof(c_tmp) / sizeof(c_tmp[0]) >= + kernel_nrows * kernel_ncols); gp.B = &(Bp(k_ind, last_blk_col)); gp.C = c_tmp; gp.ldc = kernel_ncols * sizeof(C[0]); gp.b_block_cols = 1; - KernelInfo::kernel[kernel_nrows](&gp); + kernels[kernel_nrows](&gp); for (int i = 0; i < kernel_nrows; i++) { // Todo: use assembly for (int j = last_blk_col; j < n; j++) { diff --git a/src/FbgemmFP16UKernelsAvx2.cc b/src/FbgemmFP16UKernelsAvx2.cc index 5f7492f..204fa03 100644 --- a/src/FbgemmFP16UKernelsAvx2.cc +++ b/src/FbgemmFP16UKernelsAvx2.cc @@ -32,6 +32,7 @@ void NOINLINE_ATTR gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp) { // b_block_size uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n" // Make copies of A and C +<<<<<<< HEAD float* rax = r9; //"mov rax, r9\t\n" float* rcx = r12; //"mov rcx, r12\t\n" @@ -75,6 +76,76 @@ void NOINLINE_ATTR gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp) { r12 = rcx; //"mov r12, rcx\t\n" r9 = rax; //"mov r9, rax\t\n" } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" +======= + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm3,XMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps ymm4,XMMWORD PTR [r10 + 16]\t\n" + "vbroadcastss ymm2,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm3,ymm2\t\n" + "vfmadd231ps ymm1,ymm4,ymm2\t\n" + "add r9,4\t\n" + "add r10,32\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups ymmword PTR [r12 + 0], ymm0\t\n" + "vmovups ymmword PTR [r12 + 32], ymm1\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm0\t\n" + "vfmadd231ps ymm1,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm1\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 64\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +>>>>>>> upstream } void NOINLINE_ATTR gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp) { @@ -100,6 +171,7 @@ void NOINLINE_ATTR gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp) { // b_block_size uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n" // Make copies of A and C +<<<<<<< HEAD float* rax = r9; //"mov rax, r9\t\n" float* rcx = r12; //"mov rcx, r12\t\n" @@ -158,6 +230,89 @@ void NOINLINE_ATTR gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp) { r12 = rcx; //"mov r12, rcx\t\n" r9 = rax; //"mov r9, rax\t\n" } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" +======= + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm5,XMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps ymm6,XMMWORD PTR [r10 + 16]\t\n" + "vbroadcastss ymm4,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm5,ymm4\t\n" + "vfmadd231ps ymm1,ymm6,ymm4\t\n" + "vbroadcastss 
ymm4,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm2,ymm5,ymm4\t\n" + "vfmadd231ps ymm3,ymm6,ymm4\t\n" + "add r9,8\t\n" + "add r10,32\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups ymmword PTR [r12 + 0], ymm0\t\n" + "vmovups ymmword PTR [r12 + 32], ymm1\t\n" + "add r12, r13\t\n" + "vmovups ymmword PTR [r12 + 0], ymm2\t\n" + "vmovups ymmword PTR [r12 + 32], ymm3\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm0\t\n" + "vfmadd231ps ymm1,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm2\t\n" + "vfmadd231ps ymm3,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm3\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 64\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +>>>>>>> upstream } void NOINLINE_ATTR gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp) { @@ -183,6 +338,7 @@ void NOINLINE_ATTR gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp) { // b_block_size uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n" // Make copies of A and C +<<<<<<< HEAD float* rax = r9; //"mov rax, r9\t\n" float* rcx = r12; //"mov rcx, r12\t\n" @@ -256,6 +412,102 @@ void NOINLINE_ATTR gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp) { r12 = rcx; //"mov r12, rcx\t\n" r9 = rax; //"mov r9, rax\t\n" } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" +======= + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + "vxorps ymm4,ymm4,ymm4\t\n" + "vxorps ymm5,ymm5,ymm5\t\n" + + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm7,XMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps ymm8,XMMWORD PTR [r10 + 16]\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm7,ymm6\t\n" + "vfmadd231ps ymm1,ymm8,ymm6\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm2,ymm7,ymm6\t\n" + "vfmadd231ps ymm3,ymm8,ymm6\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm4,ymm7,ymm6\t\n" + "vfmadd231ps ymm5,ymm8,ymm6\t\n" + "add r9,12\t\n" + "add r10,32\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups ymmword PTR [r12 + 0], ymm0\t\n" + "vmovups ymmword PTR [r12 + 32], ymm1\t\n" + "add r12, r13\t\n" + "vmovups ymmword PTR [r12 + 0], ymm2\t\n" + "vmovups ymmword PTR [r12 + 32], ymm3\t\n" + "add r12, r13\t\n" + "vmovups ymmword PTR [r12 + 0], ymm4\t\n" + "vmovups ymmword PTR [r12 + 32], ymm5\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm0\t\n" + "vfmadd231ps ymm1,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps 
ymm2,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm2\t\n" + "vfmadd231ps ymm3,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm4,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm4\t\n" + "vfmadd231ps ymm5,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm5\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 64\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +>>>>>>> upstream } void NOINLINE_ATTR gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp) { @@ -281,6 +533,7 @@ void NOINLINE_ATTR gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp) { // b_block_size uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n" // Make copies of A and C +<<<<<<< HEAD float* rax = r9; //"mov rax, r9\t\n" float* rcx = r12; //"mov rcx, r12\t\n" @@ -369,6 +622,115 @@ void NOINLINE_ATTR gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp) { r12 = rcx; //"mov r12, rcx\t\n" r9 = rax; //"mov r9, rax\t\n" } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" +======= + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + "vxorps ymm4,ymm4,ymm4\t\n" + "vxorps ymm5,ymm5,ymm5\t\n" + "vxorps ymm6,ymm6,ymm6\t\n" + "vxorps ymm7,ymm7,ymm7\t\n" + + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm9,XMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps ymm10,XMMWORD PTR [r10 + 16]\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm9,ymm8\t\n" + "vfmadd231ps ymm1,ymm10,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm2,ymm9,ymm8\t\n" + "vfmadd231ps ymm3,ymm10,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm4,ymm9,ymm8\t\n" + "vfmadd231ps ymm5,ymm10,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm6,ymm9,ymm8\t\n" + "vfmadd231ps ymm7,ymm10,ymm8\t\n" + "add r9,16\t\n" + "add r10,32\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups ymmword PTR [r12 + 0], ymm0\t\n" + "vmovups ymmword PTR [r12 + 32], ymm1\t\n" + "add r12, r13\t\n" + "vmovups ymmword PTR [r12 + 0], ymm2\t\n" + "vmovups ymmword PTR [r12 + 32], ymm3\t\n" + "add r12, r13\t\n" + "vmovups ymmword PTR [r12 + 0], ymm4\t\n" + "vmovups ymmword PTR [r12 + 32], ymm5\t\n" + "add r12, r13\t\n" + "vmovups ymmword PTR [r12 + 0], ymm6\t\n" + "vmovups ymmword PTR [r12 + 32], ymm7\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm0\t\n" + "vfmadd231ps ymm1,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm2\t\n" + "vfmadd231ps ymm3,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm4,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm4\t\n" + "vfmadd231ps ymm5,ymm15,ymmword 
PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm6,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm6\t\n" + "vfmadd231ps ymm7,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm7\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 64\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +>>>>>>> upstream } void NOINLINE_ATTR gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp) { @@ -394,6 +756,7 @@ void NOINLINE_ATTR gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp) { // b_block_size uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n" // Make copies of A and C +<<<<<<< HEAD float* rax = r9; //"mov rax, r9\t\n" float* rcx = r12; //"mov rcx, r12\t\n" @@ -497,6 +860,128 @@ void NOINLINE_ATTR gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp) { r12 = rcx; //"mov r12, rcx\t\n" r9 = rax; //"mov r9, rax\t\n" } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" +======= + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + "vxorps ymm4,ymm4,ymm4\t\n" + "vxorps ymm5,ymm5,ymm5\t\n" + "vxorps ymm6,ymm6,ymm6\t\n" + "vxorps ymm7,ymm7,ymm7\t\n" + "vxorps ymm8,ymm8,ymm8\t\n" + "vxorps ymm9,ymm9,ymm9\t\n" + + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm11,XMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps ymm12,XMMWORD PTR [r10 + 16]\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm11,ymm10\t\n" + "vfmadd231ps ymm1,ymm12,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm2,ymm11,ymm10\t\n" + "vfmadd231ps ymm3,ymm12,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm4,ymm11,ymm10\t\n" + "vfmadd231ps ymm5,ymm12,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm6,ymm11,ymm10\t\n" + "vfmadd231ps ymm7,ymm12,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+16]\t\n" + "vfmadd231ps ymm8,ymm11,ymm10\t\n" + "vfmadd231ps ymm9,ymm12,ymm10\t\n" + "add r9,20\t\n" + "add r10,32\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups ymmword PTR [r12 + 0], ymm0\t\n" + "vmovups ymmword PTR [r12 + 32], ymm1\t\n" + "add r12, r13\t\n" + "vmovups ymmword PTR [r12 + 0], ymm2\t\n" + "vmovups ymmword PTR [r12 + 32], ymm3\t\n" + "add r12, r13\t\n" + "vmovups ymmword PTR [r12 + 0], ymm4\t\n" + "vmovups ymmword PTR [r12 + 32], ymm5\t\n" + "add r12, r13\t\n" + "vmovups ymmword PTR [r12 + 0], ymm6\t\n" + "vmovups ymmword PTR [r12 + 32], ymm7\t\n" + "add r12, r13\t\n" + "vmovups ymmword PTR [r12 + 0], ymm8\t\n" + "vmovups ymmword PTR [r12 + 32], ymm9\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm0\t\n" + "vfmadd231ps ymm1,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm2\t\n" + "vfmadd231ps ymm3,ymm15,ymmword PTR [r12 + 32]\t\n" + 
"vmovups ymmword PTR [r12 + 32], ymm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm4,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm4\t\n" + "vfmadd231ps ymm5,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm6,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm6\t\n" + "vfmadd231ps ymm7,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm8,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm8\t\n" + "vfmadd231ps ymm9,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm9\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 64\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +>>>>>>> upstream } void NOINLINE_ATTR gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp) { @@ -522,6 +1007,7 @@ void NOINLINE_ATTR gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp) { // b_block_size uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n" // Make copies of A and C +<<<<<<< HEAD float* rax = r9; //"mov rax, r9\t\n" float* rcx = r12; //"mov rcx, r12\t\n" @@ -640,6 +1126,141 @@ void NOINLINE_ATTR gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp) { r12 = rcx; //"mov r12, rcx\t\n" r9 = rax; //"mov r9, rax\t\n" } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" +======= + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + "vxorps ymm4,ymm4,ymm4\t\n" + "vxorps ymm5,ymm5,ymm5\t\n" + "vxorps ymm6,ymm6,ymm6\t\n" + "vxorps ymm7,ymm7,ymm7\t\n" + "vxorps ymm8,ymm8,ymm8\t\n" + "vxorps ymm9,ymm9,ymm9\t\n" + "vxorps ymm10,ymm10,ymm10\t\n" + "vxorps ymm11,ymm11,ymm11\t\n" + + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm13,XMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps ymm14,XMMWORD PTR [r10 + 16]\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm13,ymm12\t\n" + "vfmadd231ps ymm1,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm2,ymm13,ymm12\t\n" + "vfmadd231ps ymm3,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm4,ymm13,ymm12\t\n" + "vfmadd231ps ymm5,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm6,ymm13,ymm12\t\n" + "vfmadd231ps ymm7,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+16]\t\n" + "vfmadd231ps ymm8,ymm13,ymm12\t\n" + "vfmadd231ps ymm9,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+20]\t\n" + "vfmadd231ps ymm10,ymm13,ymm12\t\n" + "vfmadd231ps ymm11,ymm14,ymm12\t\n" + "add r9,24\t\n" + "add r10,32\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups ymmword PTR [r12 + 0], ymm0\t\n" + "vmovups ymmword PTR [r12 + 32], ymm1\t\n" + "add r12, r13\t\n" + "vmovups ymmword PTR [r12 + 0], ymm2\t\n" + "vmovups ymmword PTR [r12 + 32], ymm3\t\n" + "add r12, r13\t\n" + "vmovups ymmword PTR [r12 + 0], ymm4\t\n" + "vmovups ymmword PTR [r12 + 32], ymm5\t\n" + "add r12, r13\t\n" + "vmovups ymmword PTR [r12 + 0], ymm6\t\n" + "vmovups ymmword PTR [r12 + 32], ymm7\t\n" + "add r12, r13\t\n" 
+ "vmovups ymmword PTR [r12 + 0], ymm8\t\n" + "vmovups ymmword PTR [r12 + 32], ymm9\t\n" + "add r12, r13\t\n" + "vmovups ymmword PTR [r12 + 0], ymm10\t\n" + "vmovups ymmword PTR [r12 + 32], ymm11\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm0\t\n" + "vfmadd231ps ymm1,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm2\t\n" + "vfmadd231ps ymm3,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm4,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm4\t\n" + "vfmadd231ps ymm5,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm6,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm6\t\n" + "vfmadd231ps ymm7,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm8,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm8\t\n" + "vfmadd231ps ymm9,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm9\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm10,ymm15,ymmword PTR [r12 + 0]\t\n" + "vmovups ymmword PTR [r12 + 0], ymm10\t\n" + "vfmadd231ps ymm11,ymm15,ymmword PTR [r12 + 32]\t\n" + "vmovups ymmword PTR [r12 + 32], ymm11\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 64\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +>>>>>>> upstream } diff --git a/src/FbgemmFP16UKernelsAvx2.h b/src/FbgemmFP16UKernelsAvx2.h index d48a88e..95fea41 100644 --- a/src/FbgemmFP16UKernelsAvx2.h +++ b/src/FbgemmFP16UKernelsAvx2.h @@ -4,8 +4,7 @@ * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ -#ifndef FBGEMM_UKERNELS -#define FBGEMM_UKERNELS +#pragma once #include <cstdint> #include "fbgemm/Types.h" @@ -40,4 +39,3 @@ typedef void (*funcptr_fp16)(GemmParams* gp); } // namespace fbgemm -#endif diff --git a/src/FbgemmFP16UKernelsAvx512.cc b/src/FbgemmFP16UKernelsAvx512.cc new file mode 100644 index 0000000..a7927c5 --- /dev/null +++ b/src/FbgemmFP16UKernelsAvx512.cc @@ -0,0 +1,2558 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include "FbgemmFP16UKernelsAvx512.h" + +namespace fbgemm { + +void __attribute__((noinline)) gemmkernel_1x2_AVX512_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps zmm0,zmm0,zmm0\t\n" + "vxorps zmm1,zmm1,zmm1\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps zmm3,YMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps zmm4,YMMWORD PTR [r10 + 32]\t\n" + "vbroadcastss zmm2,DWORD PTR [r9+0]\t\n" + "vfmadd231ps zmm0,zmm3,zmm2\t\n" + "vfmadd231ps zmm1,zmm4,zmm2\t\n" + "add r9,4\t\n" + "add r10,64\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss zmm31,DWORD PTR [r15]\t\n" + "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 128\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_2x2_AVX512_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps zmm0,zmm0,zmm0\t\n" + "vxorps zmm1,zmm1,zmm1\t\n" + "vxorps zmm2,zmm2,zmm2\t\n" + "vxorps zmm3,zmm3,zmm3\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps zmm5,YMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps zmm6,YMMWORD PTR [r10 + 32]\t\n" + "vbroadcastss zmm4,DWORD PTR [r9+0]\t\n" + "vfmadd231ps zmm0,zmm5,zmm4\t\n" + "vfmadd231ps zmm1,zmm6,zmm4\t\n" + "vbroadcastss zmm4,DWORD PTR [r9+4]\t\n" + "vfmadd231ps zmm2,zmm5,zmm4\t\n" + "vfmadd231ps zmm3,zmm6,zmm4\t\n" + "add r9,8\t\n" + "add r10,64\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vmovups zmmword PTR [r12 + 64], 
zmm3\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss zmm31,DWORD PTR [r15]\t\n" + "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 128\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_3x2_AVX512_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps zmm0,zmm0,zmm0\t\n" + "vxorps zmm1,zmm1,zmm1\t\n" + "vxorps zmm2,zmm2,zmm2\t\n" + "vxorps zmm3,zmm3,zmm3\t\n" + "vxorps zmm4,zmm4,zmm4\t\n" + "vxorps zmm5,zmm5,zmm5\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps zmm7,YMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps zmm8,YMMWORD PTR [r10 + 32]\t\n" + "vbroadcastss zmm6,DWORD PTR [r9+0]\t\n" + "vfmadd231ps zmm0,zmm7,zmm6\t\n" + "vfmadd231ps zmm1,zmm8,zmm6\t\n" + "vbroadcastss zmm6,DWORD PTR [r9+4]\t\n" + "vfmadd231ps zmm2,zmm7,zmm6\t\n" + "vfmadd231ps zmm3,zmm8,zmm6\t\n" + "vbroadcastss zmm6,DWORD PTR [r9+8]\t\n" + "vfmadd231ps zmm4,zmm7,zmm6\t\n" + "vfmadd231ps zmm5,zmm8,zmm6\t\n" + "add r9,12\t\n" + "add r10,64\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss zmm31,DWORD PTR [r15]\t\n" + "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + 
// next outer iteration + "add rcx, 128\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_4x2_AVX512_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps zmm0,zmm0,zmm0\t\n" + "vxorps zmm1,zmm1,zmm1\t\n" + "vxorps zmm2,zmm2,zmm2\t\n" + "vxorps zmm3,zmm3,zmm3\t\n" + "vxorps zmm4,zmm4,zmm4\t\n" + "vxorps zmm5,zmm5,zmm5\t\n" + "vxorps zmm6,zmm6,zmm6\t\n" + "vxorps zmm7,zmm7,zmm7\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps zmm9,YMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps zmm10,YMMWORD PTR [r10 + 32]\t\n" + "vbroadcastss zmm8,DWORD PTR [r9+0]\t\n" + "vfmadd231ps zmm0,zmm9,zmm8\t\n" + "vfmadd231ps zmm1,zmm10,zmm8\t\n" + "vbroadcastss zmm8,DWORD PTR [r9+4]\t\n" + "vfmadd231ps zmm2,zmm9,zmm8\t\n" + "vfmadd231ps zmm3,zmm10,zmm8\t\n" + "vbroadcastss zmm8,DWORD PTR [r9+8]\t\n" + "vfmadd231ps zmm4,zmm9,zmm8\t\n" + "vfmadd231ps zmm5,zmm10,zmm8\t\n" + "vbroadcastss zmm8,DWORD PTR [r9+12]\t\n" + "vfmadd231ps zmm6,zmm9,zmm8\t\n" + "vfmadd231ps zmm7,zmm10,zmm8\t\n" + "add r9,16\t\n" + "add r10,64\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss zmm31,DWORD PTR [r15]\t\n" + "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 128\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc 
rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_5x2_AVX512_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps zmm0,zmm0,zmm0\t\n" + "vxorps zmm1,zmm1,zmm1\t\n" + "vxorps zmm2,zmm2,zmm2\t\n" + "vxorps zmm3,zmm3,zmm3\t\n" + "vxorps zmm4,zmm4,zmm4\t\n" + "vxorps zmm5,zmm5,zmm5\t\n" + "vxorps zmm6,zmm6,zmm6\t\n" + "vxorps zmm7,zmm7,zmm7\t\n" + "vxorps zmm8,zmm8,zmm8\t\n" + "vxorps zmm9,zmm9,zmm9\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps zmm11,YMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps zmm12,YMMWORD PTR [r10 + 32]\t\n" + "vbroadcastss zmm10,DWORD PTR [r9+0]\t\n" + "vfmadd231ps zmm0,zmm11,zmm10\t\n" + "vfmadd231ps zmm1,zmm12,zmm10\t\n" + "vbroadcastss zmm10,DWORD PTR [r9+4]\t\n" + "vfmadd231ps zmm2,zmm11,zmm10\t\n" + "vfmadd231ps zmm3,zmm12,zmm10\t\n" + "vbroadcastss zmm10,DWORD PTR [r9+8]\t\n" + "vfmadd231ps zmm4,zmm11,zmm10\t\n" + "vfmadd231ps zmm5,zmm12,zmm10\t\n" + "vbroadcastss zmm10,DWORD PTR [r9+12]\t\n" + "vfmadd231ps zmm6,zmm11,zmm10\t\n" + "vfmadd231ps zmm7,zmm12,zmm10\t\n" + "vbroadcastss zmm10,DWORD PTR [r9+16]\t\n" + "vfmadd231ps zmm8,zmm11,zmm10\t\n" + "vfmadd231ps zmm9,zmm12,zmm10\t\n" + "add r9,20\t\n" + "add r10,64\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss zmm31,DWORD PTR [r15]\t\n" + "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vfmadd231ps 
zmm7,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 128\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_6x2_AVX512_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps zmm0,zmm0,zmm0\t\n" + "vxorps zmm1,zmm1,zmm1\t\n" + "vxorps zmm2,zmm2,zmm2\t\n" + "vxorps zmm3,zmm3,zmm3\t\n" + "vxorps zmm4,zmm4,zmm4\t\n" + "vxorps zmm5,zmm5,zmm5\t\n" + "vxorps zmm6,zmm6,zmm6\t\n" + "vxorps zmm7,zmm7,zmm7\t\n" + "vxorps zmm8,zmm8,zmm8\t\n" + "vxorps zmm9,zmm9,zmm9\t\n" + "vxorps zmm10,zmm10,zmm10\t\n" + "vxorps zmm11,zmm11,zmm11\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps zmm13,YMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps zmm14,YMMWORD PTR [r10 + 32]\t\n" + "vbroadcastss zmm12,DWORD PTR [r9+0]\t\n" + "vfmadd231ps zmm0,zmm13,zmm12\t\n" + "vfmadd231ps zmm1,zmm14,zmm12\t\n" + "vbroadcastss zmm12,DWORD PTR [r9+4]\t\n" + "vfmadd231ps zmm2,zmm13,zmm12\t\n" + "vfmadd231ps zmm3,zmm14,zmm12\t\n" + "vbroadcastss zmm12,DWORD PTR [r9+8]\t\n" + "vfmadd231ps zmm4,zmm13,zmm12\t\n" + "vfmadd231ps zmm5,zmm14,zmm12\t\n" + "vbroadcastss zmm12,DWORD PTR [r9+12]\t\n" + "vfmadd231ps zmm6,zmm13,zmm12\t\n" + "vfmadd231ps zmm7,zmm14,zmm12\t\n" + "vbroadcastss zmm12,DWORD PTR [r9+16]\t\n" + "vfmadd231ps zmm8,zmm13,zmm12\t\n" + "vfmadd231ps zmm9,zmm14,zmm12\t\n" + "vbroadcastss zmm12,DWORD PTR [r9+20]\t\n" + "vfmadd231ps zmm10,zmm13,zmm12\t\n" + "vfmadd231ps zmm11,zmm14,zmm12\t\n" + "add r9,24\t\n" + "add r10,64\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss zmm31,DWORD PTR [r15]\t\n" + "vfmadd231ps zmm0,zmm31,zmmword PTR 
[r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 128\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_7x2_AVX512_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps zmm0,zmm0,zmm0\t\n" + "vxorps zmm1,zmm1,zmm1\t\n" + "vxorps zmm2,zmm2,zmm2\t\n" + "vxorps zmm3,zmm3,zmm3\t\n" + "vxorps zmm4,zmm4,zmm4\t\n" + "vxorps zmm5,zmm5,zmm5\t\n" + "vxorps zmm6,zmm6,zmm6\t\n" + "vxorps zmm7,zmm7,zmm7\t\n" + "vxorps zmm8,zmm8,zmm8\t\n" + "vxorps zmm9,zmm9,zmm9\t\n" + "vxorps zmm10,zmm10,zmm10\t\n" + "vxorps zmm11,zmm11,zmm11\t\n" + "vxorps zmm12,zmm12,zmm12\t\n" + "vxorps zmm13,zmm13,zmm13\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps zmm15,YMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps zmm16,YMMWORD PTR [r10 + 32]\t\n" + "vbroadcastss zmm14,DWORD PTR [r9+0]\t\n" + "vfmadd231ps zmm0,zmm15,zmm14\t\n" + "vfmadd231ps zmm1,zmm16,zmm14\t\n" + "vbroadcastss zmm14,DWORD PTR [r9+4]\t\n" + "vfmadd231ps zmm2,zmm15,zmm14\t\n" + "vfmadd231ps zmm3,zmm16,zmm14\t\n" + "vbroadcastss zmm14,DWORD PTR [r9+8]\t\n" + "vfmadd231ps zmm4,zmm15,zmm14\t\n" + "vfmadd231ps zmm5,zmm16,zmm14\t\n" + "vbroadcastss zmm14,DWORD PTR [r9+12]\t\n" + "vfmadd231ps zmm6,zmm15,zmm14\t\n" + "vfmadd231ps zmm7,zmm16,zmm14\t\n" + "vbroadcastss zmm14,DWORD PTR [r9+16]\t\n" + "vfmadd231ps zmm8,zmm15,zmm14\t\n" + "vfmadd231ps zmm9,zmm16,zmm14\t\n" + "vbroadcastss zmm14,DWORD PTR [r9+20]\t\n" + "vfmadd231ps zmm10,zmm15,zmm14\t\n" + 
"vfmadd231ps zmm11,zmm16,zmm14\t\n" + "vbroadcastss zmm14,DWORD PTR [r9+24]\t\n" + "vfmadd231ps zmm12,zmm15,zmm14\t\n" + "vfmadd231ps zmm13,zmm16,zmm14\t\n" + "add r9,28\t\n" + "add r10,64\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss zmm31,DWORD PTR [r15]\t\n" + "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 128\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_8x2_AVX512_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, 
[r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps zmm0,zmm0,zmm0\t\n" + "vxorps zmm1,zmm1,zmm1\t\n" + "vxorps zmm2,zmm2,zmm2\t\n" + "vxorps zmm3,zmm3,zmm3\t\n" + "vxorps zmm4,zmm4,zmm4\t\n" + "vxorps zmm5,zmm5,zmm5\t\n" + "vxorps zmm6,zmm6,zmm6\t\n" + "vxorps zmm7,zmm7,zmm7\t\n" + "vxorps zmm8,zmm8,zmm8\t\n" + "vxorps zmm9,zmm9,zmm9\t\n" + "vxorps zmm10,zmm10,zmm10\t\n" + "vxorps zmm11,zmm11,zmm11\t\n" + "vxorps zmm12,zmm12,zmm12\t\n" + "vxorps zmm13,zmm13,zmm13\t\n" + "vxorps zmm14,zmm14,zmm14\t\n" + "vxorps zmm15,zmm15,zmm15\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps zmm17,YMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps zmm18,YMMWORD PTR [r10 + 32]\t\n" + "vbroadcastss zmm16,DWORD PTR [r9+0]\t\n" + "vfmadd231ps zmm0,zmm17,zmm16\t\n" + "vfmadd231ps zmm1,zmm18,zmm16\t\n" + "vbroadcastss zmm16,DWORD PTR [r9+4]\t\n" + "vfmadd231ps zmm2,zmm17,zmm16\t\n" + "vfmadd231ps zmm3,zmm18,zmm16\t\n" + "vbroadcastss zmm16,DWORD PTR [r9+8]\t\n" + "vfmadd231ps zmm4,zmm17,zmm16\t\n" + "vfmadd231ps zmm5,zmm18,zmm16\t\n" + "vbroadcastss zmm16,DWORD PTR [r9+12]\t\n" + "vfmadd231ps zmm6,zmm17,zmm16\t\n" + "vfmadd231ps zmm7,zmm18,zmm16\t\n" + "vbroadcastss zmm16,DWORD PTR [r9+16]\t\n" + "vfmadd231ps zmm8,zmm17,zmm16\t\n" + "vfmadd231ps zmm9,zmm18,zmm16\t\n" + "vbroadcastss zmm16,DWORD PTR [r9+20]\t\n" + "vfmadd231ps zmm10,zmm17,zmm16\t\n" + "vfmadd231ps zmm11,zmm18,zmm16\t\n" + "vbroadcastss zmm16,DWORD PTR [r9+24]\t\n" + "vfmadd231ps zmm12,zmm17,zmm16\t\n" + "vfmadd231ps zmm13,zmm18,zmm16\t\n" + "vbroadcastss zmm16,DWORD PTR [r9+28]\t\n" + "vfmadd231ps zmm14,zmm17,zmm16\t\n" + "vfmadd231ps zmm15,zmm18,zmm16\t\n" + "add r9,32\t\n" + "add r10,64\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm14\t\n" + "vmovups zmmword PTR [r12 + 64], zmm15\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss zmm31,DWORD PTR [r15]\t\n" + "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vfmadd231ps zmm5,zmm31,zmmword PTR 
[r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm14,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm14\t\n" + "vfmadd231ps zmm15,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm15\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 128\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_9x2_AVX512_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps zmm0,zmm0,zmm0\t\n" + "vxorps zmm1,zmm1,zmm1\t\n" + "vxorps zmm2,zmm2,zmm2\t\n" + "vxorps zmm3,zmm3,zmm3\t\n" + "vxorps zmm4,zmm4,zmm4\t\n" + "vxorps zmm5,zmm5,zmm5\t\n" + "vxorps zmm6,zmm6,zmm6\t\n" + "vxorps zmm7,zmm7,zmm7\t\n" + "vxorps zmm8,zmm8,zmm8\t\n" + "vxorps zmm9,zmm9,zmm9\t\n" + "vxorps zmm10,zmm10,zmm10\t\n" + "vxorps zmm11,zmm11,zmm11\t\n" + "vxorps zmm12,zmm12,zmm12\t\n" + "vxorps zmm13,zmm13,zmm13\t\n" + "vxorps zmm14,zmm14,zmm14\t\n" + "vxorps zmm15,zmm15,zmm15\t\n" + "vxorps zmm16,zmm16,zmm16\t\n" + "vxorps zmm17,zmm17,zmm17\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps zmm19,YMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps zmm20,YMMWORD PTR [r10 + 32]\t\n" + "vbroadcastss zmm18,DWORD PTR [r9+0]\t\n" + "vfmadd231ps zmm0,zmm19,zmm18\t\n" + "vfmadd231ps zmm1,zmm20,zmm18\t\n" + "vbroadcastss zmm18,DWORD PTR [r9+4]\t\n" + "vfmadd231ps zmm2,zmm19,zmm18\t\n" + "vfmadd231ps zmm3,zmm20,zmm18\t\n" + "vbroadcastss zmm18,DWORD PTR [r9+8]\t\n" + "vfmadd231ps zmm4,zmm19,zmm18\t\n" + "vfmadd231ps zmm5,zmm20,zmm18\t\n" + "vbroadcastss zmm18,DWORD PTR [r9+12]\t\n" + "vfmadd231ps zmm6,zmm19,zmm18\t\n" + "vfmadd231ps zmm7,zmm20,zmm18\t\n" + "vbroadcastss zmm18,DWORD PTR [r9+16]\t\n" + "vfmadd231ps zmm8,zmm19,zmm18\t\n" + "vfmadd231ps zmm9,zmm20,zmm18\t\n" + "vbroadcastss zmm18,DWORD PTR [r9+20]\t\n" + 
"vfmadd231ps zmm10,zmm19,zmm18\t\n" + "vfmadd231ps zmm11,zmm20,zmm18\t\n" + "vbroadcastss zmm18,DWORD PTR [r9+24]\t\n" + "vfmadd231ps zmm12,zmm19,zmm18\t\n" + "vfmadd231ps zmm13,zmm20,zmm18\t\n" + "vbroadcastss zmm18,DWORD PTR [r9+28]\t\n" + "vfmadd231ps zmm14,zmm19,zmm18\t\n" + "vfmadd231ps zmm15,zmm20,zmm18\t\n" + "vbroadcastss zmm18,DWORD PTR [r9+32]\t\n" + "vfmadd231ps zmm16,zmm19,zmm18\t\n" + "vfmadd231ps zmm17,zmm20,zmm18\t\n" + "add r9,36\t\n" + "add r10,64\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm14\t\n" + "vmovups zmmword PTR [r12 + 64], zmm15\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm16\t\n" + "vmovups zmmword PTR [r12 + 64], zmm17\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss zmm31,DWORD PTR [r15]\t\n" + "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm14,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm14\t\n" + "vfmadd231ps zmm15,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm15\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm16,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], 
zmm16\t\n" + "vfmadd231ps zmm17,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm17\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 128\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) +gemmkernel_10x2_AVX512_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps zmm0,zmm0,zmm0\t\n" + "vxorps zmm1,zmm1,zmm1\t\n" + "vxorps zmm2,zmm2,zmm2\t\n" + "vxorps zmm3,zmm3,zmm3\t\n" + "vxorps zmm4,zmm4,zmm4\t\n" + "vxorps zmm5,zmm5,zmm5\t\n" + "vxorps zmm6,zmm6,zmm6\t\n" + "vxorps zmm7,zmm7,zmm7\t\n" + "vxorps zmm8,zmm8,zmm8\t\n" + "vxorps zmm9,zmm9,zmm9\t\n" + "vxorps zmm10,zmm10,zmm10\t\n" + "vxorps zmm11,zmm11,zmm11\t\n" + "vxorps zmm12,zmm12,zmm12\t\n" + "vxorps zmm13,zmm13,zmm13\t\n" + "vxorps zmm14,zmm14,zmm14\t\n" + "vxorps zmm15,zmm15,zmm15\t\n" + "vxorps zmm16,zmm16,zmm16\t\n" + "vxorps zmm17,zmm17,zmm17\t\n" + "vxorps zmm18,zmm18,zmm18\t\n" + "vxorps zmm19,zmm19,zmm19\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps zmm21,YMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps zmm22,YMMWORD PTR [r10 + 32]\t\n" + "vbroadcastss zmm20,DWORD PTR [r9+0]\t\n" + "vfmadd231ps zmm0,zmm21,zmm20\t\n" + "vfmadd231ps zmm1,zmm22,zmm20\t\n" + "vbroadcastss zmm20,DWORD PTR [r9+4]\t\n" + "vfmadd231ps zmm2,zmm21,zmm20\t\n" + "vfmadd231ps zmm3,zmm22,zmm20\t\n" + "vbroadcastss zmm20,DWORD PTR [r9+8]\t\n" + "vfmadd231ps zmm4,zmm21,zmm20\t\n" + "vfmadd231ps zmm5,zmm22,zmm20\t\n" + "vbroadcastss zmm20,DWORD PTR [r9+12]\t\n" + "vfmadd231ps zmm6,zmm21,zmm20\t\n" + "vfmadd231ps zmm7,zmm22,zmm20\t\n" + "vbroadcastss zmm20,DWORD PTR [r9+16]\t\n" + "vfmadd231ps zmm8,zmm21,zmm20\t\n" + "vfmadd231ps zmm9,zmm22,zmm20\t\n" + "vbroadcastss zmm20,DWORD PTR [r9+20]\t\n" + "vfmadd231ps zmm10,zmm21,zmm20\t\n" + "vfmadd231ps zmm11,zmm22,zmm20\t\n" + "vbroadcastss zmm20,DWORD PTR [r9+24]\t\n" + "vfmadd231ps zmm12,zmm21,zmm20\t\n" + "vfmadd231ps zmm13,zmm22,zmm20\t\n" + "vbroadcastss zmm20,DWORD PTR [r9+28]\t\n" + "vfmadd231ps zmm14,zmm21,zmm20\t\n" + "vfmadd231ps zmm15,zmm22,zmm20\t\n" + "vbroadcastss zmm20,DWORD PTR [r9+32]\t\n" + "vfmadd231ps zmm16,zmm21,zmm20\t\n" + "vfmadd231ps zmm17,zmm22,zmm20\t\n" + "vbroadcastss zmm20,DWORD PTR [r9+36]\t\n" + "vfmadd231ps zmm18,zmm21,zmm20\t\n" + "vfmadd231ps zmm19,zmm22,zmm20\t\n" + "add r9,40\t\n" + "add r10,64\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 
0], zmm4\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm14\t\n" + "vmovups zmmword PTR [r12 + 64], zmm15\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm16\t\n" + "vmovups zmmword PTR [r12 + 64], zmm17\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm18\t\n" + "vmovups zmmword PTR [r12 + 64], zmm19\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss zmm31,DWORD PTR [r15]\t\n" + "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm14,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm14\t\n" + "vfmadd231ps zmm15,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm15\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm16,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm16\t\n" + "vfmadd231ps zmm17,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm17\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm18,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm18\t\n" + "vfmadd231ps zmm19,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm19\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 128\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) 
+gemmkernel_11x2_AVX512_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps zmm0,zmm0,zmm0\t\n" + "vxorps zmm1,zmm1,zmm1\t\n" + "vxorps zmm2,zmm2,zmm2\t\n" + "vxorps zmm3,zmm3,zmm3\t\n" + "vxorps zmm4,zmm4,zmm4\t\n" + "vxorps zmm5,zmm5,zmm5\t\n" + "vxorps zmm6,zmm6,zmm6\t\n" + "vxorps zmm7,zmm7,zmm7\t\n" + "vxorps zmm8,zmm8,zmm8\t\n" + "vxorps zmm9,zmm9,zmm9\t\n" + "vxorps zmm10,zmm10,zmm10\t\n" + "vxorps zmm11,zmm11,zmm11\t\n" + "vxorps zmm12,zmm12,zmm12\t\n" + "vxorps zmm13,zmm13,zmm13\t\n" + "vxorps zmm14,zmm14,zmm14\t\n" + "vxorps zmm15,zmm15,zmm15\t\n" + "vxorps zmm16,zmm16,zmm16\t\n" + "vxorps zmm17,zmm17,zmm17\t\n" + "vxorps zmm18,zmm18,zmm18\t\n" + "vxorps zmm19,zmm19,zmm19\t\n" + "vxorps zmm20,zmm20,zmm20\t\n" + "vxorps zmm21,zmm21,zmm21\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps zmm23,YMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps zmm24,YMMWORD PTR [r10 + 32]\t\n" + "vbroadcastss zmm22,DWORD PTR [r9+0]\t\n" + "vfmadd231ps zmm0,zmm23,zmm22\t\n" + "vfmadd231ps zmm1,zmm24,zmm22\t\n" + "vbroadcastss zmm22,DWORD PTR [r9+4]\t\n" + "vfmadd231ps zmm2,zmm23,zmm22\t\n" + "vfmadd231ps zmm3,zmm24,zmm22\t\n" + "vbroadcastss zmm22,DWORD PTR [r9+8]\t\n" + "vfmadd231ps zmm4,zmm23,zmm22\t\n" + "vfmadd231ps zmm5,zmm24,zmm22\t\n" + "vbroadcastss zmm22,DWORD PTR [r9+12]\t\n" + "vfmadd231ps zmm6,zmm23,zmm22\t\n" + "vfmadd231ps zmm7,zmm24,zmm22\t\n" + "vbroadcastss zmm22,DWORD PTR [r9+16]\t\n" + "vfmadd231ps zmm8,zmm23,zmm22\t\n" + "vfmadd231ps zmm9,zmm24,zmm22\t\n" + "vbroadcastss zmm22,DWORD PTR [r9+20]\t\n" + "vfmadd231ps zmm10,zmm23,zmm22\t\n" + "vfmadd231ps zmm11,zmm24,zmm22\t\n" + "vbroadcastss zmm22,DWORD PTR [r9+24]\t\n" + "vfmadd231ps zmm12,zmm23,zmm22\t\n" + "vfmadd231ps zmm13,zmm24,zmm22\t\n" + "vbroadcastss zmm22,DWORD PTR [r9+28]\t\n" + "vfmadd231ps zmm14,zmm23,zmm22\t\n" + "vfmadd231ps zmm15,zmm24,zmm22\t\n" + "vbroadcastss zmm22,DWORD PTR [r9+32]\t\n" + "vfmadd231ps zmm16,zmm23,zmm22\t\n" + "vfmadd231ps zmm17,zmm24,zmm22\t\n" + "vbroadcastss zmm22,DWORD PTR [r9+36]\t\n" + "vfmadd231ps zmm18,zmm23,zmm22\t\n" + "vfmadd231ps zmm19,zmm24,zmm22\t\n" + "vbroadcastss zmm22,DWORD PTR [r9+40]\t\n" + "vfmadd231ps zmm20,zmm23,zmm22\t\n" + "vfmadd231ps zmm21,zmm24,zmm22\t\n" + "add r9,44\t\n" + "add r10,64\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vmovups 
zmmword PTR [r12 + 0], zmm10\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm14\t\n" + "vmovups zmmword PTR [r12 + 64], zmm15\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm16\t\n" + "vmovups zmmword PTR [r12 + 64], zmm17\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm18\t\n" + "vmovups zmmword PTR [r12 + 64], zmm19\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm20\t\n" + "vmovups zmmword PTR [r12 + 64], zmm21\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss zmm31,DWORD PTR [r15]\t\n" + "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm14,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm14\t\n" + "vfmadd231ps zmm15,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm15\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm16,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm16\t\n" + "vfmadd231ps zmm17,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm17\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm18,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm18\t\n" + "vfmadd231ps zmm19,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm19\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm20,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm20\t\n" + "vfmadd231ps zmm21,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm21\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 128\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) 
+gemmkernel_12x2_AVX512_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps zmm0,zmm0,zmm0\t\n" + "vxorps zmm1,zmm1,zmm1\t\n" + "vxorps zmm2,zmm2,zmm2\t\n" + "vxorps zmm3,zmm3,zmm3\t\n" + "vxorps zmm4,zmm4,zmm4\t\n" + "vxorps zmm5,zmm5,zmm5\t\n" + "vxorps zmm6,zmm6,zmm6\t\n" + "vxorps zmm7,zmm7,zmm7\t\n" + "vxorps zmm8,zmm8,zmm8\t\n" + "vxorps zmm9,zmm9,zmm9\t\n" + "vxorps zmm10,zmm10,zmm10\t\n" + "vxorps zmm11,zmm11,zmm11\t\n" + "vxorps zmm12,zmm12,zmm12\t\n" + "vxorps zmm13,zmm13,zmm13\t\n" + "vxorps zmm14,zmm14,zmm14\t\n" + "vxorps zmm15,zmm15,zmm15\t\n" + "vxorps zmm16,zmm16,zmm16\t\n" + "vxorps zmm17,zmm17,zmm17\t\n" + "vxorps zmm18,zmm18,zmm18\t\n" + "vxorps zmm19,zmm19,zmm19\t\n" + "vxorps zmm20,zmm20,zmm20\t\n" + "vxorps zmm21,zmm21,zmm21\t\n" + "vxorps zmm22,zmm22,zmm22\t\n" + "vxorps zmm23,zmm23,zmm23\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps zmm25,YMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps zmm26,YMMWORD PTR [r10 + 32]\t\n" + "vbroadcastss zmm24,DWORD PTR [r9+0]\t\n" + "vfmadd231ps zmm0,zmm25,zmm24\t\n" + "vfmadd231ps zmm1,zmm26,zmm24\t\n" + "vbroadcastss zmm24,DWORD PTR [r9+4]\t\n" + "vfmadd231ps zmm2,zmm25,zmm24\t\n" + "vfmadd231ps zmm3,zmm26,zmm24\t\n" + "vbroadcastss zmm24,DWORD PTR [r9+8]\t\n" + "vfmadd231ps zmm4,zmm25,zmm24\t\n" + "vfmadd231ps zmm5,zmm26,zmm24\t\n" + "vbroadcastss zmm24,DWORD PTR [r9+12]\t\n" + "vfmadd231ps zmm6,zmm25,zmm24\t\n" + "vfmadd231ps zmm7,zmm26,zmm24\t\n" + "vbroadcastss zmm24,DWORD PTR [r9+16]\t\n" + "vfmadd231ps zmm8,zmm25,zmm24\t\n" + "vfmadd231ps zmm9,zmm26,zmm24\t\n" + "vbroadcastss zmm24,DWORD PTR [r9+20]\t\n" + "vfmadd231ps zmm10,zmm25,zmm24\t\n" + "vfmadd231ps zmm11,zmm26,zmm24\t\n" + "vbroadcastss zmm24,DWORD PTR [r9+24]\t\n" + "vfmadd231ps zmm12,zmm25,zmm24\t\n" + "vfmadd231ps zmm13,zmm26,zmm24\t\n" + "vbroadcastss zmm24,DWORD PTR [r9+28]\t\n" + "vfmadd231ps zmm14,zmm25,zmm24\t\n" + "vfmadd231ps zmm15,zmm26,zmm24\t\n" + "vbroadcastss zmm24,DWORD PTR [r9+32]\t\n" + "vfmadd231ps zmm16,zmm25,zmm24\t\n" + "vfmadd231ps zmm17,zmm26,zmm24\t\n" + "vbroadcastss zmm24,DWORD PTR [r9+36]\t\n" + "vfmadd231ps zmm18,zmm25,zmm24\t\n" + "vfmadd231ps zmm19,zmm26,zmm24\t\n" + "vbroadcastss zmm24,DWORD PTR [r9+40]\t\n" + "vfmadd231ps zmm20,zmm25,zmm24\t\n" + "vfmadd231ps zmm21,zmm26,zmm24\t\n" + "vbroadcastss zmm24,DWORD PTR [r9+44]\t\n" + "vfmadd231ps zmm22,zmm25,zmm24\t\n" + "vfmadd231ps zmm23,zmm26,zmm24\t\n" + "add r9,48\t\n" + "add r10,64\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + 
"vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm14\t\n" + "vmovups zmmword PTR [r12 + 64], zmm15\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm16\t\n" + "vmovups zmmword PTR [r12 + 64], zmm17\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm18\t\n" + "vmovups zmmword PTR [r12 + 64], zmm19\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm20\t\n" + "vmovups zmmword PTR [r12 + 64], zmm21\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm22\t\n" + "vmovups zmmword PTR [r12 + 64], zmm23\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss zmm31,DWORD PTR [r15]\t\n" + "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm14,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm14\t\n" + "vfmadd231ps zmm15,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm15\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm16,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm16\t\n" + "vfmadd231ps zmm17,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm17\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm18,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm18\t\n" + "vfmadd231ps zmm19,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm19\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm20,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm20\t\n" + "vfmadd231ps zmm21,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm21\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm22,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword 
PTR [r12 + 0], zmm22\t\n" + "vfmadd231ps zmm23,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm23\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 128\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) +gemmkernel_13x2_AVX512_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps zmm0,zmm0,zmm0\t\n" + "vxorps zmm1,zmm1,zmm1\t\n" + "vxorps zmm2,zmm2,zmm2\t\n" + "vxorps zmm3,zmm3,zmm3\t\n" + "vxorps zmm4,zmm4,zmm4\t\n" + "vxorps zmm5,zmm5,zmm5\t\n" + "vxorps zmm6,zmm6,zmm6\t\n" + "vxorps zmm7,zmm7,zmm7\t\n" + "vxorps zmm8,zmm8,zmm8\t\n" + "vxorps zmm9,zmm9,zmm9\t\n" + "vxorps zmm10,zmm10,zmm10\t\n" + "vxorps zmm11,zmm11,zmm11\t\n" + "vxorps zmm12,zmm12,zmm12\t\n" + "vxorps zmm13,zmm13,zmm13\t\n" + "vxorps zmm14,zmm14,zmm14\t\n" + "vxorps zmm15,zmm15,zmm15\t\n" + "vxorps zmm16,zmm16,zmm16\t\n" + "vxorps zmm17,zmm17,zmm17\t\n" + "vxorps zmm18,zmm18,zmm18\t\n" + "vxorps zmm19,zmm19,zmm19\t\n" + "vxorps zmm20,zmm20,zmm20\t\n" + "vxorps zmm21,zmm21,zmm21\t\n" + "vxorps zmm22,zmm22,zmm22\t\n" + "vxorps zmm23,zmm23,zmm23\t\n" + "vxorps zmm24,zmm24,zmm24\t\n" + "vxorps zmm25,zmm25,zmm25\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps zmm27,YMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps zmm28,YMMWORD PTR [r10 + 32]\t\n" + "vbroadcastss zmm26,DWORD PTR [r9+0]\t\n" + "vfmadd231ps zmm0,zmm27,zmm26\t\n" + "vfmadd231ps zmm1,zmm28,zmm26\t\n" + "vbroadcastss zmm26,DWORD PTR [r9+4]\t\n" + "vfmadd231ps zmm2,zmm27,zmm26\t\n" + "vfmadd231ps zmm3,zmm28,zmm26\t\n" + "vbroadcastss zmm26,DWORD PTR [r9+8]\t\n" + "vfmadd231ps zmm4,zmm27,zmm26\t\n" + "vfmadd231ps zmm5,zmm28,zmm26\t\n" + "vbroadcastss zmm26,DWORD PTR [r9+12]\t\n" + "vfmadd231ps zmm6,zmm27,zmm26\t\n" + "vfmadd231ps zmm7,zmm28,zmm26\t\n" + "vbroadcastss zmm26,DWORD PTR [r9+16]\t\n" + "vfmadd231ps zmm8,zmm27,zmm26\t\n" + "vfmadd231ps zmm9,zmm28,zmm26\t\n" + "vbroadcastss zmm26,DWORD PTR [r9+20]\t\n" + "vfmadd231ps zmm10,zmm27,zmm26\t\n" + "vfmadd231ps zmm11,zmm28,zmm26\t\n" + "vbroadcastss zmm26,DWORD PTR [r9+24]\t\n" + "vfmadd231ps zmm12,zmm27,zmm26\t\n" + "vfmadd231ps zmm13,zmm28,zmm26\t\n" + "vbroadcastss zmm26,DWORD PTR [r9+28]\t\n" + "vfmadd231ps zmm14,zmm27,zmm26\t\n" + "vfmadd231ps zmm15,zmm28,zmm26\t\n" + "vbroadcastss zmm26,DWORD PTR [r9+32]\t\n" + "vfmadd231ps zmm16,zmm27,zmm26\t\n" + "vfmadd231ps zmm17,zmm28,zmm26\t\n" + "vbroadcastss zmm26,DWORD PTR [r9+36]\t\n" + "vfmadd231ps zmm18,zmm27,zmm26\t\n" + "vfmadd231ps zmm19,zmm28,zmm26\t\n" + "vbroadcastss zmm26,DWORD PTR [r9+40]\t\n" + "vfmadd231ps zmm20,zmm27,zmm26\t\n" + "vfmadd231ps zmm21,zmm28,zmm26\t\n" + "vbroadcastss zmm26,DWORD PTR [r9+44]\t\n" + "vfmadd231ps zmm22,zmm27,zmm26\t\n" + 
"vfmadd231ps zmm23,zmm28,zmm26\t\n" + "vbroadcastss zmm26,DWORD PTR [r9+48]\t\n" + "vfmadd231ps zmm24,zmm27,zmm26\t\n" + "vfmadd231ps zmm25,zmm28,zmm26\t\n" + "add r9,52\t\n" + "add r10,64\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm14\t\n" + "vmovups zmmword PTR [r12 + 64], zmm15\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm16\t\n" + "vmovups zmmword PTR [r12 + 64], zmm17\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm18\t\n" + "vmovups zmmword PTR [r12 + 64], zmm19\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm20\t\n" + "vmovups zmmword PTR [r12 + 64], zmm21\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm22\t\n" + "vmovups zmmword PTR [r12 + 64], zmm23\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm24\t\n" + "vmovups zmmword PTR [r12 + 64], zmm25\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss zmm31,DWORD PTR [r15]\t\n" + "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm14,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm14\t\n" + "vfmadd231ps zmm15,zmm31,zmmword PTR 
[r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm15\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm16,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm16\t\n" + "vfmadd231ps zmm17,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm17\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm18,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm18\t\n" + "vfmadd231ps zmm19,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm19\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm20,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm20\t\n" + "vfmadd231ps zmm21,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm21\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm22,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm22\t\n" + "vfmadd231ps zmm23,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm23\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm24,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm24\t\n" + "vfmadd231ps zmm25,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm25\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 128\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) +gemmkernel_14x2_AVX512_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps zmm0,zmm0,zmm0\t\n" + "vxorps zmm1,zmm1,zmm1\t\n" + "vxorps zmm2,zmm2,zmm2\t\n" + "vxorps zmm3,zmm3,zmm3\t\n" + "vxorps zmm4,zmm4,zmm4\t\n" + "vxorps zmm5,zmm5,zmm5\t\n" + "vxorps zmm6,zmm6,zmm6\t\n" + "vxorps zmm7,zmm7,zmm7\t\n" + "vxorps zmm8,zmm8,zmm8\t\n" + "vxorps zmm9,zmm9,zmm9\t\n" + "vxorps zmm10,zmm10,zmm10\t\n" + "vxorps zmm11,zmm11,zmm11\t\n" + "vxorps zmm12,zmm12,zmm12\t\n" + "vxorps zmm13,zmm13,zmm13\t\n" + "vxorps zmm14,zmm14,zmm14\t\n" + "vxorps zmm15,zmm15,zmm15\t\n" + "vxorps zmm16,zmm16,zmm16\t\n" + "vxorps zmm17,zmm17,zmm17\t\n" + "vxorps zmm18,zmm18,zmm18\t\n" + "vxorps zmm19,zmm19,zmm19\t\n" + "vxorps zmm20,zmm20,zmm20\t\n" + "vxorps zmm21,zmm21,zmm21\t\n" + "vxorps zmm22,zmm22,zmm22\t\n" + "vxorps zmm23,zmm23,zmm23\t\n" + "vxorps zmm24,zmm24,zmm24\t\n" + "vxorps zmm25,zmm25,zmm25\t\n" + "vxorps zmm26,zmm26,zmm26\t\n" + "vxorps zmm27,zmm27,zmm27\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps zmm29,YMMWORD PTR [r10 + 0]\t\n" + "vcvtph2ps zmm30,YMMWORD PTR [r10 + 32]\t\n" + "vbroadcastss zmm28,DWORD PTR [r9+0]\t\n" + "vfmadd231ps zmm0,zmm29,zmm28\t\n" + "vfmadd231ps zmm1,zmm30,zmm28\t\n" + "vbroadcastss zmm28,DWORD PTR [r9+4]\t\n" + "vfmadd231ps zmm2,zmm29,zmm28\t\n" + "vfmadd231ps zmm3,zmm30,zmm28\t\n" + "vbroadcastss zmm28,DWORD PTR [r9+8]\t\n" + "vfmadd231ps 
zmm4,zmm29,zmm28\t\n" + "vfmadd231ps zmm5,zmm30,zmm28\t\n" + "vbroadcastss zmm28,DWORD PTR [r9+12]\t\n" + "vfmadd231ps zmm6,zmm29,zmm28\t\n" + "vfmadd231ps zmm7,zmm30,zmm28\t\n" + "vbroadcastss zmm28,DWORD PTR [r9+16]\t\n" + "vfmadd231ps zmm8,zmm29,zmm28\t\n" + "vfmadd231ps zmm9,zmm30,zmm28\t\n" + "vbroadcastss zmm28,DWORD PTR [r9+20]\t\n" + "vfmadd231ps zmm10,zmm29,zmm28\t\n" + "vfmadd231ps zmm11,zmm30,zmm28\t\n" + "vbroadcastss zmm28,DWORD PTR [r9+24]\t\n" + "vfmadd231ps zmm12,zmm29,zmm28\t\n" + "vfmadd231ps zmm13,zmm30,zmm28\t\n" + "vbroadcastss zmm28,DWORD PTR [r9+28]\t\n" + "vfmadd231ps zmm14,zmm29,zmm28\t\n" + "vfmadd231ps zmm15,zmm30,zmm28\t\n" + "vbroadcastss zmm28,DWORD PTR [r9+32]\t\n" + "vfmadd231ps zmm16,zmm29,zmm28\t\n" + "vfmadd231ps zmm17,zmm30,zmm28\t\n" + "vbroadcastss zmm28,DWORD PTR [r9+36]\t\n" + "vfmadd231ps zmm18,zmm29,zmm28\t\n" + "vfmadd231ps zmm19,zmm30,zmm28\t\n" + "vbroadcastss zmm28,DWORD PTR [r9+40]\t\n" + "vfmadd231ps zmm20,zmm29,zmm28\t\n" + "vfmadd231ps zmm21,zmm30,zmm28\t\n" + "vbroadcastss zmm28,DWORD PTR [r9+44]\t\n" + "vfmadd231ps zmm22,zmm29,zmm28\t\n" + "vfmadd231ps zmm23,zmm30,zmm28\t\n" + "vbroadcastss zmm28,DWORD PTR [r9+48]\t\n" + "vfmadd231ps zmm24,zmm29,zmm28\t\n" + "vfmadd231ps zmm25,zmm30,zmm28\t\n" + "vbroadcastss zmm28,DWORD PTR [r9+52]\t\n" + "vfmadd231ps zmm26,zmm29,zmm28\t\n" + "vfmadd231ps zmm27,zmm30,zmm28\t\n" + "add r9,56\t\n" + "add r10,64\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm14\t\n" + "vmovups zmmword PTR [r12 + 64], zmm15\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm16\t\n" + "vmovups zmmword PTR [r12 + 64], zmm17\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm18\t\n" + "vmovups zmmword PTR [r12 + 64], zmm19\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm20\t\n" + "vmovups zmmword PTR [r12 + 64], zmm21\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm22\t\n" + "vmovups zmmword PTR [r12 + 64], zmm23\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm24\t\n" + "vmovups zmmword PTR [r12 + 64], zmm25\t\n" + "add r12, r13\t\n" + "vmovups zmmword PTR [r12 + 0], zmm26\t\n" + "vmovups zmmword PTR [r12 + 64], zmm27\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss zmm31,DWORD PTR [r15]\t\n" + "vfmadd231ps zmm0,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm0\t\n" + "vfmadd231ps zmm1,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm2,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm2\t\n" + 
"vfmadd231ps zmm3,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm4,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm4\t\n" + "vfmadd231ps zmm5,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm6,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm6\t\n" + "vfmadd231ps zmm7,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm8,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm8\t\n" + "vfmadd231ps zmm9,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm9\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm10,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm10\t\n" + "vfmadd231ps zmm11,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm11\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm12,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm12\t\n" + "vfmadd231ps zmm13,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm13\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm14,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm14\t\n" + "vfmadd231ps zmm15,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm15\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm16,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm16\t\n" + "vfmadd231ps zmm17,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm17\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm18,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm18\t\n" + "vfmadd231ps zmm19,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm19\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm20,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm20\t\n" + "vfmadd231ps zmm21,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm21\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm22,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm22\t\n" + "vfmadd231ps zmm23,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm23\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm24,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm24\t\n" + "vfmadd231ps zmm25,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm25\t\n" + "add r12, r13\t\n" + "vfmadd231ps zmm26,zmm31,zmmword PTR [r12 + 0]\t\n" + "vmovups zmmword PTR [r12 + 0], zmm26\t\n" + "vfmadd231ps zmm27,zmm31,zmmword PTR [r12 + 64]\t\n" + "vmovups zmmword PTR [r12 + 64], zmm27\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 128\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} + +} // namespace fbgemm diff --git a/src/FbgemmFP16UKernelsAvx512.h b/src/FbgemmFP16UKernelsAvx512.h new file mode 100644 index 0000000..bbb41f5 --- /dev/null +++ b/src/FbgemmFP16UKernelsAvx512.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once +#include <cstdint> +#include "FbgemmFP16UKernelsAvx2.h" +#include "fbgemm/Types.h" + +namespace fbgemm { + +void __attribute__((noinline)) gemmkernel_1x2_AVX512_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_2x2_AVX512_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_3x2_AVX512_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_4x2_AVX512_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_5x2_AVX512_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_6x2_AVX512_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_7x2_AVX512_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_8x2_AVX512_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_9x2_AVX512_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_10x2_AVX512_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_11x2_AVX512_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_12x2_AVX512_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_13x2_AVX512_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_14x2_AVX512_fA0fB0fC0(GemmParams* gp); +typedef void (*funcptr_fp16)(GemmParams* gp); +; + +} // namespace fbgemm + diff --git a/src/FbgemmI8Depthwise2DAvx2-inl.h b/src/FbgemmI8Depthwise2DAvx2-inl.h new file mode 100644 index 0000000..1dda555 --- /dev/null +++ b/src/FbgemmI8Depthwise2DAvx2-inl.h @@ -0,0 +1,1623 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include <tuple> // for tie + +#include "src/FbgemmI8DepthwiseAvx2-inl.h" + +namespace fbgemm { + +template <int S = 3, bool SUM_A = false, bool REMAINDER = false> +static inline __attribute__((always_inline)) void inner_prod_2d_packed_( + const __m256i* a_v, + const __m256i* Bp, + std::int32_t* C, + int remainder, + __m256i* a_sum = nullptr) { + return inner_prod_packed_<S * S, SUM_A, REMAINDER>( + a_v, Bp, C, remainder, a_sum); +} + +template < + bool SUM_A, + bool REMAINDER = false, + bool PER_CHANNEL_QUANTIZATION = false> +static inline __attribute__((always_inline)) void inner_prod_3x3_packed_( + int H, + int W, + int K, + int h_in, + int w_in, + const std::uint8_t* A, + std::int32_t A_zero_point, + const std::int8_t* Bp, + const std::int32_t* B_zero_point, + std::int32_t* C, + int remainder, + std::int32_t* row_offsets) { + __m256i A_zero_point_v = + _mm256_set1_epi8(static_cast<std::uint8_t>(A_zero_point)); + __m256i mask_v = _mm256_setzero_si256(); + if (REMAINDER) { + mask_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(masks[remainder / 4])); + } + + // The code below can be written as a simple R*S loop but the compiler + // doesn't unroll so we're manually unrolling it. 
+ // constexpr int R = 3, S = 3; + // array<__m256i, R * S> a_v; + // for (int r = 0; r < R; ++r) { + // for (int s = 0; s < S; ++s) { + // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) { + // if (REMAINDER) { + // a_v[r * S + s] = + // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K), + // mask_v); + // } else { + // a_v[r * S + s] = + // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K)); + // } + // } else { + // a_v[r * S + s] = A_zero_point_v; + // } + // } + // } + __m256i a_v[9] = { + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + A_zero_point_v, + }; + + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[0] = load_a<REMAINDER>(A + (0 * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[1] = load_a<REMAINDER>(A + (0 * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[2] = load_a<REMAINDER>(A + (0 * W + 2) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[3] = load_a<REMAINDER>(A + (1 * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[4] = load_a<REMAINDER>(A + (1 * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[5] = load_a<REMAINDER>(A + (1 * W + 2) * K, mask_v); + } + } + + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[6] = load_a<REMAINDER>(A + (2 * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[7] = load_a<REMAINDER>(A + (2 * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[8] = load_a<REMAINDER>(A + (2 * W + 2) * K, mask_v); + } + } + + __m256i a_sum[4]; + inner_prod_2d_packed_<3, SUM_A, REMAINDER>( + a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum); + if (SUM_A) { + __m256i B_zero_point_v; + for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) { + if (PER_CHANNEL_QUANTIZATION) { + B_zero_point_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(B_zero_point + i * 8)); + } else { + B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]); + } + _mm256_store_si256( + reinterpret_cast<__m256i*>(&row_offsets[i * 8]), + _mm256_mullo_epi32(a_sum[i], B_zero_point_v)); + } + } +} + +template < + bool SUM_A, + bool REMAINDER = false, + bool PER_CHANNEL_QUANTIZATION = false> +static inline __attribute__((always_inline)) void inner_prod_5x5_packed_( + int H, + int W, + int K, + int h_in, + int w_in, + const std::uint8_t* A, + std::int32_t A_zero_point, + const std::int8_t* Bp, + const std::int32_t* B_zero_point, + std::int32_t* C, + int remainder, + std::int32_t* row_offsets) { + __m256i A_zero_point_v = + _mm256_set1_epi8(static_cast<std::uint8_t>(A_zero_point)); + __m256i mask_v = _mm256_setzero_si256(); + if (REMAINDER) { + mask_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(masks[remainder / 4])); + } + + // The code below can be written as a simple R*S loop but the compiler + // doesn't unroll so we're manually unrolling it. 
+ // constexpr int R = 5, S = 5; + // array<__m256i, R * S> a_v; + // for (int r = 0; r < R; ++r) { + // for (int s = 0; s < S; ++s) { + // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) { + // if (REMAINDER) { + // a_v[r * S + s] = + // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K), + // mask_v); + // } else { + // a_v[r * S + s] = + // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K)); + // } + // } else { + // a_v[r * S + s] = A_zero_point_v; + // } + // } + // } + __m256i a_v[25] = { + A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v, + A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v, + A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v, + A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v, + A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v, + A_zero_point_v, A_zero_point_v, A_zero_point_v, A_zero_point_v, + A_zero_point_v, + }; + + if (h_in >= 0 && h_in < H) { + if (w_in >= 0 && w_in < W) { + a_v[0] = load_a<REMAINDER>(A + (0 * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[1] = load_a<REMAINDER>(A + (0 * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[2] = load_a<REMAINDER>(A + (0 * W + 2) * K, mask_v); + } + if (w_in + 3 >= 0 && w_in + 3 < W) { + a_v[3] = load_a<REMAINDER>(A + (0 * W + 3) * K, mask_v); + } + if (w_in + 4 >= 0 && w_in + 4 < W) { + a_v[4] = load_a<REMAINDER>(A + (0 * W + 4) * K, mask_v); + } + } + + if (h_in + 1 >= 0 && h_in + 1 < H) { + if (w_in >= 0 && w_in < W) { + a_v[5] = load_a<REMAINDER>(A + (1 * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[6] = load_a<REMAINDER>(A + (1 * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[7] = load_a<REMAINDER>(A + (1 * W + 2) * K, mask_v); + } + if (w_in + 3 >= 0 && w_in + 3 < W) { + a_v[8] = load_a<REMAINDER>(A + (1 * W + 3) * K, mask_v); + } + if (w_in + 4 >= 0 && w_in + 4 < W) { + a_v[9] = load_a<REMAINDER>(A + (1 * W + 4) * K, mask_v); + } + } + + if (h_in + 2 >= 0 && h_in + 2 < H) { + if (w_in >= 0 && w_in < W) { + a_v[10] = load_a<REMAINDER>(A + (2 * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[11] = load_a<REMAINDER>(A + (2 * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[12] = load_a<REMAINDER>(A + (2 * W + 2) * K, mask_v); + } + if (w_in + 3 >= 0 && w_in + 3 < W) { + a_v[13] = load_a<REMAINDER>(A + (2 * W + 3) * K, mask_v); + } + if (w_in + 4 >= 0 && w_in + 4 < W) { + a_v[14] = load_a<REMAINDER>(A + (2 * W + 4) * K, mask_v); + } + } + + if (h_in + 3 >= 0 && h_in + 3 < H) { + if (w_in >= 0 && w_in < W) { + a_v[15] = load_a<REMAINDER>(A + (3 * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[16] = load_a<REMAINDER>(A + (3 * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[17] = load_a<REMAINDER>(A + (3 * W + 2) * K, mask_v); + } + if (w_in + 3 >= 0 && w_in + 3 < W) { + a_v[18] = load_a<REMAINDER>(A + (3 * W + 3) * K, mask_v); + } + if (w_in + 4 >= 0 && w_in + 4 < W) { + a_v[19] = load_a<REMAINDER>(A + (3 * W + 4) * K, mask_v); + } + } + + if (h_in + 4 >= 0 && h_in + 4 < H) { + if (w_in >= 0 && w_in < W) { + a_v[20] = load_a<REMAINDER>(A + (4 * W + 0) * K, mask_v); + } + if (w_in + 1 >= 0 && w_in + 1 < W) { + a_v[21] = load_a<REMAINDER>(A + (4 * W + 1) * K, mask_v); + } + if (w_in + 2 >= 0 && w_in + 2 < W) { + a_v[22] = load_a<REMAINDER>(A + (4 * W + 2) * K, mask_v); + } + if (w_in + 3 >= 0 && w_in + 3 < W) { + a_v[23] = 
load_a<REMAINDER>(A + (4 * W + 3) * K, mask_v); + } + if (w_in + 4 >= 0 && w_in + 4 < W) { + a_v[24] = load_a<REMAINDER>(A + (4 * W + 4) * K, mask_v); + } + } + + __m256i a_sum[4]; + inner_prod_2d_packed_<5, SUM_A, REMAINDER>( + a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum); + if (SUM_A) { + __m256i B_zero_point_v; + for (int i = 0; i < (REMAINDER ? (remainder / 8) : 4); ++i) { + if (PER_CHANNEL_QUANTIZATION) { + B_zero_point_v = _mm256_loadu_si256( + reinterpret_cast<const __m256i*>(B_zero_point + i * 8)); + } else { + B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]); + } + _mm256_store_si256( + reinterpret_cast<__m256i*>(&row_offsets[i * 8]), + _mm256_mullo_epi32(a_sum[i], B_zero_point_v)); + } + } +} + +template < + int S, + bool SUM_A, + bool REMAINDER = false, + bool PER_CHANNEL_QUANTIZATION = false> +static inline __attribute__((always_inline)) void inner_prod_2d_packed_( + int H, + int W, + int K, + int h_in, + int w_in, + const std::uint8_t* A, + std::int32_t A_zero_point, + const std::int8_t* Bp, + const std::int32_t* B_zero_point, + std::int32_t* C, + int remainder, + std::int32_t* row_offsets) { + if (S == 3) { + inner_prod_3x3_packed_<SUM_A, REMAINDER, PER_CHANNEL_QUANTIZATION>( + H, + W, + K, + h_in, + w_in, + A, + A_zero_point, + Bp, + B_zero_point, + C, + remainder, + row_offsets); + } else { + assert(S == 5); + inner_prod_5x5_packed_<SUM_A, REMAINDER, PER_CHANNEL_QUANTIZATION>( + H, + W, + K, + h_in, + w_in, + A, + A_zero_point, + Bp, + B_zero_point, + C, + remainder, + row_offsets); + } +} + +template < + int S, + bool FUSE_RELU, + bool HAS_BIAS, + bool A_SYMMETRIC, + bool B_SYMMETRIC, + typename BIAS_TYPE> +static inline __attribute__((always_inline)) void depthwise_2d_kernel_( + int H, + int W, + int K, + int h, + int w, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + std::int32_t B_zero_point, + const std::int8_t* Bp, + float C_multiplier, + std::int32_t C_zero_point, + std::int32_t* C_int32, + std::uint8_t* C_uint8, + std::int32_t* row_offsets, + const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + float act_times_w_scale) { + constexpr int PAD_T = (S - 1) / 2, PAD_L = (S - 1) / 2, PAD_R = (S - 1) / 2; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + int h_in = -PAD_T + h * stride_h; + int w_in = -PAD_L + w * stride_w; + + constexpr int KERNEL_PROD_ALIGNED = (S * S + 1) / 2 * 2; + + int k; + for (k = 0; k < K / 32 * 32; k += 32) { + inner_prod_2d_packed_<S, !B_SYMMETRIC /*SUM_A*/>( + H, + W, + K, + h_in, + w_in, + A + (h_in * W + w_in) * K + k, + A_zero_point, + Bp + k * KERNEL_PROD_ALIGNED, + &B_zero_point, + C_int32 + k, + 0, + B_SYMMETRIC ? nullptr : &row_offsets[k]); + } + int remainder = K - k; + if (remainder) { + inner_prod_2d_packed_<S, !B_SYMMETRIC, true>( + H, + W, + K, + h_in, + w_in, + A + (h_in * W + w_in) * K + k, + A_zero_point, + Bp + k * KERNEL_PROD_ALIGNED, + &B_zero_point, + C_int32 + k, + remainder, + B_SYMMETRIC ? 
nullptr : &row_offsets[k]);
+  }
+
+  requantize_<
+      FUSE_RELU,
+      HAS_BIAS,
+      false, /*PER_CHAN_QUANT*/
+      A_SYMMETRIC,
+      B_SYMMETRIC,
+      BIAS_TYPE>(
+      A_zero_point,
+      &C_multiplier,
+      C_zero_point,
+      C_int32,
+      C_uint8 + (h * W_OUT + w) * K,
+      K,
+      row_offsets,
+      col_offsets,
+      bias,
+      &act_times_w_scale);
+}
+
+template <
+    int S,
+    bool FUSE_RELU,
+    bool HAS_BIAS,
+    bool A_SYMMETRIC,
+    typename BIAS_TYPE>
+static inline __attribute__((always_inline)) void
+depthwise_2d_per_channel_quantization_kernel_(
+    int H,
+    int W,
+    int K,
+    int h,
+    int w,
+    int stride_h,
+    int stride_w,
+    std::int32_t A_zero_point,
+    const std::uint8_t* A,
+    const std::int32_t* B_zero_point,
+    const std::int8_t* Bp,
+    const float* C_multiplier,
+    std::int32_t C_zero_point,
+    std::int32_t* C_int32,
+    std::uint8_t* C_uint8,
+    std::int32_t* row_offsets,
+    const std::int32_t* col_offsets,
+    const BIAS_TYPE* bias,
+    const float* act_times_w_scale) {
+  constexpr int PAD_T = (S - 1) / 2, PAD_L = (S - 1) / 2, PAD_R = (S - 1) / 2;
+  int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
+  int h_in = -PAD_T + h * stride_h;
+  int w_in = -PAD_L + w * stride_w;
+
+  constexpr int KERNEL_PROD_ALIGNED = (S * S + 1) / 2 * 2;
+
+  int k;
+  for (k = 0; k < K / 32 * 32; k += 32) {
+    inner_prod_2d_packed_<
+        S,
+        true, /*SUM_A*/
+        false, /*remainder*/
+        true /*per-channel*/>(
+        H,
+        W,
+        K,
+        h_in,
+        w_in,
+        A + (h_in * W + w_in) * K + k,
+        A_zero_point,
+        Bp + k * KERNEL_PROD_ALIGNED,
+        B_zero_point + k,
+        C_int32 + k,
+        0,
+        &row_offsets[k]);
+  }
+  int remainder = K - k;
+  if (remainder) {
+    inner_prod_2d_packed_<
+        S,
+        true, /*SUM_A*/
+        true, /*remainder*/
+        true /*per-channel*/>(
+        H,
+        W,
+        K,
+        h_in,
+        w_in,
+        A + (h_in * W + w_in) * K + k,
+        A_zero_point,
+        Bp + k * KERNEL_PROD_ALIGNED,
+        B_zero_point + k,
+        C_int32 + k,
+        remainder,
+        &row_offsets[k]);
+  }
+
+  requantize_<
+      FUSE_RELU,
+      HAS_BIAS,
+      true, /*PER_CHAN_QUANT*/
+      A_SYMMETRIC,
+      false, /*B_SYMM*/
+      BIAS_TYPE>(
+      A_zero_point,
+      C_multiplier,
+      C_zero_point,
+      C_int32,
+      C_uint8 + (h * W_OUT + w) * K,
+      K,
+      row_offsets,
+      col_offsets,
+      bias,
+      act_times_w_scale);
+}
+
+// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0
+// This implementation is general enough to handle other filter shapes by
+// parameterizing with R and S, but it is restricted to 3x3 and 5x5 for now.
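The driver below relies on the "same"-padding shape arithmetic already visible in the kernels above: for an odd filter size S, a padding of (S - 1) / 2 on every side gives H_OUT == ceil(H / stride_h), and likewise for W. A minimal standalone sketch of that arithmetic (editor's illustration, not part of the diff):

    // "Same" padding output-size arithmetic used by the depthwise_2d_ driver.
    #include <cassert>

    static int out_dim(int in, int S, int stride) {
      int pad = (S - 1) / 2; // PAD_T == PAD_B == PAD_L == PAD_R for odd S
      return (in + 2 * pad - S) / stride + 1;
    }

    int main() {
      assert(out_dim(7, 3, 1) == 7);  // 3x3, stride 1 preserves spatial size
      assert(out_dim(14, 3, 2) == 7); // 3x3, stride 2: (14 + 2 - 3) / 2 + 1
      assert(out_dim(14, 5, 2) == 7); // 5x5, stride 2: (14 + 4 - 5) / 2 + 1
      return 0;
    }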
+template < + int S, + bool FUSE_RELU, + bool HAS_BIAS, + bool A_SYMMETRIC, + bool B_SYMMETRIC, + typename BIAS_TYPE> +static inline __attribute__((always_inline)) void depthwise_2d_( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + std::int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + std::int32_t C_zero_point, + std::int32_t* C_int32, + std::uint8_t* C_uint8, + const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + float act_times_w_scale, + int thread_id, + int num_threads) { + assert(K % 8 == 0); + constexpr int R = S; + constexpr int PAD_T = (R - 1) / 2, PAD_B = PAD_T, PAD_L = (S - 1) / 2, + PAD_R = PAD_L; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + const std::int8_t* Bp = B.PackedMat(); + + std::int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); + + int n_begin, n_end; + int h_begin, h_end, w_begin, w_end; + if (N >= num_threads) { + int n_per_thread = (N + num_threads - 1) / num_threads; + n_begin = std::min(thread_id * n_per_thread, N); + n_end = std::min(n_begin + n_per_thread, N); + h_begin = 0; + h_end = H_OUT; + w_begin = 0; + w_end = W_OUT; + } else { + int nthreads_per_n = num_threads / N; + n_begin = std::min(thread_id / nthreads_per_n, N); + n_end = std::min(n_begin + 1, N); + + int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); + int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); + int nthreads_of_n = tid_of_n_end - tid_of_n_begin; + int tid_within_n = thread_id - tid_of_n_begin; + assert(tid_within_n >= 0); + assert(tid_within_n < nthreads_of_n); + + // n is processed by num_threads_h * num_threads_w 2D grid of threads + int num_threads_h, num_threads_w; + // num_threads_w <= num_threads_h + std::tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n); + int tid_h = tid_within_n / num_threads_w; + int tid_w = tid_within_n % num_threads_w; + + int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; + h_begin = std::min(tid_h * h_per_thread, H_OUT); + h_end = std::min(h_begin + h_per_thread, H_OUT); + + int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w; + w_begin = std::min(tid_w * w_per_thread, W_OUT); + w_end = std::min(w_begin + w_per_thread, W_OUT); + } + + for (int n = n_begin; n < n_end; ++n) { + const std::uint8_t* A_base = A + n * H * W * K; + std::uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K; + + int h = 0; + int w = 0; + + for (h = h_begin; h < std::max(PAD_T, h_begin); ++h) { + for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) { + depthwise_2d_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + + for (; w < std::min(W_OUT - PAD_R, w_end); ++w) { + depthwise_2d_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + + for (; w < w_end; ++w) { + depthwise_2d_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, 
+ A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + } + + for (; h < std::min(H - PAD_B, h_end); ++h) { + for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) { + depthwise_2d_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + + for (; w < std::min(W_OUT - PAD_R, w_end); ++w) { + depthwise_2d_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + + for (; w < w_end; ++w) { + depthwise_2d_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + } + + for (; h < h_end; ++h) { + for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) { + depthwise_2d_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + + for (; w < std::min(W_OUT - PAD_R, w_end); ++w) { + depthwise_2d_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + + for (; w < w_end; ++w) { + depthwise_2d_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + B_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + } + } // for each n +}; + +template < + int S, + bool FUSE_RELU, + bool HAS_BIAS, + bool A_SYMMETRIC, + typename BIAS_TYPE> +static inline __attribute__((always_inline)) void +depthwise_2d_per_channel_quantization_( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& B, + const float* C_multiplier, + std::int32_t C_zero_point, + std::int32_t* C_int32, + std::uint8_t* C_uint8, + const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + const float* act_times_w_scale, + int thread_id, + int num_threads) { + assert(K % 8 == 0); + constexpr int R = S; + constexpr int PAD_T = (R - 1) / 2, PAD_B = PAD_T, PAD_L = (S - 1) / 2, + PAD_R = PAD_L; + int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; + int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; + const std::int8_t* Bp = B.PackedMat(); + + std::int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); + + int n_begin, n_end; + int h_begin, h_end, w_begin, w_end; + if (N >= num_threads) { + 
int n_per_thread = (N + num_threads - 1) / num_threads; + n_begin = std::min(thread_id * n_per_thread, N); + n_end = std::min(n_begin + n_per_thread, N); + h_begin = 0; + h_end = H_OUT; + w_begin = 0; + w_end = W_OUT; + } else { + int nthreads_per_n = num_threads / N; + n_begin = std::min(thread_id / nthreads_per_n, N); + n_end = std::min(n_begin + 1, N); + + int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); + int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); + int nthreads_of_n = tid_of_n_end - tid_of_n_begin; + int tid_within_n = thread_id - tid_of_n_begin; + assert(tid_within_n >= 0); + assert(tid_within_n < nthreads_of_n); + + // n is processed by num_threads_h * num_threads_w 2D grid of threads + int num_threads_h, num_threads_w; + // num_threads_w <= num_threads_h + std::tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n); + int tid_h = tid_within_n / num_threads_w; + int tid_w = tid_within_n % num_threads_w; + + int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; + h_begin = std::min(tid_h * h_per_thread, H_OUT); + h_end = std::min(h_begin + h_per_thread, H_OUT); + + int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w; + w_begin = std::min(tid_w * w_per_thread, W_OUT); + w_end = std::min(w_begin + w_per_thread, W_OUT); + } + + for (int n = n_begin; n < n_end; ++n) { + const std::uint8_t* A_base = A + n * H * W * K; + std::uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K; + + int h = 0; + int w = 0; + + for (h = h_begin; h < std::max(PAD_T, h_begin); ++h) { + for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) { + depthwise_2d_per_channel_quantization_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + + for (; w < std::min(W_OUT - PAD_R, w_end); ++w) { + depthwise_2d_per_channel_quantization_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + + for (; w < w_end; ++w) { + depthwise_2d_per_channel_quantization_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + } + + for (; h < std::min(H - PAD_B, h_end); ++h) { + for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) { + depthwise_2d_per_channel_quantization_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + + for (; w < std::min(W_OUT - PAD_R, w_end); ++w) { + depthwise_2d_per_channel_quantization_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + + for (; w < w_end; ++w) { + 
depthwise_2d_per_channel_quantization_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + } + + for (; h < h_end; ++h) { + for (w = w_begin; w < std::max(PAD_L, w_begin); ++w) { + depthwise_2d_per_channel_quantization_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + + for (; w < std::min(W_OUT - PAD_R, w_end); ++w) { + depthwise_2d_per_channel_quantization_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + + for (; w < w_end; ++w) { + depthwise_2d_per_channel_quantization_kernel_< + S, + FUSE_RELU, + HAS_BIAS, + A_SYMMETRIC, + BIAS_TYPE>( + H, + W, + K, + h, + w, + stride_h, + stride_w, + A_zero_point, + A_base, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32, + C_uint8_base, + row_offsets, + col_offsets, + bias, + act_times_w_scale); + } + } + } // for each n +}; + +// Dispatch A_SYMMETRIC and B_SYMMETRIC +template <int S, bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE> +static void depthwise_2d_( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + std::int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + float act_times_w_scale, + int thread_id, + int num_threads) { + std::int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (A_zero_point == 0 || col_offsets == nullptr) { + if (B_zero_point == 0) { + depthwise_2d_< + S, + FUSE_RELU, + HAS_BIAS, + true /*A_symmetric*/, + true /*B_symmetric*/, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_2d_< + S, + FUSE_RELU, + HAS_BIAS, + true /*A_symmetric*/, + false /*B_symmetric*/, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } + } else { + if (B_zero_point == 0) { + depthwise_2d_< + S, + FUSE_RELU, + HAS_BIAS, + false /*A_symmetric*/, + true /*B_symmetric*/, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_2d_< + S, + FUSE_RELU, + HAS_BIAS, + false /*A_symmetric*/, + false /*B_symmetric*/, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } + } +} + +// Dispatch HAS_BIAS 
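Both in the dispatchers above and in the HAS_BIAS overload that follows, every runtime condition (bias present, symmetric A/B zero points, per-channel quantization) is folded into a compile-time template argument exactly once at dispatch time, so the per-pixel kernels themselves stay branch-free. A minimal sketch of the pattern with hypothetical toy names (not the FBGEMM API):

    #include <cstdio>

    // The dead branch is eliminated at compile time in each instantiation.
    template <bool HAS_BIAS>
    void kernel(const int* bias, int n) {
      for (int i = 0; i < n; ++i) {
        int acc = i;
        if (HAS_BIAS) {
          acc += bias[i];
        }
        std::printf("%d\n", acc);
      }
    }

    // The runtime flag is converted to a template flag once, outside the
    // hot loop, mirroring the depthwise_2d_ dispatch layers.
    void dispatch(const int* bias, int n) {
      if (bias) {
        kernel<true>(bias, n);
      } else {
        kernel<false>(nullptr, n);
      }
    }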
+template <int S, bool FUSE_RELU, typename BIAS_TYPE> +static void depthwise_2d_( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + std::int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + float act_times_w_scale, + int thread_id, + int num_threads) { + if (bias) { + depthwise_2d_<S, FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_2d_<S, FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } +} + +// Dispatch A_SYMMETRIC +template <int S, bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE> +static void depthwise_2d_per_channel_quantization_( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + const float* act_times_w_scale, + int thread_id, + int num_threads) { + int32_t C_int32_temp[(K + 31) / 32 * 32]; + if (A_zero_point == 0 || col_offsets == nullptr) { + depthwise_2d_per_channel_quantization_< + S, + FUSE_RELU, + HAS_BIAS, + true /*A_SYMM*/, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_2d_per_channel_quantization_< + S, + FUSE_RELU, + HAS_BIAS, + false /*A_SYMM*/, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C_int32_temp, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } +} + +// Dispatch HAS_BIAS +template <int S, bool FUSE_RELU, typename BIAS_TYPE> +static void depthwise_2d_per_channel_quantization_( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + const float* act_times_w_scale, + int thread_id, + int num_threads) { + if (bias) { + depthwise_2d_per_channel_quantization_< + S, + FUSE_RELU, + true /* HAS_BIAS */, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_2d_per_channel_quantization_< + S, + FUSE_RELU, + false /* HAS_BIAS */, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } +} + +template <typename BIAS_TYPE = std::int32_t> +FBGEMM_API void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, + int 
stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + std::int32_t B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + float C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + bool fuse_relu = false, + float act_times_w_scale = 1.0f, + int thread_id = 0, + int num_threads = 1); + +template <typename BIAS_TYPE = std::int32_t> +FBGEMM_API void depthwise_3x3_per_channel_quantization_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + std::int32_t A_zero_point, + const std::uint8_t* A, + const std::int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + const float* C_multiplier, + std::int32_t C_zero_point, + std::uint8_t* C, + const std::int32_t* col_offsets, + const BIAS_TYPE* bias, + bool fuse_relu = false, + const float* act_times_w_scale = nullptr, + int thread_id = 0, + int num_threads = 1); + +} // namespace fbgemm diff --git a/src/FbgemmI8Depthwise3DAvx2.cc b/src/FbgemmI8Depthwise3DAvx2.cc index 2114b20..ee8cc29 100644 --- a/src/FbgemmI8Depthwise3DAvx2.cc +++ b/src/FbgemmI8Depthwise3DAvx2.cc @@ -1237,97 +1237,6 @@ void depthwise_3x3x3_per_channel_quantization_pad_1( } } -// To be removed -void depthwise_3x3x3_pad_1( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const PackedDepthWiseConvMatrix& B, - float C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias, - bool fuse_relu, - int thread_id, - int num_threads) { - depthwise_3x3x3_pad_1<int32_t>( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - fuse_relu, - 1.0f, // act_scale * weight_scale - thread_id, - num_threads); -} - -void depthwise_3x3x3_per_channel_quantization_pad_1( - int N, - int T, - int H, - int W, - int K, - int stride_t, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int32_t* B_zero_point, - const PackedDepthWiseConvMatrix& B, - const float* C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const int32_t* bias, - bool fuse_relu, - int thread_id, - int num_threads) { - depthwise_3x3x3_per_channel_quantization_pad_1( - N, - T, - H, - W, - K, - stride_t, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C, - col_offsets, - bias, - fuse_relu, - nullptr, // act_scale * weight_scale - thread_id, - num_threads); -} - template void depthwise_3x3x3_pad_1( int N, int T, diff --git a/src/FbgemmI8Depthwise3x3Avx2.cc b/src/FbgemmI8Depthwise3x3Avx2.cc new file mode 100644 index 0000000..5226a04 --- /dev/null +++ b/src/FbgemmI8Depthwise3x3Avx2.cc @@ -0,0 +1,618 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include "fbgemm/FbgemmI8DepthwiseAvx2.h" + +#include <string> + +#include "FbgemmI8Depthwise2DAvx2-inl.h" + +using namespace std; + +namespace fbgemm { + +// Dispatch input shape and FUSE_RELU +// assumption: W > 3 and H > 3 +template <typename BIAS_TYPE> +void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + bool fuse_relu, + float act_times_w_scale, + int thread_id, + int num_threads) { + if (B.GetKernelProduct() != 3 * 3) { + string msg = + "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " + + to_string(3 * 3) + " but has " + to_string(B.GetKernelProduct()); + throw logic_error(msg); + } + if (stride_h == 0 || stride_w == 0 || num_threads == 0) { + assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0"); + return; + } + if (N == 0) { + // In C2, batch size 0 is allowed, so we should just early return. + return; + } + if (fuse_relu) { + if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { + depthwise_2d_<3, true /* FUSE_RELU */, BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { + depthwise_2d_<3, true /* FUSE_RELU */, BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else if (1 == stride_h && 1 == stride_w) { + depthwise_2d_<3, true /* FUSE_RELU */, BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else if (2 == stride_h && 2 == stride_w) { + depthwise_2d_<3, true /* FUSE_RELU */, BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_2d_<3, true /* FUSE_RELU */, BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } + } else { + if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { + depthwise_2d_<3, false /* FUSE_RELU */, BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { + depthwise_2d_<3, false /* FUSE_RELU */, BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else if (1 == stride_h && 1 == stride_w) { + depthwise_2d_<3, false /* FUSE_RELU */, BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + 
num_threads); + } else if (2 == stride_h && 2 == stride_w) { + depthwise_2d_<3, false /* FUSE_RELU */, BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_2d_<3, false /* FUSE_RELU */, BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + B, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } + } +} + +// Dispatch input shape and FUSE_RELU +template <typename BIAS_TYPE> +void depthwise_3x3_per_channel_quantization_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const BIAS_TYPE* bias, + bool fuse_relu, + const float* act_times_w_scale, + int thread_id, + int num_threads) { + if (Bp.GetKernelProduct() != 3 * 3) { + string msg = + "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " + + to_string(3 * 3) + " but has " + to_string(Bp.GetKernelProduct()); + throw logic_error(msg); + } + if (stride_h == 0 || stride_w == 0 || num_threads == 0) { + assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0"); + return; + } + if (N == 0) { + // In C2, batch size 0 is allowed, so we should just early return. + return; + } + if (fuse_relu) { + if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { + depthwise_2d_per_channel_quantization_< + 3, + true /* FUSE_RELU */, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { + depthwise_2d_per_channel_quantization_< + 3, + true /* FUSE_RELU */, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else if (1 == stride_h && 1 == stride_w) { + depthwise_2d_per_channel_quantization_< + 3, + true /* FUSE_RELU */, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else if (2 == stride_h && 2 == stride_w) { + depthwise_2d_per_channel_quantization_< + 3, + true /* FUSE_RELU */, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_2d_per_channel_quantization_< + 3, + true /* FUSE_RELU */, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } + } else { + if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) { + depthwise_2d_per_channel_quantization_< + 3, + false /* FUSE_RELU */, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + 
act_times_w_scale, + thread_id, + num_threads); + } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) { + depthwise_2d_per_channel_quantization_< + 3, + false /* FUSE_RELU */, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else if (1 == stride_h && 1 == stride_w) { + depthwise_2d_per_channel_quantization_< + 3, + false /* FUSE_RELU */, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else if (2 == stride_h && 2 == stride_w) { + depthwise_2d_per_channel_quantization_< + 3, + false /* FUSE_RELU */, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } else { + depthwise_2d_per_channel_quantization_< + 3, + false /* FUSE_RELU */, + BIAS_TYPE>( + N, + H, + W, + K, + stride_h, + stride_w, + A_zero_point, + A, + B_zero_point, + Bp, + C_multiplier, + C_zero_point, + C, + col_offsets, + bias, + act_times_w_scale, + thread_id, + num_threads); + } + } +} + +template void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + float act_times_w_scale, + int thread_id, + int num_threads); + +template void depthwise_3x3_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + int32_t B_zero_point, + const PackedDepthWiseConvMatrix& B, + float C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const float* bias, + bool fuse_relu, + float act_times_w_scale, + int thread_id, + int num_threads); + +template void depthwise_3x3_per_channel_quantization_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const int32_t* bias, + bool fuse_relu, + const float* act_times_w_scale, + int thread_id, + int num_threads); + +template void depthwise_3x3_per_channel_quantization_pad_1( + int N, + int H, + int W, + int K, + int stride_h, + int stride_w, + int32_t A_zero_point, + const uint8_t* A, + const int32_t* B_zero_point, + const PackedDepthWiseConvMatrix& Bp, + const float* C_multiplier, + int32_t C_zero_point, + uint8_t* C, + const int32_t* col_offsets, + const float* bias, + bool fuse_relu, + const float* act_times_w_scale, + int thread_id, + int num_threads); + +} // namespace fbgemm diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc index 994f206..ada1c75 100644 --- a/src/FbgemmI8DepthwiseAvx2.cc +++ b/src/FbgemmI8DepthwiseAvx2.cc @@ -8,1126 +8,16 @@ #include "fbgemm/Utils.h" #include <string> -#include <tuple> // for tie -#include "FbgemmI8DepthwiseAvx2-inl.h" +#include "FbgemmI8Depthwise2DAvx2-inl.h" using namespace std; namespace fbgemm { -template <bool SUM_A = false, bool REMAINDER = false> 
-static inline ALWAYS_INLINE void inner_prod_3x3_packed_( - const __m256i* a_v, - const __m256i* Bp, - int32_t* C, - int remainder, - __m256i* a_sum = nullptr) { - return inner_prod_packed_<9, SUM_A, REMAINDER>(a_v, Bp, C, remainder, a_sum); -} - -template < - bool SUM_A, - bool REMAINDER = false, - bool PER_CHANNEL_QUANTIZATION = false> -static inline ALWAYS_INLINE void inner_prod_3x3_packed_( - int H, - int W, - int K, - int h_in, - int w_in, - const uint8_t* A, - int32_t A_zero_point, - const int8_t* Bp, - const int32_t* B_zero_point, - int32_t* C, - int remainder, - int32_t* row_offsets) { - __m256i A_zero_point_v = _mm256_set1_epi8(static_cast<uint8_t>(A_zero_point)); - __m256i mask_v = _mm256_setzero_si256(); - if (REMAINDER) { - mask_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(masks[remainder / 4])); - } - - // The code below can be written as a simple R*S loop but the compiler - // doesn't unroll so we're manually unrolling it. - // constexpr int R = 3, S = 3; - // array<__m256i, R * S> a_v; - // for (int r = 0; r < R; ++r) { - // for (int s = 0; s < S; ++s) { - // if (h_in + r >= 0 && h_in + r < H && w_in + s >= 0 && w_in + s < W) { - // if (REMAINDER) { - // a_v[r * S + s] = - // _mm256_maskload_epi32((const int *)(A + (r * W + s) * K), - // mask_v); - // } else { - // a_v[r * S + s] = - // _mm256_lddqu_si256((const __m256i *)(A + (r * W + s) * K)); - // } - // } else { - // a_v[r * S + s] = A_zero_point_v; - // } - // } - // } - __m256i a_v[9] = { - A_zero_point_v, - A_zero_point_v, - A_zero_point_v, - A_zero_point_v, - A_zero_point_v, - A_zero_point_v, - A_zero_point_v, - A_zero_point_v, - A_zero_point_v, - }; - - if (h_in >= 0 && h_in < H) { - if (w_in >= 0 && w_in < W) { - a_v[0] = load_a<REMAINDER>(A + (0 * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[1] = load_a<REMAINDER>(A + (0 * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[2] = load_a<REMAINDER>(A + (0 * W + 2) * K, mask_v); - } - } - - if (h_in + 1 >= 0 && h_in + 1 < H) { - if (w_in >= 0 && w_in < W) { - a_v[3] = load_a<REMAINDER>(A + (1 * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[4] = load_a<REMAINDER>(A + (1 * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[5] = load_a<REMAINDER>(A + (1 * W + 2) * K, mask_v); - } - } - - if (h_in + 2 >= 0 && h_in + 2 < H) { - if (w_in >= 0 && w_in < W) { - a_v[6] = load_a<REMAINDER>(A + (2 * W + 0) * K, mask_v); - } - if (w_in + 1 >= 0 && w_in + 1 < W) { - a_v[7] = load_a<REMAINDER>(A + (2 * W + 1) * K, mask_v); - } - if (w_in + 2 >= 0 && w_in + 2 < W) { - a_v[8] = load_a<REMAINDER>(A + (2 * W + 2) * K, mask_v); - } - } - - __m256i a_sum[4]; - inner_prod_3x3_packed_<SUM_A, REMAINDER>( - a_v, reinterpret_cast<const __m256i*>(Bp), C, remainder, a_sum); - if (SUM_A) { - __m256i B_zero_point_v; - for (int i = 0; i < (REMAINDER ? 
(remainder / 8) : 4); ++i) { - if (PER_CHANNEL_QUANTIZATION) { - B_zero_point_v = _mm256_loadu_si256( - reinterpret_cast<const __m256i*>(B_zero_point + i * 8)); - } else { - B_zero_point_v = _mm256_set1_epi32(B_zero_point[0]); - } - _mm256_store_si256( - reinterpret_cast<__m256i*>(&row_offsets[i * 8]), - _mm256_mullo_epi32(a_sum[i], B_zero_point_v)); - } - } -} - -template < - bool FUSE_RELU, - bool HAS_BIAS, - bool A_SYMMETRIC, - bool B_SYMMETRIC, - typename BIAS_TYPE> -static inline ALWAYS_INLINE void depthwise_3x3_kernel_( - int H, - int W, - int K, - int h, - int w, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const int8_t* Bp, - float C_multiplier, - int32_t C_zero_point, - int32_t* C_int32, - uint8_t* C_uint8, - int32_t* row_offsets, - const int32_t* col_offsets, - const BIAS_TYPE* bias, - float act_times_w_scale) { - constexpr int S = 3; - constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - int h_in = -PAD_T + h * stride_h; - int w_in = -PAD_L + w * stride_w; - - int k; - for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3_packed_<!B_SYMMETRIC /*SUM_A*/>( - H, - W, - K, - h_in, - w_in, - A + (h_in * W + w_in) * K + k, - A_zero_point, - Bp + k * 10, - &B_zero_point, - C_int32 + k, - 0, - B_SYMMETRIC ? nullptr : &row_offsets[k]); - } - int remainder = K - k; - if (remainder) { - inner_prod_3x3_packed_<!B_SYMMETRIC, true>( - H, - W, - K, - h_in, - w_in, - A + (h_in * W + w_in) * K + k, - A_zero_point, - Bp + k * 10, - &B_zero_point, - C_int32 + k, - remainder, - B_SYMMETRIC ? nullptr : &row_offsets[k]); - } - - requantize_< - FUSE_RELU, - HAS_BIAS, - false, /*PER_CHAN_QUANT*/ - A_SYMMETRIC, - B_SYMMETRIC, - BIAS_TYPE>( - A_zero_point, - &C_multiplier, - C_zero_point, - C_int32, - C_uint8 + (h * W_OUT + w) * K, - K, - row_offsets, - col_offsets, - bias, - &act_times_w_scale); -} - -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE> -static inline ALWAYS_INLINE void -depthwise_3x3_per_channel_quantization_kernel_( - int H, - int W, - int K, - int h, - int w, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int32_t* B_zero_point, - const int8_t* Bp, - const float* C_multiplier, - int32_t C_zero_point, - int32_t* C_int32, - uint8_t* C_uint8, - int32_t* row_offsets, - const int32_t* col_offsets, - const BIAS_TYPE* bias, - const float* act_times_w_scale) { - constexpr int S = 3; - constexpr int PAD_T = 1, PAD_L = 1, PAD_R = 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - int h_in = -PAD_T + h * stride_h; - int w_in = -PAD_L + w * stride_w; - - int k; - for (k = 0; k < K / 32 * 32; k += 32) { - inner_prod_3x3_packed_< - true, /*SUM_A*/ - false, /*remainder*/ - true /*per-channel*/>( - H, - W, - K, - h_in, - w_in, - A + (h_in * W + w_in) * K + k, - A_zero_point, - Bp + k * 10, - B_zero_point + k, - C_int32 + k, - 0, - &row_offsets[k]); - } - int remainder = K - k; - if (remainder) { - inner_prod_3x3_packed_< - true, /*SUM_A*/ - true, /*remainder*/ - true /*per-channel*/>( - H, - W, - K, - h_in, - w_in, - A + (h_in * W + w_in) * K + k, - A_zero_point, - Bp + k * 10, - B_zero_point + k, - C_int32 + k, - remainder, - &row_offsets[k]); - } - - requantize_< - FUSE_RELU, - HAS_BIAS, - true, /*PER_CHAN_QUANT*/ - A_SYMMETRIC, - false, /*B_SYMM*/ - BIAS_TYPE>( - A_zero_point, - C_multiplier, - C_zero_point, - C_int32, - C_uint8 + (h * W_OUT + w) * K, - K, - row_offsets, - col_offsets, - bias, - act_times_w_scale); -} - 
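Note the hard-coded stride Bp + k * 10 in the removed kernels above: the packed weight layout rounds the kernel product up to an even count (9 entries become 10 for a 3x3 filter), and the S-parameterized replacement generalizes this as KERNEL_PROD_ALIGNED = (S * S + 1) / 2 * 2. A quick standalone check of that expression (editor's sketch, not part of the diff):

    // Packed weight stride per 32-channel block: the kernel product rounded
    // up to the next even number.
    constexpr int kernel_prod_aligned(int S) {
      return (S * S + 1) / 2 * 2;
    }
    static_assert(kernel_prod_aligned(3) == 10, "3x3: the old hard-coded 10");
    static_assert(kernel_prod_aligned(5) == 26, "5x5: 25 rounds up to 26");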
-// TODO: short-circuit when B_zero_point is 0 or A_zero_point is 0 -// This implemntation should be general enough to handle not just 3x3 but other -// filter shapes by parameterizing with R and S but restricting it to just 3x3 -// for now. -template < - bool FUSE_RELU, - bool HAS_BIAS, - bool A_SYMMETRIC, - bool B_SYMMETRIC, - typename BIAS_TYPE> -static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const PackedDepthWiseConvMatrix& B, - float C_multiplier, - int32_t C_zero_point, - int32_t* C_int32, - uint8_t* C_uint8, - const int32_t* col_offsets, - const BIAS_TYPE* bias, - float act_times_w_scale, - int thread_id, - int num_threads) { - assert(K % 8 == 0); - constexpr int R = 3, S = 3; - constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - const int8_t* Bp = B.PackedMat(); - - int32_t* row_offsets = static_cast<int32_t *>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); - - int n_begin, n_end; - int h_begin, h_end, w_begin, w_end; - if (N >= num_threads) { - int n_per_thread = (N + num_threads - 1) / num_threads; - n_begin = std::min(thread_id * n_per_thread, N); - n_end = std::min(n_begin + n_per_thread, N); - h_begin = 0; - h_end = H_OUT; - w_begin = 0; - w_end = W_OUT; - } else { - int nthreads_per_n = num_threads / N; - n_begin = std::min(thread_id / nthreads_per_n, N); - n_end = std::min(n_begin + 1, N); - - int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); - int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); - int nthreads_of_n = tid_of_n_end - tid_of_n_begin; - int tid_within_n = thread_id - tid_of_n_begin; - assert(tid_within_n >= 0); - assert(tid_within_n < nthreads_of_n); - - // n is processed by num_threads_h * num_threads_w 2D grid of threads - int num_threads_h, num_threads_w; - // num_threads_w <= num_threads_h - tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n); - int tid_h = tid_within_n / num_threads_w; - int tid_w = tid_within_n % num_threads_w; - - int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; - h_begin = std::min(tid_h * h_per_thread, H_OUT); - h_end = std::min(h_begin + h_per_thread, H_OUT); - - int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w; - w_begin = std::min(tid_w * w_per_thread, W_OUT); - w_end = std::min(w_begin + w_per_thread, W_OUT); - } - - for (int n = n_begin; n < n_end; ++n) { - const uint8_t* A_base = A + n * H * W * K; - uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K; - - int h = 0; - int w = 0; - - if (h_begin == 0) { - if (w_begin == 0) { - depthwise_3x3_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - B_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - B_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - if (w_end == W_OUT) { - w = W_OUT - 1; - 
depthwise_3x3_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - B_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - } - - for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) { - if (w_begin == 0) { - w = 0; - depthwise_3x3_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - B_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - B_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - if (w_end == W_OUT) { - w = W_OUT - 1; - depthwise_3x3_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - B_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - } - - if (h_end == H_OUT) { - h = H_OUT - 1; - w = 0; - if (w_begin == 0) { - depthwise_3x3_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - B_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - B_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - if (w_end == W_OUT) { - w = W_OUT - 1; - depthwise_3x3_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - B_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - } - } // for each n - FREE(row_offsets); -}; - -template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, typename BIAS_TYPE> -static inline ALWAYS_INLINE void -depthwise_3x3_per_channel_quantization_pad_1_( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - const int32_t* B_zero_point, - const PackedDepthWiseConvMatrix& B, - const float* C_multiplier, - int32_t C_zero_point, - int32_t* C_int32, - uint8_t* C_uint8, - const int32_t* col_offsets, - const BIAS_TYPE* bias, - const float* act_times_w_scale, - int thread_id, - int num_threads) { - assert(K % 8 == 0); - constexpr int R = 3, S = 3; - constexpr int PAD_T = 1, PAD_B = 1, PAD_L = 1, PAD_R = 1; - int H_OUT = (H + PAD_T + PAD_B - R) / stride_h + 1; - int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; - const int8_t* Bp = B.PackedMat(); - - int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) 
/ 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64))); - - int n_begin, n_end; - int h_begin, h_end, w_begin, w_end; - if (N >= num_threads) { - int n_per_thread = (N + num_threads - 1) / num_threads; - n_begin = std::min(thread_id * n_per_thread, N); - n_end = std::min(n_begin + n_per_thread, N); - h_begin = 0; - h_end = H_OUT; - w_begin = 0; - w_end = W_OUT; - } else { - int nthreads_per_n = num_threads / N; - n_begin = std::min(thread_id / nthreads_per_n, N); - n_end = std::min(n_begin + 1, N); - - int tid_of_n_begin = std::min(n_begin * nthreads_per_n, num_threads); - int tid_of_n_end = std::min(tid_of_n_begin + nthreads_per_n, num_threads); - int nthreads_of_n = tid_of_n_end - tid_of_n_begin; - int tid_within_n = thread_id - tid_of_n_begin; - assert(tid_within_n >= 0); - assert(tid_within_n < nthreads_of_n); - - // n is processed by num_threads_h * num_threads_w 2D grid of threads - int num_threads_h, num_threads_w; - // num_threads_w <= num_threads_h - tie(num_threads_w, num_threads_h) = closest_factors_(nthreads_of_n); - int tid_h = tid_within_n / num_threads_w; - int tid_w = tid_within_n % num_threads_w; - - int h_per_thread = (H_OUT + num_threads_h - 1) / num_threads_h; - h_begin = std::min(tid_h * h_per_thread, H_OUT); - h_end = std::min(h_begin + h_per_thread, H_OUT); - - int w_per_thread = (W_OUT + num_threads_w - 1) / num_threads_w; - w_begin = std::min(tid_w * w_per_thread, W_OUT); - w_end = std::min(w_begin + w_per_thread, W_OUT); - } - - for (int n = n_begin; n < n_end; ++n) { - const uint8_t* A_base = A + n * H * W * K; - uint8_t* C_uint8_base = C_uint8 + n * H_OUT * W_OUT * K; - - int h = 0; - int w = 0; - - if (h_begin == 0) { - if (w_begin == 0) { - depthwise_3x3_per_channel_quantization_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_per_channel_quantization_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - if (w_end == W_OUT) { - w = W_OUT - 1; - depthwise_3x3_per_channel_quantization_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - } - - for (h = std::max(1, h_begin); h < std::min(H - 1, h_end); ++h) { - if (w_begin == 0) { - w = 0; - depthwise_3x3_per_channel_quantization_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_per_channel_quantization_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - 
C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - if (w_end == W_OUT) { - w = W_OUT - 1; - depthwise_3x3_per_channel_quantization_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - } - - if (h_end == H_OUT) { - h = H_OUT - 1; - w = 0; - if (w_begin == 0) { - depthwise_3x3_per_channel_quantization_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - for (w = std::max(1, w_begin); w < std::min(W_OUT - 1, w_end); ++w) { - depthwise_3x3_per_channel_quantization_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - - if (w_end == W_OUT) { - w = W_OUT - 1; - depthwise_3x3_per_channel_quantization_kernel_< - FUSE_RELU, - HAS_BIAS, - A_SYMMETRIC, - BIAS_TYPE>( - H, - W, - K, - h, - w, - stride_h, - stride_w, - A_zero_point, - A_base, - B_zero_point, - Bp, - C_multiplier, - C_zero_point, - C_int32, - C_uint8_base, - row_offsets, - col_offsets, - bias, - act_times_w_scale); - } - } - } // for each n -}; - -// Dispatch A_SYMMETRIC and B_SYMMETRIC -template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE> -static void depthwise_3x3_pad_1_( - int N, - int H, - int W, - int K, - int stride_h, - int stride_w, - int32_t A_zero_point, - const uint8_t* A, - int32_t B_zero_point, - const PackedDepthWiseConvMatrix& B, - float C_multiplier, - int32_t C_zero_point, - uint8_t* C, - const int32_t* col_offsets, - const BIAS_TYPE* bias, - float act_times_w_scale, - int thread_id, - int num_threads) { - int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; - if (A_zero_point == 0 || col_offsets == nullptr) { - if (B_zero_point == 0) { - depthwise_3x3_pad_1_< - FUSE_RELU, - HAS_BIAS, - true /*A_symmetric*/, - true /*B_symmetric*/, - BIAS_TYPE>( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - act_times_w_scale, - thread_id, - num_threads); - } else { - depthwise_3x3_pad_1_< - FUSE_RELU, - HAS_BIAS, - true /*A_symmetric*/, - false /*B_symmetric*/, - BIAS_TYPE>( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - act_times_w_scale, - thread_id, - num_threads); - } - } else { - if (B_zero_point == 0) { - depthwise_3x3_pad_1_< - FUSE_RELU, - HAS_BIAS, - false /*A_symmetric*/, - true /*B_symmetric*/, - BIAS_TYPE>( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - C, - col_offsets, - bias, - act_times_w_scale, - thread_id, - num_threads); - } else { - depthwise_3x3_pad_1_< - FUSE_RELU, - HAS_BIAS, - false /*A_symmetric*/, - false /*B_symmetric*/, - BIAS_TYPE>( - N, - H, - W, - K, - stride_h, - stride_w, - A_zero_point, - A, - B_zero_point, - B, - C_multiplier, - C_zero_point, - C_int32_temp, - 
-
-// Dispatch HAS_BIAS
-template <bool FUSE_RELU, typename BIAS_TYPE>
-static void depthwise_3x3_pad_1_(
+// Dispatch input shape and FUSE_RELU
+template <typename BIAS_TYPE /*=std::int32_t*/>
+void depthwise_2d_same_pad(
     int N,
     int H,
     int W,
@@ -1143,31 +33,12 @@ static void depthwise_3x3_pad_1_(
     uint8_t* C,
     const int32_t* col_offsets,
     const BIAS_TYPE* bias,
+    bool fuse_relu,
     float act_times_w_scale,
     int thread_id,
     int num_threads) {
-  if (bias) {
-    depthwise_3x3_pad_1_<FUSE_RELU, true /*HAS_BIAS*/, BIAS_TYPE>(
-        N,
-        H,
-        W,
-        K,
-        stride_h,
-        stride_w,
-        A_zero_point,
-        A,
-        B_zero_point,
-        B,
-        C_multiplier,
-        C_zero_point,
-        C,
-        col_offsets,
-        bias,
-        act_times_w_scale,
-        thread_id,
-        num_threads);
-  } else {
-    depthwise_3x3_pad_1_<FUSE_RELU, false /*HAS_BIAS*/, BIAS_TYPE>(
+  if (B.GetKernelProduct() == 3 * 3) {
+    depthwise_3x3_pad_1(
         N,
         H,
         W,
@@ -1183,284 +54,21 @@ static void depthwise_3x3_pad_1_(
         C,
         col_offsets,
         bias,
+        fuse_relu,
         act_times_w_scale,
         thread_id,
         num_threads);
+    return;
   }
-}
 
-// Dispatch input shape and FUSE_RELU
-// assumption: W > 3 and H > 3
-template <typename BIAS_TYPE>
-void depthwise_3x3_pad_1(
-    int N,
-    int H,
-    int W,
-    int K,
-    int stride_h,
-    int stride_w,
-    int32_t A_zero_point,
-    const uint8_t* A,
-    int32_t B_zero_point,
-    const PackedDepthWiseConvMatrix& B,
-    float C_multiplier,
-    int32_t C_zero_point,
-    uint8_t* C,
-    const int32_t* col_offsets,
-    const BIAS_TYPE* bias,
-    bool fuse_relu,
-    float act_times_w_scale,
-    int thread_id,
-    int num_threads) {
-  if (B.GetKernelProduct() != 3 * 3) {
+  if (B.GetKernelProduct() != 5 * 5) {
     string msg =
         "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
-        to_string(3 * 3) + " but has " + to_string(B.GetKernelProduct());
+        to_string(5 * 5) + " but has " + to_string(B.GetKernelProduct());
     throw logic_error(msg);
   }
-  if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
-    assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
-    return;
-  }
-  if (N == 0) {
-    // In C2, batch size 0 is allowed, so we should just early return.
-    return;
-  }
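The rename from *_pad_1 to *_same_pad tracks the padding rule these entry points assume: padding of (K - 1) / 2 on each side (1 for 3x3, 2 for 5x5), which preserves the spatial size at stride 1 and halves it at stride 2. A quick check of that arithmetic (out_dim is a hypothetical helper, not FBGEMM code):

    #include <cassert>

    // Same-pad output-size arithmetic: pad = (K - 1) / 2 per side.
    static int out_dim(int in, int K, int stride) {
      int pad = (K - 1) / 2;
      return (in + 2 * pad - K) / stride + 1;
    }

    int main() {
      assert(out_dim(7, 3, 1) == 7);   // 3x3, stride 1: size preserved
      assert(out_dim(14, 3, 2) == 7);  // 3x3, stride 2: halved
      assert(out_dim(7, 5, 1) == 7);   // 5x5, stride 1: size preserved
      return 0;
    }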
   if (fuse_relu) {
-    if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
-      depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          B,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
-      depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          B,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else if (1 == stride_h && 1 == stride_w) {
-      depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          B,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else if (2 == stride_h && 2 == stride_w) {
-      depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          B,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else {
-      depthwise_3x3_pad_1_<true /* FUSE_RELU */, BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          B,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    }
-  } else {
-    if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
-      depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          B,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
-      depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          B,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else if (1 == stride_h && 1 == stride_w) {
-      depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          B,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else if (2 == stride_h && 2 == stride_w) {
-      depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          B,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else {
-      depthwise_3x3_pad_1_<false /* FUSE_RELU */, BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          B,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    }
-  }
-}
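Every branch of the shape dispatch deleted above issues a textually identical call with identical template arguments, so the chain appears to be a leftover from an earlier version whose kernels were specialized per input shape; collapsing it to a single call is behavior-preserving. Schematically (kernel_ is a stand-in, not an FBGEMM symbol):

    #include <iostream>

    template <bool RELU>
    static void kernel_(int H, int W, int stride) {
      std::cout << H << "x" << W << " stride " << stride
                << " relu=" << RELU << "\n";
    }

    // Before: several branches, all resolving to the same call.
    template <bool RELU>
    static void dispatch_before(int H, int W, int stride) {
      if (H == 7 && W == 7 && stride == 1) {
        kernel_<RELU>(H, W, stride);       // identical to ...
      } else if (H == 14 && W == 14 && stride == 2) {
        kernel_<RELU>(H, W, stride);       // ... this ...
      } else {
        kernel_<RELU>(H, W, stride);       // ... and this.
      }
    }

    // After: one call, equivalent behavior.
    template <bool RELU>
    static void dispatch_after(int H, int W, int stride) {
      kernel_<RELU>(H, W, stride);
    }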
-
-// Dispatch A_SYMMETRIC
-template <bool FUSE_RELU, bool HAS_BIAS, typename BIAS_TYPE>
-static void depthwise_3x3_per_channel_quantization_pad_1_(
-    int N,
-    int H,
-    int W,
-    int K,
-    int stride_h,
-    int stride_w,
-    int32_t A_zero_point,
-    const uint8_t* A,
-    const int32_t* B_zero_point,
-    const PackedDepthWiseConvMatrix& Bp,
-    const float* C_multiplier,
-    int32_t C_zero_point,
-    uint8_t* C,
-    const int32_t* col_offsets,
-    const BIAS_TYPE* bias,
-    const float* act_times_w_scale,
-    int thread_id,
-    int num_threads) {
-  int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
-  if (A_zero_point == 0 || col_offsets == nullptr) {
-    depthwise_3x3_per_channel_quantization_pad_1_<
-        FUSE_RELU,
-        HAS_BIAS,
-        true /*A_SYMM*/,
-        BIAS_TYPE>(
+    depthwise_2d_<5, true /* FUSE_RELU */, BIAS_TYPE>(
         N,
         H,
         W,
@@ -1470,81 +78,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
         A_zero_point,
         A,
         B_zero_point,
-        Bp,
-        C_multiplier,
-        C_zero_point,
-        C_int32_temp,
-        C,
-        col_offsets,
-        bias,
-        act_times_w_scale,
-        thread_id,
-        num_threads);
-  } else {
-    depthwise_3x3_per_channel_quantization_pad_1_<
-        FUSE_RELU,
-        HAS_BIAS,
-        false /*A_SYMM*/,
-        BIAS_TYPE>(
-        N,
-        H,
-        W,
-        K,
-        stride_h,
-        stride_w,
-        A_zero_point,
-        A,
-        B_zero_point,
-        Bp,
-        C_multiplier,
-        C_zero_point,
-        C_int32_temp,
-        C,
-        col_offsets,
-        bias,
-        act_times_w_scale,
-        thread_id,
-        num_threads);
-  }
-  delete[] C_int32_temp;
-}
-
-// Dispatch HAS_BIAS
-template <bool FUSE_RELU, typename BIAS_TYPE>
-static void depthwise_3x3_per_channel_quantization_pad_1_(
-    int N,
-    int H,
-    int W,
-    int K,
-    int stride_h,
-    int stride_w,
-    int32_t A_zero_point,
-    const uint8_t* A,
-    const int32_t* B_zero_point,
-    const PackedDepthWiseConvMatrix& Bp,
-    const float* C_multiplier,
-    int32_t C_zero_point,
-    uint8_t* C,
-    const int32_t* col_offsets,
-    const BIAS_TYPE* bias,
-    const float* act_times_w_scale,
-    int thread_id,
-    int num_threads) {
-  if (bias) {
-    depthwise_3x3_per_channel_quantization_pad_1_<
-        FUSE_RELU,
-        true /* HAS_BIAS */,
-        BIAS_TYPE>(
-        N,
-        H,
-        W,
-        K,
-        stride_h,
-        stride_w,
-        A_zero_point,
-        A,
-        B_zero_point,
-        Bp,
+        B,
         C_multiplier,
         C_zero_point,
         C,
@@ -1554,10 +88,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
         thread_id,
         num_threads);
   } else {
-    depthwise_3x3_per_channel_quantization_pad_1_<
-        FUSE_RELU,
-        false /* HAS_BIAS */,
-        BIAS_TYPE>(
+    depthwise_2d_<5, false /* FUSE_RELU */, BIAS_TYPE>(
         N,
         H,
         W,
@@ -1567,7 +98,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
         A_zero_point,
         A,
         B_zero_point,
-        Bp,
+        B,
         C_multiplier,
         C_zero_point,
         C,
@@ -1579,354 +110,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
   }
 }
 
-// Dispatch input shape and FUSE_RELU
-template <typename BIAS_TYPE>
-void depthwise_3x3_per_channel_quantization_pad_1(
-    int N,
-    int H,
-    int W,
-    int K,
-    int stride_h,
-    int stride_w,
-    int32_t A_zero_point,
-    const uint8_t* A,
-    const int32_t* B_zero_point,
-    const PackedDepthWiseConvMatrix& Bp,
-    const float* C_multiplier,
-    int32_t C_zero_point,
-    uint8_t* C,
-    const int32_t* col_offsets,
-    const BIAS_TYPE* bias,
-    bool fuse_relu,
-    const float* act_times_w_scale,
-    int thread_id,
-    int num_threads) {
-  if (Bp.GetKernelProduct() != 3 * 3) {
-    string msg =
-        "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
-        to_string(3 * 3) + " but has " + to_string(Bp.GetKernelProduct());
-    throw logic_error(msg);
-  }
-  if (stride_h == 0 || stride_w == 0 || num_threads == 0) {
-    assert(0 && "stride_h == 0 || stride_w == 0 || num_threads == 0");
-    return;
-  }
-  if (N == 0) {
-    // In C2, batch size 0 is allowed, so we should just early return.
-    return;
-  }
-  if (fuse_relu) {
-    if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
-      depthwise_3x3_per_channel_quantization_pad_1_<
-          true /* FUSE_RELU */,
-          BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          Bp,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
-      depthwise_3x3_per_channel_quantization_pad_1_<
-          true /* FUSE_RELU */,
-          BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          Bp,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else if (1 == stride_h && 1 == stride_w) {
-      depthwise_3x3_per_channel_quantization_pad_1_<
-          true /* FUSE_RELU */,
-          BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          Bp,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else if (2 == stride_h && 2 == stride_w) {
-      depthwise_3x3_per_channel_quantization_pad_1_<
-          true /* FUSE_RELU */,
-          BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          Bp,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else {
-      depthwise_3x3_per_channel_quantization_pad_1_<
-          true /* FUSE_RELU */,
-          BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          Bp,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    }
-  } else {
-    if (7 == H && 7 == W && 1 == stride_h && 1 == stride_w) {
-      depthwise_3x3_per_channel_quantization_pad_1_<
-          false /* FUSE_RELU */,
-          BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          Bp,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else if (14 == H && 14 == W && 2 == stride_h && 2 == stride_w) {
-      depthwise_3x3_per_channel_quantization_pad_1_<
-          false /* FUSE_RELU */,
-          BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          Bp,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else if (1 == stride_h && 1 == stride_w) {
-      depthwise_3x3_per_channel_quantization_pad_1_<
-          false /* FUSE_RELU */,
-          BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          Bp,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else if (2 == stride_h && 2 == stride_w) {
-      depthwise_3x3_per_channel_quantization_pad_1_<
-          false /* FUSE_RELU */,
-          BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          Bp,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    } else {
-      depthwise_3x3_per_channel_quantization_pad_1_<
-          false /* FUSE_RELU */,
-          BIAS_TYPE>(
-          N,
-          H,
-          W,
-          K,
-          stride_h,
-          stride_w,
-          A_zero_point,
-          A,
-          B_zero_point,
-          Bp,
-          C_multiplier,
-          C_zero_point,
-          C,
-          col_offsets,
-          bias,
-          act_times_w_scale,
-          thread_id,
-          num_threads);
-    }
-  }
-}
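The two "To be removed" wrappers deleted below are thin compatibility shims: they keep the legacy int32-bias signatures callable by forwarding to the templated entry points with act_times_w_scale pinned to 1.0f (or nullptr in the per-channel case). The pattern, schematically (stand-in names, not FBGEMM symbols):

    #include <cstdint>

    template <typename BIAS_TYPE>
    void new_entry(const BIAS_TYPE* bias, float act_times_w_scale) {
      // real work would happen here
    }

    // Legacy signature kept only for source compatibility; forwards with
    // the historical default scale of 1.0f.
    inline void old_entry(const std::int32_t* bias) {
      new_entry<std::int32_t>(bias, 1.0f);
    }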
-
-// To be removed
-void depthwise_3x3_pad_1(
-    int N,
-    int H,
-    int W,
-    int K,
-    int stride_h,
-    int stride_w,
-    int32_t A_zero_point,
-    const uint8_t* A,
-    int32_t B_zero_point,
-    const PackedDepthWiseConvMatrix& B,
-    float C_multiplier,
-    int32_t C_zero_point,
-    uint8_t* C,
-    const int32_t* col_offsets,
-    const int32_t* bias,
-    bool fuse_relu,
-    int thread_id,
-    int num_threads) {
-  depthwise_3x3_pad_1<std::int32_t>(
-      N,
-      H,
-      W,
-      K,
-      stride_h,
-      stride_w,
-      A_zero_point,
-      A,
-      B_zero_point,
-      B,
-      C_multiplier,
-      C_zero_point,
-      C,
-      col_offsets,
-      bias,
-      fuse_relu,
-      1.0f,
-      thread_id,
-      num_threads);
-}
-
-// To be removed
-void depthwise_3x3_per_channel_quantization_pad_1(
-    int N,
-    int H,
-    int W,
-    int K,
-    int stride_h,
-    int stride_w,
-    int32_t A_zero_point,
-    const uint8_t* A,
-    const int32_t* B_zero_point,
-    const PackedDepthWiseConvMatrix& Bp,
-    const float* C_multiplier,
-    int32_t C_zero_point,
-    uint8_t* C,
-    const int32_t* col_offsets,
-    const int32_t* bias,
-    bool fuse_relu,
-    int thread_id,
-    int num_threads) {
-  depthwise_3x3_per_channel_quantization_pad_1<std::int32_t>(
-      N,
-      H,
-      W,
-      K,
-      stride_h,
-      stride_w,
-      A_zero_point,
-      A,
-      B_zero_point,
-      Bp,
-      C_multiplier,
-      C_zero_point,
-      C,
-      col_offsets,
-      bias,
-      fuse_relu,
-      nullptr,
-      thread_id,
-      num_threads);
-}
-
-template void depthwise_3x3_pad_1(
+template void depthwise_2d_same_pad<int32_t>(
     int N,
     int H,
     int W,
@@ -1947,7 +131,7 @@ template void depthwise_3x3_pad_1(
     int thread_id,
     int num_threads);
 
-template void depthwise_3x3_pad_1(
+template void depthwise_2d_same_pad<float>(
     int N,
     int H,
     int W,
@@ -1968,46 +152,4 @@ template void depthwise_3x3_pad_1(
     int thread_id,
     int num_threads);
 
-template void depthwise_3x3_per_channel_quantization_pad_1(
-    int N,
-    int H,
-    int W,
-    int K,
-    int stride_h,
-    int stride_w,
-    int32_t A_zero_point,
-    const uint8_t* A,
-    const int32_t* B_zero_point,
-    const PackedDepthWiseConvMatrix& Bp,
-    const float* C_multiplier,
-    int32_t C_zero_point,
-    uint8_t* C,
-    const int32_t* col_offsets,
-    const int32_t* bias,
-    bool fuse_relu,
-    const float* act_times_w_scale,
-    int thread_id,
-    int num_threads);
-
-template void depthwise_3x3_per_channel_quantization_pad_1(
-    int N,
-    int H,
-    int W,
-    int K,
-    int stride_h,
-    int stride_w,
-    int32_t A_zero_point,
-    const uint8_t* A,
-    const int32_t* B_zero_point,
-    const PackedDepthWiseConvMatrix& Bp,
-    const float* C_multiplier,
-    int32_t C_zero_point,
-    uint8_t* C,
-    const int32_t* col_offsets,
-    const float* bias,
-    bool fuse_relu,
-    const float* act_times_w_scale,
-    int thread_id,
-    int num_threads);
-
 } // namespace fbgemm
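With the instantiations above, depthwise_2d_same_pad<int32_t> and <float> become the public tensor-granularity entry points for both 3x3 and 5x5 filters. A hedged usage sketch; shapes and quantization parameters are illustrative only, and it assumes the declaration lives in fbgemm/FbgemmI8DepthwiseAvx2.h alongside PackedDepthWiseConvMatrix:

    #include <cstdint>
    #include <vector>
    #include "fbgemm/FbgemmI8DepthwiseAvx2.h"

    using namespace fbgemm;

    // Illustrative 5x5 depthwise call (NHWC, K == groups).
    void run_depthwise_example() {
      int N = 1, H = 7, W = 7, K = 32;
      std::vector<uint8_t> A(N * H * W * K, 1);   // activations
      std::vector<int8_t> B_raw(K * 5 * 5, 1);    // 5x5 filters
      PackedDepthWiseConvMatrix B(K, 5 * 5, B_raw.data());

      std::vector<uint8_t> C(N * H * W * K);      // stride 1, same pad
      std::vector<int32_t> col_offsets(K, 0);
      std::vector<int32_t> bias(K, 0);

      depthwise_2d_same_pad<int32_t>(
          N, H, W, K,
          /*stride_h=*/1, /*stride_w=*/1,
          /*A_zero_point=*/0, A.data(),
          /*B_zero_point=*/0, B,
          /*C_multiplier=*/0.01f, /*C_zero_point=*/0,
          C.data(), col_offsets.data(), bias.data(),
          /*fuse_relu=*/false, /*act_times_w_scale=*/1.0f,
          /*thread_id=*/0, /*num_threads=*/1);
    }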
diff --git a/src/FbgemmI8DepthwisePerChannelQuantAvx2.cc b/src/FbgemmI8DepthwisePerChannelQuantAvx2.cc
new file mode 100644
index 0000000..f429214
--- /dev/null
+++ b/src/FbgemmI8DepthwisePerChannelQuantAvx2.cc
@@ -0,0 +1,154 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+
+#include <string>
+
+#include "FbgemmI8Depthwise2DAvx2-inl.h"
+
+using namespace std;
+
+namespace fbgemm {
+
+// Dispatch input shape and FUSE_RELU
+template <typename BIAS_TYPE /*=std::int32_t*/>
+void depthwise_2d_per_channel_quantization_same_pad(
+    int N,
+    int H,
+    int W,
+    int K,
+    int stride_h,
+    int stride_w,
+    int32_t A_zero_point,
+    const uint8_t* A,
+    const int32_t* B_zero_point,
+    const PackedDepthWiseConvMatrix& Bp,
+    const float* C_multiplier,
+    int32_t C_zero_point,
+    uint8_t* C,
+    const int32_t* col_offsets,
+    const BIAS_TYPE* bias,
+    bool fuse_relu,
+    const float* act_times_w_scale,
+    int thread_id,
+    int num_threads) {
+  if (Bp.GetKernelProduct() == 3 * 3) {
+    depthwise_3x3_per_channel_quantization_pad_1(
+        N,
+        H,
+        W,
+        K,
+        stride_h,
+        stride_w,
+        A_zero_point,
+        A,
+        B_zero_point,
+        Bp,
+        C_multiplier,
+        C_zero_point,
+        C,
+        col_offsets,
+        bias,
+        fuse_relu,
+        act_times_w_scale,
+        thread_id,
+        num_threads);
+    return;
+  }
+
+  if (Bp.GetKernelProduct() != 5 * 5) {
+    string msg =
+        "[FBGEMM_CONV_ERROR] Packed weight is expected to have kernel_prod " +
+        to_string(5 * 5) + " but has " + to_string(Bp.GetKernelProduct());
+    throw logic_error(msg);
+  }
+  if (fuse_relu) {
+    depthwise_2d_per_channel_quantization_<5, true /* FUSE_RELU */, BIAS_TYPE>(
+        N,
+        H,
+        W,
+        K,
+        stride_h,
+        stride_w,
+        A_zero_point,
+        A,
+        B_zero_point,
+        Bp,
+        C_multiplier,
+        C_zero_point,
+        C,
+        col_offsets,
+        bias,
+        act_times_w_scale,
+        thread_id,
+        num_threads);
+  } else {
+    depthwise_2d_per_channel_quantization_<5, false /* FUSE_RELU */, BIAS_TYPE>(
+        N,
+        H,
+        W,
+        K,
+        stride_h,
+        stride_w,
+        A_zero_point,
+        A,
+        B_zero_point,
+        Bp,
+        C_multiplier,
+        C_zero_point,
+        C,
+        col_offsets,
+        bias,
+        act_times_w_scale,
+        thread_id,
+        num_threads);
+  }
+}
+
+template void depthwise_2d_per_channel_quantization_same_pad<int32_t>(
+    int N,
+    int H,
+    int W,
+    int K,
+    int stride_h,
+    int stride_w,
+    int32_t A_zero_point,
+    const uint8_t* A,
+    const int32_t* B_zero_point,
+    const PackedDepthWiseConvMatrix& Bp,
+    const float* C_multiplier,
+    int32_t C_zero_point,
+    uint8_t* C,
+    const int32_t* col_offsets,
+    const int32_t* bias,
+    bool fuse_relu,
+    const float* act_times_w_scale,
+    int thread_id,
+    int num_threads);
+
+template void depthwise_2d_per_channel_quantization_same_pad<float>(
+    int N,
+    int H,
+    int W,
+    int K,
+    int stride_h,
+    int stride_w,
+    int32_t A_zero_point,
+    const uint8_t* A,
+    const int32_t* B_zero_point,
+    const PackedDepthWiseConvMatrix& Bp,
+    const float* C_multiplier,
+    int32_t C_zero_point,
+    uint8_t* C,
+    const int32_t* col_offsets,
+    const float* bias,
+    bool fuse_relu,
+    const float* act_times_w_scale,
+    int thread_id,
+    int num_threads);
+
+} // namespace fbgemm
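Relative to the tensor-granularity path, the only signature change in this new file is that B_zero_point and C_multiplier become per-channel arrays. What each output element goes through is ordinary affine requantization with a per-channel multiplier, roughly as follows (a schematic of the semantics, not code from this file):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Schematic per-channel requantization: output channel k has its own
    // multiplier, mirroring C_multiplier[k] above; helper is illustrative.
    static uint8_t requantize_per_channel(
        int32_t acc,          // int32 accumulator for one output
        float multiplier_k,   // C_multiplier[k]
        int32_t C_zero_point,
        bool fuse_relu) {
      float scaled = acc * multiplier_k;
      int32_t rounded =
          static_cast<int32_t>(std::nearbyint(scaled)) + C_zero_point;
      int32_t lo = fuse_relu ? C_zero_point : 0;  // relu clamps at zero point
      return static_cast<uint8_t>(std::min(255, std::max(lo, rounded)));
    }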
diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc
index 396e792..dfd4498 100644
--- a/src/GroupwiseConvAcc32Avx2.cc
+++ b/src/GroupwiseConvAcc32Avx2.cc
@@ -1764,6 +1764,11 @@ void fbgemmGroupwiseConv(
     const processOutputType& outProcess,
     int thread_id,
     int num_threads) {
+  // TODO: Remove this when threading is supported.
+  if (thread_id > 0) {
+    return;
+  }
+
   return fbgemmGroupwiseConvBase_<
       packed_W,
       outType,
@@ -1804,6 +1809,10 @@ void fbgemmGroupwiseConv(
   if (!cpuinfo_initialize()) {
     throw std::runtime_error("Failed to initialize cpuinfo!");
   }
+  // TODO: Remove this when threading is supported.
+  if (thread_id > 0) {
+    return;
+  }
   if (!fbgemmOptimizedGConv<SPATIAL_DIM>(conv_param) ||
       (!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
     return fbgemmGroupwiseConvBase_<
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
index 192fb00..abbe9ad 100644
--- a/src/PackWeightsForConv.cc
+++ b/src/PackWeightsForConv.cc
@@ -24,7 +24,9 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
   switch (ConvFastPath<SPATIAL_DIM, accT>(conv_p)) {
     case optimized_conv_t::depthwise: {
       W_dw_packed_ = std::make_shared<PackedDepthWiseConvMatrix>(
-          conv_p.G, SPATIAL_DIM == 3 ? 3 * 3 * 3 : 3 * 3, sdata);
+          conv_p.G,
+          SPATIAL_DIM == 3 ? 3 * 3 * 3 : conv_p.K[0] * conv_p.K[1],
+          sdata);
       break;
     }
     case optimized_conv_t::groupwise: {
diff --git a/src/codegen_fp16fp32.cc b/src/codegen_fp16fp32.cc
index 8f80593..edd1e83 100644
--- a/src/codegen_fp16fp32.cc
+++ b/src/codegen_fp16fp32.cc
@@ -82,8 +82,28 @@ int main() {
       //  {12, 1, 0},
       //  {13, 1, 0},
       //  {14, 1, 0},
+      }},
+     {3,
+      "AVX512",
+      {
+          // 6x2 register layout
+          {1, 2, 0},
+          {2, 2, 0},
+          {3, 2, 0},
+          {4, 2, 0},
+          {5, 2, 0},
+          {6, 2, 0},
+          {7, 2, 0},
+          {8, 2, 0},
+          {9, 2, 0},
+          {10, 2, 0},
+          {11, 2, 0},
+          {12, 2, 0},
+          {13, 2, 0},
+          {14, 2, 0},
       }}};
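The new table gives every AVX512 microkernel two B columns per row tile (the "6x2 register layout" comment looks stale, since the shapes run up to 14x2). The 14x2 maximum is exactly what the register-count asserts later in this file enforce: 28 zmm accumulators for the C tile, plus one register for the broadcast A value, plus two for the current B columns, is 31 of the 32 available. The budget, spelled out with a hypothetical helper:

    #include <cassert>

    // m*n C accumulators + 1 A broadcast + n B columns must fit in the
    // vector register file; illustrative, not from the codegen itself.
    static bool tile_fits(int m, int n, int num_vec_regs) {
      return m * n + 1 + n <= num_vec_regs;
    }

    int main() {
      assert(tile_fits(14, 2, 32));   // largest AVX512 shape: 31 of 32 zmms
      assert(!tile_fits(15, 2, 32));  // one more row tile would not fit
      assert(tile_fits(14, 1, 16));   // largest AVX2 shape: 16 of 16 ymms
      return 0;
    }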
"fp16" : "fp32"); + string prefix = s.name + /*"_" + B_type */ +"_" + "fA" + to_string(fixedA) + + "fB" + to_string(fixedB) + "fC" + to_string(fixedC); + cout << "Generating code for " << s.name << " " << B_type << "\n"; + + string vec_reg_prefix = s.avx <= 2 ? "ymm" : "zmm"; + int num_vec_regs = s.avx <= 2 ? 16 : 32; + int vec_len_in_bytes = s.avx <= 2 ? 32 : 64; + + for (unsigned k = 0; k < ukernel_shape.size(); k++) { + printf("shape: %d x %d * 32\n", ukernel_shape[k][0], ukernel_shape[k][1]); + + string p1 = "GemmParams* gp"; + + funcname[k] = "gemmkernel_" + to_string(ukernel_shape[k][0]) + "x" + + to_string(ukernel_shape[k][1]) + "_"; + funcname[k] += prefix; + + fargs = "(" + p1 + ")"; + + fheader[k] = "void __attribute__((noinline)) " + funcname[k] + fargs; + srcfile << fheader[k] << " {\n"; + + unsigned last_free_vecreg = 0; + // produce register block of C + vector<vector<string>> vCtile(ukernel_shape[k][0]); + for (auto r = 0; r < ukernel_shape[k][0]; r++) + for (auto c = 0; c < ukernel_shape[k][1]; c++) { + vCtile[r].push_back(vec_reg_prefix + to_string(last_free_vecreg)); + last_free_vecreg++; } + assert(last_free_vecreg <= num_vec_regs - 2); - srcfile << "\n} // namespace fbgemm\n"; - srcfile.close(); + string vAtmp = vec_reg_prefix + to_string(last_free_vecreg++); + // produce register block of B col + vector<string> vBcol(ukernel_shape[k][1]); - hdrfile << fptr_typedef["fp16"] << ";\n"; - hdrfile << fptr_typedef["fp32"] << ";\n"; - hdrfile << "\n} // namespace fbgemm\n\n"; - hdrfile << "#endif\n"; - hdrfile.close(); + for (auto c = 0; c < ukernel_shape[k][1]; c++) { + vBcol[c] = (vec_reg_prefix + to_string(last_free_vecreg)); + last_free_vecreg++; + } + + assert(last_free_vecreg <= num_vec_regs); + + srcfile << " asm volatile(\n"; + + srcfile << "#if !defined(__clang__)" + << "\n"; + addi(srcfile, "mov r14, %[gp]"); + srcfile << "#else\n"; + addi(srcfile, "mov %[gp], %%r14"); + addi(srcfile, ".intel_syntax noprefix"); + srcfile << "#endif\n"; + + srcfile << "\n // Copy parameters\n"; + srcfile << " // k\n"; + addi(srcfile, "mov r8, [r14 + 0]"); + srcfile << " // A\n"; + addi(srcfile, "mov r9, [r14 + 8]"); + srcfile << " // B\n"; + addi(srcfile, "mov r10, [r14 + 16]"); + srcfile << " // beta\n"; + addi(srcfile, "mov r15, [r14 + 24]"); + srcfile << " // accum\n"; + addi(srcfile, "mov rdx, [r14 + 32]"); + srcfile << " // C\n"; + addi(srcfile, "mov r12, [r14 + 40]"); + srcfile << " // ldc\n"; + addi(srcfile, "mov r13, [r14 + 48]"); + srcfile << " // b_block_cols\n"; + addi(srcfile, "mov rdi, [r14 + 56]"); + srcfile << " // b_block_size\n"; + addi(srcfile, "mov rsi, [r14 + 64]"); + srcfile << " // Make copies of A and C\n"; + addi(srcfile, "mov rax, r9"); + addi(srcfile, "mov rcx, r12"); + srcfile << "\n"; + + addi(srcfile, "mov rbx, 0"); + + string exitlabel = "L_exit%="; + string label2 = "loop_outter%="; + addi(srcfile, label2 + ":"); + addi(srcfile, "mov r14, 0"); + + // set all vCtile regs to zeros + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + addi( + srcfile, + "vxorps " + vCtile[r][c] + "," + vCtile[r][c] + "," + + vCtile[r][c]); + } + } + + // start marker + if (iaca) { + addi(srcfile, "mov ebx, 111"); + addi(srcfile, ".byte 0x64, 0x67, 0x90"); + } + + srcfile << "\n"; + + srcfile << "\n"; + string label = "loop_inner%="; + addi(srcfile, label + ":"); + srcfile << "\n"; + + for (int c = 0; c < vCtile[0].size(); c++) { + addi( + srcfile, + "vcvtph2ps " + vBcol[c] + "," + (s.avx <= 2 ? 
"XMM" : "YMM") + + "WORD PTR [r10 + " + to_string(vec_len_in_bytes / 2 * c) + "]"); + } + + for (int r = 0; r < vCtile.size(); r++) { + addi( + srcfile, + "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + to_string(4 * r) + + "]"); + for (int c = 0; c < vCtile[0].size(); c++) { + addi( + srcfile, + "vfmadd231ps " + vCtile[r][c] + "," + vBcol[c] + "," + vAtmp); + } + } + + addi( + srcfile, + "add r9," + to_string(4 * ukernel_shape[k][0]), + fixedA); // move A ptr + + addi( + srcfile, + "add r10," + to_string(vec_len_in_bytes / 2 * ukernel_shape[k][1]), + fixedA); // move A ptr + + addi(srcfile, "inc r14"); + addi(srcfile, "cmp r14, r8"); + addi(srcfile, "jl " + label); + + srcfile << "\n"; + + addi(srcfile, exitlabel + ":"); + + // addi(srcfile, "add r10, rsi"); + srcfile << "\n"; + + // end marker + if (iaca) { + addi(srcfile, "mov ebx, 222"); + addi(srcfile, ".byte 0x64, 0x67, 0x90"); + } + + addi(srcfile, "cmp rdx, 1"); + addi(srcfile, "je L_accum%="); + srcfile << " // Dump C\n"; + + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + addi( + srcfile, + "vmovups " + vec_reg_prefix + "word PTR [r12 + " + + to_string(vec_len_in_bytes * c) + "], " + vCtile[r][c], + fixedC); + } + addi(srcfile, "add r12, r13", fixedC); // move C ptr + } + addi(srcfile, "jmp L_done%="); + + srcfile << "\n"; + addi(srcfile, "L_accum%=:"); + srcfile << " // Dump C with accumulate\n"; + + string r_spare = + vec_reg_prefix + to_string(num_vec_regs - (s.avx == 1 ? 2 : 1)); + string r_last = vec_reg_prefix + to_string(num_vec_regs - 1); + addi( + srcfile, + "vbroadcastss " + r_spare + string(",DWORD PTR [r15]"), + fixedC); + // store out C + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + switch (s.avx) { + case 1: + addi( + srcfile, + string("vmulps " + r_last + ", ") + r_spare + comma + + "YMMWORD PTR [r12 + " + to_string(vec_len_in_bytes * c) + + "]", + fixedC); + addi( + srcfile, + "vaddps " + vCtile[r][c] + "," + vCtile[r][c] + "," + r_last, + fixedC); + break; + case 2: + case 3: + addi( + srcfile, + "vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," + + vec_reg_prefix + "word PTR [r12 + " + + to_string(vec_len_in_bytes * c) + "]", + fixedC); + break; + default: + assert(0); +>>>>>>> upstream + } + addi( + srcfile, + "vmovups " + vec_reg_prefix + "word PTR [r12 + " + + to_string(vec_len_in_bytes * c) + "], " + vCtile[r][c], + fixedC); + } + addi(srcfile, "add r12, r13", fixedC); // move C ptr + } + + srcfile << "\n"; + addi(srcfile, "L_done%=:"); + + srcfile << "\n // next outer iteration\n"; + // C + addi( + srcfile, + "add rcx, " + to_string(vec_len_in_bytes * ukernel_shape[k][1]), + fixedC); + addi(srcfile, "mov r12, rcx", fixedC); + // A + addi(srcfile, "mov r9, rax"); + + addi(srcfile, "inc rbx"); + addi(srcfile, "cmp rbx, rdi"); + addi(srcfile, "jl " + label2); + + // output + srcfile << " :\n"; + // input + srcfile << " : [gp] \"rm\"(gp)\n"; + + // clobbered + srcfile << " : \"r8\",\n \"r9\",\n \"r10\",\n" + " \"r11\",\n \"r15\",\n \"r13\",\n" + " \"r14\",\n \"rax\",\n \"rcx\",\n" + " \"rdx\",\n \"rsi\",\n \"rdi\",\n" + " \"rbx\",\n \"r12\",\n" + " \"memory\");\n"; + srcfile << "}\n"; + } + + for (unsigned k = 0; k < ukernel_shape.size(); k++) { + hdrfile << fheader[k] << ";\n"; + } + + fptr_typedef[B_type] = "typedef void (*funcptr_" + B_type + ")" + fargs; + + srcfile << "\n} // namespace fbgemm\n"; + srcfile.close(); + + hdrfile << fptr_typedef["fp16"] << ";\n"; + hdrfile << fptr_typedef["fp32"] << ";\n"; + hdrfile 
<< "\n} // namespace fbgemm\n\n"; + hdrfile.close(); + } // isa } |