From 3af8fe54149d9451d593c635d616d2c380e21acb Mon Sep 17 00:00:00 2001
From: Daya S Khudia
Date: Wed, 5 Dec 2018 14:13:59 -0800
Subject: Final cleanup for avx2 isolation and consistent file names (#40)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/40

File name changes + removal of the -mavx2 compiler flag from non-avx files

This completes the separation of avx2 code into a few files that make minimal use of the C++ standard library.

Reviewed By: jianyuh

Differential Revision: D13330577

fbshipit-source-id: b469ebee484168800ce2d12fd2356edecbf0fa4d
---
 src/FbgemmFP16.cc                        |    2 +-
 src/FbgemmFP16UKernels.cc                | 2301 ------------------------------
 src/FbgemmFP16UKernels.h                 |   48 -
 src/FbgemmFP16UKernelsAvx2.cc            | 2301 ++++++++++++++++++++++++++++++
 src/FbgemmFP16UKernelsAvx2.h             |   46 +
 src/GenerateKernelU8S8S32ACC16Avx512.cc  |  301 ++++
 src/GenerateKernelU8S8S32ACC16_avx512.cc |  301 ----
 src/GenerateKernelU8S8S32ACC32Avx512.cc  |  317 ++++
 src/GenerateKernelU8S8S32ACC32_avx512.cc |  317 ----
 src/UtilsAvx2.cc                         |  169 +++
 src/UtilsAvx512.cc                       |  246 ++++
 src/Utils_avx2.cc                        |  169 ---
 src/Utils_avx512.cc                      |  246 ----
 src/codegen_fp16fp32.cc                  |    8 +-
 14 files changed, 3384 insertions(+), 3388 deletions(-)
 delete mode 100644 src/FbgemmFP16UKernels.cc
 delete mode 100644 src/FbgemmFP16UKernels.h
 create mode 100644 src/FbgemmFP16UKernelsAvx2.cc
 create mode 100644 src/FbgemmFP16UKernelsAvx2.h
 create mode 100644 src/GenerateKernelU8S8S32ACC16Avx512.cc
 delete mode 100644 src/GenerateKernelU8S8S32ACC16_avx512.cc
 create mode 100644 src/GenerateKernelU8S8S32ACC32Avx512.cc
 delete mode 100644 src/GenerateKernelU8S8S32ACC32_avx512.cc
 create mode 100644 src/UtilsAvx2.cc
 create mode 100644 src/UtilsAvx512.cc
 delete mode 100644 src/Utils_avx2.cc
 delete mode 100644 src/Utils_avx512.cc

diff --git a/src/FbgemmFP16.cc b/src/FbgemmFP16.cc
index 6d44c74..2af1f89 100644
--- a/src/FbgemmFP16.cc
+++ b/src/FbgemmFP16.cc
@@ -10,7 +10,7 @@
 #include
 #include
-#include "FbgemmFP16UKernels.h"
+#include "FbgemmFP16UKernelsAvx2.h"
 using namespace std;
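Context for why the flag removal matters: once a translation unit is compiled with -mavx2, the compiler is free to emit AVX2 instructions anywhere in that unit, including in inlined C++ standard library code, so the binary can crash with SIGILL on non-AVX2 machines even on paths that never intend to use AVX2. The sketch below is illustrative only, with hypothetical file and function names (not FBGEMM's actual API): it shows the isolation pattern this diff completes, where all AVX2 code lives in one translation unit that is the only one built with -mavx2, reached from generic code through an ordinary declaration plus runtime dispatch.

```cpp
// Hypothetical sketch of the isolation pattern; names are illustrative,
// not FBGEMM's actual code.

// ---- scale_avx2.cc (the ONLY file built with -mavx2) ----------------
#include <immintrin.h>

void scale_fp32_avx2(const float* in, float* out, int n, float alpha) {
  const __m256 a = _mm256_set1_ps(alpha);
  int i = 0;
  for (; i + 8 <= n; i += 8) {  // 8 floats per 256-bit vector
    _mm256_storeu_ps(out + i, _mm256_mul_ps(a, _mm256_loadu_ps(in + i)));
  }
  for (; i < n; ++i) {  // scalar tail
    out[i] = alpha * in[i];
  }
}

// ---- scale.cc (built WITHOUT -mavx2; safe on any x86-64 CPU) --------
void scale_fp32_avx2(const float* in, float* out, int n, float alpha);

static void scale_fp32_scalar(const float* in, float* out, int n,
                              float alpha) {
  for (int i = 0; i < n; ++i) {
    out[i] = alpha * in[i];
  }
}

void scale_fp32(const float* in, float* out, int n, float alpha) {
  // __builtin_cpu_supports is a GCC/Clang builtin; calling it does not
  // require AVX2 code generation in this translation unit.
  if (__builtin_cpu_supports("avx2")) {
    scale_fp32_avx2(in, out, n, alpha);
  } else {
    scale_fp32_scalar(in, out, n, alpha);
  }
}
```

With this split, only the AVX2 file needs -mavx2 (e.g., via a per-file compile flag in the build system), which is what lets the diff drop -mavx2 from every non-avx file.

diff --git a/src/FbgemmFP16UKernels.cc b/src/FbgemmFP16UKernels.cc
deleted file mode 100644
index d915765..0000000
--- a/src/FbgemmFP16UKernels.cc
+++ /dev/null
@@ -1,2301 +0,0 @@
-/*
- * Copyright (c) Facebook, Inc. and its affiliates.
- * All rights reserved.
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.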
- */ -#include "FbgemmFP16UKernels.h" - -namespace fbgemm { - -void __attribute__((noinline)) gemmkernel_1x1_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif - - // Copy parameters - // k - "mov r8, [r14 + 0]\t\n" - // A - "mov r9, [r14 + 8]\t\n" - // B - "mov r10, [r14 + 16]\t\n" - // beta - "mov r15, [r14 + 24]\t\n" - // accum - "mov rdx, [r14 + 32]\t\n" - // C - "mov r12, [r14 + 40]\t\n" - // ldc - "mov r13, [r14 + 48]\t\n" - // b_block_cols - "mov rdi, [r14 + 56]\t\n" - // b_block_size - "mov rsi, [r14 + 64]\t\n" - // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - - "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" - "mov r11, 16\t\n" - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" - "inc r14\t\n" - "vbroadcastss ymm1,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm15,ymm1\t\n" - "cmp r14, r8\t\n" - "jge L_exit%=\t\n" - "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" - "inc r14\t\n" - "vbroadcastss ymm1,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm0,ymm14,ymm1\t\n" - "add r11, 32\t\n" - "add r9,8\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - "add r10, rsi\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 32\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); -} -void __attribute__((noinline)) gemmkernel_2x1_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif - - // Copy parameters - // k - "mov r8, [r14 + 0]\t\n" - // A - "mov r9, [r14 + 8]\t\n" - // B - "mov r10, [r14 + 16]\t\n" - // beta - "mov r15, [r14 + 24]\t\n" - // accum - "mov rdx, [r14 + 32]\t\n" - // C - "mov r12, [r14 + 40]\t\n" - // ldc - "mov r13, [r14 + 48]\t\n" - // b_block_cols - "mov rdi, [r14 + 56]\t\n" - // b_block_size - "mov rsi, [r14 + 64]\t\n" - // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - - "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" - "mov r11, 16\t\n" - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" - "inc r14\t\n" - "vbroadcastss ymm2,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm15,ymm2\t\n" - "vbroadcastss ymm2,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm1,ymm15,ymm2\t\n" - "cmp r14, r8\t\n" - "jge L_exit%=\t\n" - "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" - "inc r14\t\n" - "vbroadcastss ymm2,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm0,ymm14,ymm2\t\n" - "vbroadcastss ymm2,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm1,ymm14,ymm2\t\n" - "add r11, 32\t\n" - "add r9,16\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - "add r10, rsi\t\n" - - "cmp rdx, 1\t\n" - "je 
L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 32\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); -} -void __attribute__((noinline)) gemmkernel_3x1_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif - - // Copy parameters - // k - "mov r8, [r14 + 0]\t\n" - // A - "mov r9, [r14 + 8]\t\n" - // B - "mov r10, [r14 + 16]\t\n" - // beta - "mov r15, [r14 + 24]\t\n" - // accum - "mov rdx, [r14 + 32]\t\n" - // C - "mov r12, [r14 + 40]\t\n" - // ldc - "mov r13, [r14 + 48]\t\n" - // b_block_cols - "mov rdi, [r14 + 56]\t\n" - // b_block_size - "mov rsi, [r14 + 64]\t\n" - // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - - "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" - "mov r11, 16\t\n" - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" - "inc r14\t\n" - "vbroadcastss ymm3,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm15,ymm3\t\n" - "vbroadcastss ymm3,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm1,ymm15,ymm3\t\n" - "vbroadcastss ymm3,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm2,ymm15,ymm3\t\n" - "cmp r14, r8\t\n" - "jge L_exit%=\t\n" - "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" - "inc r14\t\n" - "vbroadcastss ymm3,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm0,ymm14,ymm3\t\n" - "vbroadcastss ymm3,DWORD PTR [r9+16]\t\n" - "vfmadd231ps ymm1,ymm14,ymm3\t\n" - "add r11, 32\t\n" - "vbroadcastss ymm3,DWORD PTR [r9+20]\t\n" - "vfmadd231ps ymm2,ymm14,ymm3\t\n" - "add r9,24\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - "add r10, rsi\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 32\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - 
"memory"); -} -void __attribute__((noinline)) gemmkernel_4x1_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif - - // Copy parameters - // k - "mov r8, [r14 + 0]\t\n" - // A - "mov r9, [r14 + 8]\t\n" - // B - "mov r10, [r14 + 16]\t\n" - // beta - "mov r15, [r14 + 24]\t\n" - // accum - "mov rdx, [r14 + 32]\t\n" - // C - "mov r12, [r14 + 40]\t\n" - // ldc - "mov r13, [r14 + 48]\t\n" - // b_block_cols - "mov rdi, [r14 + 56]\t\n" - // b_block_size - "mov rsi, [r14 + 64]\t\n" - // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - - "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" - "mov r11, 16\t\n" - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" - "inc r14\t\n" - "vbroadcastss ymm4,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm15,ymm4\t\n" - "vbroadcastss ymm4,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm1,ymm15,ymm4\t\n" - "vbroadcastss ymm4,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm2,ymm15,ymm4\t\n" - "vbroadcastss ymm4,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm3,ymm15,ymm4\t\n" - "cmp r14, r8\t\n" - "jge L_exit%=\t\n" - "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" - "inc r14\t\n" - "vbroadcastss ymm4,DWORD PTR [r9+16]\t\n" - "vfmadd231ps ymm0,ymm14,ymm4\t\n" - "vbroadcastss ymm4,DWORD PTR [r9+20]\t\n" - "vfmadd231ps ymm1,ymm14,ymm4\t\n" - "vbroadcastss ymm4,DWORD PTR [r9+24]\t\n" - "vfmadd231ps ymm2,ymm14,ymm4\t\n" - "add r11, 32\t\n" - "vbroadcastss ymm4,DWORD PTR [r9+28]\t\n" - "vfmadd231ps ymm3,ymm14,ymm4\t\n" - "add r9,32\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - "add r10, rsi\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 32\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); -} -void __attribute__((noinline)) gemmkernel_5x1_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif - - // Copy parameters - // k - "mov r8, [r14 + 0]\t\n" - // A - "mov r9, [r14 + 8]\t\n" - // B - "mov r10, [r14 + 16]\t\n" - // beta - "mov r15, [r14 + 24]\t\n" - // accum - "mov rdx, [r14 + 32]\t\n" - // C - "mov r12, [r14 + 
40]\t\n" - // ldc - "mov r13, [r14 + 48]\t\n" - // b_block_cols - "mov rdi, [r14 + 56]\t\n" - // b_block_size - "mov rsi, [r14 + 64]\t\n" - // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - "vxorps ymm4,ymm4,ymm4\t\n" - - "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" - "mov r11, 16\t\n" - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" - "inc r14\t\n" - "vbroadcastss ymm5,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm15,ymm5\t\n" - "vbroadcastss ymm5,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm1,ymm15,ymm5\t\n" - "vbroadcastss ymm5,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm2,ymm15,ymm5\t\n" - "vbroadcastss ymm5,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm3,ymm15,ymm5\t\n" - "vbroadcastss ymm5,DWORD PTR [r9+16]\t\n" - "vfmadd231ps ymm4,ymm15,ymm5\t\n" - "cmp r14, r8\t\n" - "jge L_exit%=\t\n" - "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" - "inc r14\t\n" - "vbroadcastss ymm5,DWORD PTR [r9+20]\t\n" - "vfmadd231ps ymm0,ymm14,ymm5\t\n" - "vbroadcastss ymm5,DWORD PTR [r9+24]\t\n" - "vfmadd231ps ymm1,ymm14,ymm5\t\n" - "vbroadcastss ymm5,DWORD PTR [r9+28]\t\n" - "vfmadd231ps ymm2,ymm14,ymm5\t\n" - "add r11, 32\t\n" - "vbroadcastss ymm5,DWORD PTR [r9+32]\t\n" - "vfmadd231ps ymm3,ymm14,ymm5\t\n" - "vbroadcastss ymm5,DWORD PTR [r9+36]\t\n" - "vfmadd231ps ymm4,ymm14,ymm5\t\n" - "add r9,40\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - "add r10, rsi\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 32\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); -} -void __attribute__((noinline)) gemmkernel_6x1_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif - - // Copy parameters - // k - "mov r8, [r14 + 0]\t\n" - // A - "mov r9, [r14 + 8]\t\n" - // B - "mov r10, [r14 + 16]\t\n" - // beta - "mov r15, [r14 + 24]\t\n" - // accum - "mov rdx, [r14 + 32]\t\n" - // C - "mov r12, [r14 + 40]\t\n" - // ldc - "mov r13, [r14 + 48]\t\n" - // b_block_cols - "mov rdi, [r14 + 
56]\t\n" - // b_block_size - "mov rsi, [r14 + 64]\t\n" - // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - "vxorps ymm4,ymm4,ymm4\t\n" - "vxorps ymm5,ymm5,ymm5\t\n" - - "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" - "mov r11, 16\t\n" - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" - "inc r14\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm15,ymm6\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm1,ymm15,ymm6\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm2,ymm15,ymm6\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm3,ymm15,ymm6\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+16]\t\n" - "vfmadd231ps ymm4,ymm15,ymm6\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+20]\t\n" - "vfmadd231ps ymm5,ymm15,ymm6\t\n" - "cmp r14, r8\t\n" - "jge L_exit%=\t\n" - "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" - "inc r14\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+24]\t\n" - "vfmadd231ps ymm0,ymm14,ymm6\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+28]\t\n" - "vfmadd231ps ymm1,ymm14,ymm6\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+32]\t\n" - "vfmadd231ps ymm2,ymm14,ymm6\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+36]\t\n" - "vfmadd231ps ymm3,ymm14,ymm6\t\n" - "add r11, 32\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+40]\t\n" - "vfmadd231ps ymm4,ymm14,ymm6\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+44]\t\n" - "vfmadd231ps ymm5,ymm14,ymm6\t\n" - "add r9,48\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - "add r10, rsi\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 32\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); -} -void __attribute__((noinline)) gemmkernel_7x1_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif - - // Copy 
parameters - // k - "mov r8, [r14 + 0]\t\n" - // A - "mov r9, [r14 + 8]\t\n" - // B - "mov r10, [r14 + 16]\t\n" - // beta - "mov r15, [r14 + 24]\t\n" - // accum - "mov rdx, [r14 + 32]\t\n" - // C - "mov r12, [r14 + 40]\t\n" - // ldc - "mov r13, [r14 + 48]\t\n" - // b_block_cols - "mov rdi, [r14 + 56]\t\n" - // b_block_size - "mov rsi, [r14 + 64]\t\n" - // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - "vxorps ymm4,ymm4,ymm4\t\n" - "vxorps ymm5,ymm5,ymm5\t\n" - "vxorps ymm6,ymm6,ymm6\t\n" - - "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" - "mov r11, 16\t\n" - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" - "inc r14\t\n" - "vbroadcastss ymm7,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm15,ymm7\t\n" - "vbroadcastss ymm7,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm1,ymm15,ymm7\t\n" - "vbroadcastss ymm7,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm2,ymm15,ymm7\t\n" - "vbroadcastss ymm7,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm3,ymm15,ymm7\t\n" - "vbroadcastss ymm7,DWORD PTR [r9+16]\t\n" - "vfmadd231ps ymm4,ymm15,ymm7\t\n" - "vbroadcastss ymm7,DWORD PTR [r9+20]\t\n" - "vfmadd231ps ymm5,ymm15,ymm7\t\n" - "vbroadcastss ymm7,DWORD PTR [r9+24]\t\n" - "vfmadd231ps ymm6,ymm15,ymm7\t\n" - "cmp r14, r8\t\n" - "jge L_exit%=\t\n" - "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" - "inc r14\t\n" - "vbroadcastss ymm7,DWORD PTR [r9+28]\t\n" - "vfmadd231ps ymm0,ymm14,ymm7\t\n" - "vbroadcastss ymm7,DWORD PTR [r9+32]\t\n" - "vfmadd231ps ymm1,ymm14,ymm7\t\n" - "vbroadcastss ymm7,DWORD PTR [r9+36]\t\n" - "vfmadd231ps ymm2,ymm14,ymm7\t\n" - "vbroadcastss ymm7,DWORD PTR [r9+40]\t\n" - "vfmadd231ps ymm3,ymm14,ymm7\t\n" - "add r11, 32\t\n" - "vbroadcastss ymm7,DWORD PTR [r9+44]\t\n" - "vfmadd231ps ymm4,ymm14,ymm7\t\n" - "vbroadcastss ymm7,DWORD PTR [r9+48]\t\n" - "vfmadd231ps ymm5,ymm14,ymm7\t\n" - "vbroadcastss ymm7,DWORD PTR [r9+52]\t\n" - "vfmadd231ps ymm6,ymm14,ymm7\t\n" - "add r9,56\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - "add r10, rsi\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vfmadd231ps 
ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 32\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); -} -void __attribute__((noinline)) gemmkernel_8x1_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif - - // Copy parameters - // k - "mov r8, [r14 + 0]\t\n" - // A - "mov r9, [r14 + 8]\t\n" - // B - "mov r10, [r14 + 16]\t\n" - // beta - "mov r15, [r14 + 24]\t\n" - // accum - "mov rdx, [r14 + 32]\t\n" - // C - "mov r12, [r14 + 40]\t\n" - // ldc - "mov r13, [r14 + 48]\t\n" - // b_block_cols - "mov rdi, [r14 + 56]\t\n" - // b_block_size - "mov rsi, [r14 + 64]\t\n" - // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - "vxorps ymm4,ymm4,ymm4\t\n" - "vxorps ymm5,ymm5,ymm5\t\n" - "vxorps ymm6,ymm6,ymm6\t\n" - "vxorps ymm7,ymm7,ymm7\t\n" - - "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" - "mov r11, 16\t\n" - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" - "inc r14\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm15,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm1,ymm15,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm2,ymm15,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm3,ymm15,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+16]\t\n" - "vfmadd231ps ymm4,ymm15,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+20]\t\n" - "vfmadd231ps ymm5,ymm15,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+24]\t\n" - "vfmadd231ps ymm6,ymm15,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+28]\t\n" - "vfmadd231ps ymm7,ymm15,ymm8\t\n" - "cmp r14, r8\t\n" - "jge L_exit%=\t\n" - "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" - "inc r14\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+32]\t\n" - "vfmadd231ps ymm0,ymm14,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+36]\t\n" - "vfmadd231ps ymm1,ymm14,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+40]\t\n" - "vfmadd231ps ymm2,ymm14,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+44]\t\n" - "vfmadd231ps ymm3,ymm14,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+48]\t\n" - "vfmadd231ps ymm4,ymm14,ymm8\t\n" - "add r11, 32\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+52]\t\n" - "vfmadd231ps ymm5,ymm14,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+56]\t\n" - "vfmadd231ps ymm6,ymm14,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+60]\t\n" - "vfmadd231ps ymm7,ymm14,ymm8\t\n" - "add r9,64\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - "add r10, rsi\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - 
"add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 32\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); -} -void __attribute__((noinline)) gemmkernel_9x1_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif - - // Copy parameters - // k - "mov r8, [r14 + 0]\t\n" - // A - "mov r9, [r14 + 8]\t\n" - // B - "mov r10, [r14 + 16]\t\n" - // beta - "mov r15, [r14 + 24]\t\n" - // accum - "mov rdx, [r14 + 32]\t\n" - // C - "mov r12, [r14 + 40]\t\n" - // ldc - "mov r13, [r14 + 48]\t\n" - // b_block_cols - "mov rdi, [r14 + 56]\t\n" - // b_block_size - "mov rsi, [r14 + 64]\t\n" - // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - "vxorps ymm4,ymm4,ymm4\t\n" - "vxorps ymm5,ymm5,ymm5\t\n" - "vxorps ymm6,ymm6,ymm6\t\n" - "vxorps ymm7,ymm7,ymm7\t\n" - "vxorps ymm8,ymm8,ymm8\t\n" - - "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" - "mov r11, 16\t\n" - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" - "inc r14\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm15,ymm9\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm1,ymm15,ymm9\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm2,ymm15,ymm9\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm3,ymm15,ymm9\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+16]\t\n" - "vfmadd231ps ymm4,ymm15,ymm9\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+20]\t\n" - "vfmadd231ps ymm5,ymm15,ymm9\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+24]\t\n" - "vfmadd231ps ymm6,ymm15,ymm9\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+28]\t\n" - "vfmadd231ps ymm7,ymm15,ymm9\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+32]\t\n" - "vfmadd231ps ymm8,ymm15,ymm9\t\n" - "cmp r14, r8\t\n" - "jge L_exit%=\t\n" - "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" - "inc r14\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+36]\t\n" - "vfmadd231ps ymm0,ymm14,ymm9\t\n" - "vbroadcastss ymm9,DWORD 
PTR [r9+40]\t\n" - "vfmadd231ps ymm1,ymm14,ymm9\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+44]\t\n" - "vfmadd231ps ymm2,ymm14,ymm9\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+48]\t\n" - "vfmadd231ps ymm3,ymm14,ymm9\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+52]\t\n" - "vfmadd231ps ymm4,ymm14,ymm9\t\n" - "add r11, 32\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+56]\t\n" - "vfmadd231ps ymm5,ymm14,ymm9\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+60]\t\n" - "vfmadd231ps ymm6,ymm14,ymm9\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+64]\t\n" - "vfmadd231ps ymm7,ymm14,ymm9\t\n" - "vbroadcastss ymm9,DWORD PTR [r9+68]\t\n" - "vfmadd231ps ymm8,ymm14,ymm9\t\n" - "add r9,72\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - "add r10, rsi\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 32\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); -} -void __attribute__((noinline)) gemmkernel_10x1_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif - - // Copy parameters - // k - "mov r8, [r14 + 0]\t\n" - // A - "mov r9, [r14 + 8]\t\n" - // B - "mov r10, [r14 + 16]\t\n" - // beta - "mov r15, [r14 + 24]\t\n" - // accum - "mov rdx, [r14 + 32]\t\n" - // C - "mov r12, [r14 + 40]\t\n" - // ldc - "mov r13, [r14 + 48]\t\n" - // b_block_cols - "mov rdi, [r14 + 56]\t\n" - // b_block_size - "mov rsi, [r14 + 64]\t\n" - // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 
0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - "vxorps ymm4,ymm4,ymm4\t\n" - "vxorps ymm5,ymm5,ymm5\t\n" - "vxorps ymm6,ymm6,ymm6\t\n" - "vxorps ymm7,ymm7,ymm7\t\n" - "vxorps ymm8,ymm8,ymm8\t\n" - "vxorps ymm9,ymm9,ymm9\t\n" - - "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" - "mov r11, 16\t\n" - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" - "inc r14\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm15,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm1,ymm15,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm2,ymm15,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm3,ymm15,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+16]\t\n" - "vfmadd231ps ymm4,ymm15,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+20]\t\n" - "vfmadd231ps ymm5,ymm15,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+24]\t\n" - "vfmadd231ps ymm6,ymm15,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+28]\t\n" - "vfmadd231ps ymm7,ymm15,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+32]\t\n" - "vfmadd231ps ymm8,ymm15,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+36]\t\n" - "vfmadd231ps ymm9,ymm15,ymm10\t\n" - "cmp r14, r8\t\n" - "jge L_exit%=\t\n" - "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" - "inc r14\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+40]\t\n" - "vfmadd231ps ymm0,ymm14,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+44]\t\n" - "vfmadd231ps ymm1,ymm14,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+48]\t\n" - "vfmadd231ps ymm2,ymm14,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+52]\t\n" - "vfmadd231ps ymm3,ymm14,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+56]\t\n" - "vfmadd231ps ymm4,ymm14,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+60]\t\n" - "vfmadd231ps ymm5,ymm14,ymm10\t\n" - "add r11, 32\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+64]\t\n" - "vfmadd231ps ymm6,ymm14,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+68]\t\n" - "vfmadd231ps ymm7,ymm14,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+72]\t\n" - "vfmadd231ps ymm8,ymm14,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+76]\t\n" - "vfmadd231ps ymm9,ymm14,ymm10\t\n" - "add r9,80\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - "add r10, rsi\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - 
"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 32\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); -} -void __attribute__((noinline)) gemmkernel_11x1_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif - - // Copy parameters - // k - "mov r8, [r14 + 0]\t\n" - // A - "mov r9, [r14 + 8]\t\n" - // B - "mov r10, [r14 + 16]\t\n" - // beta - "mov r15, [r14 + 24]\t\n" - // accum - "mov rdx, [r14 + 32]\t\n" - // C - "mov r12, [r14 + 40]\t\n" - // ldc - "mov r13, [r14 + 48]\t\n" - // b_block_cols - "mov rdi, [r14 + 56]\t\n" - // b_block_size - "mov rsi, [r14 + 64]\t\n" - // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - "vxorps ymm4,ymm4,ymm4\t\n" - "vxorps ymm5,ymm5,ymm5\t\n" - "vxorps ymm6,ymm6,ymm6\t\n" - "vxorps ymm7,ymm7,ymm7\t\n" - "vxorps ymm8,ymm8,ymm8\t\n" - "vxorps ymm9,ymm9,ymm9\t\n" - "vxorps ymm10,ymm10,ymm10\t\n" - - "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" - "mov r11, 16\t\n" - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" - "inc r14\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm15,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm1,ymm15,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm2,ymm15,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm3,ymm15,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+16]\t\n" - "vfmadd231ps ymm4,ymm15,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+20]\t\n" - "vfmadd231ps ymm5,ymm15,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+24]\t\n" - "vfmadd231ps ymm6,ymm15,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+28]\t\n" - "vfmadd231ps ymm7,ymm15,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+32]\t\n" - "vfmadd231ps ymm8,ymm15,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+36]\t\n" - "vfmadd231ps ymm9,ymm15,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+40]\t\n" - "vfmadd231ps ymm10,ymm15,ymm11\t\n" - "cmp r14, r8\t\n" - "jge L_exit%=\t\n" - "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" - "inc r14\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+44]\t\n" - "vfmadd231ps ymm0,ymm14,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+48]\t\n" - "vfmadd231ps ymm1,ymm14,ymm11\t\n" - 
"vbroadcastss ymm11,DWORD PTR [r9+52]\t\n" - "vfmadd231ps ymm2,ymm14,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+56]\t\n" - "vfmadd231ps ymm3,ymm14,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+60]\t\n" - "vfmadd231ps ymm4,ymm14,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+64]\t\n" - "vfmadd231ps ymm5,ymm14,ymm11\t\n" - "add r11, 32\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+68]\t\n" - "vfmadd231ps ymm6,ymm14,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+72]\t\n" - "vfmadd231ps ymm7,ymm14,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+76]\t\n" - "vfmadd231ps ymm8,ymm14,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+80]\t\n" - "vfmadd231ps ymm9,ymm14,ymm11\t\n" - "vbroadcastss ymm11,DWORD PTR [r9+84]\t\n" - "vfmadd231ps ymm10,ymm14,ymm11\t\n" - "add r9,88\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - "add r10, rsi\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 32\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); -} -void __attribute__((noinline)) gemmkernel_12x1_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" 
- ".intel_syntax noprefix\t\n" -#endif - - // Copy parameters - // k - "mov r8, [r14 + 0]\t\n" - // A - "mov r9, [r14 + 8]\t\n" - // B - "mov r10, [r14 + 16]\t\n" - // beta - "mov r15, [r14 + 24]\t\n" - // accum - "mov rdx, [r14 + 32]\t\n" - // C - "mov r12, [r14 + 40]\t\n" - // ldc - "mov r13, [r14 + 48]\t\n" - // b_block_cols - "mov rdi, [r14 + 56]\t\n" - // b_block_size - "mov rsi, [r14 + 64]\t\n" - // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - "vxorps ymm4,ymm4,ymm4\t\n" - "vxorps ymm5,ymm5,ymm5\t\n" - "vxorps ymm6,ymm6,ymm6\t\n" - "vxorps ymm7,ymm7,ymm7\t\n" - "vxorps ymm8,ymm8,ymm8\t\n" - "vxorps ymm9,ymm9,ymm9\t\n" - "vxorps ymm10,ymm10,ymm10\t\n" - "vxorps ymm11,ymm11,ymm11\t\n" - - "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" - "mov r11, 16\t\n" - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" - "inc r14\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm15,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm1,ymm15,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm2,ymm15,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm3,ymm15,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+16]\t\n" - "vfmadd231ps ymm4,ymm15,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+20]\t\n" - "vfmadd231ps ymm5,ymm15,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+24]\t\n" - "vfmadd231ps ymm6,ymm15,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+28]\t\n" - "vfmadd231ps ymm7,ymm15,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+32]\t\n" - "vfmadd231ps ymm8,ymm15,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+36]\t\n" - "vfmadd231ps ymm9,ymm15,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+40]\t\n" - "vfmadd231ps ymm10,ymm15,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+44]\t\n" - "vfmadd231ps ymm11,ymm15,ymm12\t\n" - "cmp r14, r8\t\n" - "jge L_exit%=\t\n" - "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" - "inc r14\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+48]\t\n" - "vfmadd231ps ymm0,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+52]\t\n" - "vfmadd231ps ymm1,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+56]\t\n" - "vfmadd231ps ymm2,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+60]\t\n" - "vfmadd231ps ymm3,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+64]\t\n" - "vfmadd231ps ymm4,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+68]\t\n" - "vfmadd231ps ymm5,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+72]\t\n" - "vfmadd231ps ymm6,ymm14,ymm12\t\n" - "add r11, 32\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+76]\t\n" - "vfmadd231ps ymm7,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+80]\t\n" - "vfmadd231ps ymm8,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+84]\t\n" - "vfmadd231ps ymm9,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+88]\t\n" - "vfmadd231ps ymm10,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+92]\t\n" - "vfmadd231ps ymm11,ymm14,ymm12\t\n" - "add r9,96\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - "add r10, rsi\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], 
ymm3\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 32\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); -} -void __attribute__((noinline)) gemmkernel_13x1_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif - - // Copy parameters - // k - "mov r8, [r14 + 0]\t\n" - // A - "mov r9, [r14 + 8]\t\n" - // B - "mov r10, [r14 + 16]\t\n" - // beta - "mov r15, [r14 + 24]\t\n" - // accum - "mov rdx, [r14 + 32]\t\n" - // C - "mov r12, [r14 + 40]\t\n" - // ldc - "mov r13, [r14 + 48]\t\n" - // b_block_cols - "mov rdi, [r14 + 56]\t\n" - // b_block_size - "mov rsi, [r14 + 64]\t\n" - // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - "vxorps ymm4,ymm4,ymm4\t\n" - "vxorps ymm5,ymm5,ymm5\t\n" - "vxorps ymm6,ymm6,ymm6\t\n" - "vxorps ymm7,ymm7,ymm7\t\n" - "vxorps ymm8,ymm8,ymm8\t\n" - "vxorps ymm9,ymm9,ymm9\t\n" - "vxorps ymm10,ymm10,ymm10\t\n" - "vxorps ymm11,ymm11,ymm11\t\n" - "vxorps ymm12,ymm12,ymm12\t\n" - - "vcvtph2ps ymm15, 
XMMWORD PTR [r10 + 0]\t\n" - "mov r11, 16\t\n" - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" - "inc r14\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm15,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm1,ymm15,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm2,ymm15,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm3,ymm15,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+16]\t\n" - "vfmadd231ps ymm4,ymm15,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+20]\t\n" - "vfmadd231ps ymm5,ymm15,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+24]\t\n" - "vfmadd231ps ymm6,ymm15,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+28]\t\n" - "vfmadd231ps ymm7,ymm15,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+32]\t\n" - "vfmadd231ps ymm8,ymm15,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+36]\t\n" - "vfmadd231ps ymm9,ymm15,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+40]\t\n" - "vfmadd231ps ymm10,ymm15,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+44]\t\n" - "vfmadd231ps ymm11,ymm15,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+48]\t\n" - "vfmadd231ps ymm12,ymm15,ymm13\t\n" - "cmp r14, r8\t\n" - "jge L_exit%=\t\n" - "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" - "inc r14\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+52]\t\n" - "vfmadd231ps ymm0,ymm14,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+56]\t\n" - "vfmadd231ps ymm1,ymm14,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+60]\t\n" - "vfmadd231ps ymm2,ymm14,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+64]\t\n" - "vfmadd231ps ymm3,ymm14,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+68]\t\n" - "vfmadd231ps ymm4,ymm14,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+72]\t\n" - "vfmadd231ps ymm5,ymm14,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+76]\t\n" - "vfmadd231ps ymm6,ymm14,ymm13\t\n" - "add r11, 32\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+80]\t\n" - "vfmadd231ps ymm7,ymm14,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+84]\t\n" - "vfmadd231ps ymm8,ymm14,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+88]\t\n" - "vfmadd231ps ymm9,ymm14,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+92]\t\n" - "vfmadd231ps ymm10,ymm14,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+96]\t\n" - "vfmadd231ps ymm11,ymm14,ymm13\t\n" - "vbroadcastss ymm13,DWORD PTR [r9+100]\t\n" - "vfmadd231ps ymm12,ymm14,ymm13\t\n" - "add r9,104\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - "add r10, rsi\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps 
ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm12,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 32\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); -} -void __attribute__((noinline)) gemmkernel_14x1_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif - - // Copy parameters - // k - "mov r8, [r14 + 0]\t\n" - // A - "mov r9, [r14 + 8]\t\n" - // B - "mov r10, [r14 + 16]\t\n" - // beta - "mov r15, [r14 + 24]\t\n" - // accum - "mov rdx, [r14 + 32]\t\n" - // C - "mov r12, [r14 + 40]\t\n" - // ldc - "mov r13, [r14 + 48]\t\n" - // b_block_cols - "mov rdi, [r14 + 56]\t\n" - // b_block_size - "mov rsi, [r14 + 64]\t\n" - // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - "vxorps ymm4,ymm4,ymm4\t\n" - "vxorps ymm5,ymm5,ymm5\t\n" - "vxorps ymm6,ymm6,ymm6\t\n" - "vxorps ymm7,ymm7,ymm7\t\n" - "vxorps ymm8,ymm8,ymm8\t\n" - "vxorps ymm9,ymm9,ymm9\t\n" - "vxorps ymm10,ymm10,ymm10\t\n" - "vxorps ymm11,ymm11,ymm11\t\n" - "vxorps ymm12,ymm12,ymm12\t\n" - "vxorps ymm13,ymm13,ymm13\t\n" - - "mov r11, 0\t\n" - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11]\t\n" - "vbroadcastss ymm14,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm15,ymm14\t\n" - "vbroadcastss ymm14,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm1,ymm15,ymm14\t\n" - "vbroadcastss ymm14,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm2,ymm15,ymm14\t\n" - "vbroadcastss ymm14,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm3,ymm15,ymm14\t\n" - "vbroadcastss ymm14,DWORD PTR [r9+16]\t\n" - "vfmadd231ps ymm4,ymm15,ymm14\t\n" - "vbroadcastss ymm14,DWORD PTR [r9+20]\t\n" - 
"vfmadd231ps ymm5,ymm15,ymm14\t\n" - "vbroadcastss ymm14,DWORD PTR [r9+24]\t\n" - "vfmadd231ps ymm6,ymm15,ymm14\t\n" - "vbroadcastss ymm14,DWORD PTR [r9+28]\t\n" - "vfmadd231ps ymm7,ymm15,ymm14\t\n" - "vbroadcastss ymm14,DWORD PTR [r9+32]\t\n" - "vfmadd231ps ymm8,ymm15,ymm14\t\n" - "vbroadcastss ymm14,DWORD PTR [r9+36]\t\n" - "vfmadd231ps ymm9,ymm15,ymm14\t\n" - "vbroadcastss ymm14,DWORD PTR [r9+40]\t\n" - "vfmadd231ps ymm10,ymm15,ymm14\t\n" - "vbroadcastss ymm14,DWORD PTR [r9+44]\t\n" - "vfmadd231ps ymm11,ymm15,ymm14\t\n" - "vbroadcastss ymm14,DWORD PTR [r9+48]\t\n" - "vfmadd231ps ymm12,ymm15,ymm14\t\n" - "vbroadcastss ymm14,DWORD PTR [r9+52]\t\n" - "vfmadd231ps ymm13,ymm15,ymm14\t\n" - "add r9,56\t\n" - "add r11, 16\t\n" - "inc r14\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - "add r10, rsi\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm13\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm12,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm13,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm13\t\n" - "add r12, r13\t\n" 
- - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 32\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); -} - -} // namespace fbgemm diff --git a/src/FbgemmFP16UKernels.h b/src/FbgemmFP16UKernels.h deleted file mode 100644 index d35d431..0000000 --- a/src/FbgemmFP16UKernels.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#ifndef FBGEMM_UKERNELS -#define FBGEMM_UKERNELS -#include -#include -#include -#include "fbgemm/Types.h" - -namespace fbgemm { - -using fp16 = float16; -using fp32 = float; -struct GemmParams { - uint64_t k; - float* A; - const fp16* B; - float* beta; - uint64_t accum; - float* C; - uint64_t ldc; - uint64_t b_block_cols; - uint64_t b_block_size; -}; -void __attribute__((noinline)) gemmkernel_1x1_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_2x1_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_3x1_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_4x1_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_5x1_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_6x1_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_7x1_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_8x1_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_9x1_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_10x1_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_11x1_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_12x1_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_13x1_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_14x1_AVX2_fA0fB0fC0(GemmParams* gp); -typedef void (*funcptr_fp16)(GemmParams* gp); -; - -} // namespace fbgemm - -#endif diff --git a/src/FbgemmFP16UKernelsAvx2.cc b/src/FbgemmFP16UKernelsAvx2.cc new file mode 100644 index 0000000..8a0cb0d --- /dev/null +++ b/src/FbgemmFP16UKernelsAvx2.cc @@ -0,0 +1,2301 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include "FbgemmFP16UKernelsAvx2.h" + +namespace fbgemm { + +void __attribute__((noinline)) gemmkernel_1x1_AVX2_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + + "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" + "mov r11, 16\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" + "inc r14\t\n" + "vbroadcastss ymm1,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm15,ymm1\t\n" + "cmp r14, r8\t\n" + "jge L_exit%=\t\n" + "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" + "inc r14\t\n" + "vbroadcastss ymm1,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm0,ymm14,ymm1\t\n" + "add r11, 32\t\n" + "add r9,8\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + "add r10, rsi\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 32\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_2x1_AVX2_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + + "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" + "mov r11, 16\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" + "inc r14\t\n" + "vbroadcastss ymm2,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm15,ymm2\t\n" + "vbroadcastss ymm2,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm1,ymm15,ymm2\t\n" + "cmp r14, r8\t\n" + "jge L_exit%=\t\n" + "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" + "inc r14\t\n" + "vbroadcastss ymm2,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm0,ymm14,ymm2\t\n" + "vbroadcastss ymm2,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm1,ymm14,ymm2\t\n" + "add r11, 32\t\n" + "add r9,16\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + "add r10, rsi\t\n" + + "cmp rdx, 1\t\n" + 
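// rdx holds the accum flag copied from GemmParams; when it is 1, the L_accum path below rescales the existing C tile by beta (broadcast from [r15]) before adding the fresh products. + 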
"je L_accum%=\t\n" + // Dump C + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 32\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_3x1_AVX2_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + + "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" + "mov r11, 16\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" + "inc r14\t\n" + "vbroadcastss ymm3,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm15,ymm3\t\n" + "vbroadcastss ymm3,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm1,ymm15,ymm3\t\n" + "vbroadcastss ymm3,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm2,ymm15,ymm3\t\n" + "cmp r14, r8\t\n" + "jge L_exit%=\t\n" + "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" + "inc r14\t\n" + "vbroadcastss ymm3,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm0,ymm14,ymm3\t\n" + "vbroadcastss ymm3,DWORD PTR [r9+16]\t\n" + "vfmadd231ps ymm1,ymm14,ymm3\t\n" + "add r11, 32\t\n" + "vbroadcastss ymm3,DWORD PTR [r9+20]\t\n" + "vfmadd231ps ymm2,ymm14,ymm3\t\n" + "add r9,24\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + "add r10, rsi\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 32\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", 
+ "memory"); +} +void __attribute__((noinline)) gemmkernel_4x1_AVX2_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + + "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" + "mov r11, 16\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" + "inc r14\t\n" + "vbroadcastss ymm4,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm15,ymm4\t\n" + "vbroadcastss ymm4,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm1,ymm15,ymm4\t\n" + "vbroadcastss ymm4,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm2,ymm15,ymm4\t\n" + "vbroadcastss ymm4,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm3,ymm15,ymm4\t\n" + "cmp r14, r8\t\n" + "jge L_exit%=\t\n" + "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" + "inc r14\t\n" + "vbroadcastss ymm4,DWORD PTR [r9+16]\t\n" + "vfmadd231ps ymm0,ymm14,ymm4\t\n" + "vbroadcastss ymm4,DWORD PTR [r9+20]\t\n" + "vfmadd231ps ymm1,ymm14,ymm4\t\n" + "vbroadcastss ymm4,DWORD PTR [r9+24]\t\n" + "vfmadd231ps ymm2,ymm14,ymm4\t\n" + "add r11, 32\t\n" + "vbroadcastss ymm4,DWORD PTR [r9+28]\t\n" + "vfmadd231ps ymm3,ymm14,ymm4\t\n" + "add r9,32\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + "add r10, rsi\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 32\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_5x1_AVX2_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 
+ 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + "vxorps ymm4,ymm4,ymm4\t\n" + + "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" + "mov r11, 16\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" + "inc r14\t\n" + "vbroadcastss ymm5,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm15,ymm5\t\n" + "vbroadcastss ymm5,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm1,ymm15,ymm5\t\n" + "vbroadcastss ymm5,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm2,ymm15,ymm5\t\n" + "vbroadcastss ymm5,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm3,ymm15,ymm5\t\n" + "vbroadcastss ymm5,DWORD PTR [r9+16]\t\n" + "vfmadd231ps ymm4,ymm15,ymm5\t\n" + "cmp r14, r8\t\n" + "jge L_exit%=\t\n" + "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" + "inc r14\t\n" + "vbroadcastss ymm5,DWORD PTR [r9+20]\t\n" + "vfmadd231ps ymm0,ymm14,ymm5\t\n" + "vbroadcastss ymm5,DWORD PTR [r9+24]\t\n" + "vfmadd231ps ymm1,ymm14,ymm5\t\n" + "vbroadcastss ymm5,DWORD PTR [r9+28]\t\n" + "vfmadd231ps ymm2,ymm14,ymm5\t\n" + "add r11, 32\t\n" + "vbroadcastss ymm5,DWORD PTR [r9+32]\t\n" + "vfmadd231ps ymm3,ymm14,ymm5\t\n" + "vbroadcastss ymm5,DWORD PTR [r9+36]\t\n" + "vfmadd231ps ymm4,ymm14,ymm5\t\n" + "add r9,40\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + "add r10, rsi\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 32\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_6x1_AVX2_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 
+ 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + "vxorps ymm4,ymm4,ymm4\t\n" + "vxorps ymm5,ymm5,ymm5\t\n" + + "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" + "mov r11, 16\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" + "inc r14\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm15,ymm6\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm1,ymm15,ymm6\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm2,ymm15,ymm6\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm3,ymm15,ymm6\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+16]\t\n" + "vfmadd231ps ymm4,ymm15,ymm6\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+20]\t\n" + "vfmadd231ps ymm5,ymm15,ymm6\t\n" + "cmp r14, r8\t\n" + "jge L_exit%=\t\n" + "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" + "inc r14\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+24]\t\n" + "vfmadd231ps ymm0,ymm14,ymm6\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+28]\t\n" + "vfmadd231ps ymm1,ymm14,ymm6\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+32]\t\n" + "vfmadd231ps ymm2,ymm14,ymm6\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+36]\t\n" + "vfmadd231ps ymm3,ymm14,ymm6\t\n" + "add r11, 32\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+40]\t\n" + "vfmadd231ps ymm4,ymm14,ymm6\t\n" + "vbroadcastss ymm6,DWORD PTR [r9+44]\t\n" + "vfmadd231ps ymm5,ymm14,ymm6\t\n" + "add r9,48\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + "add r10, rsi\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 32\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_7x1_AVX2_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy 
parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + "vxorps ymm4,ymm4,ymm4\t\n" + "vxorps ymm5,ymm5,ymm5\t\n" + "vxorps ymm6,ymm6,ymm6\t\n" + + "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" + "mov r11, 16\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" + "inc r14\t\n" + "vbroadcastss ymm7,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm15,ymm7\t\n" + "vbroadcastss ymm7,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm1,ymm15,ymm7\t\n" + "vbroadcastss ymm7,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm2,ymm15,ymm7\t\n" + "vbroadcastss ymm7,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm3,ymm15,ymm7\t\n" + "vbroadcastss ymm7,DWORD PTR [r9+16]\t\n" + "vfmadd231ps ymm4,ymm15,ymm7\t\n" + "vbroadcastss ymm7,DWORD PTR [r9+20]\t\n" + "vfmadd231ps ymm5,ymm15,ymm7\t\n" + "vbroadcastss ymm7,DWORD PTR [r9+24]\t\n" + "vfmadd231ps ymm6,ymm15,ymm7\t\n" + "cmp r14, r8\t\n" + "jge L_exit%=\t\n" + "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" + "inc r14\t\n" + "vbroadcastss ymm7,DWORD PTR [r9+28]\t\n" + "vfmadd231ps ymm0,ymm14,ymm7\t\n" + "vbroadcastss ymm7,DWORD PTR [r9+32]\t\n" + "vfmadd231ps ymm1,ymm14,ymm7\t\n" + "vbroadcastss ymm7,DWORD PTR [r9+36]\t\n" + "vfmadd231ps ymm2,ymm14,ymm7\t\n" + "vbroadcastss ymm7,DWORD PTR [r9+40]\t\n" + "vfmadd231ps ymm3,ymm14,ymm7\t\n" + "add r11, 32\t\n" + "vbroadcastss ymm7,DWORD PTR [r9+44]\t\n" + "vfmadd231ps ymm4,ymm14,ymm7\t\n" + "vbroadcastss ymm7,DWORD PTR [r9+48]\t\n" + "vfmadd231ps ymm5,ymm14,ymm7\t\n" + "vbroadcastss ymm7,DWORD PTR [r9+52]\t\n" + "vfmadd231ps ymm6,ymm14,ymm7\t\n" + "add r9,56\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + "add r10, rsi\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps 
ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 32\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_8x1_AVX2_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + "vxorps ymm4,ymm4,ymm4\t\n" + "vxorps ymm5,ymm5,ymm5\t\n" + "vxorps ymm6,ymm6,ymm6\t\n" + "vxorps ymm7,ymm7,ymm7\t\n" + + "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" + "mov r11, 16\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" + "inc r14\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm15,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm1,ymm15,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm2,ymm15,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm3,ymm15,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+16]\t\n" + "vfmadd231ps ymm4,ymm15,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+20]\t\n" + "vfmadd231ps ymm5,ymm15,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+24]\t\n" + "vfmadd231ps ymm6,ymm15,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+28]\t\n" + "vfmadd231ps ymm7,ymm15,ymm8\t\n" + "cmp r14, r8\t\n" + "jge L_exit%=\t\n" + "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" + "inc r14\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+32]\t\n" + "vfmadd231ps ymm0,ymm14,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+36]\t\n" + "vfmadd231ps ymm1,ymm14,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+40]\t\n" + "vfmadd231ps ymm2,ymm14,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+44]\t\n" + "vfmadd231ps ymm3,ymm14,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+48]\t\n" + "vfmadd231ps ymm4,ymm14,ymm8\t\n" + "add r11, 32\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+52]\t\n" + "vfmadd231ps ymm5,ymm14,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+56]\t\n" + "vfmadd231ps ymm6,ymm14,ymm8\t\n" + "vbroadcastss ymm8,DWORD PTR [r9+60]\t\n" + "vfmadd231ps ymm7,ymm14,ymm8\t\n" + "add r9,64\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + "add r10, rsi\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + 
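// each vmovups stores one eight-float YMM row of the C tile; the following add of r13 (ldc, a byte stride) advances r12 to the next row of C. + 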
"add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 32\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_9x1_AVX2_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + "vxorps ymm4,ymm4,ymm4\t\n" + "vxorps ymm5,ymm5,ymm5\t\n" + "vxorps ymm6,ymm6,ymm6\t\n" + "vxorps ymm7,ymm7,ymm7\t\n" + "vxorps ymm8,ymm8,ymm8\t\n" + + "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" + "mov r11, 16\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" + "inc r14\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm15,ymm9\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm1,ymm15,ymm9\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm2,ymm15,ymm9\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm3,ymm15,ymm9\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+16]\t\n" + "vfmadd231ps ymm4,ymm15,ymm9\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+20]\t\n" + "vfmadd231ps ymm5,ymm15,ymm9\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+24]\t\n" + "vfmadd231ps ymm6,ymm15,ymm9\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+28]\t\n" + "vfmadd231ps ymm7,ymm15,ymm9\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+32]\t\n" + "vfmadd231ps ymm8,ymm15,ymm9\t\n" + "cmp r14, r8\t\n" + "jge L_exit%=\t\n" + "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" + "inc r14\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+36]\t\n" + "vfmadd231ps ymm0,ymm14,ymm9\t\n" + "vbroadcastss ymm9,DWORD 
PTR [r9+40]\t\n" + "vfmadd231ps ymm1,ymm14,ymm9\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+44]\t\n" + "vfmadd231ps ymm2,ymm14,ymm9\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+48]\t\n" + "vfmadd231ps ymm3,ymm14,ymm9\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+52]\t\n" + "vfmadd231ps ymm4,ymm14,ymm9\t\n" + "add r11, 32\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+56]\t\n" + "vfmadd231ps ymm5,ymm14,ymm9\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+60]\t\n" + "vfmadd231ps ymm6,ymm14,ymm9\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+64]\t\n" + "vfmadd231ps ymm7,ymm14,ymm9\t\n" + "vbroadcastss ymm9,DWORD PTR [r9+68]\t\n" + "vfmadd231ps ymm8,ymm14,ymm9\t\n" + "add r9,72\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + "add r10, rsi\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 32\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_10x1_AVX2_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 
0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + "vxorps ymm4,ymm4,ymm4\t\n" + "vxorps ymm5,ymm5,ymm5\t\n" + "vxorps ymm6,ymm6,ymm6\t\n" + "vxorps ymm7,ymm7,ymm7\t\n" + "vxorps ymm8,ymm8,ymm8\t\n" + "vxorps ymm9,ymm9,ymm9\t\n" + + "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" + "mov r11, 16\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" + "inc r14\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm15,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm1,ymm15,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm2,ymm15,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm3,ymm15,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+16]\t\n" + "vfmadd231ps ymm4,ymm15,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+20]\t\n" + "vfmadd231ps ymm5,ymm15,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+24]\t\n" + "vfmadd231ps ymm6,ymm15,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+28]\t\n" + "vfmadd231ps ymm7,ymm15,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+32]\t\n" + "vfmadd231ps ymm8,ymm15,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+36]\t\n" + "vfmadd231ps ymm9,ymm15,ymm10\t\n" + "cmp r14, r8\t\n" + "jge L_exit%=\t\n" + "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" + "inc r14\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+40]\t\n" + "vfmadd231ps ymm0,ymm14,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+44]\t\n" + "vfmadd231ps ymm1,ymm14,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+48]\t\n" + "vfmadd231ps ymm2,ymm14,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+52]\t\n" + "vfmadd231ps ymm3,ymm14,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+56]\t\n" + "vfmadd231ps ymm4,ymm14,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+60]\t\n" + "vfmadd231ps ymm5,ymm14,ymm10\t\n" + "add r11, 32\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+64]\t\n" + "vfmadd231ps ymm6,ymm14,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+68]\t\n" + "vfmadd231ps ymm7,ymm14,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+72]\t\n" + "vfmadd231ps ymm8,ymm14,ymm10\t\n" + "vbroadcastss ymm10,DWORD PTR [r9+76]\t\n" + "vfmadd231ps ymm9,ymm14,ymm10\t\n" + "add r9,80\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + "add r10, rsi\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + 
"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 32\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_11x1_AVX2_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + "vxorps ymm4,ymm4,ymm4\t\n" + "vxorps ymm5,ymm5,ymm5\t\n" + "vxorps ymm6,ymm6,ymm6\t\n" + "vxorps ymm7,ymm7,ymm7\t\n" + "vxorps ymm8,ymm8,ymm8\t\n" + "vxorps ymm9,ymm9,ymm9\t\n" + "vxorps ymm10,ymm10,ymm10\t\n" + + "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" + "mov r11, 16\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" + "inc r14\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm15,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm1,ymm15,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm2,ymm15,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm3,ymm15,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+16]\t\n" + "vfmadd231ps ymm4,ymm15,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+20]\t\n" + "vfmadd231ps ymm5,ymm15,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+24]\t\n" + "vfmadd231ps ymm6,ymm15,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+28]\t\n" + "vfmadd231ps ymm7,ymm15,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+32]\t\n" + "vfmadd231ps ymm8,ymm15,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+36]\t\n" + "vfmadd231ps ymm9,ymm15,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+40]\t\n" + "vfmadd231ps ymm10,ymm15,ymm11\t\n" + "cmp r14, r8\t\n" + "jge L_exit%=\t\n" + "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" + "inc r14\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+44]\t\n" + "vfmadd231ps ymm0,ymm14,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+48]\t\n" + "vfmadd231ps ymm1,ymm14,ymm11\t\n" + 
"vbroadcastss ymm11,DWORD PTR [r9+52]\t\n" + "vfmadd231ps ymm2,ymm14,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+56]\t\n" + "vfmadd231ps ymm3,ymm14,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+60]\t\n" + "vfmadd231ps ymm4,ymm14,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+64]\t\n" + "vfmadd231ps ymm5,ymm14,ymm11\t\n" + "add r11, 32\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+68]\t\n" + "vfmadd231ps ymm6,ymm14,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+72]\t\n" + "vfmadd231ps ymm7,ymm14,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+76]\t\n" + "vfmadd231ps ymm8,ymm14,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+80]\t\n" + "vfmadd231ps ymm9,ymm14,ymm11\t\n" + "vbroadcastss ymm11,DWORD PTR [r9+84]\t\n" + "vfmadd231ps ymm10,ymm14,ymm11\t\n" + "add r9,88\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + "add r10, rsi\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 32\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_12x1_AVX2_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" 
+ ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + "vxorps ymm4,ymm4,ymm4\t\n" + "vxorps ymm5,ymm5,ymm5\t\n" + "vxorps ymm6,ymm6,ymm6\t\n" + "vxorps ymm7,ymm7,ymm7\t\n" + "vxorps ymm8,ymm8,ymm8\t\n" + "vxorps ymm9,ymm9,ymm9\t\n" + "vxorps ymm10,ymm10,ymm10\t\n" + "vxorps ymm11,ymm11,ymm11\t\n" + + "vcvtph2ps ymm15, XMMWORD PTR [r10 + 0]\t\n" + "mov r11, 16\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" + "inc r14\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm15,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm1,ymm15,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm2,ymm15,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm3,ymm15,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+16]\t\n" + "vfmadd231ps ymm4,ymm15,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+20]\t\n" + "vfmadd231ps ymm5,ymm15,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+24]\t\n" + "vfmadd231ps ymm6,ymm15,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+28]\t\n" + "vfmadd231ps ymm7,ymm15,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+32]\t\n" + "vfmadd231ps ymm8,ymm15,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+36]\t\n" + "vfmadd231ps ymm9,ymm15,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+40]\t\n" + "vfmadd231ps ymm10,ymm15,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+44]\t\n" + "vfmadd231ps ymm11,ymm15,ymm12\t\n" + "cmp r14, r8\t\n" + "jge L_exit%=\t\n" + "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" + "inc r14\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+48]\t\n" + "vfmadd231ps ymm0,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+52]\t\n" + "vfmadd231ps ymm1,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+56]\t\n" + "vfmadd231ps ymm2,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+60]\t\n" + "vfmadd231ps ymm3,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+64]\t\n" + "vfmadd231ps ymm4,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+68]\t\n" + "vfmadd231ps ymm5,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+72]\t\n" + "vfmadd231ps ymm6,ymm14,ymm12\t\n" + "add r11, 32\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+76]\t\n" + "vfmadd231ps ymm7,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+80]\t\n" + "vfmadd231ps ymm8,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+84]\t\n" + "vfmadd231ps ymm9,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+88]\t\n" + "vfmadd231ps ymm10,ymm14,ymm12\t\n" + "vbroadcastss ymm12,DWORD PTR [r9+92]\t\n" + "vfmadd231ps ymm11,ymm14,ymm12\t\n" + "add r9,96\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + "add r10, rsi\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], 
ymm3\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 32\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_13x1_AVX2_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + "vxorps ymm4,ymm4,ymm4\t\n" + "vxorps ymm5,ymm5,ymm5\t\n" + "vxorps ymm6,ymm6,ymm6\t\n" + "vxorps ymm7,ymm7,ymm7\t\n" + "vxorps ymm8,ymm8,ymm8\t\n" + "vxorps ymm9,ymm9,ymm9\t\n" + "vxorps ymm10,ymm10,ymm10\t\n" + "vxorps ymm11,ymm11,ymm11\t\n" + "vxorps ymm12,ymm12,ymm12\t\n" + + "vcvtph2ps ymm15, 
XMMWORD PTR [r10 + 0]\t\n" + "mov r11, 16\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm14,XMMWORD PTR [r10 + r11 + 0]\t\n" + "inc r14\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm15,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm1,ymm15,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm2,ymm15,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm3,ymm15,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+16]\t\n" + "vfmadd231ps ymm4,ymm15,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+20]\t\n" + "vfmadd231ps ymm5,ymm15,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+24]\t\n" + "vfmadd231ps ymm6,ymm15,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+28]\t\n" + "vfmadd231ps ymm7,ymm15,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+32]\t\n" + "vfmadd231ps ymm8,ymm15,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+36]\t\n" + "vfmadd231ps ymm9,ymm15,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+40]\t\n" + "vfmadd231ps ymm10,ymm15,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+44]\t\n" + "vfmadd231ps ymm11,ymm15,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+48]\t\n" + "vfmadd231ps ymm12,ymm15,ymm13\t\n" + "cmp r14, r8\t\n" + "jge L_exit%=\t\n" + "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11 + 16]\t\n" + "inc r14\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+52]\t\n" + "vfmadd231ps ymm0,ymm14,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+56]\t\n" + "vfmadd231ps ymm1,ymm14,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+60]\t\n" + "vfmadd231ps ymm2,ymm14,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+64]\t\n" + "vfmadd231ps ymm3,ymm14,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+68]\t\n" + "vfmadd231ps ymm4,ymm14,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+72]\t\n" + "vfmadd231ps ymm5,ymm14,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+76]\t\n" + "vfmadd231ps ymm6,ymm14,ymm13\t\n" + "add r11, 32\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+80]\t\n" + "vfmadd231ps ymm7,ymm14,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+84]\t\n" + "vfmadd231ps ymm8,ymm14,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+88]\t\n" + "vfmadd231ps ymm9,ymm14,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+92]\t\n" + "vfmadd231ps ymm10,ymm14,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+96]\t\n" + "vfmadd231ps ymm11,ymm14,ymm13\t\n" + "vbroadcastss ymm13,DWORD PTR [r9+100]\t\n" + "vfmadd231ps ymm12,ymm14,ymm13\t\n" + "add r9,104\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + + "L_exit%=:\t\n" + "add r10, rsi\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps 
ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm12,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" + "add r12, r13\t\n" + + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 32\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} +void __attribute__((noinline)) gemmkernel_14x1_AVX2_fA0fB0fC0(GemmParams* gp) { + asm volatile( +#if !defined(__clang__) + "mov r14, %[gp]\t\n" +#else + "mov %[gp], %%r14\t\n" + ".intel_syntax noprefix\t\n" +#endif + + // Copy parameters + // k + "mov r8, [r14 + 0]\t\n" + // A + "mov r9, [r14 + 8]\t\n" + // B + "mov r10, [r14 + 16]\t\n" + // beta + "mov r15, [r14 + 24]\t\n" + // accum + "mov rdx, [r14 + 32]\t\n" + // C + "mov r12, [r14 + 40]\t\n" + // ldc + "mov r13, [r14 + 48]\t\n" + // b_block_cols + "mov rdi, [r14 + 56]\t\n" + // b_block_size + "mov rsi, [r14 + 64]\t\n" + // Make copies of A and C + "mov rax, r9\t\n" + "mov rcx, r12\t\n" + + "mov rbx, 0\t\n" + "loop_outter%=:\t\n" + "mov r14, 0\t\n" + "vxorps ymm0,ymm0,ymm0\t\n" + "vxorps ymm1,ymm1,ymm1\t\n" + "vxorps ymm2,ymm2,ymm2\t\n" + "vxorps ymm3,ymm3,ymm3\t\n" + "vxorps ymm4,ymm4,ymm4\t\n" + "vxorps ymm5,ymm5,ymm5\t\n" + "vxorps ymm6,ymm6,ymm6\t\n" + "vxorps ymm7,ymm7,ymm7\t\n" + "vxorps ymm8,ymm8,ymm8\t\n" + "vxorps ymm9,ymm9,ymm9\t\n" + "vxorps ymm10,ymm10,ymm10\t\n" + "vxorps ymm11,ymm11,ymm11\t\n" + "vxorps ymm12,ymm12,ymm12\t\n" + "vxorps ymm13,ymm13,ymm13\t\n" + + "mov r11, 0\t\n" + + "loop_inner%=:\t\n" + + "vcvtph2ps ymm15,XMMWORD PTR [r10 + r11]\t\n" + "vbroadcastss ymm14,DWORD PTR [r9+0]\t\n" + "vfmadd231ps ymm0,ymm15,ymm14\t\n" + "vbroadcastss ymm14,DWORD PTR [r9+4]\t\n" + "vfmadd231ps ymm1,ymm15,ymm14\t\n" + "vbroadcastss ymm14,DWORD PTR [r9+8]\t\n" + "vfmadd231ps ymm2,ymm15,ymm14\t\n" + "vbroadcastss ymm14,DWORD PTR [r9+12]\t\n" + "vfmadd231ps ymm3,ymm15,ymm14\t\n" + "vbroadcastss ymm14,DWORD PTR [r9+16]\t\n" + "vfmadd231ps ymm4,ymm15,ymm14\t\n" + "vbroadcastss ymm14,DWORD PTR [r9+20]\t\n" + 
"vfmadd231ps ymm5,ymm15,ymm14\t\n" + "vbroadcastss ymm14,DWORD PTR [r9+24]\t\n" + "vfmadd231ps ymm6,ymm15,ymm14\t\n" + "vbroadcastss ymm14,DWORD PTR [r9+28]\t\n" + "vfmadd231ps ymm7,ymm15,ymm14\t\n" + "vbroadcastss ymm14,DWORD PTR [r9+32]\t\n" + "vfmadd231ps ymm8,ymm15,ymm14\t\n" + "vbroadcastss ymm14,DWORD PTR [r9+36]\t\n" + "vfmadd231ps ymm9,ymm15,ymm14\t\n" + "vbroadcastss ymm14,DWORD PTR [r9+40]\t\n" + "vfmadd231ps ymm10,ymm15,ymm14\t\n" + "vbroadcastss ymm14,DWORD PTR [r9+44]\t\n" + "vfmadd231ps ymm11,ymm15,ymm14\t\n" + "vbroadcastss ymm14,DWORD PTR [r9+48]\t\n" + "vfmadd231ps ymm12,ymm15,ymm14\t\n" + "vbroadcastss ymm14,DWORD PTR [r9+52]\t\n" + "vfmadd231ps ymm13,ymm15,ymm14\t\n" + "add r9,56\t\n" + "add r11, 16\t\n" + "inc r14\t\n" + "cmp r14, r8\t\n" + "jl loop_inner%=\t\n" + "add r10, rsi\t\n" + + "cmp rdx, 1\t\n" + "je L_accum%=\t\n" + // Dump C + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" + "add r12, r13\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm13\t\n" + "add r12, r13\t\n" + "jmp L_done%=\t\n" + + "L_accum%=:\t\n" + // Dump C with accumulate + "vbroadcastss ymm15,DWORD PTR [r15]\t\n" + "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm1\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm3\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm5\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm7\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm9\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm11\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm12,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm12\t\n" + "add r12, r13\t\n" + "vfmadd231ps ymm13,ymm15,YMMWORD PTR [r12 + 0]\t\n" + "vmovups YMMWORD PTR [r12 + 0], ymm13\t\n" + "add r12, r13\t\n" 
+ + "L_done%=:\t\n" + + // next outer iteration + "add rcx, 32\t\n" + "mov r12, rcx\t\n" + "mov r9, rax\t\n" + "inc rbx\t\n" + "cmp rbx, rdi\t\n" + "jl loop_outter%=\t\n" + : + : [gp] "rm"(gp) + : "r8", + "r9", + "r10", + "r11", + "r15", + "r13", + "r14", + "rax", + "rcx", + "rdx", + "rsi", + "rdi", + "rbx", + "r12", + "memory"); +} + +} // namespace fbgemm diff --git a/src/FbgemmFP16UKernelsAvx2.h b/src/FbgemmFP16UKernelsAvx2.h new file mode 100644 index 0000000..4053332 --- /dev/null +++ b/src/FbgemmFP16UKernelsAvx2.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#ifndef FBGEMM_UKERNELS +#define FBGEMM_UKERNELS +#include +#include "fbgemm/Types.h" + +namespace fbgemm { + +using fp16 = float16; +using fp32 = float; +struct GemmParams { + uint64_t k; + float* A; + const fp16* B; + float* beta; + uint64_t accum; + float* C; + uint64_t ldc; + uint64_t b_block_cols; + uint64_t b_block_size; +}; +void __attribute__((noinline)) gemmkernel_1x1_AVX2_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_2x1_AVX2_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_3x1_AVX2_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_4x1_AVX2_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_5x1_AVX2_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_6x1_AVX2_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_7x1_AVX2_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_8x1_AVX2_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_9x1_AVX2_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_10x1_AVX2_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_11x1_AVX2_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_12x1_AVX2_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_13x1_AVX2_fA0fB0fC0(GemmParams* gp); +void __attribute__((noinline)) gemmkernel_14x1_AVX2_fA0fB0fC0(GemmParams* gp); +typedef void (*funcptr_fp16)(GemmParams* gp); +; + +} // namespace fbgemm + +#endif diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc new file mode 100644 index 0000000..eeeaea0 --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -0,0 +1,301 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include "GenerateKernel.h" + +namespace fbgemm { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit + * Accumulation kernel. + */ +template <> +template <> +void CodeGenBase::initCRegs< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCRegAssign) { + for (int i = 0; i < rowRegs; ++i) { + for (int j = 0; j < colRegs; ++j) { + a->vxorps( + CRegs_avx512_[i * leadingDimCRegAssign + j], + CRegs_avx512_[i * leadingDimCRegAssign + j], + CRegs_avx512_[i * leadingDimCRegAssign + j]); + } + } +} + +/** + * Generate AVX512 instructions for computing block in the rank-k update of + * 16-bit Accmulation kernel. 
+ */ +template <> +template <> +void CodeGenBase::genComputeBlock< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp /* unused (reserved for prefetching)*/, + int rowRegs, + int colRegs, + int lda, + int leadingDimCRegAssign) { + // used for matrix A + asmjit::X86Zmm AReg = x86::zmm29; + + asmjit::X86Zmm tmpReg = x86::zmm30; + + for (int i = 0; i < rowRegs; ++i) { + // broadcast A + a->vpbroadcastw( + AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); + for (int j = 0; j < colRegs; ++j) { + a->vpmaddubsw( + tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + a->vpaddsw( + CRegs_avx512_[i * leadingDimCRegAssign + j], + tmpReg, + CRegs_avx512_[i * leadingDimCRegAssign + j]); + // Prefetching is hurting performance in some cases + // because the prefetch instructions themselves consume a slot + // in the issue pipeline, thus slowing down the kernel. + // if((i == rowRegs - 1) && j % 2 == 0){ + // a->prefetcht0(x86::dword_ptr(B_pf, j*VLEN_*sizeof(int8_t))); + //} + } + } +} + +/** + * Generate AVX512 instructions for storing the C registers back to the memory + * in 16-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase::storeCRegs< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, + bool accum, + int leadingDimCRegAssign) { + asmjit::X86Ymm extractDest256 = x86::ymm31; + asmjit::X86Zmm extractDest512 = x86::zmm31; + + for (int i = 0; i < rowRegs; ++i) { + a->imul(C_Offset, ldcReg, static_cast(i * sizeof(int32_t))); + for (int j = 0; j < colRegs; ++j) { + for (int idx = 0; idx < 2; ++idx) { + a->vextracti32x8( + extractDest256, CRegs_avx512_[i * leadingDimCRegAssign + j], idx); + a->vpmovsxwd(extractDest512, extractDest256); + asmjit::X86Mem destAddr = x86::dword_ptr( + a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t)); + if (accum) { + a->vpaddd(extractDest512, extractDest512, destAddr); + } + a->vmovups(destAddr, extractDest512); + } + } + } +} + +/** + * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel.
+ * + */ +template <> +template <> +CodeGenBase::jit_micro_kernel_fp +CodeGenBase::getOrCreate( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + auto kernelSig = std::make_tuple(accum, mc, nc); + if (codeCache_.find(kernelSig) != codeCache_.end()) { + return codeCache_[kernelSig]; + } + + code_.reset(false); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); + // ToDo: Dump in a file for debugging + // code dumping/logging + // asmjit::FileLogger logger(stderr); + // code_.setLogger(&logger); + + constexpr int kBlock = + PackingTraits::KCB; + constexpr int mRegBlockSize = + PackingTraits::MR; + // constexpr int nRegBlockSize = + // PackingTraits::NR; + constexpr int row_interleave = + PackingTraits::ROW_INTERLEAVE; + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); + // assert((nc == nRegBlockSize) && + //"nc must be equal to the number of register blocks"); + + // arguments to the function created + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); + + asmjit::FuncDetail func; + func.init( + asmjit:: + FuncSignature6( + asmjit::CallConv::kIdHost)); + + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + + asmjit::FuncArgsMapper args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); + + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); + + asmjit::Label Loopk = a->newLabel(); + asmjit::Label LoopMBlocks = a->newLabel(); + + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + // asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + asmjit::X86Gp kIdx = a->gpzRef(14); + + int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + // a->mov(B_pf_saved, B_pf); + + a->bind(LoopMBlocks); + a->inc(iIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs(a, rowRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + // k is incremented by row_interleave + a->add(kIdx, static_cast(row_interleave)); + + genComputeBlock( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); + + // update buffer_A address for next k iteration + a->add( + buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add( + buffer_B, static_cast(VLEN_ * colRegs * sizeof(int8_t))); + // a->add(B_pf, static_cast(VLEN_ * colRegs * sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs( + a, rowRegs, colRegs, C_Offset, ldcReg, accum); + + // increment A for next block + a->sub(buffer_A, kSize); + a->add( + buffer_A, static_cast((rowRegs)*kBlock * sizeof(uint8_t))); + // increment C for next block + a->imul( + C_Offset, ldcReg, 
static_cast(rowRegs * sizeof(int32_t))); + a->add(CBase, C_Offset); + // reset B + a->mov(buffer_B, buffer_B_saved); + // a->mov(B_pf, B_pf_saved); + + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + // init C registers + initCRegs(a, rowRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, static_cast(row_interleave)); + + genComputeBlock( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); + + // update buffer_A address for next k iteration + a->add( + buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add( + buffer_B, static_cast(VLEN_ * colRegs * sizeof(int8_t))); + // a->add(B_pf, static_cast(VLEN_ * colRegs * sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // store C matrix + storeCRegs( + a, rowRegs, colRegs, C_Offset, ldcReg, accum); + } + + asmjit::FuncUtils::emitEpilog(a, layout); + + jit_micro_kernel_fp fn; + asmjit::Error err = rt_.add(&fn, &code_); + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + codeCache_[kernelSig] = fn; + return fn; +} + +} // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC16_avx512.cc b/src/GenerateKernelU8S8S32ACC16_avx512.cc deleted file mode 100644 index eeeaea0..0000000 --- a/src/GenerateKernelU8S8S32ACC16_avx512.cc +++ /dev/null @@ -1,301 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#include -#include "GenerateKernel.h" - -namespace fbgemm { - -namespace x86 = asmjit::x86; - -/** - * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit - * Accumulation kernel. - */ -template <> -template <> -void CodeGenBase::initCRegs< - inst_set_t::avx512>( - asmjit::X86Emitter* a, - int rowRegs, - int colRegs, - int leadingDimCRegAssign) { - for (int i = 0; i < rowRegs; ++i) { - for (int j = 0; j < colRegs; ++j) { - a->vxorps( - CRegs_avx512_[i * leadingDimCRegAssign + j], - CRegs_avx512_[i * leadingDimCRegAssign + j], - CRegs_avx512_[i * leadingDimCRegAssign + j]); - } - } -} - -/** - * Generate AVX512 instructions for computing block in the rank-k update of - * 16-bit Accmulation kernel. - */ -template <> -template <> -void CodeGenBase::genComputeBlock< - inst_set_t::avx512>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp /* unused (reserved for prefetching)*/, - int rowRegs, - int colRegs, - int lda, - int leadingDimCRegAssign) { - // used for matrix A - asmjit::X86Zmm AReg = x86::zmm29; - - asmjit::X86Zmm tmpReg = x86::zmm30; - - for (int i = 0; i < rowRegs; ++i) { - // broadcast A - a->vpbroadcastw( - AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); - for (int j = 0; j < colRegs; ++j) { - a->vpmaddubsw( - tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); - a->vpaddsw( - CRegs_avx512_[i * leadingDimCRegAssign + j], - tmpReg, - CRegs_avx512_[i * leadingDimCRegAssign + j]); - // Prefetching is hurting performance in some cases - // because prefetch instructions itself consumes a slot - // in pipeline issue thus slowing down the kernel. 
- // if((i == rowRegs - 1) && j % 2 == 0){ - // a->prefetcht0(x86::dword_ptr(B_pf, j*VLEN_*sizeof(int8_t))); - //} - } - } -} - -/** - * Generate AVX512 instructions for storing the C registers back to the memory - * in 16-bit Accumulation kernel. - */ -template <> -template <> -void CodeGenBase::storeCRegs< - inst_set_t::avx512>( - asmjit::X86Emitter* a, - int rowRegs, - int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, - bool accum, - int leadingDimCRegAssign) { - asmjit::X86Ymm extractDest256 = x86::ymm31; - asmjit::X86Zmm extractDest512 = x86::zmm31; - - for (int i = 0; i < rowRegs; ++i) { - a->imul(C_Offset, ldcReg, static_cast(i * sizeof(int32_t))); - for (int j = 0; j < colRegs; ++j) { - for (int idx = 0; idx < 2; ++idx) { - a->vextracti32x8( - extractDest256, CRegs_avx512_[i * leadingDimCRegAssign + j], idx); - a->vpmovsxwd(extractDest512, extractDest256); - asmjit::X86Mem destAddr = x86::dword_ptr( - a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t)); - if (accum) { - a->vpaddd(extractDest512, extractDest512, destAddr); - } - a->vmovups(destAddr, extractDest512); - } - } - } -} - -/** - * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel. - * - */ -template <> -template <> -CodeGenBase::jit_micro_kernel_fp -CodeGenBase::getOrCreate( - bool accum, - int32_t mc, - int32_t nc, - int32_t kc, - int32_t /* unused */) { - auto kernelSig = std::make_tuple(accum, mc, nc); - if (codeCache_.find(kernelSig) != codeCache_.end()) { - return codeCache_[kernelSig]; - } - - code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); - // ToDo: Dump in a file for debugging - // code dumping/logging - // asmjit::FileLogger logger(stderr); - // code_.setLogger(&logger); - - constexpr int kBlock = - PackingTraits::KCB; - constexpr int mRegBlockSize = - PackingTraits::MR; - // constexpr int nRegBlockSize = - // PackingTraits::NR; - constexpr int row_interleave = - PackingTraits::ROW_INTERLEAVE; - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - // assert((nc == nRegBlockSize) && - //"nc must be equal to the number of register blocks"); - - // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); - - asmjit::FuncDetail func; - func.init( - asmjit:: - FuncSignature6( - asmjit::CallConv::kIdHost)); - - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); - - asmjit::FuncArgsMapper args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFrameInfo(ffi); - - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); - - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); - - asmjit::Label Loopk = a->newLabel(); - asmjit::Label LoopMBlocks = a->newLabel(); - - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - // asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp kIdx = a->gpzRef(14); - - int colRegs = 
nc * row_interleave * sizeof(int8_t) / VLEN_; - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - // a->mov(B_pf_saved, B_pf); - - a->bind(LoopMBlocks); - a->inc(iIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs(a, rowRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(Loopk); - // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); - - genComputeBlock( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, static_cast(VLEN_ * colRegs * sizeof(int8_t))); - // a->add(B_pf, static_cast(VLEN_ * colRegs * sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs( - a, rowRegs, colRegs, C_Offset, ldcReg, accum); - - // increment A for next block - a->sub(buffer_A, kSize); - a->add( - buffer_A, static_cast((rowRegs)*kBlock * sizeof(uint8_t))); - // increment C for next block - a->imul( - C_Offset, ldcReg, static_cast(rowRegs * sizeof(int32_t))); - a->add(CBase, C_Offset); - // reset B - a->mov(buffer_B, buffer_B_saved); - // a->mov(B_pf, B_pf_saved); - - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; - - // init C registers - initCRegs(a, rowRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); - - // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); - - genComputeBlock( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, static_cast(VLEN_ * colRegs * sizeof(int8_t))); - // a->add(B_pf, static_cast(VLEN_ * colRegs * sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(LoopkRem); - - // store C matrix - storeCRegs( - a, rowRegs, colRegs, C_Offset, ldcReg, accum); - } - - asmjit::FuncUtils::emitEpilog(a, layout); - - jit_micro_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } - codeCache_[kernelSig] = fn; - return fn; -} - -} // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc new file mode 100644 index 0000000..0621bb0 --- /dev/null +++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc @@ -0,0 +1,317 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include <iostream> +#include "GenerateKernel.h" + +namespace fbgemm { + +namespace x86 = asmjit::x86; + +/** + * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit + * Accumulation kernel.
+ */ +template <> +template <> +void CodeGenBase::initCRegs< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + int leadingDimCReg) { + for (int i = 0; i < rowRegs; ++i) { + for (int j = 0; j < colRegs; ++j) { + a->vxorps( + CRegs_avx512_[i * leadingDimCReg + j], + CRegs_avx512_[i * leadingDimCReg + j], + CRegs_avx512_[i * leadingDimCReg + j]); + } + } +} + +/** + * Generate AVX512 instructions for computing block in the rank-k update of + * 32-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase::genComputeBlock< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp B_pf, + int rowRegs, + int colRegs, + int lda, + int leadingDimCRegAssign) { + // used for matrix A + asmjit::X86Zmm AReg = x86::zmm31; + + // used for matrix B + asmjit::X86Zmm BReg = x86::zmm30; + + // Contains 16-bit 1s + asmjit::X86Zmm oneReg = x86::zmm29; + + // temporary register + asmjit::X86Zmm res1 = x86::zmm28; + + for (int j = 0; j < colRegs; ++j) { + // load B + a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); + // load A, broadcast and fmas + for (int i = 0; i < rowRegs; ++i) { + a->vpbroadcastd( + AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); + a->vpmaddubsw(res1, AReg, BReg); + a->vpmaddwd(res1, oneReg, res1); + a->vpaddd( + CRegs_avx512_[i * leadingDimCRegAssign + j], + res1, + CRegs_avx512_[i * leadingDimCRegAssign + j]); + } + a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); + } +} + +/** + * Generate AVX512 instructions for storing the C registers back to the memory + * in 32-bit Accumulation kernel. + */ +template <> +template <> +void CodeGenBase::storeCRegs< + inst_set_t::avx512>( + asmjit::X86Emitter* a, + int rowRegs, + int colRegs, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, + bool accum, + int leadingDimCRegAssign) { + // temp register + asmjit::X86Zmm tmpReg = x86::zmm28; + + for (int i = 0; i < rowRegs; ++i) { + if (i != 0) { + a->add(C_Offset, ldcReg); + } + for (int j = 0; j < colRegs; ++j) { + if (accum) { + a->vpaddd( + CRegs_avx512_[i * leadingDimCRegAssign + j], + CRegs_avx512_[i * leadingDimCRegAssign + j], + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t))); + } + a->vmovups( + x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)), + CRegs_avx512_[i * leadingDimCRegAssign + j]); + } + } +} + +/** + * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel.
+ * + */ +template <> +template <> +CodeGenBase::jit_micro_kernel_fp +CodeGenBase::getOrCreate( + bool accum, + int32_t mc, + int32_t nc, + int32_t kc, + int32_t /* unused */) { + auto kernelSig = std::make_tuple(accum, mc, nc); + if (codeCache_.find(kernelSig) != codeCache_.end()) { + return codeCache_[kernelSig]; + } + + code_.reset(false); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); + // ToDo: Dump in a file for debugging + // code dumping/logging + // asmjit::FileLogger logger(stderr); + // code_.setLogger(&logger); + + constexpr int kBlock = + PackingTraits::KCB; + constexpr int mRegBlockSize = + PackingTraits::MR; + constexpr int row_interleave = + PackingTraits::ROW_INTERLEAVE; + assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); + // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)"); + int mRegBlocks = mc / mRegBlockSize; + int mRegBlocksRem = mc % mRegBlockSize; + + // arguments to the function created + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); + + asmjit::FuncDetail func; + func.init( + asmjit:: + FuncSignature6( + asmjit::CallConv::kIdHost)); + + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); + + asmjit::FuncArgsMapper args(&func); + args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); + + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); + + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); + + asmjit::Label Loopk = a->newLabel(); + asmjit::Label LoopMBlocks = a->newLabel(); + + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + asmjit::X86Gp kIdx = a->gpzRef(14); + // asmjit::X86Gp B_pf = a->gpzRef(8); + + asmjit::X86Zmm oneReg = x86::zmm29; + // create 16-bit 1s + // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 + // and so on + // a->vpcmpeqw(oneReg, oneReg, oneReg); + a->vpternlogd(oneReg, oneReg, oneReg, 0xff); + a->vpsrlw(oneReg, oneReg, 15); + a->imul(ldcReg, ldcReg, static_cast(sizeof(int32_t))); + a->mov(C_Offset, 0); + + int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; + if (mRegBlocks > 0) { + // move 0 to iteration variables + a->mov(iIdx, 0); + + // save B_buffer address + a->mov(buffer_B_saved, buffer_B); + a->mov(B_pf_saved, B_pf); + + a->bind(LoopMBlocks); + a->inc(iIdx); + + int rowRegs = mRegBlockSize; + + // init C registers + initCRegs(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(Loopk); + + // k is incremented by row_interleave + a->add(kIdx, static_cast(row_interleave)); + + genComputeBlock( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add( + buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add( + buffer_B, static_cast(VLEN_ * colRegs * sizeof(int8_t))); + a->add(B_pf, static_cast(VLEN_ * colRegs * sizeof(int8_t))); 
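+ // buffer_B and B_pf advance by the same stride, VLEN_ * colRegs bytes per + // k iteration (exactly the amount of packed B consumed by one + // genComputeBlock call), so the prefetch pointer stays a fixed distance + // ahead of the load pointer.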
+ + // a->add(B_pf, static_cast(32*sizeof(float))); + + a->cmp(kIdx, kSize); + a->jl(Loopk); + + // store C matrix + storeCRegs( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + + // increment A for next block + a->sub(buffer_A, kSize); + a->add( + buffer_A, static_cast((rowRegs)*kBlock * sizeof(uint8_t))); + + // increment C for next block + a->imul(C_Offset, ldcReg, static_cast(rowRegs)); + a->add(CBase, C_Offset); + a->mov(C_Offset, 0); + + // reset B + a->mov(buffer_B, buffer_B_saved); + a->mov(B_pf, B_pf_saved); + a->cmp(iIdx, mRegBlocks); + a->jl(LoopMBlocks); + } + // generate code for remainder + if (mRegBlocksRem > 0) { + asmjit::Label LoopkRem = a->newLabel(); + int rowRegs = mRegBlocksRem; + + // init C registers + initCRegs(a, rowRegs, colRegs, colRegs); + + // init k loop index + a->mov(kIdx, 0); + a->bind(LoopkRem); + + // k is incremented by row_interleave + a->add(kIdx, static_cast(row_interleave)); + + genComputeBlock( + a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); + + // update buffer_A address for next k iteration + a->add( + buffer_A, static_cast(row_interleave * sizeof(uint8_t))); + + // update buffer_B address for next k iteration + a->add( + buffer_B, static_cast(VLEN_ * colRegs * sizeof(int8_t))); + a->add(B_pf, static_cast(VLEN_ * colRegs * sizeof(int8_t))); + + a->cmp(kIdx, kSize); + a->jl(LoopkRem); + + // store C matrix + storeCRegs( + a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); + } + + asmjit::FuncUtils::emitEpilog(a, layout); + + jit_micro_kernel_fp fn; + asmjit::Error err = rt_.add(&fn, &code_); + if (err) { + std::cout << "Error: in fn add" << std::endl; + return nullptr; + } + codeCache_[kernelSig] = fn; + return fn; +} + +} // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC32_avx512.cc b/src/GenerateKernelU8S8S32ACC32_avx512.cc deleted file mode 100644 index 0621bb0..0000000 --- a/src/GenerateKernelU8S8S32ACC32_avx512.cc +++ /dev/null @@ -1,317 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#include -#include "GenerateKernel.h" - -namespace fbgemm { - -namespace x86 = asmjit::x86; - -/** - * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit - * Accumulation kernel. - */ -template <> -template <> -void CodeGenBase::initCRegs< - inst_set_t::avx512>( - asmjit::X86Emitter* a, - int rowRegs, - int colRegs, - int leadingDimCReg) { - for (int i = 0; i < rowRegs; ++i) { - for (int j = 0; j < colRegs; ++j) { - a->vxorps( - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j]); - } - } -} - -/** - * Generate AVX512 instructions for computing block in the rank-k update of - * 32-bit Accmulation kernel. 
- */ -template <> -template <> -void CodeGenBase::genComputeBlock< - inst_set_t::avx512>( - asmjit::X86Emitter* a, - asmjit::X86Gp buffer_A, - asmjit::X86Gp buffer_B, - asmjit::X86Gp B_pf, - int rowRegs, - int colRegs, - int lda, - int leadingDimCRegAssign) { - // used for matrix A - asmjit::X86Zmm AReg = x86::zmm31; - - // used for matrix B - asmjit::X86Zmm BReg = x86::zmm30; - - // Contains 16-bit 1s - asmjit::X86Zmm oneReg = x86::zmm29; - - // temporary register - asmjit::X86Zmm res1 = x86::zmm28; - - for (int j = 0; j < colRegs; ++j) { - // load B - a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); - // load A, broadcast and fmas - for (int i = 0; i < rowRegs; ++i) { - a->vpbroadcastd( - AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); - a->vpmaddubsw(res1, AReg, BReg); - a->vpmaddwd(res1, oneReg, res1); - a->vpaddd( - CRegs_avx512_[i * leadingDimCRegAssign + j], - res1, - CRegs_avx512_[i * leadingDimCRegAssign + j]); - } - a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); - } -} - -/** - * Generate AVX512 instructions for storing the C registers back to the memory - * in 32-bit Accumulation kernel. - */ -template <> -template <> -void CodeGenBase::storeCRegs< - inst_set_t::avx512>( - asmjit::X86Emitter* a, - int rowRegs, - int colRegs, - asmjit::X86Gp C_Offset, - asmjit::X86Gp ldcReg, - bool accum, - int leadingDimCRegAssign) { - // temp register - asmjit::X86Zmm tmpReg = x86::zmm28; - - for (int i = 0; i < rowRegs; ++i) { - if (i != 0) { - a->add(C_Offset, ldcReg); - } - for (int j = 0; j < colRegs; ++j) { - if (accum) { - a->vpaddd( - CRegs_avx512_[i * leadingDimCRegAssign + j], - CRegs_avx512_[i * leadingDimCRegAssign + j], - x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t))); - } - a->vmovups( - x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)), - CRegs_avx512_[i * leadingDimCRegAssign + j]); - } - } -} - -/** - * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel. 
- * - */ -template <> -template <> -CodeGenBase::jit_micro_kernel_fp -CodeGenBase::getOrCreate( - bool accum, - int32_t mc, - int32_t nc, - int32_t kc, - int32_t /* unused */) { - auto kernelSig = std::make_tuple(accum, mc, nc); - if (codeCache_.find(kernelSig) != codeCache_.end()) { - return codeCache_[kernelSig]; - } - - code_.reset(false); - code_.init(rt_.getCodeInfo()); - asmjit::X86Assembler assembler(&code_); - asmjit::X86Emitter* a = assembler.asEmitter(); - // ToDo: Dump in a file for debugging - // code dumping/logging - // asmjit::FileLogger logger(stderr); - // code_.setLogger(&logger); - - constexpr int kBlock = - PackingTraits::KCB; - constexpr int mRegBlockSize = - PackingTraits::MR; - constexpr int row_interleave = - PackingTraits::ROW_INTERLEAVE; - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - // assert(mc <= 12 && "mc must be <= 12 (available registers constraint)"); - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; - - // arguments to the function created - asmjit::X86Gp buffer_A = a->zdi(); - asmjit::X86Gp buffer_B = a->zsi(); - asmjit::X86Gp B_pf = a->zdx(); - asmjit::X86Gp CBase = a->zcx(); - asmjit::X86Gp kSize = a->gpzRef(8); - asmjit::X86Gp ldcReg = a->gpzRef(9); - - asmjit::FuncDetail func; - func.init( - asmjit:: - FuncSignature6( - asmjit::CallConv::kIdHost)); - - asmjit::FuncFrameInfo ffi; - ffi.setDirtyRegs( - asmjit::X86Reg::kKindVec, - asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - ffi.setDirtyRegs( - asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); - - asmjit::FuncArgsMapper args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFrameInfo(ffi); - - asmjit::FuncFrameLayout layout; - layout.init(func, ffi); - - asmjit::FuncUtils::emitProlog(a, layout); - asmjit::FuncUtils::allocArgs(a, layout, args); - - asmjit::Label Loopk = a->newLabel(); - asmjit::Label LoopMBlocks = a->newLabel(); - - asmjit::X86Gp buffer_B_saved = a->gpzRef(10); - asmjit::X86Gp C_Offset = a->gpzRef(11); - asmjit::X86Gp B_pf_saved = a->gpzRef(12); - asmjit::X86Gp iIdx = a->gpzRef(13); - asmjit::X86Gp kIdx = a->gpzRef(14); - // asmjit::X86Gp B_pf = a->gpzRef(8); - - asmjit::X86Zmm oneReg = x86::zmm29; - // create 16-bit 1s - // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 - // and so on - // a->vpcmpeqw(oneReg, oneReg, oneReg); - a->vpternlogd(oneReg, oneReg, oneReg, 0xff); - a->vpsrlw(oneReg, oneReg, 15); - a->imul(ldcReg, ldcReg, static_cast(sizeof(int32_t))); - a->mov(C_Offset, 0); - - int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - a->mov(B_pf_saved, B_pf); - - a->bind(LoopMBlocks); - a->inc(iIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(Loopk); - - // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); - - genComputeBlock( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, static_cast(VLEN_ * colRegs * sizeof(int8_t))); - a->add(B_pf, static_cast(VLEN_ * colRegs * sizeof(int8_t))); 
- - // a->add(B_pf, static_cast(32*sizeof(float))); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // increment A for next block - a->sub(buffer_A, kSize); - a->add( - buffer_A, static_cast((rowRegs)*kBlock * sizeof(uint8_t))); - - // increment C for next block - a->imul(C_Offset, ldcReg, static_cast(rowRegs)); - a->add(CBase, C_Offset); - a->mov(C_Offset, 0); - - // reset B - a->mov(buffer_B, buffer_B_saved); - a->mov(B_pf, B_pf_saved); - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; - - // init C registers - initCRegs(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); - - // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); - - genComputeBlock( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, static_cast(VLEN_ * colRegs * sizeof(int8_t))); - a->add(B_pf, static_cast(VLEN_ * colRegs * sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(LoopkRem); - - // store C matrix - storeCRegs( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - } - - asmjit::FuncUtils::emitEpilog(a, layout); - - jit_micro_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } - codeCache_[kernelSig] = fn; - return fn; -} - -} // namespace fbgemm diff --git a/src/UtilsAvx2.cc b/src/UtilsAvx2.cc new file mode 100644 index 0000000..badf70b --- /dev/null +++ b/src/UtilsAvx2.cc @@ -0,0 +1,169 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#include "TransposeUtils.h" +#include <immintrin.h> + +namespace fbgemm { + +namespace internal { + +inline void +transpose_kernel_4x4_sse(const float* src, int ld_src, float* dst, int ld_dst) { + // load from src to registers + // a : a0 a1 a2 a3 + // b : b0 b1 b2 b3 + // c : c0 c1 c2 c3 + // d : d0 d1 d2 d3 + __m128 a = _mm_loadu_ps(&src[0 * ld_src]); + __m128 b = _mm_loadu_ps(&src[1 * ld_src]); + __m128 c = _mm_loadu_ps(&src[2 * ld_src]); + __m128 d = _mm_loadu_ps(&src[3 * ld_src]); + + // transpose the 4x4 matrix formed by 32-bit elements: Macro from SSE + // a : a0 b0 c0 d0 + // b : a1 b1 c1 d1 + // c : a2 b2 c2 d2 + // d : a3 b3 c3 d3 + _MM_TRANSPOSE4_PS(a, b, c, d); + + // store from registers to dst + _mm_storeu_ps(&dst[0 * ld_dst], a); + _mm_storeu_ps(&dst[1 * ld_dst], b); + _mm_storeu_ps(&dst[2 * ld_dst], c); + _mm_storeu_ps(&dst[3 * ld_dst], d); +} + +inline void transpose_4x4( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst) { + int ib = 0, jb = 0; + for (ib = 0; ib + 4 <= M; ib += 4) { + for (jb = 0; jb + 4 <= N; jb += 4) { + transpose_kernel_4x4_sse( + &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); + } + } + transpose_ref(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst); + transpose_ref(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst); +} + +inline void transpose_kernel_8x8_avx2( + const float* src, + int ld_src, + float* dst, + int ld_dst) { + // load from src to registers + // a : a0 a1 a2 a3 a4 a5 a6 a7 + // b : b0 b1 b2 b3 b4 b5 b6 b7 + // c : c0 c1 c2 c3 c4 c5 c6 c7 + // d : d0 d1 d2 d3 d4 d5 d6 d7 + // e : e0 e1 e2 e3 e4 e5 e6 e7 + // f : f0 f1 f2 f3 f4 f5 f6 f7 + // g : g0 g1 g2 g3 g4 g5 g6 g7 + // h : h0 h1 h2 h3 h4 h5 h6 h7 + __m256 a = _mm256_loadu_ps(&src[0 * ld_src]); + __m256 b = _mm256_loadu_ps(&src[1 * ld_src]); + __m256 c = _mm256_loadu_ps(&src[2 * ld_src]); + __m256 d = _mm256_loadu_ps(&src[3 * ld_src]); + __m256 e = _mm256_loadu_ps(&src[4 * ld_src]); + __m256 f = _mm256_loadu_ps(&src[5 * ld_src]); + __m256 g = _mm256_loadu_ps(&src[6 * ld_src]); + __m256 h = _mm256_loadu_ps(&src[7 * ld_src]); + + __m256 ab0145, ab2367, cd0145, cd2367, ef0145, ef2367, gh0145, gh2367; + __m256 abcd04, abcd15, efgh04, efgh15, abcd26, abcd37, efgh26, efgh37; + // unpacking and interleaving 32-bit elements + // ab0145 : a0 b0 a1 b1 a4 b4 a5 b5 + // ab2367 : a2 b2 a3 b3 a6 b6 a7 b7 + // cd0145 : c0 d0 c1 d1 c4 d4 c5 d5 + // cd2367 : c2 d2 c3 d3 c6 d6 c7 d7 + // ef0145 : e0 f0 e1 f1 e4 f4 e5 f5 + // ef2367 : e2 f2 e3 f3 e6 f6 e7 f7 + // gh0145 : g0 h0 g1 h1 g4 h4 g5 h5 + // gh2367 : g2 h2 g3 h3 g6 h6 g7 h7 + ab0145 = _mm256_unpacklo_ps(a, b); + ab2367 = _mm256_unpackhi_ps(a, b); + cd0145 = _mm256_unpacklo_ps(c, d); + cd2367 = _mm256_unpackhi_ps(c, d); + ef0145 = _mm256_unpacklo_ps(e, f); + ef2367 = _mm256_unpackhi_ps(e, f); + gh0145 = _mm256_unpacklo_ps(g, h); + gh2367 = _mm256_unpackhi_ps(g, h); + + // shuffling the 32-bit elements + // abcd04 : a0 b0 c0 d0 a4 b4 c4 d4 + // abcd15 : a1 b1 c1 d1 a5 b5 c5 d5 + // efgh04 : e0 f0 g0 h0 e4 f4 g4 h4 + // efgh15 : e1 f1 g1 h1 e5 f5 g5 h5 + // abcd26 : a2 b2 c2 d2 a6 b6 c6 d6 + // abcd37 : a3 b3 c3 d3 a7 b7 c7 d7 + // efgh26 : e2 f2 g2 h2 e6 f6 g6 h6 + // efgh37 : e3 f3 g3 h3 e7 f7 g7 h7 + abcd04 = _mm256_shuffle_ps(ab0145, cd0145, 0x44); + abcd15 = _mm256_shuffle_ps(ab0145, cd0145, 0xee); + efgh04 = _mm256_shuffle_ps(ef0145, gh0145, 0x44); + efgh15 = _mm256_shuffle_ps(ef0145, gh0145, 0xee); + abcd26 = _mm256_shuffle_ps(ab2367, cd2367, 0x44); + abcd37 = _mm256_shuffle_ps(ab2367,
cd2367, 0xee); + efgh26 = _mm256_shuffle_ps(ef2367, gh2367, 0x44); + efgh37 = _mm256_shuffle_ps(ef2367, gh2367, 0xee); + + // shuffling 128-bit elements + // a : a0 b0 c0 d0 e0 f0 g0 h0 + // b : a1 b1 c1 d1 e1 f1 g1 h1 + // c : a2 b2 c2 d2 e2 f2 g2 h2 + // d : a3 b3 c3 d3 e3 f3 g3 h3 + // e : a4 b4 c4 d4 e4 f4 g4 h4 + // f : a5 b5 c5 d5 e5 f5 g5 h5 + // g : a6 b6 c6 d6 e6 f6 g6 h6 + // h : a7 b7 c7 d7 e7 f7 g7 h7 + a = _mm256_permute2f128_ps(efgh04, abcd04, 0x02); + b = _mm256_permute2f128_ps(efgh15, abcd15, 0x02); + c = _mm256_permute2f128_ps(efgh26, abcd26, 0x02); + d = _mm256_permute2f128_ps(efgh37, abcd37, 0x02); + e = _mm256_permute2f128_ps(efgh04, abcd04, 0x13); + f = _mm256_permute2f128_ps(efgh15, abcd15, 0x13); + g = _mm256_permute2f128_ps(efgh26, abcd26, 0x13); + h = _mm256_permute2f128_ps(efgh37, abcd37, 0x13); + + // store from registers to dst + _mm256_storeu_ps(&dst[0 * ld_dst], a); + _mm256_storeu_ps(&dst[1 * ld_dst], b); + _mm256_storeu_ps(&dst[2 * ld_dst], c); + _mm256_storeu_ps(&dst[3 * ld_dst], d); + _mm256_storeu_ps(&dst[4 * ld_dst], e); + _mm256_storeu_ps(&dst[5 * ld_dst], f); + _mm256_storeu_ps(&dst[6 * ld_dst], g); + _mm256_storeu_ps(&dst[7 * ld_dst], h); +} + + +void transpose_8x8( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst) { + int ib = 0, jb = 0; + for (ib = 0; ib + 8 <= M; ib += 8) { + for (jb = 0; jb + 8 <= N; jb += 8) { + transpose_kernel_8x8_avx2( + &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); + } + } + transpose_4x4(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst); + transpose_4x4(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst); +} + +} // namespace internal + +} // namespace fbgemm diff --git a/src/UtilsAvx512.cc b/src/UtilsAvx512.cc new file mode 100644 index 0000000..f49bb6f --- /dev/null +++ b/src/UtilsAvx512.cc @@ -0,0 +1,246 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "TransposeUtils.h" +#include <immintrin.h> + +namespace fbgemm { + +namespace internal { + +inline void transpose_kernel_16x16_avx512( + const float* src, + int ld_src, + float* dst, + int ld_dst) { + // load from src to registers + // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15 + // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15 + // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15 + // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15 + // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15 + // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15 + // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15 + // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15 + // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15 + // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15 + // k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15 + // l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15 + // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15 + // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15 + // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15 + // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15 + __m512 a = _mm512_loadu_ps(&src[0 * ld_src]); + __m512 b = _mm512_loadu_ps(&src[1 * ld_src]); + __m512 c = _mm512_loadu_ps(&src[2 * ld_src]); + __m512 d = _mm512_loadu_ps(&src[3 * ld_src]); + __m512 e = _mm512_loadu_ps(&src[4 * ld_src]); + __m512 f = _mm512_loadu_ps(&src[5 * ld_src]); + __m512 g = _mm512_loadu_ps(&src[6 * ld_src]); + __m512 h = _mm512_loadu_ps(&src[7 * ld_src]); + __m512 i = _mm512_loadu_ps(&src[8 * ld_src]); + __m512 j = _mm512_loadu_ps(&src[9 * ld_src]); + __m512 k = _mm512_loadu_ps(&src[10 * ld_src]); + __m512 l = _mm512_loadu_ps(&src[11 * ld_src]); + __m512 m = _mm512_loadu_ps(&src[12 * ld_src]); + __m512 n = _mm512_loadu_ps(&src[13 * ld_src]); + __m512 o = _mm512_loadu_ps(&src[14 * ld_src]); + __m512 p = _mm512_loadu_ps(&src[15 * ld_src]); + + __m512 ta, tb, tc, td, te, tf, tg, th, ti, tj, tk, tl, tm, tn, to, tq; + // unpacking and interleaving 32-bit elements + // a0 b0 a1 b1 a4 b4 a5 b5 a8 b8 a9 b9 a12 b12 a13 b13 + // a2 b2 a3 b3 a6 b6 a7 b7 a10 b10 a11 b11 a14 b14 a15 b15 + // c0 d0 c1 d1 ... + // c2 d2 c3 d3 ... + // e0 f0 e1 f1 ... + // e2 f2 e3 f3 ... + // g0 h0 g1 h1 ... + // g2 h2 g3 h3 ... + // i0 ... + // i2 ... + // k0 ... + // k2 ... + // m0 ... + // m2 ... + // o0 ... + // o2 ... + ta = _mm512_unpacklo_ps(a, b); + tb = _mm512_unpackhi_ps(a, b); + tc = _mm512_unpacklo_ps(c, d); + td = _mm512_unpackhi_ps(c, d); + te = _mm512_unpacklo_ps(e, f); + tf = _mm512_unpackhi_ps(e, f); + tg = _mm512_unpacklo_ps(g, h); + th = _mm512_unpackhi_ps(g, h); + ti = _mm512_unpacklo_ps(i, j); + tj = _mm512_unpackhi_ps(i, j); + tk = _mm512_unpacklo_ps(k, l); + tl = _mm512_unpackhi_ps(k, l); + tm = _mm512_unpacklo_ps(m, n); + tn = _mm512_unpackhi_ps(m, n); + to = _mm512_unpacklo_ps(o, p); + tq = _mm512_unpackhi_ps(o, p); + + // unpacking and interleaving 64-bit elements + // a0 b0 c0 d0 a4 b4 c4 d4 a8 b8 c8 d8 a12 b12 c12 d12 + // a1 b1 c1 d1 ... + // a2 b2 c2 d2 ... + // a3 b3 c3 d3 ... + // e0 f0 g0 h0 e4 f4 g4 h4 e8 f8 g8 h8 e12 f12 g12 h12 + // e1 f1 g1 h1 ... + // e2 f2 g2 h2 ... + // e3 f3 g3 h3 ... + // i0 j0 k0 l0 ... + // i1 j1 k1 l1 ... + // i2 j2 k2 l2 ... + // i3 j3 k3 l3 ... + // m0 n0 o0 p0 ... + // m1 n1 o1 p1 ... + // m2 n2 o2 p2 ... + // m3 n3 o3 p3 ...
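+ // The float rows are reinterpreted as __m512d below so that + // _mm512_unpacklo_pd / _mm512_unpackhi_pd can move pairs of adjacent + // 32-bit elements as single 64-bit lanes; the casts emit no instructions, + // they only change the element width the unpack operates on.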
+ a = reinterpret_cast<__m512>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d>(ta), reinterpret_cast<__m512d>(tc))); + b = reinterpret_cast<__m512>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d>(ta), reinterpret_cast<__m512d>(tc))); + c = reinterpret_cast<__m512>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d>(tb), reinterpret_cast<__m512d>(td))); + d = reinterpret_cast<__m512>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d>(tb), reinterpret_cast<__m512d>(td))); + e = reinterpret_cast<__m512>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d>(te), reinterpret_cast<__m512d>(tg))); + f = reinterpret_cast<__m512>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d>(te), reinterpret_cast<__m512d>(tg))); + g = reinterpret_cast<__m512>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d>(tf), reinterpret_cast<__m512d>(th))); + h = reinterpret_cast<__m512>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d>(tf), reinterpret_cast<__m512d>(th))); + i = reinterpret_cast<__m512>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d>(ti), reinterpret_cast<__m512d>(tk))); + j = reinterpret_cast<__m512>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d>(ti), reinterpret_cast<__m512d>(tk))); + k = reinterpret_cast<__m512>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d>(tj), reinterpret_cast<__m512d>(tl))); + l = reinterpret_cast<__m512>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d>(tj), reinterpret_cast<__m512d>(tl))); + m = reinterpret_cast<__m512>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d>(tm), reinterpret_cast<__m512d>(to))); + n = reinterpret_cast<__m512>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d>(tm), reinterpret_cast<__m512d>(to))); + o = reinterpret_cast<__m512>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d>(tn), reinterpret_cast<__m512d>(tq))); + p = reinterpret_cast<__m512>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d>(tn), reinterpret_cast<__m512d>(tq))); + + // shuffle 128-bits (composed of 4 32-bit elements) + // a0 b0 c0 d0 a8 b8 c8 d8 e0 f0 g0 h0 e8 f8 g8 h8 + // a1 b1 c1 d1 ... + // a2 b2 c2 d2 ... + // a3 b3 c3 d3 ... + // a4 b4 c4 d4 ... + // a5 b5 c5 d5 ... + // a6 b6 c6 d6 ... + // a7 b7 c7 d7 ... + // i0 j0 k0 l0 i8 j8 k8 l8 m0 n0 o0 p0 m8 n8 o8 p8 + // i1 j1 k1 l1 ... + // i2 j2 k2 l2 ... + // i3 j3 k3 l3 ... + // i4 j4 k4 l4 ... + // i5 j5 k5 l5 ... + // i6 j6 k6 l6 ... + // i7 j7 k7 l7 ... + ta = _mm512_shuffle_f32x4(a, e, 0x88); + tb = _mm512_shuffle_f32x4(b, f, 0x88); + tc = _mm512_shuffle_f32x4(c, g, 0x88); + td = _mm512_shuffle_f32x4(d, h, 0x88); + te = _mm512_shuffle_f32x4(a, e, 0xdd); + tf = _mm512_shuffle_f32x4(b, f, 0xdd); + tg = _mm512_shuffle_f32x4(c, g, 0xdd); + th = _mm512_shuffle_f32x4(d, h, 0xdd); + ti = _mm512_shuffle_f32x4(i, m, 0x88); + tj = _mm512_shuffle_f32x4(j, n, 0x88); + tk = _mm512_shuffle_f32x4(k, o, 0x88); + tl = _mm512_shuffle_f32x4(l, p, 0x88); + tm = _mm512_shuffle_f32x4(i, m, 0xdd); + tn = _mm512_shuffle_f32x4(j, n, 0xdd); + to = _mm512_shuffle_f32x4(k, o, 0xdd); + tq = _mm512_shuffle_f32x4(l, p, 0xdd); + + // shuffle 128-bits (composed of 4 32-bit elements) + // a0 b0 c0 d0 ... o0 + // a1 b1 c1 d1 ... o1 + // a2 b2 c2 d2 ... o2 + // a3 b3 c3 d3 ... o3 + // a4 ... + // a5 ... + // a6 ... + // a7 ... + // a8 ... + // a9 ... + // a10 ... + // a11 ... + // a12 ... + // a13 ... + // a14 ... + // a15 b15 c15 d15 ... 
o15 + a = _mm512_shuffle_f32x4(ta, ti, 0x88); + b = _mm512_shuffle_f32x4(tb, tj, 0x88); + c = _mm512_shuffle_f32x4(tc, tk, 0x88); + d = _mm512_shuffle_f32x4(td, tl, 0x88); + e = _mm512_shuffle_f32x4(te, tm, 0x88); + f = _mm512_shuffle_f32x4(tf, tn, 0x88); + g = _mm512_shuffle_f32x4(tg, to, 0x88); + h = _mm512_shuffle_f32x4(th, tq, 0x88); + i = _mm512_shuffle_f32x4(ta, ti, 0xdd); + j = _mm512_shuffle_f32x4(tb, tj, 0xdd); + k = _mm512_shuffle_f32x4(tc, tk, 0xdd); + l = _mm512_shuffle_f32x4(td, tl, 0xdd); + m = _mm512_shuffle_f32x4(te, tm, 0xdd); + n = _mm512_shuffle_f32x4(tf, tn, 0xdd); + o = _mm512_shuffle_f32x4(tg, to, 0xdd); + p = _mm512_shuffle_f32x4(th, tq, 0xdd); + + // store from registers to dst + _mm512_storeu_ps(&dst[0 * ld_dst], a); + _mm512_storeu_ps(&dst[1 * ld_dst], b); + _mm512_storeu_ps(&dst[2 * ld_dst], c); + _mm512_storeu_ps(&dst[3 * ld_dst], d); + _mm512_storeu_ps(&dst[4 * ld_dst], e); + _mm512_storeu_ps(&dst[5 * ld_dst], f); + _mm512_storeu_ps(&dst[6 * ld_dst], g); + _mm512_storeu_ps(&dst[7 * ld_dst], h); + _mm512_storeu_ps(&dst[8 * ld_dst], i); + _mm512_storeu_ps(&dst[9 * ld_dst], j); + _mm512_storeu_ps(&dst[10 * ld_dst], k); + _mm512_storeu_ps(&dst[11 * ld_dst], l); + _mm512_storeu_ps(&dst[12 * ld_dst], m); + _mm512_storeu_ps(&dst[13 * ld_dst], n); + _mm512_storeu_ps(&dst[14 * ld_dst], o); + _mm512_storeu_ps(&dst[15 * ld_dst], p); +} + +void transpose_16x16( + int M, + int N, + const float* src, + int ld_src, + float* dst, + int ld_dst) { + int ib = 0, jb = 0; + for (ib = 0; ib + 16 <= M; ib += 16) { + for (jb = 0; jb + 16 <= N; jb += 16) { + transpose_kernel_16x16_avx512( + &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst); + } + } + transpose_8x8(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst); + transpose_8x8(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst); +} + +} // namespace internal + +} // namespace fbgemm diff --git a/src/Utils_avx2.cc b/src/Utils_avx2.cc deleted file mode 100644 index badf70b..0000000 --- a/src/Utils_avx2.cc +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
diff --git a/src/Utils_avx2.cc b/src/Utils_avx2.cc
deleted file mode 100644
index badf70b..0000000
--- a/src/Utils_avx2.cc
+++ /dev/null
@@ -1,169 +0,0 @@
-/*
- * Copyright (c) Facebook, Inc. and its affiliates.
- * All rights reserved.
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-#include "TransposeUtils.h"
-#include <immintrin.h>
-
-namespace fbgemm {
-
-namespace internal {
-
-inline void
-transpose_kernel_4x4_sse(const float* src, int ld_src, float* dst, int ld_dst) {
-  // load from src to registers
-  // a : a0 a1 a2 a3
-  // b : b0 b1 b2 b3
-  // c : c0 c1 c2 c3
-  // d : d0 d1 d2 d3
-  __m128 a = _mm_loadu_ps(&src[0 * ld_src]);
-  __m128 b = _mm_loadu_ps(&src[1 * ld_src]);
-  __m128 c = _mm_loadu_ps(&src[2 * ld_src]);
-  __m128 d = _mm_loadu_ps(&src[3 * ld_src]);
-
-  // transpose the 4x4 matrix formed by 32-bit elements: Macro from SSE
-  // a : a0 b0 c0 d0
-  // b : a1 b1 c1 d1
-  // c : a2 b2 c2 d2
-  // d : a3 b3 c3 d3
-  _MM_TRANSPOSE4_PS(a, b, c, d);
-
-  // store from registers to dst
-  _mm_storeu_ps(&dst[0 * ld_dst], a);
-  _mm_storeu_ps(&dst[1 * ld_dst], b);
-  _mm_storeu_ps(&dst[2 * ld_dst], c);
-  _mm_storeu_ps(&dst[3 * ld_dst], d);
-}
-
-inline void transpose_4x4(
-    int M,
-    int N,
-    const float* src,
-    int ld_src,
-    float* dst,
-    int ld_dst) {
-  int ib = 0, jb = 0;
-  for (ib = 0; ib + 4 <= M; ib += 4) {
-    for (jb = 0; jb + 4 <= N; jb += 4) {
-      transpose_kernel_4x4_sse(
-          &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
-    }
-  }
-  transpose_ref(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
-  transpose_ref(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
-}
-
-inline void transpose_kernel_8x8_avx2(
-    const float* src,
-    int ld_src,
-    float* dst,
-    int ld_dst) {
-  // load from src to registers
-  // a : a0 a1 a2 a3 a4 a5 a6 a7
-  // b : b0 b1 b2 b3 b4 b5 b6 b7
-  // c : c0 c1 c2 c3 c4 c5 c6 c7
-  // d : d0 d1 d2 d3 d4 d5 d6 d7
-  // e : e0 e1 e2 e3 e4 e5 e6 e7
-  // f : f0 f1 f2 f3 f4 f5 f6 f7
-  // g : g0 g1 g2 g3 g4 g5 g6 g7
-  // h : h0 h1 h2 h3 h4 h5 h6 h7
-  __m256 a = _mm256_loadu_ps(&src[0 * ld_src]);
-  __m256 b = _mm256_loadu_ps(&src[1 * ld_src]);
-  __m256 c = _mm256_loadu_ps(&src[2 * ld_src]);
-  __m256 d = _mm256_loadu_ps(&src[3 * ld_src]);
-  __m256 e = _mm256_loadu_ps(&src[4 * ld_src]);
-  __m256 f = _mm256_loadu_ps(&src[5 * ld_src]);
-  __m256 g = _mm256_loadu_ps(&src[6 * ld_src]);
-  __m256 h = _mm256_loadu_ps(&src[7 * ld_src]);
-
-  __m256 ab0145, ab2367, cd0145, cd2367, ef0145, ef2367, gh0145, gh2367;
-  __m256 abcd04, abcd15, efgh04, efgh15, abcd26, abcd37, efgh26, efgh37;
-  // unpacking and interleaving 32-bit elements
-  // ab0145 : a0 b0 a1 b1 a4 b4 a5 b5
-  // ab2367 : a2 b2 a3 b3 a6 b6 a7 b7
-  // cd0145 : c0 d0 c1 d1 c4 d4 c5 d5
-  // cd2367 : c2 d2 c3 d3 c6 d6 c7 d7
-  // ef0145 : e0 f0 e1 f1 e4 f4 e5 f5
-  // ef2367 : e2 f2 e3 f3 e6 f6 e7 f7
-  // gh0145 : g0 h0 g1 h1 g4 h4 g5 h5
-  // gh2367 : g2 h2 g3 h3 g6 h6 g7 h7
-  ab0145 = _mm256_unpacklo_ps(a, b);
-  ab2367 = _mm256_unpackhi_ps(a, b);
-  cd0145 = _mm256_unpacklo_ps(c, d);
-  cd2367 = _mm256_unpackhi_ps(c, d);
-  ef0145 = _mm256_unpacklo_ps(e, f);
-  ef2367 = _mm256_unpackhi_ps(e, f);
-  gh0145 = _mm256_unpacklo_ps(g, h);
-  gh2367 = _mm256_unpackhi_ps(g, h);
-
-  // shuffling the 32-bit elements
-  // abcd04 : a0 b0 c0 d0 a4 b4 c4 d4
-  // abcd15 : a1 b1 c1 d1 a5 b5 c5 d5
-  // efgh04 : e0 f0 g0 h0 e4 f4 g4 h4
-  // efgh15 : e1 f1 g1 h1 e5 f5 g5 h5
-  // abcd26 : a2 b2 c2 d2 a6 b6 c6 d6
-  // abcd37 : a3 b3 c3 d3 a7 b7 c7 d7
-  // efgh26 : e2 f2 g2 h2 e6 f6 g6 h6
-  // efgh37 : e3 f3 g3 h3 e7 f7 g7 h7
-  abcd04 = _mm256_shuffle_ps(ab0145, cd0145, 0x44);
-  abcd15 = _mm256_shuffle_ps(ab0145, cd0145, 0xee);
-  efgh04 = _mm256_shuffle_ps(ef0145, gh0145, 0x44);
-  efgh15 = _mm256_shuffle_ps(ef0145, gh0145, 0xee);
-  abcd26 = _mm256_shuffle_ps(ab2367, cd2367, 0x44);
-  abcd37 = _mm256_shuffle_ps(ab2367, cd2367, 0xee);
-  efgh26 = _mm256_shuffle_ps(ef2367, gh2367, 0x44);
-  efgh37 = _mm256_shuffle_ps(ef2367, gh2367, 0xee);
-
-  // shuffling 128-bit elements
-  // a : a0 b0 c0 d0 e0 f0 g0 h0
-  // b : a1 b1 c1 d1 e1 f1 g1 h1
-  // c : a2 b2 c2 d2 e2 f2 g2 h2
-  // d : a3 b3 c3 d3 e3 f3 g3 h3
-  // e : a4 b4 c4 d4 e4 f4 g4 h4
-  // f : a5 b5 c5 d5 e5 f5 g5 h5
-  // g : a6 b6 c6 d6 e6 f6 g6 h6
-  // h : a7 b7 c7 d7 e7 f7 g7 h7
-  a = _mm256_permute2f128_ps(efgh04, abcd04, 0x02);
-  b = _mm256_permute2f128_ps(efgh15, abcd15, 0x02);
-  c = _mm256_permute2f128_ps(efgh26, abcd26, 0x02);
-  d = _mm256_permute2f128_ps(efgh37, abcd37, 0x02);
-  e = _mm256_permute2f128_ps(efgh04, abcd04, 0x13);
-  f = _mm256_permute2f128_ps(efgh15, abcd15, 0x13);
-  g = _mm256_permute2f128_ps(efgh26, abcd26, 0x13);
-  h = _mm256_permute2f128_ps(efgh37, abcd37, 0x13);
-
-  // store from registers to dst
-  _mm256_storeu_ps(&dst[0 * ld_dst], a);
-  _mm256_storeu_ps(&dst[1 * ld_dst], b);
-  _mm256_storeu_ps(&dst[2 * ld_dst], c);
-  _mm256_storeu_ps(&dst[3 * ld_dst], d);
-  _mm256_storeu_ps(&dst[4 * ld_dst], e);
-  _mm256_storeu_ps(&dst[5 * ld_dst], f);
-  _mm256_storeu_ps(&dst[6 * ld_dst], g);
-  _mm256_storeu_ps(&dst[7 * ld_dst], h);
-}
-
-
-void transpose_8x8(
-    int M,
-    int N,
-    const float* src,
-    int ld_src,
-    float* dst,
-    int ld_dst) {
-  int ib = 0, jb = 0;
-  for (ib = 0; ib + 8 <= M; ib += 8) {
-    for (jb = 0; jb + 8 <= N; jb += 8) {
-      transpose_kernel_8x8_avx2(
-          &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
-    }
-  }
-  transpose_4x4(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
-  transpose_4x4(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
-}
-
-} // namespace internal
-
-} // namespace fbgemm
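transpose_4x4 and transpose_8x8 above, and transpose_16x16 below, all share the same driver shape: a blocked main loop over full tiles, then two fallback calls that sweep the right edge of the blocked rows and the whole bottom strip, each delegating to the next smaller kernel (16 -> 8 -> 4 -> transpose_ref). A generic sketch of that structure, with hypothetical names (transpose_blocked, B, kernel, fallback are ours; only the loop and remainder logic come from the code above):

template <int B, typename Kernel, typename Fallback>
void transpose_blocked(
    int M,
    int N,
    const float* src,
    int ld_src,
    float* dst,
    int ld_dst,
    Kernel kernel,
    Fallback fallback) {
  int ib = 0, jb = 0;
  // Full B x B tiles; on exit, ib/jb sit at the first unblocked row/column.
  for (ib = 0; ib + B <= M; ib += B) {
    for (jb = 0; jb + B <= N; jb += B) {
      kernel(&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
    }
  }
  // Columns jb..N-1 of the fully blocked rows, then rows ib..M-1 in full.
  fallback(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
  fallback(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
}

With this shape, transpose_8x8 would be transpose_blocked<8>(..., transpose_kernel_8x8_avx2, transpose_4x4), so the wide kernels only ever run on full tiles.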
diff --git a/src/Utils_avx512.cc b/src/Utils_avx512.cc
deleted file mode 100644
index f49bb6f..0000000
--- a/src/Utils_avx512.cc
+++ /dev/null
@@ -1,246 +0,0 @@
-/*
- * Copyright (c) Facebook, Inc. and its affiliates.
- * All rights reserved.
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "TransposeUtils.h"
-#include <immintrin.h>
-
-namespace fbgemm {
-
-namespace internal {
-
-inline void transpose_kernel_16x16_avx512(
-    const float* src,
-    int ld_src,
-    float* dst,
-    int ld_dst) {
-  // load from src to registers
-  // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15
-  // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15
-  // c: c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 c10 c11 c12 c13 c14 c15
-  // d: d0 d1 d2 d3 d4 d5 d6 d7 d8 d9 d10 d11 d12 d13 d14 d15
-  // e: e0 e1 e2 e3 e4 e5 e6 e7 e8 e9 e10 e11 e12 e13 e14 e15
-  // f: f0 f1 f2 f3 f4 f5 f6 f7 f8 f9 f10 f11 f12 f13 f14 f15
-  // g: g0 g1 g2 g3 g4 g5 g6 g7 g8 g9 g10 g11 g12 g13 g14 g15
-  // h: h0 h1 h2 h3 h4 h5 h6 h7 h8 h9 h10 h11 h12 h13 h14 h15
-  // i: i0 i1 i2 i3 i4 i5 i6 i7 i8 i9 i10 i11 i12 i13 i14 i15
-  // j: j0 j1 j2 j3 j4 j5 j6 j7 j8 j9 j10 j11 j12 j13 j14 j15
-  // k: k0 k1 k2 k3 k4 k5 k6 k7 k8 k9 k10 k11 k12 k13 k14 k15
-  // l: l0 l1 l2 l3 l4 l5 l6 l7 l8 l9 l10 l11 l12 l13 l14 l15
-  // m: m0 m1 m2 m3 m4 m5 m6 m7 m8 m9 m10 m11 m12 m13 m14 m15
-  // n: n0 n1 n2 n3 n4 n5 n6 n7 n8 n9 n10 n11 n12 n13 n14 n15
-  // o: o0 o1 o2 o3 o4 o5 o6 o7 o8 o9 o10 o11 o12 o13 o14 o15
-  // p: p0 p1 p2 p3 p4 p5 p6 p7 p8 p9 p10 p11 p12 p13 p14 p15
-  __m512 a = _mm512_loadu_ps(&src[0 * ld_src]);
-  __m512 b = _mm512_loadu_ps(&src[1 * ld_src]);
-  __m512 c = _mm512_loadu_ps(&src[2 * ld_src]);
-  __m512 d = _mm512_loadu_ps(&src[3 * ld_src]);
-  __m512 e = _mm512_loadu_ps(&src[4 * ld_src]);
-  __m512 f = _mm512_loadu_ps(&src[5 * ld_src]);
-  __m512 g = _mm512_loadu_ps(&src[6 * ld_src]);
-  __m512 h = _mm512_loadu_ps(&src[7 * ld_src]);
-  __m512 i = _mm512_loadu_ps(&src[8 * ld_src]);
-  __m512 j = _mm512_loadu_ps(&src[9 * ld_src]);
-  __m512 k = _mm512_loadu_ps(&src[10 * ld_src]);
-  __m512 l = _mm512_loadu_ps(&src[11 * ld_src]);
-  __m512 m = _mm512_loadu_ps(&src[12 * ld_src]);
-  __m512 n = _mm512_loadu_ps(&src[13 * ld_src]);
-  __m512 o = _mm512_loadu_ps(&src[14 * ld_src]);
-  __m512 p = _mm512_loadu_ps(&src[15 * ld_src]);
-
-  __m512 ta, tb, tc, td, te, tf, tg, th, ti, tj, tk, tl, tm, tn, to, tq;
-  // unpacking and interleaving 32-bit elements
-  // a0 b0 a1 b1 a4 b4 a5 b5 a8 b8 a9 b9 a12 b12 a13 b13
-  // a2 b2 a3 b3 a6 b6 a7 b7 a10 b10 a11 b11 a14 b14 a15 b15
-  // c0 d0 c1 d1 ...
-  // c2 d2 c3 d3 ...
-  // e0 f0 e1 f1 ...
-  // e2 f2 e3 f3 ...
-  // g0 h0 g1 h1 ...
-  // g2 h2 g3 h3 ...
-  // i0 ...
-  // i2 ...
-  // k0 ...
-  // k2 ...
-  // m0 ...
-  // m2 ...
-  // o0 ...
-  // o2 ...
-  ta = _mm512_unpacklo_ps(a, b);
-  tb = _mm512_unpackhi_ps(a, b);
-  tc = _mm512_unpacklo_ps(c, d);
-  td = _mm512_unpackhi_ps(c, d);
-  te = _mm512_unpacklo_ps(e, f);
-  tf = _mm512_unpackhi_ps(e, f);
-  tg = _mm512_unpacklo_ps(g, h);
-  th = _mm512_unpackhi_ps(g, h);
-  ti = _mm512_unpacklo_ps(i, j);
-  tj = _mm512_unpackhi_ps(i, j);
-  tk = _mm512_unpacklo_ps(k, l);
-  tl = _mm512_unpackhi_ps(k, l);
-  tm = _mm512_unpacklo_ps(m, n);
-  tn = _mm512_unpackhi_ps(m, n);
-  to = _mm512_unpacklo_ps(o, p);
-  tq = _mm512_unpackhi_ps(o, p);
-
-  // unpacking and interleaving 64-bit elements
-  // a0 b0 c0 d0 a4 b4 c4 d4 a8 b8 c8 d8 a12 b12 c12 d12
-  // a1 b1 c1 d1 ...
-  // a2 b2 c2 d2 ...
-  // a3 b3 c3 d3 ...
-  // e0 f0 g0 h0 e4 f4 g4 h4 e8 f8 g8 h8 e12 f12 g12 h12
-  // e1 f1 g1 h1 ...
-  // e2 f2 g2 h2 ...
-  // e3 f3 g3 h3 ...
-  // i0 j0 k0 l0 ...
-  // i1 j1 k1 l1 ...
-  // i2 j2 k2 l2 ...
-  // i3 j3 k3 l3 ...
-  // m0 n0 o0 p0 ...
-  // m1 n1 o1 p1 ...
-  // m2 n2 o2 p2 ...
-  // m3 n3 o3 p3 ...
-  a = reinterpret_cast<__m512>(_mm512_unpacklo_pd(
-      reinterpret_cast<__m512d>(ta), reinterpret_cast<__m512d>(tc)));
-  b = reinterpret_cast<__m512>(_mm512_unpackhi_pd(
-      reinterpret_cast<__m512d>(ta), reinterpret_cast<__m512d>(tc)));
-  c = reinterpret_cast<__m512>(_mm512_unpacklo_pd(
-      reinterpret_cast<__m512d>(tb), reinterpret_cast<__m512d>(td)));
-  d = reinterpret_cast<__m512>(_mm512_unpackhi_pd(
-      reinterpret_cast<__m512d>(tb), reinterpret_cast<__m512d>(td)));
-  e = reinterpret_cast<__m512>(_mm512_unpacklo_pd(
-      reinterpret_cast<__m512d>(te), reinterpret_cast<__m512d>(tg)));
-  f = reinterpret_cast<__m512>(_mm512_unpackhi_pd(
-      reinterpret_cast<__m512d>(te), reinterpret_cast<__m512d>(tg)));
-  g = reinterpret_cast<__m512>(_mm512_unpacklo_pd(
-      reinterpret_cast<__m512d>(tf), reinterpret_cast<__m512d>(th)));
-  h = reinterpret_cast<__m512>(_mm512_unpackhi_pd(
-      reinterpret_cast<__m512d>(tf), reinterpret_cast<__m512d>(th)));
-  i = reinterpret_cast<__m512>(_mm512_unpacklo_pd(
-      reinterpret_cast<__m512d>(ti), reinterpret_cast<__m512d>(tk)));
-  j = reinterpret_cast<__m512>(_mm512_unpackhi_pd(
-      reinterpret_cast<__m512d>(ti), reinterpret_cast<__m512d>(tk)));
-  k = reinterpret_cast<__m512>(_mm512_unpacklo_pd(
-      reinterpret_cast<__m512d>(tj), reinterpret_cast<__m512d>(tl)));
-  l = reinterpret_cast<__m512>(_mm512_unpackhi_pd(
-      reinterpret_cast<__m512d>(tj), reinterpret_cast<__m512d>(tl)));
-  m = reinterpret_cast<__m512>(_mm512_unpacklo_pd(
-      reinterpret_cast<__m512d>(tm), reinterpret_cast<__m512d>(to)));
-  n = reinterpret_cast<__m512>(_mm512_unpackhi_pd(
-      reinterpret_cast<__m512d>(tm), reinterpret_cast<__m512d>(to)));
-  o = reinterpret_cast<__m512>(_mm512_unpacklo_pd(
-      reinterpret_cast<__m512d>(tn), reinterpret_cast<__m512d>(tq)));
-  p = reinterpret_cast<__m512>(_mm512_unpackhi_pd(
-      reinterpret_cast<__m512d>(tn), reinterpret_cast<__m512d>(tq)));
-
-  // shuffle 128-bits (composed of 4 32-bit elements)
-  // a0 b0 c0 d0 a8 b8 c8 d8 e0 f0 g0 h0 e8 f8 g8 h8
-  // a1 b1 c1 d1 ...
-  // a2 b2 c2 d2 ...
-  // a3 b3 c3 d3 ...
-  // a4 b4 c4 d4 ...
-  // a5 b5 c5 d5 ...
-  // a6 b6 c6 d6 ...
-  // a7 b7 c7 d7 ...
-  // i0 j0 k0 l0 i8 j8 k8 l8 m0 n0 o0 p0 m8 n8 o8 p8
-  // i1 j1 k1 l1 ...
-  // i2 j2 k2 l2 ...
-  // i3 j3 k3 l3 ...
-  // i4 j4 k4 l4 ...
-  // i5 j5 k5 l5 ...
-  // i6 j6 k6 l6 ...
-  // i7 j7 k7 l7 ...
-  ta = _mm512_shuffle_f32x4(a, e, 0x88);
-  tb = _mm512_shuffle_f32x4(b, f, 0x88);
-  tc = _mm512_shuffle_f32x4(c, g, 0x88);
-  td = _mm512_shuffle_f32x4(d, h, 0x88);
-  te = _mm512_shuffle_f32x4(a, e, 0xdd);
-  tf = _mm512_shuffle_f32x4(b, f, 0xdd);
-  tg = _mm512_shuffle_f32x4(c, g, 0xdd);
-  th = _mm512_shuffle_f32x4(d, h, 0xdd);
-  ti = _mm512_shuffle_f32x4(i, m, 0x88);
-  tj = _mm512_shuffle_f32x4(j, n, 0x88);
-  tk = _mm512_shuffle_f32x4(k, o, 0x88);
-  tl = _mm512_shuffle_f32x4(l, p, 0x88);
-  tm = _mm512_shuffle_f32x4(i, m, 0xdd);
-  tn = _mm512_shuffle_f32x4(j, n, 0xdd);
-  to = _mm512_shuffle_f32x4(k, o, 0xdd);
-  tq = _mm512_shuffle_f32x4(l, p, 0xdd);
-
-  // shuffle 128-bits (composed of 4 32-bit elements)
-  // a0 b0 c0 d0 ... o0
-  // a1 b1 c1 d1 ... o1
-  // a2 b2 c2 d2 ... o2
-  // a3 b3 c3 d3 ... o3
-  // a4 ...
-  // a5 ...
-  // a6 ...
-  // a7 ...
-  // a8 ...
-  // a9 ...
-  // a10 ...
-  // a11 ...
-  // a12 ...
-  // a13 ...
-  // a14 ...
-  // a15 b15 c15 d15 ... o15
-  a = _mm512_shuffle_f32x4(ta, ti, 0x88);
-  b = _mm512_shuffle_f32x4(tb, tj, 0x88);
-  c = _mm512_shuffle_f32x4(tc, tk, 0x88);
-  d = _mm512_shuffle_f32x4(td, tl, 0x88);
-  e = _mm512_shuffle_f32x4(te, tm, 0x88);
-  f = _mm512_shuffle_f32x4(tf, tn, 0x88);
-  g = _mm512_shuffle_f32x4(tg, to, 0x88);
-  h = _mm512_shuffle_f32x4(th, tq, 0x88);
-  i = _mm512_shuffle_f32x4(ta, ti, 0xdd);
-  j = _mm512_shuffle_f32x4(tb, tj, 0xdd);
-  k = _mm512_shuffle_f32x4(tc, tk, 0xdd);
-  l = _mm512_shuffle_f32x4(td, tl, 0xdd);
-  m = _mm512_shuffle_f32x4(te, tm, 0xdd);
-  n = _mm512_shuffle_f32x4(tf, tn, 0xdd);
-  o = _mm512_shuffle_f32x4(tg, to, 0xdd);
-  p = _mm512_shuffle_f32x4(th, tq, 0xdd);
-
-  // store from registers to dst
-  _mm512_storeu_ps(&dst[0 * ld_dst], a);
-  _mm512_storeu_ps(&dst[1 * ld_dst], b);
-  _mm512_storeu_ps(&dst[2 * ld_dst], c);
-  _mm512_storeu_ps(&dst[3 * ld_dst], d);
-  _mm512_storeu_ps(&dst[4 * ld_dst], e);
-  _mm512_storeu_ps(&dst[5 * ld_dst], f);
-  _mm512_storeu_ps(&dst[6 * ld_dst], g);
-  _mm512_storeu_ps(&dst[7 * ld_dst], h);
-  _mm512_storeu_ps(&dst[8 * ld_dst], i);
-  _mm512_storeu_ps(&dst[9 * ld_dst], j);
-  _mm512_storeu_ps(&dst[10 * ld_dst], k);
-  _mm512_storeu_ps(&dst[11 * ld_dst], l);
-  _mm512_storeu_ps(&dst[12 * ld_dst], m);
-  _mm512_storeu_ps(&dst[13 * ld_dst], n);
-  _mm512_storeu_ps(&dst[14 * ld_dst], o);
-  _mm512_storeu_ps(&dst[15 * ld_dst], p);
-}
-
-void transpose_16x16(
-    int M,
-    int N,
-    const float* src,
-    int ld_src,
-    float* dst,
-    int ld_dst) {
-  int ib = 0, jb = 0;
-  for (ib = 0; ib + 16 <= M; ib += 16) {
-    for (jb = 0; jb + 16 <= N; jb += 16) {
-      transpose_kernel_16x16_avx512(
-          &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
-    }
-  }
-  transpose_8x8(ib, N - jb, &src[jb], ld_src, &dst[jb * ld_dst], ld_dst);
-  transpose_8x8(M - ib, N, &src[ib * ld_src], ld_src, &dst[ib], ld_dst);
-}
-
-} // namespace internal
-
-} // namespace fbgemm
diff --git a/src/codegen_fp16fp32.cc b/src/codegen_fp16fp32.cc
index 8dd3166..17bb113 100644
--- a/src/codegen_fp16fp32.cc
+++ b/src/codegen_fp16fp32.cc
@@ -65,7 +65,7 @@ int main() {
   // open all files
   ofstream srcfile;
-  srcfile.open("FbgemmFP16UKernels.cc");
+  srcfile.open("FbgemmFP16UKernelsAvx2.cc");
   srcfile << "/*\n"
              " * Copyright (c) Facebook, Inc. and its affiliates.\n"
@@ -73,14 +73,14 @@ int main() {
              " * This source code is licensed under the BSD-style license found in the\n"
              " * LICENSE file in the root directory of this source tree.\n"
              " */\n";
-  srcfile << "#include \"FbgemmFP16UKernels.h\"\n\n";
+  srcfile << "#include \"FbgemmFP16UKernelsAvx2.h\"\n\n";
   srcfile << "namespace fbgemm {\n\n";
   if (iaca) {
     srcfile << "#include \"iacaMarks.h\"\n";
   }
 
   ofstream hdrfile;
-  hdrfile.open("FbgemmFP16UKernels.h");
+  hdrfile.open("FbgemmFP16UKernelsAvx2.h");
   hdrfile << "/*\n"
              " * Copyright (c) Facebook, Inc. and its affiliates.\n"
@@ -92,8 +92,6 @@ int main() {
   hdrfile << "#ifndef FBGEMM_UKERNELS\n";
   hdrfile << "#define FBGEMM_UKERNELS\n";
   hdrfile << "#include <cstdint>\n";
-  hdrfile << "#include <tuple>\n";
-  hdrfile << "#include <vector>\n";
   hdrfile << "#include \"fbgemm/Types.h\"\n\n";
   hdrfile << "namespace fbgemm {\n\n";
   hdrfile << "using fp16 = float16;\n";
--
cgit v1.2.3
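The payoff of the consistent Avx2/Avx512 suffixes is build-system hygiene: only the files whose names carry an ISA suffix need -mavx2 or -mavx512f, while every other translation unit is compiled for the baseline ISA and selects an implementation at run time. A hedged sketch of such caller-side dispatch (illustrative only, not FBGEMM's actual dispatch code; __builtin_cpu_supports is a GCC/Clang builtin, and the transpose declarations are assumed to come from TransposeUtils.h):

#include "TransposeUtils.h"

void transpose_dispatch(
    int M,
    int N,
    const float* src,
    int ld_src,
    float* dst,
    int ld_dst) {
  // This TU compiles without -mavx2/-mavx512f; it only branches into the
  // ISA-specific translation units renamed by this patch.
  if (__builtin_cpu_supports("avx512f")) {
    fbgemm::internal::transpose_16x16(M, N, src, ld_src, dst, ld_dst);
  } else if (__builtin_cpu_supports("avx2")) {
    fbgemm::internal::transpose_8x8(M, N, src, ld_src, dst, ld_dst);
  } else {
    fbgemm::internal::transpose_ref(M, N, src, ld_src, dst, ld_dst);
  }
}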