
github.com/marian-nmt/FBGEMM.git
author     Young Jin Kim <youki@microsoft.com>  2019-06-18 03:39:01 +0300
committer  Young Jin Kim <youki@microsoft.com>  2019-06-18 03:39:01 +0300
commit     25c1595cebfc88775692660b1da7d0bf632112a8 (patch)
tree       b33cd867fb0514f98233bf9d7e200cb622e17ae0
parent     bf2f45f35cc0d7b6f420894652824a377f764714 (diff)
parent     a838fc2a9c354b7da8a2663d300b64686a234247 (diff)

Merged PR 8337: Enable windows build and FP16 packed GEMM on windows

- Compile options changed
- Memory management fixed to support both OSs
Diffstat:
 -rw-r--r--  CMakeLists.txt                            |   46
 -rw-r--r--  bench/AlignedVec.h                        |   12
 -rw-r--r--  bench/BenchUtils.cc                       |   15
 -rw-r--r--  bench/CMakeLists.txt                      |    8
 -rw-r--r--  bench/ConvUnifiedBenchmark.cc             |    1
 -rw-r--r--  bench/PackedRequantizeAcc16Benchmark.cc   |    1
 -rw-r--r--  cmake/modules/FindMKL.cmake               |    2
 -rw-r--r--  include/fbgemm/Fbgemm.h                   |   12
 -rw-r--r--  include/fbgemm/FbgemmFP16.h               |   16
 -rw-r--r--  include/fbgemm/Types.h                    |   10
 -rw-r--r--  include/fbgemm/Utils.h                    |   10
 -rw-r--r--  src/FbgemmFP16UKernelsAvx2.cc             | 1261
 -rw-r--r--  src/FbgemmFP16UKernelsAvx2.h              |   17
 -rw-r--r--  src/FbgemmI8DepthwiseAvx2.cc              |   80
 -rw-r--r--  src/FbgemmI8Spmdm.cc                      |   16
 -rw-r--r--  src/PackAWithQuantRowOffset.cc            |    4
 -rw-r--r--  src/UtilsAvx512.cc                        |   35
 -rw-r--r--  src/codegen_fp16fp32.cc                   |  432
 -rw-r--r--  test/CMakeLists.txt                       |    8
 -rw-r--r--  test/Im2ColFusedRequantizeTest.cc         |    1
 -rw-r--r--  test/PackedRequantizeAcc16Test.cc         |    1

 21 files changed, 1137 insertions(+), 851 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index b575e17..e6c7419 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -22,6 +22,10 @@ set(FBGEMM_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
set(FBGEMM_THIRDPARTY_DIR ${FBGEMM_BINARY_DIR}/third_party)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+# only static library is available for windows
+if(MSVC)
+ set(FBGEMM_LIBRARY_TYPE "static")
+endif(MSVC)
#All the source files that either use avx2 instructions statically or JIT
#avx2/avx512 instructions.
@@ -49,10 +53,20 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
src/Utils.cc)
#check if compiler supports avx512
-include(CheckCXXCompilerFlag)
-CHECK_CXX_COMPILER_FLAG(-mavx512f COMPILER_SUPPORTS_AVX512)
-if(NOT COMPILER_SUPPORTS_AVX512)
- message(FATAL_ERROR "A compiler with AVX512 support is required.")
+if (MSVC)
+ set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\"")
+ set(INTRINSICS "/arch:AVX2")
+ set(CMAKE_CXX_FLAGS "/EHsc /DWIN32 /D_WINDOWS /DUNICODE /D_UNICODE /D_CRT_NONSTDC_NO_WARNINGS /D_CRT_SECURE_NO_WARNINGS ${DISABLE_GLOBALLY}")
+ set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS} /MT /O2 ${INTRINSICS} /Zi /MP /GL /DNDEBUG")
+ set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} /MTd /Od /Ob0 ${INTRINSICS} /RTC1 /Zi /D_DEBUG")
+ set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DEBUG /LTCG:incremental /INCREMENTAL:NO /NODEFAULTLIB:MSVCRT /ignore:4049")
+ set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} /LTCG:incremental /NODEFAULTLIB:MSVCRT")
+else()
+ include(CheckCXXCompilerFlag)
+ CHECK_CXX_COMPILER_FLAG(-mavx512f COMPILER_SUPPORTS_AVX512)
+ if(NOT COMPILER_SUPPORTS_AVX512)
+ message(FATAL_ERROR "A compiler with AVX512 support is required.")
+ endif()
endif()
#We should default to a Release build
@@ -104,11 +118,13 @@ set_target_properties(fbgemm_generic fbgemm_avx2 fbgemm_avx512 PROPERTIES
CXX_EXTENSIONS NO
CXX_VISIBILITY_PRESET hidden)
-target_compile_options(fbgemm_avx2 PRIVATE
- "-m64" "-mavx2" "-mfma" "-masm=intel")
-target_compile_options(fbgemm_avx512 PRIVATE
- "-m64" "-mavx2" "-mfma" "-mavx512f" "-mavx512bw" "-mavx512dq"
- "-mavx512vl" "-masm=intel")
+if (NOT MSVC)
+ target_compile_options(fbgemm_avx2 PRIVATE
+ "-m64" "-mavx2" "-mfma" "-masm=intel" "-mf16c")
+ target_compile_options(fbgemm_avx512 PRIVATE
+ "-m64" "-mavx2" "-mfma" "-mavx512f" "-mavx512bw" "-mavx512dq"
+ "-mavx512vl" "-masm=intel" "-mf16c")
+endif()
if(NOT TARGET asmjit)
#Download asmjit from github if ASMJIT_SRC_DIR is not specified.
@@ -118,7 +134,7 @@ if(NOT TARGET asmjit)
endif()
#build asmjit
- set(ASMJIT_STATIC ON)
+ set(ASMJIT_STATIC TRUE CACHE STRING "" FORCE)
add_subdirectory("${ASMJIT_SRC_DIR}" "${FBGEMM_BINARY_DIR}/asmjit")
set_property(TARGET asmjit PROPERTY POSITION_INDEPENDENT_CODE ON)
endif()
@@ -135,6 +151,9 @@ if(NOT TARGET cpuinfo)
set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "Do not build cpuinfo mock tests")
set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "Do not build cpuinfo benchmarks")
set(CPUINFO_LIBRARY_TYPE static)
+ if(MSVC)
+ set(CPUINFO_RUNTIME_TYPE static)
+ endif(MSVC)
add_subdirectory("${CPUINFO_SOURCE_DIR}" "${FBGEMM_BINARY_DIR}/cpuinfo")
set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
endif()
@@ -177,9 +196,10 @@ elseif(FBGEMM_LIBRARY_TYPE STREQUAL "static")
$<TARGET_OBJECTS:fbgemm_generic>
$<TARGET_OBJECTS:fbgemm_avx2>
$<TARGET_OBJECTS:fbgemm_avx512>)
- target_compile_definitions(fbgemm_avx2 PRIVATE FBGEMM_STATIC)
- target_compile_definitions(fbgemm_avx512 PRIVATE FBGEMM_STATIC)
- target_compile_definitions(fbgemm PRIVATE FBGEMM_STATIC)
+ target_compile_definitions(fbgemm_generic PRIVATE FBGEMM_STATIC ASMJIT_STATIC)
+ target_compile_definitions(fbgemm_avx2 PRIVATE FBGEMM_STATIC ASMJIT_STATIC)
+ target_compile_definitions(fbgemm_avx512 PRIVATE FBGEMM_STATIC ASMJIT_STATIC)
+ target_compile_definitions(fbgemm PRIVATE FBGEMM_STATIC ASMJIT_STATIC)
else()
message(FATAL_ERROR "Unsupported library type ${FBGEMM_LIBRARY_TYPE}")
endif()
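
Note on the FBGEMM_STATIC / ASMJIT_STATIC definitions above: on Windows the
export macro used in the headers (FBGEMM_API) would otherwise declare symbols
as dllimport, so every target that compiles against the static archive has to
see the same define — which is also why the benchmark targets below pass
/DFBGEMM_STATIC explicitly. A minimal sketch of the conventional pattern behind
such a macro, assuming the usual FbgemmBuild.h-style layout (FBGEMM_BUILDING is
a placeholder name, not taken from this patch):

    // Sketch only: the real definition lives in the FBGEMM headers
    // (include/fbgemm/FbgemmBuild.h) and may differ in detail.
    #if defined(FBGEMM_STATIC)
      #define FBGEMM_API                              // static lib: plain linkage
    #elif defined(_WIN32)
      #if defined(FBGEMM_BUILDING)                    // placeholder: set while building the DLL
        #define FBGEMM_API __declspec(dllexport)
      #else
        #define FBGEMM_API __declspec(dllimport)
      #endif
    #else
      #define FBGEMM_API __attribute__((visibility("default")))
    #endif

ASMJIT_STATIC plays the same role for the bundled asmjit, which is likewise
forced to build as a static library above (ASMJIT_STATIC TRUE CACHE ... FORCE).
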
diff --git a/bench/AlignedVec.h b/bench/AlignedVec.h
index fd4b88e..0b8c7ce 100644
--- a/bench/AlignedVec.h
+++ b/bench/AlignedVec.h
@@ -99,8 +99,12 @@ class aligned_allocator {
// Mallocator wraps malloc().
void* pv = nullptr;
- posix_memalign(&pv, Alignment, n * sizeof(T));
- // pv = aligned_alloc(Alignment, n * sizeof(T));
+#ifdef _MSC_VER
+ pv = _aligned_malloc(n * sizeof(T), Alignment);
+#else
+ int result = posix_memalign(&pv, Alignment, n * sizeof(T));
+ assert(result == 0);
+#endif
// Allocators should throw std::bad_alloc in the case of memory allocation
// failure.
@@ -112,7 +116,11 @@ class aligned_allocator {
}
void deallocate(T* const p, const std::size_t /*n*/) const {
+#ifdef _MSC_VER
+ _aligned_free(p);
+#else
free(p);
+#endif
}
// The following will be the same for all allocators that ignore hints.
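
posix_memalign is POSIX-only, so the Windows path switches to _aligned_malloc
(note the swapped argument order) and the matching _aligned_free below; memory
from _aligned_malloc must never be passed to free(). The same split recurs in
several files of this patch. A self-contained sketch of a portable pair with
the bad_alloc behaviour the surrounding comment asks for (helper names are
illustrative, not part of the patch):

    #include <cstdlib>
    #include <new>
    #ifdef _MSC_VER
    #include <malloc.h>   // _aligned_malloc / _aligned_free
    #endif

    inline void* portable_aligned_alloc(std::size_t alignment, std::size_t size) {
    #ifdef _MSC_VER
      void* p = _aligned_malloc(size, alignment);        // (size, alignment)
    #else
      void* p = nullptr;
      if (posix_memalign(&p, alignment, size) != 0) {    // (ptr, alignment, size)
        p = nullptr;
      }
    #endif
      if (p == nullptr) {
        throw std::bad_alloc();   // allocators are expected to throw on failure
      }
      return p;
    }

    inline void portable_aligned_free(void* p) {
    #ifdef _MSC_VER
      _aligned_free(p);
    #else
      std::free(p);
    #endif
    }
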
diff --git a/bench/BenchUtils.cc b/bench/BenchUtils.cc
index a5ab949..fb20988 100644
--- a/bench/BenchUtils.cc
+++ b/bench/BenchUtils.cc
@@ -24,6 +24,21 @@ void randFill(aligned_vector<T>& vec, T low, T high, std::true_type) {
std::generate(vec.begin(), vec.end(), [&] { return dis(eng); });
}
+// MSVC doesn't accept uint8_t and int8_t as a template argument.
+#ifdef _MSC_VER
+void randFill(aligned_vector<uint8_t>& vec, uint8_t low, uint8_t high, std::true_type) {
+ std::uniform_int_distribution<unsigned short> dis(low, high);
+ for (int i = 0; i < vec.size(); i++)
+ vec[i] = (uint8_t)dis(eng);
+}
+
+void randFill(aligned_vector<int8_t>& vec, int8_t low, int8_t high, std::true_type) {
+ std::uniform_int_distribution<short> dis(low, high);
+ for (int i = 0; i < vec.size(); i++)
+ vec[i] = (int8_t)dis(eng);
+}
+#endif
+
template <typename T>
void randFill(aligned_vector<T>& vec, T low, T high, std::false_type) {
std::uniform_real_distribution<T> dis(low, high);
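
Background for the two overloads added above: std::uniform_int_distribution is
only specified for short, int, long, long long and their unsigned counterparts,
so instantiating it with int8_t/uint8_t is undefined, and MSVC's standard
library rejects it with a static_assert. Drawing from a 16-bit distribution and
narrowing the result is the portable workaround. A standalone sketch of the
same idea (function name and seed are illustrative):

    #include <cstddef>
    #include <cstdint>
    #include <random>
    #include <vector>

    std::vector<std::uint8_t> random_bytes(std::size_t n,
                                           std::uint8_t low, std::uint8_t high) {
      std::mt19937 eng(42);
      // uniform_int_distribution<uint8_t> would be an invalid instantiation;
      // draw from a 16-bit type and cast each result back down.
      std::uniform_int_distribution<unsigned short> dis(low, high);
      std::vector<std::uint8_t> out(n);
      for (auto& b : out) {
        b = static_cast<std::uint8_t>(dis(eng));
      }
      return out;
    }
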
diff --git a/bench/CMakeLists.txt b/bench/CMakeLists.txt
index a58eed7..c2e76ef 100644
--- a/bench/CMakeLists.txt
+++ b/bench/CMakeLists.txt
@@ -12,8 +12,12 @@ macro(add_benchmark BENCHNAME)
set_target_properties(${BENCHNAME} PROPERTIES
CXX_STANDARD 11
CXX_EXTENSIONS NO)
- target_compile_options(${BENCHNAME} PRIVATE
- "-m64" "-mavx2" "-mfma" "-masm=intel")
+ if(MSVC)
+ target_compile_options(${BENCHNAME} PRIVATE "/DFBGEMM_STATIC" "/MT")
+ else(MSVC)
+ target_compile_options(${BENCHNAME} PRIVATE
+ "-m64" "-mavx2" "-mfma" "-masm=intel")
+ endif(MSVC)
target_link_libraries(${BENCHNAME} fbgemm)
add_dependencies(${BENCHNAME} fbgemm)
if(${MKL_FOUND})
diff --git a/bench/ConvUnifiedBenchmark.cc b/bench/ConvUnifiedBenchmark.cc
index 59079c7..6bc2cf4 100644
--- a/bench/ConvUnifiedBenchmark.cc
+++ b/bench/ConvUnifiedBenchmark.cc
@@ -11,6 +11,7 @@
#include <iostream>
#include <random>
#include <vector>
+#include <numeric>
#ifdef _OPENMP
#include <omp.h>
diff --git a/bench/PackedRequantizeAcc16Benchmark.cc b/bench/PackedRequantizeAcc16Benchmark.cc
index 8706f96..40ff662 100644
--- a/bench/PackedRequantizeAcc16Benchmark.cc
+++ b/bench/PackedRequantizeAcc16Benchmark.cc
@@ -11,6 +11,7 @@
#include <iostream>
#include <random>
#include <vector>
+#include <numeric>
#ifdef _OPENMP
#include <omp.h>
diff --git a/cmake/modules/FindMKL.cmake b/cmake/modules/FindMKL.cmake
index 6b38af7..7661a40 100644
--- a/cmake/modules/FindMKL.cmake
+++ b/cmake/modules/FindMKL.cmake
@@ -47,7 +47,7 @@ ELSE(CMAKE_COMPILER_IS_GNUCC)
SET(mklifaces "intel")
SET(mklrtls "iomp5" "guide")
IF (MSVC)
- SET(mklrtls "libiomp5md")
+ SET(mklrtls "libiomp5mt")
ENDIF (MSVC)
ENDIF (CMAKE_COMPILER_IS_GNUCC)
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 721b12f..90d1ee9 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -266,7 +266,11 @@ class PackMatrix {
virtual ~PackMatrix() {
if (bufAllocatedHere_) {
+#ifdef _MSC_VER
+ _aligned_free(buf_);
+#else
free(buf_);
+#endif
}
}
@@ -512,7 +516,11 @@ class FBGEMM_API PackWeightMatrixForGConv {
~PackWeightMatrixForGConv() {
if (bufAllocatedHere_) {
+#ifdef _MSC_VER
+ _aligned_free(pdata_);
+#else
free(pdata_);
+#endif
}
}
@@ -1358,8 +1366,12 @@ FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p);
*/
static void* fbgemmAlignedAlloc(size_t __align, size_t __size) {
void* aligned_mem;
+#ifdef _MSC_VER
+ aligned_mem = _aligned_malloc(__size, __align);
+#else
if (posix_memalign(&aligned_mem, __align, __size))
return 0;
+#endif
return aligned_mem;
}
diff --git a/include/fbgemm/FbgemmFP16.h b/include/fbgemm/FbgemmFP16.h
index 2b4c42b..3d84977 100644
--- a/include/fbgemm/FbgemmFP16.h
+++ b/include/fbgemm/FbgemmFP16.h
@@ -108,16 +108,24 @@ class PackedGemmMatrixFP16 {
// allocate and initialize packed memory
const int padding = 1024; // required by sw pipelined kernels
size_ = (blockRowSize() * nbrow_) * (blockColSize() * nbcol_);
- // pmat_ = (float16 *)aligned_alloc(64, matSize() * sizeof(float16) +
- // padding);
- posix_memalign((void**)&pmat_, 64, matSize() * sizeof(float16) + padding);
+#ifdef _MSC_VER
+ pmat_ = (float16 *)_aligned_malloc(matSize() * sizeof(float16) +
+ padding, 64);
+#else
+ int result = posix_memalign((void**)&pmat_, 64, matSize() * sizeof(float16) + padding);
+ assert(result == 0);
+#endif
for (auto i = 0; i < matSize(); i++) {
pmat_[i] = tconv(0.f, pmat_[i]);
}
}
~PackedGemmMatrixFP16() {
+#ifdef _MSC_VER
+ _aligned_free(pmat_);
+#else
free(pmat_);
+#endif
}
// protected:
@@ -166,7 +174,7 @@ class PackedGemmMatrixFP16 {
}
int matSize() const {
- return size_;
+ return (int)size_;
}
int numRows() const {
return nrow_;
diff --git a/include/fbgemm/Types.h b/include/fbgemm/Types.h
index d5f3f6a..c71e7a4 100644
--- a/include/fbgemm/Types.h
+++ b/include/fbgemm/Types.h
@@ -12,7 +12,11 @@
namespace fbgemm {
+#ifdef _MSC_VER
+typedef struct __declspec(align(2)) __f16 {
+#else
typedef struct __attribute__((aligned(2))) __f16 {
+#endif
uint16_t x;
} float16;
@@ -38,11 +42,11 @@ static inline float16 cpu_float2half_rn(float f) {
// Get rid of +Inf/-Inf, +0/-0.
if (u > 0x477fefff) {
- ret.x = sign | 0x7c00U;
+ ret.x = (uint16_t) (sign | 0x7c00U);
return ret;
}
if (u < 0x33000001) {
- ret.x = (sign | 0x0000);
+ ret.x = (uint16_t)(sign | 0x0000);
return ret;
}
@@ -72,7 +76,7 @@ static inline float16 cpu_float2half_rn(float f) {
}
}
- ret.x = (sign | (exponent << 10) | mantissa);
+ ret.x = (uint16_t)(sign | (exponent << 10) | mantissa);
return ret;
}
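
Both attribute spellings request 2-byte alignment for the half-precision
storage struct, and the added (uint16_t) casts silence MSVC's narrowing
warnings, since the shift/or expressions are evaluated at int width. As an
aside, C++11 alignas would give a single portable spelling; a sketch for
illustration only (the name is made up, and the patch deliberately keeps the
existing typedef and layout):

    #include <cstdint>

    // Portable alternative to __declspec(align(2)) / __attribute__((aligned(2))).
    struct alignas(2) half_storage {
      std::uint16_t x;
    };

    static_assert(alignof(half_storage) == 2, "half storage must be 2-byte aligned");
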
diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h
index ef1d4ab..9f8e1ee 100644
--- a/include/fbgemm/Utils.h
+++ b/include/fbgemm/Utils.h
@@ -10,6 +10,16 @@
#include "FbgemmBuild.h"
#include "UtilsAvx2.h"
+#ifdef _MSC_VER
+# define ALWAYS_INLINE // __forceinline
+# define ALIGNED_MALLOC(size, alignment) _aligned_malloc(size, alignment)
+# define FREE(ptr) _aligned_free(ptr)
+#else
+# define ALWAYS_INLINE __attribute__((always_inline))
+# define ALIGNED_MALLOC(size, alignment) aligned_alloc(alignment, size)
+# define FREE(ptr) free(ptr)
+#endif
+
namespace fbgemm {
/**
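
The new macros centralize the platform split used throughout this patch. Two
things worth noting: the MSVC ALWAYS_INLINE expands to nothing because
__forceinline is left commented out, and the non-MSVC ALIGNED_MALLOC maps to
C11 aligned_alloc, which requires the size to be a multiple of the alignment.
A short usage sketch (function and sizes are illustrative and chosen to satisfy
that constraint; the allocator headers are included explicitly rather than
relying on Utils.h to provide them):

    #include <cstddef>
    #include <cstdlib>
    #ifdef _MSC_VER
    #include <malloc.h>
    #endif
    #include "fbgemm/Utils.h"

    void scratch_example() {
      const std::size_t bytes = 16 * 64;   // multiple of the 64-byte alignment
      float* buf = static_cast<float*>(ALIGNED_MALLOC(bytes, 64));
      // ... fill and use buf ...
      FREE(buf);   // _aligned_free on MSVC, free() elsewhere
    }
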
diff --git a/src/FbgemmFP16UKernelsAvx2.cc b/src/FbgemmFP16UKernelsAvx2.cc
index 0c795b0..5f7492f 100644
--- a/src/FbgemmFP16UKernelsAvx2.cc
+++ b/src/FbgemmFP16UKernelsAvx2.cc
@@ -5,791 +5,642 @@
* LICENSE file in the root directory of this source tree.
*/
#include "FbgemmFP16UKernelsAvx2.h"
+#include <immintrin.h>
namespace fbgemm {
-void __attribute__((noinline)) gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp) {
- asm volatile(
-#if !defined(__clang__)
- "mov r14, %[gp]\t\n"
-#else
- "mov %[gp], %%r14\t\n"
- ".intel_syntax noprefix\t\n"
-#endif
+void NOINLINE_ATTR gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp) {
+ char* r14 = (char*)gp; //"mov r14, %[gp]\t\n"
// Copy parameters
// k
- "mov r8, [r14 + 0]\t\n"
+ uint64_t r8 = *(uint64_t *)((char*)r14 + 0 ); //"mov r8, [r14 + 0]\t\n"
// A
- "mov r9, [r14 + 8]\t\n"
+ float* r9 = *(float* *)((char*)r14 + 8 ); //"mov r9, [r14 + 8]\t\n"
// B
- "mov r10, [r14 + 16]\t\n"
+ const fp16* r10 = *(const fp16**)((char*)r14 + 16); //"mov r10, [r14 + 16]\t\n"
// beta
- "mov r15, [r14 + 24]\t\n"
+ float* r15 = *(float* *)((char*)r14 + 24); //"mov r15, [r14 + 24]\t\n"
// accum
- "mov rdx, [r14 + 32]\t\n"
+ uint64_t rdx = *(uint64_t *)((char*)r14 + 32); //"mov rdx, [r14 + 32]\t\n"
// C
- "mov r12, [r14 + 40]\t\n"
+ float* r12 = *(float* *)((char*)r14 + 40); //"mov r12, [r14 + 40]\t\n"
// ldc
- "mov r13, [r14 + 48]\t\n"
+ uint64_t r13 = *(uint64_t *)((char*)r14 + 48); //"mov r13, [r14 + 48]\t\n"
// b_block_cols
- "mov rdi, [r14 + 56]\t\n"
+ uint64_t rdi = *(uint64_t *)((char*)r14 + 56); //"mov rdi, [r14 + 56]\t\n"
// b_block_size
- "mov rsi, [r14 + 64]\t\n"
+ uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n"
// Make copies of A and C
- "mov rax, r9\t\n"
- "mov rcx, r12\t\n"
-
- "mov rbx, 0\t\n"
- "loop_outter%=:\t\n"
- "mov r14, 0\t\n"
- "vxorps ymm0,ymm0,ymm0\t\n"
- "vxorps ymm1,ymm1,ymm1\t\n"
-
-
- "loop_inner%=:\t\n"
-
- "vcvtph2ps ymm3,XMMWORD PTR [r10 + 0]\t\n"
- "vcvtph2ps ymm4,XMMWORD PTR [r10 + 16]\t\n"
- "vbroadcastss ymm2,DWORD PTR [r9+0]\t\n"
- "vfmadd231ps ymm0,ymm3,ymm2\t\n"
- "vfmadd231ps ymm1,ymm4,ymm2\t\n"
- "add r9,4\t\n"
- "add r10,32\t\n"
- "inc r14\t\n"
- "cmp r14, r8\t\n"
- "jl loop_inner%=\t\n"
-
- "L_exit%=:\t\n"
-
- "cmp rdx, 1\t\n"
- "je L_accum%=\t\n"
- // Dump C
- "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
- "add r12, r13\t\n"
- "jmp L_done%=\t\n"
-
- "L_accum%=:\t\n"
- // Dump C with accumulate
- "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
- "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
- "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
- "add r12, r13\t\n"
-
- "L_done%=:\t\n"
-
- // next outer iteration
- "add rcx, 64\t\n"
- "mov r12, rcx\t\n"
- "mov r9, rax\t\n"
- "inc rbx\t\n"
- "cmp rbx, rdi\t\n"
- "jl loop_outter%=\t\n"
- :
- : [gp] "rm"(gp)
- : "r8",
- "r9",
- "r10",
- "r11",
- "r15",
- "r13",
- "r14",
- "rax",
- "rcx",
- "rdx",
- "rsi",
- "rdi",
- "rbx",
- "r12",
- "memory");
+ float* rax = r9; //"mov rax, r9\t\n"
+ float* rcx = r12; //"mov rcx, r12\t\n"
+
+ uint64_t rbx = 0; //"mov rbx, 0\t\n"
+ for (; rbx < rdi; ++rbx) { //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
+ // //"loop_outter%=:\t\n"
+ uint64_t r14_i = 0; //"mov r14, 0\t\n"
+ __m256 ymm0 = _mm256_setzero_ps(); //"vxorps ymm0,ymm0,ymm0\t\n"
+ __m256 ymm1 = _mm256_setzero_ps(); //"vxorps ymm1,ymm1,ymm1\t\n"
+
+ for (; r14_i < r8; ++r14_i) { //"inc r14; cmp r14, r8; jl loop_inner%=\t\n"
+ // loop_inner%=: //"\t\n"
+ auto fp16mem0 = _mm_load_si128((__m128i*)((char*)r10 + 0)); //"vcvtph2ps ymm3,XMMWORD PTR [r10 + 0]\t\n"
+ auto ymm3 = _mm256_cvtph_ps(fp16mem0); //"vcvtph2ps ymm3,XMMWORD PTR [r10 + 0]\t\n"
+ auto fp16mem16 = _mm_load_si128((__m128i*)((char*)r10 + 16)); //"vcvtph2ps ymm4,XMMWORD PTR [r10 + 16]\t\n"
+ auto ymm4 = _mm256_cvtph_ps(fp16mem16); //"vcvtph2ps ymm4,XMMWORD PTR [r10 + 16]\t\n"
+ auto ymm2 = _mm256_broadcast_ss((float*)((char*)r9 + 0)); //"vbroadcastss ymm2,DWORD PTR [r9+0]\t\n"
+ ymm0 = _mm256_fmadd_ps(ymm2, ymm3, ymm0); //"vfmadd231ps ymm0,ymm3,ymm2\t\n"
+ ymm1 = _mm256_fmadd_ps(ymm2, ymm4, ymm1); //"vfmadd231ps ymm1,ymm4,ymm2\t\n"
+ r9 = (float*)((char*)r9 + 4); //"add r9,4\t\n"
+ r10 = (fp16*)((char*)r10 + 32); //"add r10,32\t\n"
+ } //"inc r14; cmp r14, r8; jl loop_outter%=\t\n"
+
+ if(rdx != 1) { //"cmp rdx, 1; je L_accum%=\t\n"
+ // Dump C
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
+ } else { //"jmp L_done%=\t\n"
+ // Dump C with accumulate
+ auto ymm15 = _mm256_broadcast_ss((float*)r15); //"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+ auto r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm0 = _mm256_fmadd_ps(r12_0, ymm15, ymm0); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+ auto r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm1 = _mm256_fmadd_ps(r12_32, ymm15, ymm1); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
+ } //"L_done%=:\t\n"
+
+ // next outer iteration
+ rcx = (float*)((char*)rcx + 64); //"add rcx, 64\t\n"
+ r12 = rcx; //"mov r12, rcx\t\n"
+ r9 = rax; //"mov r9, rax\t\n"
+ } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
}
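
This is the core of the port: x64 MSVC has no inline assembler, so the
GCC-style asm blocks are rewritten as AVX2/F16C intrinsics (hence the new
<immintrin.h> include here and the -mf16c flag added in CMakeLists.txt), with
each original instruction kept as a trailing comment. The register names
survive only as variable names; the compiler now does the register allocation.
Distilled to its essence, each kernel's inner loop converts 16 packed fp16
values of B to two float vectors, broadcasts one float of A, and
FMA-accumulates. A standalone sketch for a single output row (names are mine,
not from the patch):

    #include <cstdint>
    #include <immintrin.h>

    // C[0..15] = sum over i of A[i] * B[i][0..15], with B stored as packed fp16.
    // Illustrative only: the real kernels also handle beta/accumulate, the ldc
    // stride, and the walk over b_block_cols. Compile with AVX2+FMA+F16C enabled.
    void fp16_row_fma(const float* A, const std::uint16_t* B_fp16,
                      std::uint64_t k, float* C) {
      __m256 acc0 = _mm256_setzero_ps();                 // vxorps
      __m256 acc1 = _mm256_setzero_ps();
      for (std::uint64_t i = 0; i < k; ++i) {
        // vcvtph2ps: widen 8 half floats to 8 single floats
        __m256 b0 = _mm256_cvtph_ps(
            _mm_loadu_si128((const __m128i*)(B_fp16 + 16 * i)));
        __m256 b1 = _mm256_cvtph_ps(
            _mm_loadu_si128((const __m128i*)(B_fp16 + 16 * i + 8)));
        __m256 a = _mm256_broadcast_ss(A + i);           // vbroadcastss
        acc0 = _mm256_fmadd_ps(a, b0, acc0);             // vfmadd231ps
        acc1 = _mm256_fmadd_ps(a, b1, acc1);
      }
      _mm256_storeu_ps(C, acc0);                         // vmovups
      _mm256_storeu_ps(C + 8, acc1);
    }

The remaining kernels below follow exactly the same translation, only with more
rows of accumulators per block.
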
-void __attribute__((noinline)) gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp) {
- asm volatile(
-#if !defined(__clang__)
- "mov r14, %[gp]\t\n"
-#else
- "mov %[gp], %%r14\t\n"
- ".intel_syntax noprefix\t\n"
-#endif
+
+void NOINLINE_ATTR gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp) {
+ char* r14 = (char*)gp; //"mov r14, %[gp]\t\n"
// Copy parameters
// k
- "mov r8, [r14 + 0]\t\n"
+ uint64_t r8 = *(uint64_t *)((char*)r14 + 0 ); //"mov r8, [r14 + 0]\t\n"
// A
- "mov r9, [r14 + 8]\t\n"
+ float* r9 = *(float* *)((char*)r14 + 8 ); //"mov r9, [r14 + 8]\t\n"
// B
- "mov r10, [r14 + 16]\t\n"
+ const fp16* r10 = *(const fp16**)((char*)r14 + 16); //"mov r10, [r14 + 16]\t\n"
// beta
- "mov r15, [r14 + 24]\t\n"
+ float* r15 = *(float* *)((char*)r14 + 24); //"mov r15, [r14 + 24]\t\n"
// accum
- "mov rdx, [r14 + 32]\t\n"
+ uint64_t rdx = *(uint64_t *)((char*)r14 + 32); //"mov rdx, [r14 + 32]\t\n"
// C
- "mov r12, [r14 + 40]\t\n"
+ float* r12 = *(float* *)((char*)r14 + 40); //"mov r12, [r14 + 40]\t\n"
// ldc
- "mov r13, [r14 + 48]\t\n"
+ uint64_t r13 = *(uint64_t *)((char*)r14 + 48); //"mov r13, [r14 + 48]\t\n"
// b_block_cols
- "mov rdi, [r14 + 56]\t\n"
+ uint64_t rdi = *(uint64_t *)((char*)r14 + 56); //"mov rdi, [r14 + 56]\t\n"
// b_block_size
- "mov rsi, [r14 + 64]\t\n"
+ uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n"
// Make copies of A and C
- "mov rax, r9\t\n"
- "mov rcx, r12\t\n"
-
- "mov rbx, 0\t\n"
- "loop_outter%=:\t\n"
- "mov r14, 0\t\n"
- "vxorps ymm0,ymm0,ymm0\t\n"
- "vxorps ymm1,ymm1,ymm1\t\n"
- "vxorps ymm2,ymm2,ymm2\t\n"
- "vxorps ymm3,ymm3,ymm3\t\n"
-
-
- "loop_inner%=:\t\n"
-
- "vcvtph2ps ymm5,XMMWORD PTR [r10 + 0]\t\n"
- "vcvtph2ps ymm6,XMMWORD PTR [r10 + 16]\t\n"
- "vbroadcastss ymm4,DWORD PTR [r9+0]\t\n"
- "vfmadd231ps ymm0,ymm5,ymm4\t\n"
- "vfmadd231ps ymm1,ymm6,ymm4\t\n"
- "vbroadcastss ymm4,DWORD PTR [r9+4]\t\n"
- "vfmadd231ps ymm2,ymm5,ymm4\t\n"
- "vfmadd231ps ymm3,ymm6,ymm4\t\n"
- "add r9,8\t\n"
- "add r10,32\t\n"
- "inc r14\t\n"
- "cmp r14, r8\t\n"
- "jl loop_inner%=\t\n"
-
- "L_exit%=:\t\n"
-
- "cmp rdx, 1\t\n"
- "je L_accum%=\t\n"
- // Dump C
- "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
- "add r12, r13\t\n"
- "jmp L_done%=\t\n"
-
- "L_accum%=:\t\n"
- // Dump C with accumulate
- "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
- "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
- "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
- "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
- "add r12, r13\t\n"
-
- "L_done%=:\t\n"
-
- // next outer iteration
- "add rcx, 64\t\n"
- "mov r12, rcx\t\n"
- "mov r9, rax\t\n"
- "inc rbx\t\n"
- "cmp rbx, rdi\t\n"
- "jl loop_outter%=\t\n"
- :
- : [gp] "rm"(gp)
- : "r8",
- "r9",
- "r10",
- "r11",
- "r15",
- "r13",
- "r14",
- "rax",
- "rcx",
- "rdx",
- "rsi",
- "rdi",
- "rbx",
- "r12",
- "memory");
+ float* rax = r9; //"mov rax, r9\t\n"
+ float* rcx = r12; //"mov rcx, r12\t\n"
+
+ uint64_t rbx = 0; //"mov rbx, 0\t\n"
+ for (; rbx < rdi; ++rbx) { //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
+ // //"loop_outter%=:\t\n"
+ uint64_t r14_i = 0; //"mov r14, 0\t\n"
+ __m256 ymm0 = _mm256_setzero_ps(); //"vxorps ymm0,ymm0,ymm0\t\n"
+ __m256 ymm1 = _mm256_setzero_ps(); //"vxorps ymm1,ymm1,ymm1\t\n"
+ __m256 ymm2 = _mm256_setzero_ps(); //"vxorps ymm2,ymm2,ymm2\t\n"
+ __m256 ymm3 = _mm256_setzero_ps(); //"vxorps ymm3,ymm3,ymm3\t\n"
+
+ for (; r14_i < r8; ++r14_i) { //"inc r14; cmp r14, r8; jl loop_inner%=\t\n"
+ // loop_inner%=: //"\t\n"
+ auto fp16mem0 = _mm_load_si128((__m128i*)((char*)r10 + 0)); //"vcvtph2ps ymm5,XMMWORD PTR [r10 + 0]\t\n"
+ auto ymm5 = _mm256_cvtph_ps(fp16mem0); //"vcvtph2ps ymm5,XMMWORD PTR [r10 + 0]\t\n"
+ auto fp16mem16 = _mm_load_si128((__m128i*)((char*)r10 + 16)); //"vcvtph2ps ymm6,XMMWORD PTR [r10 + 16]\t\n"
+ auto ymm6 = _mm256_cvtph_ps(fp16mem16); //"vcvtph2ps ymm6,XMMWORD PTR [r10 + 16]\t\n"
+ auto ymm4 = _mm256_broadcast_ss((float*)((char*)r9 + 0)); //"vbroadcastss ymm4,DWORD PTR [r9+0]\t\n"
+ ymm0 = _mm256_fmadd_ps(ymm4, ymm5, ymm0); //"vfmadd231ps ymm0,ymm5,ymm4\t\n"
+ ymm1 = _mm256_fmadd_ps(ymm4, ymm6, ymm1); //"vfmadd231ps ymm1,ymm6,ymm4\t\n"
+ ymm4 = _mm256_broadcast_ss((float*)((char*)r9 + 4)); //"vbroadcastss ymm4,DWORD PTR [r9+4]\t\n"
+ ymm2 = _mm256_fmadd_ps(ymm4, ymm5, ymm2); //"vfmadd231ps ymm2,ymm5,ymm4\t\n"
+ ymm3 = _mm256_fmadd_ps(ymm4, ymm6, ymm3); //"vfmadd231ps ymm3,ymm6,ymm4\t\n"
+ r9 = (float*)((char*)r9 + 8); //"add r9,8\t\n"
+ r10 = (fp16*)((char*)r10 + 32); //"add r10,32\t\n"
+ } //"inc r14; cmp r14, r8; jl loop_outter%=\t\n"
+
+ if(rdx != 1) { //"cmp rdx, 1; je L_accum%=\t\n"
+ // Dump C
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
+ } else { //"jmp L_done%=\t\n"
+ // Dump C with accumulate
+ auto ymm15 = _mm256_broadcast_ss((float*)r15); //"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+ auto r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm0 = _mm256_fmadd_ps(r12_0, ymm15, ymm0); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+ auto r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm1 = _mm256_fmadd_ps(r12_32, ymm15, ymm1); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm2 = _mm256_fmadd_ps(r12_0, ymm15, ymm2); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm3 = _mm256_fmadd_ps(r12_32, ymm15, ymm3); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
+ } //"L_done%=:\t\n"
+
+ // next outer iteration
+ rcx = (float*)((char*)rcx + 64); //"add rcx, 64\t\n"
+ r12 = rcx; //"mov r12, rcx\t\n"
+ r9 = rax; //"mov r9, rax\t\n"
+ } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
}
-void __attribute__((noinline)) gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp) {
- asm volatile(
-#if !defined(__clang__)
- "mov r14, %[gp]\t\n"
-#else
- "mov %[gp], %%r14\t\n"
- ".intel_syntax noprefix\t\n"
-#endif
+
+void NOINLINE_ATTR gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp) {
+ char* r14 = (char*)gp; //"mov r14, %[gp]\t\n"
// Copy parameters
// k
- "mov r8, [r14 + 0]\t\n"
+ uint64_t r8 = *(uint64_t *)((char*)r14 + 0 ); //"mov r8, [r14 + 0]\t\n"
// A
- "mov r9, [r14 + 8]\t\n"
+ float* r9 = *(float* *)((char*)r14 + 8 ); //"mov r9, [r14 + 8]\t\n"
// B
- "mov r10, [r14 + 16]\t\n"
+ const fp16* r10 = *(const fp16**)((char*)r14 + 16); //"mov r10, [r14 + 16]\t\n"
// beta
- "mov r15, [r14 + 24]\t\n"
+ float* r15 = *(float* *)((char*)r14 + 24); //"mov r15, [r14 + 24]\t\n"
// accum
- "mov rdx, [r14 + 32]\t\n"
+ uint64_t rdx = *(uint64_t *)((char*)r14 + 32); //"mov rdx, [r14 + 32]\t\n"
// C
- "mov r12, [r14 + 40]\t\n"
+ float* r12 = *(float* *)((char*)r14 + 40); //"mov r12, [r14 + 40]\t\n"
// ldc
- "mov r13, [r14 + 48]\t\n"
+ uint64_t r13 = *(uint64_t *)((char*)r14 + 48); //"mov r13, [r14 + 48]\t\n"
// b_block_cols
- "mov rdi, [r14 + 56]\t\n"
+ uint64_t rdi = *(uint64_t *)((char*)r14 + 56); //"mov rdi, [r14 + 56]\t\n"
// b_block_size
- "mov rsi, [r14 + 64]\t\n"
+ uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n"
// Make copies of A and C
- "mov rax, r9\t\n"
- "mov rcx, r12\t\n"
-
- "mov rbx, 0\t\n"
- "loop_outter%=:\t\n"
- "mov r14, 0\t\n"
- "vxorps ymm0,ymm0,ymm0\t\n"
- "vxorps ymm1,ymm1,ymm1\t\n"
- "vxorps ymm2,ymm2,ymm2\t\n"
- "vxorps ymm3,ymm3,ymm3\t\n"
- "vxorps ymm4,ymm4,ymm4\t\n"
- "vxorps ymm5,ymm5,ymm5\t\n"
-
-
- "loop_inner%=:\t\n"
-
- "vcvtph2ps ymm7,XMMWORD PTR [r10 + 0]\t\n"
- "vcvtph2ps ymm8,XMMWORD PTR [r10 + 16]\t\n"
- "vbroadcastss ymm6,DWORD PTR [r9+0]\t\n"
- "vfmadd231ps ymm0,ymm7,ymm6\t\n"
- "vfmadd231ps ymm1,ymm8,ymm6\t\n"
- "vbroadcastss ymm6,DWORD PTR [r9+4]\t\n"
- "vfmadd231ps ymm2,ymm7,ymm6\t\n"
- "vfmadd231ps ymm3,ymm8,ymm6\t\n"
- "vbroadcastss ymm6,DWORD PTR [r9+8]\t\n"
- "vfmadd231ps ymm4,ymm7,ymm6\t\n"
- "vfmadd231ps ymm5,ymm8,ymm6\t\n"
- "add r9,12\t\n"
- "add r10,32\t\n"
- "inc r14\t\n"
- "cmp r14, r8\t\n"
- "jl loop_inner%=\t\n"
-
- "L_exit%=:\t\n"
-
- "cmp rdx, 1\t\n"
- "je L_accum%=\t\n"
- // Dump C
- "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
- "add r12, r13\t\n"
- "jmp L_done%=\t\n"
-
- "L_accum%=:\t\n"
- // Dump C with accumulate
- "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
- "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
- "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
- "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
- "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
- "add r12, r13\t\n"
-
- "L_done%=:\t\n"
-
- // next outer iteration
- "add rcx, 64\t\n"
- "mov r12, rcx\t\n"
- "mov r9, rax\t\n"
- "inc rbx\t\n"
- "cmp rbx, rdi\t\n"
- "jl loop_outter%=\t\n"
- :
- : [gp] "rm"(gp)
- : "r8",
- "r9",
- "r10",
- "r11",
- "r15",
- "r13",
- "r14",
- "rax",
- "rcx",
- "rdx",
- "rsi",
- "rdi",
- "rbx",
- "r12",
- "memory");
+ float* rax = r9; //"mov rax, r9\t\n"
+ float* rcx = r12; //"mov rcx, r12\t\n"
+
+ uint64_t rbx = 0; //"mov rbx, 0\t\n"
+ for (; rbx < rdi; ++rbx) { //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
+ // //"loop_outter%=:\t\n"
+ uint64_t r14_i = 0; //"mov r14, 0\t\n"
+ __m256 ymm0 = _mm256_setzero_ps(); //"vxorps ymm0,ymm0,ymm0\t\n"
+ __m256 ymm1 = _mm256_setzero_ps(); //"vxorps ymm1,ymm1,ymm1\t\n"
+ __m256 ymm2 = _mm256_setzero_ps(); //"vxorps ymm2,ymm2,ymm2\t\n"
+ __m256 ymm3 = _mm256_setzero_ps(); //"vxorps ymm3,ymm3,ymm3\t\n"
+ __m256 ymm4 = _mm256_setzero_ps(); //"vxorps ymm4,ymm4,ymm4\t\n"
+ __m256 ymm5 = _mm256_setzero_ps(); //"vxorps ymm5,ymm5,ymm5\t\n"
+
+ for (; r14_i < r8; ++r14_i) { //"inc r14; cmp r14, r8; jl loop_inner%=\t\n"
+ // loop_inner%=: //"\t\n"
+ auto fp16mem0 = _mm_load_si128((__m128i*)((char*)r10 + 0)); //"vcvtph2ps ymm7,XMMWORD PTR [r10 + 0]\t\n"
+ auto ymm7 = _mm256_cvtph_ps(fp16mem0); //"vcvtph2ps ymm7,XMMWORD PTR [r10 + 0]\t\n"
+ auto fp16mem16 = _mm_load_si128((__m128i*)((char*)r10 + 16)); //"vcvtph2ps ymm8,XMMWORD PTR [r10 + 16]\t\n"
+ auto ymm8 = _mm256_cvtph_ps(fp16mem16); //"vcvtph2ps ymm8,XMMWORD PTR [r10 + 16]\t\n"
+ auto ymm6 = _mm256_broadcast_ss((float*)((char*)r9 + 0)); //"vbroadcastss ymm6,DWORD PTR [r9+0]\t\n"
+ ymm0 = _mm256_fmadd_ps(ymm6, ymm7, ymm0); //"vfmadd231ps ymm0,ymm7,ymm6\t\n"
+ ymm1 = _mm256_fmadd_ps(ymm6, ymm8, ymm1); //"vfmadd231ps ymm1,ymm8,ymm6\t\n"
+ ymm6 = _mm256_broadcast_ss((float*)((char*)r9 + 4)); //"vbroadcastss ymm6,DWORD PTR [r9+4]\t\n"
+ ymm2 = _mm256_fmadd_ps(ymm6, ymm7, ymm2); //"vfmadd231ps ymm2,ymm7,ymm6\t\n"
+ ymm3 = _mm256_fmadd_ps(ymm6, ymm8, ymm3); //"vfmadd231ps ymm3,ymm8,ymm6\t\n"
+ ymm6 = _mm256_broadcast_ss((float*)((char*)r9 + 8)); //"vbroadcastss ymm6,DWORD PTR [r9+8]\t\n"
+ ymm4 = _mm256_fmadd_ps(ymm6, ymm7, ymm4); //"vfmadd231ps ymm4,ymm7,ymm6\t\n"
+ ymm5 = _mm256_fmadd_ps(ymm6, ymm8, ymm5); //"vfmadd231ps ymm5,ymm8,ymm6\t\n"
+ r9 = (float*)((char*)r9 + 12); //"add r9,12\t\n"
+ r10 = (fp16*)((char*)r10 + 32); //"add r10,32\t\n"
+ } //"inc r14; cmp r14, r8; jl loop_outter%=\t\n"
+
+ if(rdx != 1) { //"cmp rdx, 1; je L_accum%=\t\n"
+ // Dump C
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
+ } else { //"jmp L_done%=\t\n"
+ // Dump C with accumulate
+ auto ymm15 = _mm256_broadcast_ss((float*)r15); //"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+ auto r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm0 = _mm256_fmadd_ps(r12_0, ymm15, ymm0); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+ auto r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm1 = _mm256_fmadd_ps(r12_32, ymm15, ymm1); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm2 = _mm256_fmadd_ps(r12_0, ymm15, ymm2); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm3 = _mm256_fmadd_ps(r12_32, ymm15, ymm3); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm4 = _mm256_fmadd_ps(r12_0, ymm15, ymm4); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm5 = _mm256_fmadd_ps(r12_32, ymm15, ymm5); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
+ } //"L_done%=:\t\n"
+
+ // next outer iteration
+ rcx = (float*)((char*)rcx + 64); //"add rcx, 64\t\n"
+ r12 = rcx; //"mov r12, rcx\t\n"
+ r9 = rax; //"mov r9, rax\t\n"
+ } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
}
-void __attribute__((noinline)) gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp) {
- asm volatile(
-#if !defined(__clang__)
- "mov r14, %[gp]\t\n"
-#else
- "mov %[gp], %%r14\t\n"
- ".intel_syntax noprefix\t\n"
-#endif
+
+void NOINLINE_ATTR gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp) {
+ char* r14 = (char*)gp; //"mov r14, %[gp]\t\n"
// Copy parameters
// k
- "mov r8, [r14 + 0]\t\n"
+ uint64_t r8 = *(uint64_t *)((char*)r14 + 0 ); //"mov r8, [r14 + 0]\t\n"
// A
- "mov r9, [r14 + 8]\t\n"
+ float* r9 = *(float* *)((char*)r14 + 8 ); //"mov r9, [r14 + 8]\t\n"
// B
- "mov r10, [r14 + 16]\t\n"
+ const fp16* r10 = *(const fp16**)((char*)r14 + 16); //"mov r10, [r14 + 16]\t\n"
// beta
- "mov r15, [r14 + 24]\t\n"
+ float* r15 = *(float* *)((char*)r14 + 24); //"mov r15, [r14 + 24]\t\n"
// accum
- "mov rdx, [r14 + 32]\t\n"
+ uint64_t rdx = *(uint64_t *)((char*)r14 + 32); //"mov rdx, [r14 + 32]\t\n"
// C
- "mov r12, [r14 + 40]\t\n"
+ float* r12 = *(float* *)((char*)r14 + 40); //"mov r12, [r14 + 40]\t\n"
// ldc
- "mov r13, [r14 + 48]\t\n"
+ uint64_t r13 = *(uint64_t *)((char*)r14 + 48); //"mov r13, [r14 + 48]\t\n"
// b_block_cols
- "mov rdi, [r14 + 56]\t\n"
+ uint64_t rdi = *(uint64_t *)((char*)r14 + 56); //"mov rdi, [r14 + 56]\t\n"
// b_block_size
- "mov rsi, [r14 + 64]\t\n"
+ uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n"
// Make copies of A and C
- "mov rax, r9\t\n"
- "mov rcx, r12\t\n"
-
- "mov rbx, 0\t\n"
- "loop_outter%=:\t\n"
- "mov r14, 0\t\n"
- "vxorps ymm0,ymm0,ymm0\t\n"
- "vxorps ymm1,ymm1,ymm1\t\n"
- "vxorps ymm2,ymm2,ymm2\t\n"
- "vxorps ymm3,ymm3,ymm3\t\n"
- "vxorps ymm4,ymm4,ymm4\t\n"
- "vxorps ymm5,ymm5,ymm5\t\n"
- "vxorps ymm6,ymm6,ymm6\t\n"
- "vxorps ymm7,ymm7,ymm7\t\n"
-
-
- "loop_inner%=:\t\n"
-
- "vcvtph2ps ymm9,XMMWORD PTR [r10 + 0]\t\n"
- "vcvtph2ps ymm10,XMMWORD PTR [r10 + 16]\t\n"
- "vbroadcastss ymm8,DWORD PTR [r9+0]\t\n"
- "vfmadd231ps ymm0,ymm9,ymm8\t\n"
- "vfmadd231ps ymm1,ymm10,ymm8\t\n"
- "vbroadcastss ymm8,DWORD PTR [r9+4]\t\n"
- "vfmadd231ps ymm2,ymm9,ymm8\t\n"
- "vfmadd231ps ymm3,ymm10,ymm8\t\n"
- "vbroadcastss ymm8,DWORD PTR [r9+8]\t\n"
- "vfmadd231ps ymm4,ymm9,ymm8\t\n"
- "vfmadd231ps ymm5,ymm10,ymm8\t\n"
- "vbroadcastss ymm8,DWORD PTR [r9+12]\t\n"
- "vfmadd231ps ymm6,ymm9,ymm8\t\n"
- "vfmadd231ps ymm7,ymm10,ymm8\t\n"
- "add r9,16\t\n"
- "add r10,32\t\n"
- "inc r14\t\n"
- "cmp r14, r8\t\n"
- "jl loop_inner%=\t\n"
-
- "L_exit%=:\t\n"
-
- "cmp rdx, 1\t\n"
- "je L_accum%=\t\n"
- // Dump C
- "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm7\t\n"
- "add r12, r13\t\n"
- "jmp L_done%=\t\n"
-
- "L_accum%=:\t\n"
- // Dump C with accumulate
- "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
- "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
- "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
- "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
- "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
- "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm7\t\n"
- "add r12, r13\t\n"
-
- "L_done%=:\t\n"
-
- // next outer iteration
- "add rcx, 64\t\n"
- "mov r12, rcx\t\n"
- "mov r9, rax\t\n"
- "inc rbx\t\n"
- "cmp rbx, rdi\t\n"
- "jl loop_outter%=\t\n"
- :
- : [gp] "rm"(gp)
- : "r8",
- "r9",
- "r10",
- "r11",
- "r15",
- "r13",
- "r14",
- "rax",
- "rcx",
- "rdx",
- "rsi",
- "rdi",
- "rbx",
- "r12",
- "memory");
+ float* rax = r9; //"mov rax, r9\t\n"
+ float* rcx = r12; //"mov rcx, r12\t\n"
+
+ uint64_t rbx = 0; //"mov rbx, 0\t\n"
+ for (; rbx < rdi; ++rbx) { //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
+ // //"loop_outter%=:\t\n"
+ uint64_t r14_i = 0; //"mov r14, 0\t\n"
+ __m256 ymm0 = _mm256_setzero_ps(); //"vxorps ymm0,ymm0,ymm0\t\n"
+ __m256 ymm1 = _mm256_setzero_ps(); //"vxorps ymm1,ymm1,ymm1\t\n"
+ __m256 ymm2 = _mm256_setzero_ps(); //"vxorps ymm2,ymm2,ymm2\t\n"
+ __m256 ymm3 = _mm256_setzero_ps(); //"vxorps ymm3,ymm3,ymm3\t\n"
+ __m256 ymm4 = _mm256_setzero_ps(); //"vxorps ymm4,ymm4,ymm4\t\n"
+ __m256 ymm5 = _mm256_setzero_ps(); //"vxorps ymm5,ymm5,ymm5\t\n"
+ __m256 ymm6 = _mm256_setzero_ps(); //"vxorps ymm6,ymm6,ymm6\t\n"
+ __m256 ymm7 = _mm256_setzero_ps(); //"vxorps ymm7,ymm7,ymm7\t\n"
+
+ for (; r14_i < r8; ++r14_i) { //"inc r14; cmp r14, r8; jl loop_inner%=\t\n"
+ // loop_inner%=: //"\t\n"
+ auto fp16mem0 = _mm_load_si128((__m128i*)((char*)r10 + 0)); //"vcvtph2ps ymm9,XMMWORD PTR [r10 + 0]\t\n"
+ auto ymm9 = _mm256_cvtph_ps(fp16mem0); //"vcvtph2ps ymm9,XMMWORD PTR [r10 + 0]\t\n"
+ auto fp16mem16 = _mm_load_si128((__m128i*)((char*)r10 + 16)); //"vcvtph2ps ymm10,XMMWORD PTR [r10 + 16]\t\n"
+ auto ymm10 = _mm256_cvtph_ps(fp16mem16); //"vcvtph2ps ymm10,XMMWORD PTR [r10 + 16]\t\n"
+ auto ymm8 = _mm256_broadcast_ss((float*)((char*)r9 + 0)); //"vbroadcastss ymm8,DWORD PTR [r9+0]\t\n"
+ ymm0 = _mm256_fmadd_ps(ymm8, ymm9, ymm0); //"vfmadd231ps ymm0,ymm9,ymm8\t\n"
+ ymm1 = _mm256_fmadd_ps(ymm8, ymm10, ymm1); //"vfmadd231ps ymm1,ymm10,ymm8\t\n"
+ ymm8 = _mm256_broadcast_ss((float*)((char*)r9 + 4)); //"vbroadcastss ymm8,DWORD PTR [r9+4]\t\n"
+ ymm2 = _mm256_fmadd_ps(ymm8, ymm9, ymm2); //"vfmadd231ps ymm2,ymm9,ymm8\t\n"
+ ymm3 = _mm256_fmadd_ps(ymm8, ymm10, ymm3); //"vfmadd231ps ymm3,ymm10,ymm8\t\n"
+ ymm8 = _mm256_broadcast_ss((float*)((char*)r9 + 8)); //"vbroadcastss ymm8,DWORD PTR [r9+8]\t\n"
+ ymm4 = _mm256_fmadd_ps(ymm8, ymm9, ymm4); //"vfmadd231ps ymm4,ymm9,ymm8\t\n"
+ ymm5 = _mm256_fmadd_ps(ymm8, ymm10, ymm5); //"vfmadd231ps ymm5,ymm10,ymm8\t\n"
+ ymm8 = _mm256_broadcast_ss((float*)((char*)r9 + 12)); //"vbroadcastss ymm8,DWORD PTR [r9+12]\t\n"
+ ymm6 = _mm256_fmadd_ps(ymm8, ymm9, ymm6); //"vfmadd231ps ymm6,ymm9,ymm8\t\n"
+ ymm7 = _mm256_fmadd_ps(ymm8, ymm10, ymm7); //"vfmadd231ps ymm7,ymm10,ymm8\t\n"
+ r9 = (float*)((char*)r9 + 16); //"add r9,16\t\n"
+ r10 = (fp16*)((char*)r10 + 32); //"add r10,32\t\n"
+ } //"inc r14; cmp r14, r8; jl loop_outter%=\t\n"
+
+ if(rdx != 1) { //"cmp rdx, 1; je L_accum%=\t\n"
+ // Dump C
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm6); //"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm7); //"vmovups YMMWORD PTR [r12 + 32], ymm7\t\n"
+ } else { //"jmp L_done%=\t\n"
+ // Dump C with accumulate
+ auto ymm15 = _mm256_broadcast_ss((float*)r15); //"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+ auto r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm0 = _mm256_fmadd_ps(r12_0, ymm15, ymm0); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+ auto r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm1 = _mm256_fmadd_ps(r12_32, ymm15, ymm1); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm2 = _mm256_fmadd_ps(r12_0, ymm15, ymm2); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm3 = _mm256_fmadd_ps(r12_32, ymm15, ymm3); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm4 = _mm256_fmadd_ps(r12_0, ymm15, ymm4); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm5 = _mm256_fmadd_ps(r12_32, ymm15, ymm5); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm6 = _mm256_fmadd_ps(r12_0, ymm15, ymm6); //"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm6); //"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm7 = _mm256_fmadd_ps(r12_32, ymm15, ymm7); //"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm7); //"vmovups YMMWORD PTR [r12 + 32], ymm7\t\n"
+ } //"L_done%=:\t\n"
+
+ // next outer iteration
+ rcx = (float*)((char*)rcx + 64); //"add rcx, 64\t\n"
+ r12 = rcx; //"mov r12, rcx\t\n"
+ r9 = rax; //"mov r9, rax\t\n"
+ } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
}
-void __attribute__((noinline)) gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp) {
- asm volatile(
-#if !defined(__clang__)
- "mov r14, %[gp]\t\n"
-#else
- "mov %[gp], %%r14\t\n"
- ".intel_syntax noprefix\t\n"
-#endif
+
+void NOINLINE_ATTR gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp) {
+ char* r14 = (char*)gp; //"mov r14, %[gp]\t\n"
// Copy parameters
// k
- "mov r8, [r14 + 0]\t\n"
+ uint64_t r8 = *(uint64_t *)((char*)r14 + 0 ); //"mov r8, [r14 + 0]\t\n"
// A
- "mov r9, [r14 + 8]\t\n"
+ float* r9 = *(float* *)((char*)r14 + 8 ); //"mov r9, [r14 + 8]\t\n"
// B
- "mov r10, [r14 + 16]\t\n"
+ const fp16* r10 = *(const fp16**)((char*)r14 + 16); //"mov r10, [r14 + 16]\t\n"
// beta
- "mov r15, [r14 + 24]\t\n"
+ float* r15 = *(float* *)((char*)r14 + 24); //"mov r15, [r14 + 24]\t\n"
// accum
- "mov rdx, [r14 + 32]\t\n"
+ uint64_t rdx = *(uint64_t *)((char*)r14 + 32); //"mov rdx, [r14 + 32]\t\n"
// C
- "mov r12, [r14 + 40]\t\n"
+ float* r12 = *(float* *)((char*)r14 + 40); //"mov r12, [r14 + 40]\t\n"
// ldc
- "mov r13, [r14 + 48]\t\n"
+ uint64_t r13 = *(uint64_t *)((char*)r14 + 48); //"mov r13, [r14 + 48]\t\n"
// b_block_cols
- "mov rdi, [r14 + 56]\t\n"
+ uint64_t rdi = *(uint64_t *)((char*)r14 + 56); //"mov rdi, [r14 + 56]\t\n"
// b_block_size
- "mov rsi, [r14 + 64]\t\n"
+ uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n"
// Make copies of A and C
- "mov rax, r9\t\n"
- "mov rcx, r12\t\n"
-
- "mov rbx, 0\t\n"
- "loop_outter%=:\t\n"
- "mov r14, 0\t\n"
- "vxorps ymm0,ymm0,ymm0\t\n"
- "vxorps ymm1,ymm1,ymm1\t\n"
- "vxorps ymm2,ymm2,ymm2\t\n"
- "vxorps ymm3,ymm3,ymm3\t\n"
- "vxorps ymm4,ymm4,ymm4\t\n"
- "vxorps ymm5,ymm5,ymm5\t\n"
- "vxorps ymm6,ymm6,ymm6\t\n"
- "vxorps ymm7,ymm7,ymm7\t\n"
- "vxorps ymm8,ymm8,ymm8\t\n"
- "vxorps ymm9,ymm9,ymm9\t\n"
-
-
- "loop_inner%=:\t\n"
-
- "vcvtph2ps ymm11,XMMWORD PTR [r10 + 0]\t\n"
- "vcvtph2ps ymm12,XMMWORD PTR [r10 + 16]\t\n"
- "vbroadcastss ymm10,DWORD PTR [r9+0]\t\n"
- "vfmadd231ps ymm0,ymm11,ymm10\t\n"
- "vfmadd231ps ymm1,ymm12,ymm10\t\n"
- "vbroadcastss ymm10,DWORD PTR [r9+4]\t\n"
- "vfmadd231ps ymm2,ymm11,ymm10\t\n"
- "vfmadd231ps ymm3,ymm12,ymm10\t\n"
- "vbroadcastss ymm10,DWORD PTR [r9+8]\t\n"
- "vfmadd231ps ymm4,ymm11,ymm10\t\n"
- "vfmadd231ps ymm5,ymm12,ymm10\t\n"
- "vbroadcastss ymm10,DWORD PTR [r9+12]\t\n"
- "vfmadd231ps ymm6,ymm11,ymm10\t\n"
- "vfmadd231ps ymm7,ymm12,ymm10\t\n"
- "vbroadcastss ymm10,DWORD PTR [r9+16]\t\n"
- "vfmadd231ps ymm8,ymm11,ymm10\t\n"
- "vfmadd231ps ymm9,ymm12,ymm10\t\n"
- "add r9,20\t\n"
- "add r10,32\t\n"
- "inc r14\t\n"
- "cmp r14, r8\t\n"
- "jl loop_inner%=\t\n"
-
- "L_exit%=:\t\n"
-
- "cmp rdx, 1\t\n"
- "je L_accum%=\t\n"
- // Dump C
- "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm7\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm9\t\n"
- "add r12, r13\t\n"
- "jmp L_done%=\t\n"
-
- "L_accum%=:\t\n"
- // Dump C with accumulate
- "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
- "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
- "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
- "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
- "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
- "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm7\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
- "vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm9\t\n"
- "add r12, r13\t\n"
-
- "L_done%=:\t\n"
-
- // next outer iteration
- "add rcx, 64\t\n"
- "mov r12, rcx\t\n"
- "mov r9, rax\t\n"
- "inc rbx\t\n"
- "cmp rbx, rdi\t\n"
- "jl loop_outter%=\t\n"
- :
- : [gp] "rm"(gp)
- : "r8",
- "r9",
- "r10",
- "r11",
- "r15",
- "r13",
- "r14",
- "rax",
- "rcx",
- "rdx",
- "rsi",
- "rdi",
- "rbx",
- "r12",
- "memory");
+ float* rax = r9; //"mov rax, r9\t\n"
+ float* rcx = r12; //"mov rcx, r12\t\n"
+
+ uint64_t rbx = 0; //"mov rbx, 0\t\n"
+ for (; rbx < rdi; ++rbx) { //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
+ // //"loop_outter%=:\t\n"
+ uint64_t r14_i = 0; //"mov r14, 0\t\n"
+ __m256 ymm0 = _mm256_setzero_ps(); //"vxorps ymm0,ymm0,ymm0\t\n"
+ __m256 ymm1 = _mm256_setzero_ps(); //"vxorps ymm1,ymm1,ymm1\t\n"
+ __m256 ymm2 = _mm256_setzero_ps(); //"vxorps ymm2,ymm2,ymm2\t\n"
+ __m256 ymm3 = _mm256_setzero_ps(); //"vxorps ymm3,ymm3,ymm3\t\n"
+ __m256 ymm4 = _mm256_setzero_ps(); //"vxorps ymm4,ymm4,ymm4\t\n"
+ __m256 ymm5 = _mm256_setzero_ps(); //"vxorps ymm5,ymm5,ymm5\t\n"
+ __m256 ymm6 = _mm256_setzero_ps(); //"vxorps ymm6,ymm6,ymm6\t\n"
+ __m256 ymm7 = _mm256_setzero_ps(); //"vxorps ymm7,ymm7,ymm7\t\n"
+ __m256 ymm8 = _mm256_setzero_ps(); //"vxorps ymm8,ymm8,ymm8\t\n"
+ __m256 ymm9 = _mm256_setzero_ps(); //"vxorps ymm9,ymm9,ymm9\t\n"
+
+ for (; r14_i < r8; ++r14_i) { //"inc r14; cmp r14, r8; jl loop_inner%=\t\n"
+ // loop_inner%=: //"\t\n"
+ auto fp16mem0 = _mm_load_si128((__m128i*)((char*)r10 + 0)); //"vcvtph2ps ymm11,XMMWORD PTR [r10 + 0]\t\n"
+ auto ymm11 = _mm256_cvtph_ps(fp16mem0); //"vcvtph2ps ymm11,XMMWORD PTR [r10 + 0]\t\n"
+ auto fp16mem16 = _mm_load_si128((__m128i*)((char*)r10 + 16)); //"vcvtph2ps ymm12,XMMWORD PTR [r10 + 16]\t\n"
+ auto ymm12 = _mm256_cvtph_ps(fp16mem16); //"vcvtph2ps ymm12,XMMWORD PTR [r10 + 16]\t\n"
+ auto ymm10 = _mm256_broadcast_ss((float*)((char*)r9 + 0)); //"vbroadcastss ymm10,DWORD PTR [r9+0]\t\n"
+ ymm0 = _mm256_fmadd_ps(ymm10, ymm11, ymm0); //"vfmadd231ps ymm0,ymm11,ymm10\t\n"
+ ymm1 = _mm256_fmadd_ps(ymm10, ymm12, ymm1); //"vfmadd231ps ymm1,ymm12,ymm10\t\n"
+ ymm10 = _mm256_broadcast_ss((float*)((char*)r9 + 4)); //"vbroadcastss ymm10,DWORD PTR [r9+4]\t\n"
+ ymm2 = _mm256_fmadd_ps(ymm10, ymm11, ymm2); //"vfmadd231ps ymm2,ymm11,ymm10\t\n"
+ ymm3 = _mm256_fmadd_ps(ymm10, ymm12, ymm3); //"vfmadd231ps ymm3,ymm12,ymm10\t\n"
+ ymm10 = _mm256_broadcast_ss((float*)((char*)r9 + 8)); //"vbroadcastss ymm10,DWORD PTR [r9+8]\t\n"
+ ymm4 = _mm256_fmadd_ps(ymm10, ymm11, ymm4); //"vfmadd231ps ymm4,ymm11,ymm10\t\n"
+ ymm5 = _mm256_fmadd_ps(ymm10, ymm12, ymm5); //"vfmadd231ps ymm5,ymm12,ymm10\t\n"
+ ymm10 = _mm256_broadcast_ss((float*)((char*)r9 + 12)); //"vbroadcastss ymm10,DWORD PTR [r9+12]\t\n"
+ ymm6 = _mm256_fmadd_ps(ymm10, ymm11, ymm6); //"vfmadd231ps ymm6,ymm11,ymm10\t\n"
+ ymm7 = _mm256_fmadd_ps(ymm10, ymm12, ymm7); //"vfmadd231ps ymm7,ymm12,ymm10\t\n"
+ ymm10 = _mm256_broadcast_ss((float*)((char*)r9 + 16)); //"vbroadcastss ymm10,DWORD PTR [r9+16]\t\n"
+ ymm8 = _mm256_fmadd_ps(ymm10, ymm11, ymm8); //"vfmadd231ps ymm8,ymm11,ymm10\t\n"
+ ymm9 = _mm256_fmadd_ps(ymm10, ymm12, ymm9); //"vfmadd231ps ymm9,ymm12,ymm10\t\n"
+ r9 = (float*)((char*)r9 + 20); //"add r9,20\t\n"
+ r10 = (fp16*)((char*)r10 + 32); //"add r10,32\t\n"
+ } //"inc r14; cmp r14, r8; jl loop_outter%=\t\n"
+
+ if(rdx != 1) { //"cmp rdx, 1; je L_accum%=\t\n"
+ // Dump C
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm6); //"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm7); //"vmovups YMMWORD PTR [r12 + 32], ymm7\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm8); //"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm9); //"vmovups YMMWORD PTR [r12 + 32], ymm9\t\n"
+ } else { //"jmp L_done%=\t\n"
+ // Dump C with accumulate
+ auto ymm15 = _mm256_broadcast_ss((float*)r15); //"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+ auto r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm0 = _mm256_fmadd_ps(r12_0, ymm15, ymm0); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+ auto r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm1 = _mm256_fmadd_ps(r12_32, ymm15, ymm1); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm2 = _mm256_fmadd_ps(r12_0, ymm15, ymm2); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm3 = _mm256_fmadd_ps(r12_32, ymm15, ymm3); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm4 = _mm256_fmadd_ps(r12_0, ymm15, ymm4); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm5 = _mm256_fmadd_ps(r12_32, ymm15, ymm5); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm6 = _mm256_fmadd_ps(r12_0, ymm15, ymm6); //"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm6); //"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm7 = _mm256_fmadd_ps(r12_32, ymm15, ymm7); //"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm7); //"vmovups YMMWORD PTR [r12 + 32], ymm7\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm8 = _mm256_fmadd_ps(r12_0, ymm15, ymm8); //"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm8); //"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm9 = _mm256_fmadd_ps(r12_32, ymm15, ymm9); //"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm9); //"vmovups YMMWORD PTR [r12 + 32], ymm9\t\n"
+ } //"L_done%=:\t\n"
+
+ // next outer iteration
+ rcx = (float*)((char*)rcx + 64); //"add rcx, 64\t\n"
+ r12 = rcx; //"mov r12, rcx\t\n"
+ r9 = rax; //"mov r9, rax\t\n"
+ } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
}
-void __attribute__((noinline)) gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp) {
- asm volatile(
-#if !defined(__clang__)
- "mov r14, %[gp]\t\n"
-#else
- "mov %[gp], %%r14\t\n"
- ".intel_syntax noprefix\t\n"
-#endif
+
+void NOINLINE_ATTR gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp) {
+ char* r14 = (char*)gp; //"mov r14, %[gp]\t\n"
// Copy parameters
// k
- "mov r8, [r14 + 0]\t\n"
+ uint64_t r8 = *(uint64_t *)((char*)r14 + 0 ); //"mov r8, [r14 + 0]\t\n"
// A
- "mov r9, [r14 + 8]\t\n"
+ float* r9 = *(float* *)((char*)r14 + 8 ); //"mov r9, [r14 + 8]\t\n"
// B
- "mov r10, [r14 + 16]\t\n"
+ const fp16* r10 = *(const fp16**)((char*)r14 + 16); //"mov r10, [r14 + 16]\t\n"
// beta
- "mov r15, [r14 + 24]\t\n"
+ float* r15 = *(float* *)((char*)r14 + 24); //"mov r15, [r14 + 24]\t\n"
// accum
- "mov rdx, [r14 + 32]\t\n"
+ uint64_t rdx = *(uint64_t *)((char*)r14 + 32); //"mov rdx, [r14 + 32]\t\n"
// C
- "mov r12, [r14 + 40]\t\n"
+ float* r12 = *(float* *)((char*)r14 + 40); //"mov r12, [r14 + 40]\t\n"
// ldc
- "mov r13, [r14 + 48]\t\n"
+ uint64_t r13 = *(uint64_t *)((char*)r14 + 48); //"mov r13, [r14 + 48]\t\n"
// b_block_cols
- "mov rdi, [r14 + 56]\t\n"
+ uint64_t rdi = *(uint64_t *)((char*)r14 + 56); //"mov rdi, [r14 + 56]\t\n"
// b_block_size
- "mov rsi, [r14 + 64]\t\n"
+ uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n"
// Make copies of A and C
- "mov rax, r9\t\n"
- "mov rcx, r12\t\n"
-
- "mov rbx, 0\t\n"
- "loop_outter%=:\t\n"
- "mov r14, 0\t\n"
- "vxorps ymm0,ymm0,ymm0\t\n"
- "vxorps ymm1,ymm1,ymm1\t\n"
- "vxorps ymm2,ymm2,ymm2\t\n"
- "vxorps ymm3,ymm3,ymm3\t\n"
- "vxorps ymm4,ymm4,ymm4\t\n"
- "vxorps ymm5,ymm5,ymm5\t\n"
- "vxorps ymm6,ymm6,ymm6\t\n"
- "vxorps ymm7,ymm7,ymm7\t\n"
- "vxorps ymm8,ymm8,ymm8\t\n"
- "vxorps ymm9,ymm9,ymm9\t\n"
- "vxorps ymm10,ymm10,ymm10\t\n"
- "vxorps ymm11,ymm11,ymm11\t\n"
-
-
- "loop_inner%=:\t\n"
-
- "vcvtph2ps ymm13,XMMWORD PTR [r10 + 0]\t\n"
- "vcvtph2ps ymm14,XMMWORD PTR [r10 + 16]\t\n"
- "vbroadcastss ymm12,DWORD PTR [r9+0]\t\n"
- "vfmadd231ps ymm0,ymm13,ymm12\t\n"
- "vfmadd231ps ymm1,ymm14,ymm12\t\n"
- "vbroadcastss ymm12,DWORD PTR [r9+4]\t\n"
- "vfmadd231ps ymm2,ymm13,ymm12\t\n"
- "vfmadd231ps ymm3,ymm14,ymm12\t\n"
- "vbroadcastss ymm12,DWORD PTR [r9+8]\t\n"
- "vfmadd231ps ymm4,ymm13,ymm12\t\n"
- "vfmadd231ps ymm5,ymm14,ymm12\t\n"
- "vbroadcastss ymm12,DWORD PTR [r9+12]\t\n"
- "vfmadd231ps ymm6,ymm13,ymm12\t\n"
- "vfmadd231ps ymm7,ymm14,ymm12\t\n"
- "vbroadcastss ymm12,DWORD PTR [r9+16]\t\n"
- "vfmadd231ps ymm8,ymm13,ymm12\t\n"
- "vfmadd231ps ymm9,ymm14,ymm12\t\n"
- "vbroadcastss ymm12,DWORD PTR [r9+20]\t\n"
- "vfmadd231ps ymm10,ymm13,ymm12\t\n"
- "vfmadd231ps ymm11,ymm14,ymm12\t\n"
- "add r9,24\t\n"
- "add r10,32\t\n"
- "inc r14\t\n"
- "cmp r14, r8\t\n"
- "jl loop_inner%=\t\n"
-
- "L_exit%=:\t\n"
-
- "cmp rdx, 1\t\n"
- "je L_accum%=\t\n"
- // Dump C
- "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm7\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm9\t\n"
- "add r12, r13\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm11\t\n"
- "add r12, r13\t\n"
- "jmp L_done%=\t\n"
-
- "L_accum%=:\t\n"
- // Dump C with accumulate
- "vbroadcastss ymm15,DWORD PTR [r15]\t\n"
- "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
- "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
- "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
- "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
- "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm7\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
- "vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm9\t\n"
- "add r12, r13\t\n"
- "vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n"
- "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
- "vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 32]\t\n"
- "vmovups YMMWORD PTR [r12 + 32], ymm11\t\n"
- "add r12, r13\t\n"
-
- "L_done%=:\t\n"
-
- // next outer iteration
- "add rcx, 64\t\n"
- "mov r12, rcx\t\n"
- "mov r9, rax\t\n"
- "inc rbx\t\n"
- "cmp rbx, rdi\t\n"
- "jl loop_outter%=\t\n"
- :
- : [gp] "rm"(gp)
- : "r8",
- "r9",
- "r10",
- "r11",
- "r15",
- "r13",
- "r14",
- "rax",
- "rcx",
- "rdx",
- "rsi",
- "rdi",
- "rbx",
- "r12",
- "memory");
+ float* rax = r9; //"mov rax, r9\t\n"
+ float* rcx = r12; //"mov rcx, r12\t\n"
+
+ uint64_t rbx = 0; //"mov rbx, 0\t\n"
+ for (; rbx < rdi; ++rbx) { //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
+ // //"loop_outter%=:\t\n"
+ uint64_t r14_i = 0; //"mov r14, 0\t\n"
+ __m256 ymm0 = _mm256_setzero_ps(); //"vxorps ymm0,ymm0,ymm0\t\n"
+ __m256 ymm1 = _mm256_setzero_ps(); //"vxorps ymm1,ymm1,ymm1\t\n"
+ __m256 ymm2 = _mm256_setzero_ps(); //"vxorps ymm2,ymm2,ymm2\t\n"
+ __m256 ymm3 = _mm256_setzero_ps(); //"vxorps ymm3,ymm3,ymm3\t\n"
+ __m256 ymm4 = _mm256_setzero_ps(); //"vxorps ymm4,ymm4,ymm4\t\n"
+ __m256 ymm5 = _mm256_setzero_ps(); //"vxorps ymm5,ymm5,ymm5\t\n"
+ __m256 ymm6 = _mm256_setzero_ps(); //"vxorps ymm6,ymm6,ymm6\t\n"
+ __m256 ymm7 = _mm256_setzero_ps(); //"vxorps ymm7,ymm7,ymm7\t\n"
+ __m256 ymm8 = _mm256_setzero_ps(); //"vxorps ymm8,ymm8,ymm8\t\n"
+ __m256 ymm9 = _mm256_setzero_ps(); //"vxorps ymm9,ymm9,ymm9\t\n"
+ __m256 ymm10 = _mm256_setzero_ps(); //"vxorps ymm10,ymm10,ymm10\t\n"
+ __m256 ymm11 = _mm256_setzero_ps(); //"vxorps ymm11,ymm11,ymm11\t\n"
+
+ for (; r14_i < r8; ++r14_i) { //"inc r14; cmp r14, r8; jl loop_inner%=\t\n"
+ // loop_inner%=: //"\t\n"
+ auto fp16mem0 = _mm_load_si128((__m128i*)((char*)r10 + 0)); //"vcvtph2ps ymm13,XMMWORD PTR [r10 + 0]\t\n"
+ auto ymm13 = _mm256_cvtph_ps(fp16mem0); //"vcvtph2ps ymm13,XMMWORD PTR [r10 + 0]\t\n"
+ auto fp16mem16 = _mm_load_si128((__m128i*)((char*)r10 + 16)); //"vcvtph2ps ymm14,XMMWORD PTR [r10 + 16]\t\n"
+ auto ymm14 = _mm256_cvtph_ps(fp16mem16); //"vcvtph2ps ymm14,XMMWORD PTR [r10 + 16]\t\n"
+ auto ymm12 = _mm256_broadcast_ss((float*)((char*)r9 + 0)); //"vbroadcastss ymm12,DWORD PTR [r9+0]\t\n"
+ ymm0 = _mm256_fmadd_ps(ymm12, ymm13, ymm0); //"vfmadd231ps ymm0,ymm13,ymm12\t\n"
+ ymm1 = _mm256_fmadd_ps(ymm12, ymm14, ymm1); //"vfmadd231ps ymm1,ymm14,ymm12\t\n"
+ ymm12 = _mm256_broadcast_ss((float*)((char*)r9 + 4)); //"vbroadcastss ymm12,DWORD PTR [r9+4]\t\n"
+ ymm2 = _mm256_fmadd_ps(ymm12, ymm13, ymm2); //"vfmadd231ps ymm2,ymm13,ymm12\t\n"
+ ymm3 = _mm256_fmadd_ps(ymm12, ymm14, ymm3); //"vfmadd231ps ymm3,ymm14,ymm12\t\n"
+ ymm12 = _mm256_broadcast_ss((float*)((char*)r9 + 8)); //"vbroadcastss ymm12,DWORD PTR [r9+8]\t\n"
+ ymm4 = _mm256_fmadd_ps(ymm12, ymm13, ymm4); //"vfmadd231ps ymm4,ymm13,ymm12\t\n"
+ ymm5 = _mm256_fmadd_ps(ymm12, ymm14, ymm5); //"vfmadd231ps ymm5,ymm14,ymm12\t\n"
+ ymm12 = _mm256_broadcast_ss((float*)((char*)r9 + 12)); //"vbroadcastss ymm12,DWORD PTR [r9+12]\t\n"
+ ymm6 = _mm256_fmadd_ps(ymm12, ymm13, ymm6); //"vfmadd231ps ymm6,ymm13,ymm12\t\n"
+ ymm7 = _mm256_fmadd_ps(ymm12, ymm14, ymm7); //"vfmadd231ps ymm7,ymm14,ymm12\t\n"
+ ymm12 = _mm256_broadcast_ss((float*)((char*)r9 + 16)); //"vbroadcastss ymm12,DWORD PTR [r9+16]\t\n"
+ ymm8 = _mm256_fmadd_ps(ymm12, ymm13, ymm8); //"vfmadd231ps ymm8,ymm13,ymm12\t\n"
+ ymm9 = _mm256_fmadd_ps(ymm12, ymm14, ymm9); //"vfmadd231ps ymm9,ymm14,ymm12\t\n"
+ ymm12 = _mm256_broadcast_ss((float*)((char*)r9 + 20)); //"vbroadcastss ymm12,DWORD PTR [r9+20]\t\n"
+ ymm10 = _mm256_fmadd_ps(ymm12, ymm13, ymm10); //"vfmadd231ps ymm10,ymm13,ymm12\t\n"
+ ymm11 = _mm256_fmadd_ps(ymm12, ymm14, ymm11); //"vfmadd231ps ymm11,ymm14,ymm12\t\n"
+ r9 = (float*)((char*)r9 + 24); //"add r9,24\t\n"
+ r10 = (fp16*)((char*)r10 + 32); //"add r10,32\t\n"
+    }                                        //"inc r14; cmp r14, r8; jl loop_inner%=\t\n"
+
+ if(rdx != 1) { //"cmp rdx, 1; je L_accum%=\t\n"
+ // Dump C
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm6); //"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm7); //"vmovups YMMWORD PTR [r12 + 32], ymm7\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm8); //"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm9); //"vmovups YMMWORD PTR [r12 + 32], ymm9\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm10); //"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm11); //"vmovups YMMWORD PTR [r12 + 32], ymm11\t\n"
+ } else { //"jmp L_done%=\t\n"
+ // Dump C with accumulate
+ auto ymm15 = _mm256_broadcast_ss((float*)r15); //"vbroadcastss ymm15,DWORD PTR [r15]\t\n"
+ auto r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm0 = _mm256_fmadd_ps(r12_0, ymm15, ymm0); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n"
+ auto r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm1 = _mm256_fmadd_ps(r12_32, ymm15, ymm1); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm2 = _mm256_fmadd_ps(r12_0, ymm15, ymm2); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm3 = _mm256_fmadd_ps(r12_32, ymm15, ymm3); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm4 = _mm256_fmadd_ps(r12_0, ymm15, ymm4); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm5 = _mm256_fmadd_ps(r12_32, ymm15, ymm5); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm6 = _mm256_fmadd_ps(r12_0, ymm15, ymm6); //"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm6); //"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm7 = _mm256_fmadd_ps(r12_32, ymm15, ymm7); //"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm7); //"vmovups YMMWORD PTR [r12 + 32], ymm7\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm8 = _mm256_fmadd_ps(r12_0, ymm15, ymm8); //"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm8); //"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm9 = _mm256_fmadd_ps(r12_32, ymm15, ymm9); //"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm9); //"vmovups YMMWORD PTR [r12 + 32], ymm9\t\n"
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n"
+ r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ ymm10 = _mm256_fmadd_ps(r12_0, ymm15, ymm10); //"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 0), ymm10); //"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n"
+ r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ ymm11 = _mm256_fmadd_ps(r12_32, ymm15, ymm11); //"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 32]\t\n"
+ _mm256_storeu_ps((float*)((char*)r12 + 32), ymm11); //"vmovups YMMWORD PTR [r12 + 32], ymm11\t\n"
+ } //"L_done%=:\t\n"
+
+ // next outer iteration
+ rcx = (float*)((char*)rcx + 64); //"add rcx, 64\t\n"
+ r12 = rcx; //"mov r12, rcx\t\n"
+ r9 = rax; //"mov r9, rax\t\n"
+ } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n"
}
+
} // namespace fbgemm
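The intrinsics bodies above replace the former GCC inline-assembly kernels so that MSVC can compile them; the original Intel-syntax instructions are kept as trailing comments. For reference, the arithmetic that one outer iteration of the 6x2 kernel performs on a single 16-float column panel of B reduces to the scalar loop below (a sketch only: Bf stands for the fp16 panel already widened to float, and ldc is treated as a byte stride, exactly as the generated code adds it to a char-cast pointer).

    #include <cstdint>

    // Scalar sketch of what gemmkernel_6x2_AVX2_fA0fB0fC0 computes for a single
    // 16-float column panel of B. Bf is assumed to hold the fp16 panel already
    // converted to float; ldc_bytes mirrors the byte-stride use of ldc above.
    void gemm_6x16_reference(uint64_t k, const float* A, const float* Bf,
                             float* C, uint64_t ldc_bytes, bool accumulate,
                             float beta) {
      for (int r = 0; r < 6; ++r) {
        float* Crow =
            reinterpret_cast<float*>(reinterpret_cast<char*>(C) + r * ldc_bytes);
        for (int j = 0; j < 16; ++j) {
          float acc = 0.f;
          for (uint64_t p = 0; p < k; ++p)
            acc += A[p * 6 + r] * Bf[p * 16 + j];  // the vfmadd231ps updates
          Crow[j] = accumulate ? beta * Crow[j] + acc : acc;
        }
      }
    }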
diff --git a/src/FbgemmFP16UKernelsAvx2.h b/src/FbgemmFP16UKernelsAvx2.h
index 6e7dfbc..d48a88e 100644
--- a/src/FbgemmFP16UKernelsAvx2.h
+++ b/src/FbgemmFP16UKernelsAvx2.h
@@ -13,6 +13,11 @@ namespace fbgemm {
using fp16 = float16;
using fp32 = float;
+#ifdef _MSC_VER
+ #define NOINLINE_ATTR __declspec(noinline)
+#else
+ #define NOINLINE_ATTR __attribute__((noinline))
+#endif
struct GemmParams {
uint64_t k;
float* A;
@@ -24,12 +29,12 @@ struct GemmParams {
uint64_t b_block_cols;
uint64_t b_block_size;
};
-void __attribute__((noinline)) gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp);
-void __attribute__((noinline)) gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp);
-void __attribute__((noinline)) gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp);
-void __attribute__((noinline)) gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp);
-void __attribute__((noinline)) gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp);
-void __attribute__((noinline)) gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp);
+void NOINLINE_ATTR gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp);
+void NOINLINE_ATTR gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp);
+void NOINLINE_ATTR gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp);
+void NOINLINE_ATTR gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp);
+void NOINLINE_ATTR gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp);
+void NOINLINE_ATTR gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp);
typedef void (*funcptr_fp16)(GemmParams* gp);
;
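A hypothetical call site for the declarations above could look like the sketch below. Assumptions are flagged in the comments: packing of A and B happens elsewhere, ldc is passed as a byte stride because the kernel bodies add it to char-cast pointers, and b_block_size is loaded but not otherwise read in the kernel shown earlier.

    #include <cstdint>
    #include "FbgemmFP16UKernelsAvx2.h"

    // Hypothetical wrapper (sketch only) around the 6x2 kernel declared above.
    void run_6x2_panel(uint64_t k, float* packed_A, const fbgemm::fp16* packed_B,
                       float* C, uint64_t ldc_floats, uint64_t b_block_cols,
                       float beta, bool accumulate) {
      fbgemm::GemmParams gp;
      gp.k = k;                             // reduction depth
      gp.A = packed_A;                      // 6 floats of A per k step
      gp.B = packed_B;                      // 16 fp16 values of B per k step
      gp.beta = &beta;                      // scale on existing C when accum == 1
      gp.accum = accumulate ? 1 : 0;        // 1: C = beta*C + A*B, else C = A*B
      gp.C = C;
      gp.ldc = ldc_floats * sizeof(float);  // used as a byte stride by the kernel
      gp.b_block_cols = b_block_cols;       // number of 16-float panels of B
      gp.b_block_size = 0;                  // not read by the kernel body shown
      fbgemm::gemmkernel_6x2_AVX2_fA0fB0fC0(&gp);
    }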
diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc
index ee39faf..f96d1d2 100644
--- a/src/FbgemmI8DepthwiseAvx2.cc
+++ b/src/FbgemmI8DepthwiseAvx2.cc
@@ -5,6 +5,7 @@
* LICENSE file in the root directory of this source tree.
*/
#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+#include "fbgemm/Utils.h"
#include <algorithm> // for min and max
#include <cassert>
@@ -36,7 +37,8 @@ PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
const int8_t* smat)
: K_(K) {
// Transpose the input matrix to make packing faster.
- alignas(64) int8_t smat_transposed[K * KERNEL_PROD];
+ int8_t* smat_transposed = static_cast<int8_t *>(ALIGNED_MALLOC(
+ K * KERNEL_PROD * sizeof(int8_t), 64));
for (int i = 0; i < KERNEL_PROD; ++i) {
for (int j = 0; j < K; ++j) {
smat_transposed[i * K + j] = smat[i + j * KERNEL_PROD];
@@ -45,12 +47,15 @@ PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
// Allocate packed arrays
constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2;
- // pmat_ = static_cast<int8_t *>(fbgemmAlignedAlloc(
- // 64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t)));
+#ifdef _MSC_VER
+ pmat_ = static_cast<int8_t *>(_aligned_malloc(
+ ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t), 64));
+#else
posix_memalign(
(void**)&pmat_,
64,
((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t));
+#endif
// Pack input matrix
// The layout is optimized to use vpmaddubsw efficiently (see
@@ -160,11 +165,17 @@ PackedDepthWiseConvMatrix<KERNEL_PROD>::PackedDepthWiseConvMatrix(
b_interleaved_epi32[i]);
}
}
+
+ FREE(smat_transposed);
}
template <int KERNEL_PROD>
PackedDepthWiseConvMatrix<KERNEL_PROD>::~PackedDepthWiseConvMatrix() {
+#ifdef _MSC_VER
+ _aligned_free(pmat_);
+#else
free(pmat_);
+#endif
}
template class PackedDepthWiseConvMatrix<3 * 3>;
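The constructor and destructor above rely on the ALIGNED_MALLOC and FREE helpers pulled in from fbgemm/Utils.h; their definitions are not part of this diff. A plausible shape for them, matching the ALIGNED_MALLOC(size, alignment) argument order used at the call sites, is sketched below as an assumption rather than the actual Utils.h contents.

    #include <cstddef>
    #ifdef _MSC_VER
    #include <malloc.h>
    static inline void* fbgemm_aligned_malloc(size_t size, size_t align) {
      return _aligned_malloc(size, align);    // MSVC aligned allocation
    }
    static inline void fbgemm_aligned_free(void* p) { _aligned_free(p); }
    #else
    #include <cstdlib>
    static inline void* fbgemm_aligned_malloc(size_t size, size_t align) {
      void* p = nullptr;                      // posix_memalign is the POSIX path
      return posix_memalign(&p, align, size) == 0 ? p : nullptr;
    }
    static inline void fbgemm_aligned_free(void* p) { free(p); }
    #endif
    #define ALIGNED_MALLOC(size, align) fbgemm_aligned_malloc((size), (align))
    #define FREE(ptr) fbgemm_aligned_free((ptr))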
@@ -179,7 +190,7 @@ template class PackedDepthWiseConvMatrix<3 * 3 * 3>;
// c2_v: c[8:12], c[24:28]
// c3_v: c[12:16], c[28:32]
template <bool SUM_A = false>
-static inline __attribute__((always_inline)) void madd_epi16x4_packed(
+static inline ALWAYS_INLINE void madd_epi16x4_packed(
__m256i a0_v,
__m256i a1_v,
__m256i a2_v,
@@ -238,7 +249,7 @@ static inline __attribute__((always_inline)) void madd_epi16x4_packed(
// c2_v: c[8:12], c[24:28]
// c3_v: c[12:16], c[28:32]
template <bool SUM_A = false>
-static inline __attribute__((always_inline)) void madd_epi16x3_packed(
+static inline ALWAYS_INLINE void madd_epi16x3_packed(
__m256i a0_v,
__m256i a1_v,
__m256i a2_v,
@@ -298,7 +309,7 @@ static inline __attribute__((always_inline)) void madd_epi16x3_packed(
// c2_v: c[16:20], c[20:24]
// c3_v: c[24:28], c[28:32]
template <bool SUM_A = false>
-static inline __attribute__((always_inline)) void madd_epi16x2_packed(
+static inline ALWAYS_INLINE void madd_epi16x2_packed(
__m256i a0_v,
__m256i a1_v,
const __m256i* b,
@@ -339,7 +350,7 @@ static inline __attribute__((always_inline)) void madd_epi16x2_packed(
// c2_v: c[16:20], c[20:24]
// c3_v: c[24:28], c[28:32]
template <bool SUM_A = false>
-static inline __attribute__((always_inline)) void madd_epi16_packed(
+static inline ALWAYS_INLINE void madd_epi16_packed(
__m256i a_v,
const __m256i* b,
__m256i* c0_v,
@@ -374,7 +385,7 @@ static inline __attribute__((always_inline)) void madd_epi16_packed(
// K is the number of accumulations we're doing
template <int K, bool SUM_A = false, bool REMAINDER = false, bool ACC = false>
-static inline __attribute__((always_inline)) void inner_prod_packed_(
+static inline ALWAYS_INLINE void inner_prod_packed_(
const __m256i* a_v,
const __m256i* Bp,
int32_t* C,
@@ -514,7 +525,7 @@ static inline __attribute__((always_inline)) void inner_prod_packed_(
}
template <bool SUM_A = false, bool REMAINDER = false>
-static inline __attribute__((always_inline)) void inner_prod_3x3_packed_(
+static inline ALWAYS_INLINE void inner_prod_3x3_packed_(
const __m256i* a_v,
const __m256i* Bp,
int32_t* C,
@@ -531,7 +542,7 @@ template <
bool PER_CHANNEL_QUANTIZATION,
bool A_SYMMETRIC,
bool B_SYMMETRIC>
-static inline __attribute__((always_inline)) void requantize_(
+static inline ALWAYS_INLINE void requantize_(
int32_t A_zero_point,
const float* C_multiplier,
int32_t C_zero_point,
@@ -745,7 +756,7 @@ static inline __attribute__((always_inline)) void requantize_(
}
template <bool REMAINDER>
-static inline __attribute__((always_inline)) __m256i load_a(
+static inline ALWAYS_INLINE __m256i load_a(
const uint8_t* A,
__m256i mask_v) {
if (REMAINDER) {
@@ -759,7 +770,7 @@ template <
bool SUM_A,
bool REMAINDER = false,
bool PER_CHANNEL_QUANTIZATION = false>
-static inline __attribute__((always_inline)) void inner_prod_3x3_packed_(
+static inline ALWAYS_INLINE void inner_prod_3x3_packed_(
int H,
int W,
int K,
@@ -870,7 +881,7 @@ template <
bool SUM_A,
bool REMAINDER = false,
bool PER_CHANNEL_QUANTIZATION = false>
-static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_(
+static inline ALWAYS_INLINE void inner_prod_3x3x3_packed_(
int T,
int H,
int W,
@@ -1118,7 +1129,7 @@ static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_(
}
template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
-static inline __attribute__((always_inline)) void depthwise_3x3_kernel_(
+static inline ALWAYS_INLINE void depthwise_3x3_kernel_(
int H,
int W,
int K,
@@ -1194,7 +1205,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_(
}
template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
-static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_(
+static inline ALWAYS_INLINE void depthwise_3x3x3_kernel_(
int T,
int H,
int W,
@@ -1279,7 +1290,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_(
}
template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
-static inline __attribute__((always_inline)) void
+static inline ALWAYS_INLINE void
depthwise_3x3_per_channel_quantization_kernel_(
int H,
int W,
@@ -1362,7 +1373,7 @@ depthwise_3x3_per_channel_quantization_kernel_(
}
template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
-static inline __attribute__((always_inline)) void
+static inline ALWAYS_INLINE void
depthwise_3x3x3_per_channel_quantization_kernel_(
int T,
int H,
@@ -1465,7 +1476,7 @@ static pair<int, int> closest_factors_(int n) {
// filter shapes by parameterizing with R and S but restricting it to just 3x3
// for now.
template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
-static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
+static inline ALWAYS_INLINE void depthwise_3x3_pad_1_(
int N,
int H,
int W,
@@ -1491,7 +1502,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
const int8_t* Bp = B.PackedMat();
- int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+ int32_t* row_offsets = static_cast<int32_t *>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64));
int n_begin, n_end;
int h_begin, h_end, w_begin, w_end;
@@ -1748,10 +1759,11 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_(
}
}
} // for each n
+ FREE(row_offsets);
};
template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC, bool B_SYMMETRIC>
-static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_(
+static inline ALWAYS_INLINE void depthwise_3x3x3_pad_1_(
int N,
int T,
int H,
@@ -1781,7 +1793,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_(
int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
const int8_t* Bp = B.PackedMat();
- int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+ int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64)));
int n_begin, n_end;
int t_begin, t_end, h_begin, h_end;
@@ -1858,10 +1870,12 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_(
} // h
} // t
} // for each n
+
+ FREE(row_offsets);
};
template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
-static inline __attribute__((always_inline)) void
+static inline ALWAYS_INLINE void
depthwise_3x3_per_channel_quantization_pad_1_(
int N,
int H,
@@ -1888,7 +1902,7 @@ depthwise_3x3_per_channel_quantization_pad_1_(
int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1;
const int8_t* Bp = B.PackedMat();
- int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+ int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64)));
int n_begin, n_end;
int h_begin, h_end, w_begin, w_end;
@@ -2172,10 +2186,12 @@ depthwise_3x3_per_channel_quantization_pad_1_(
}
}
} // for each n
+
+ FREE(row_offsets);
};
template <bool FUSE_RELU, bool HAS_BIAS, bool A_SYMMETRIC>
-static inline __attribute__((always_inline)) void
+static inline ALWAYS_INLINE void
depthwise_3x3x3_per_channel_quantization_pad_1_(
int N,
int T,
@@ -2206,7 +2222,7 @@ depthwise_3x3x3_per_channel_quantization_pad_1_(
int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1;
const int8_t* Bp = B.PackedMat();
- int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64)));
+ int32_t* row_offsets = static_cast<int32_t*>(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64)));
int n_begin, n_end;
int t_begin, t_end, h_begin, h_end;
@@ -2282,6 +2298,8 @@ depthwise_3x3x3_per_channel_quantization_pad_1_(
} // h
} // t
} // for each n
+
+ FREE(row_offsets);
};
// Dispatch A_SYMMETRIC and B_SYMMETRIC
@@ -2304,7 +2322,7 @@ static void depthwise_3x3_pad_1_(
const int32_t* bias,
int thread_id,
int num_threads) {
- int32_t C_int32_temp[(K + 31) / 32 * 32];
+ int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
if (A_zero_point == 0 || col_offsets == nullptr) {
if (B_zero_point == 0) {
depthwise_3x3_pad_1_<
@@ -2406,6 +2424,7 @@ static void depthwise_3x3_pad_1_(
num_threads);
}
}
+ delete[] C_int32_temp;
}
// Dispatch HAS_BIAS
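The stack VLA C_int32_temp had to go because MSVC does not support C99-style variable-length arrays, so the commit switches to new[]/delete[]. An equivalent, exception-safe alternative (a sketch, not what this commit does) would be std::vector:

    #include <cstdint>
    #include <vector>

    // Sketch of an alternative to the new[]/delete[] pair above: same size
    // computation, with the storage released automatically at scope exit.
    void dispatch_example(int K) {
      std::vector<int32_t> C_int32_temp((K + 31) / 32 * 32);
      int32_t* buf = C_int32_temp.data();  // pass wherever an int32_t* is expected
      (void)buf;
    }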
@@ -2709,7 +2728,7 @@ static void depthwise_3x3x3_pad_1_(
const int32_t* bias,
int thread_id,
int num_threads) {
- int32_t C_int32_temp[(K + 31) / 32 * 32];
+ int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
if (A_zero_point == 0 || col_offsets == nullptr) {
if (B_zero_point == 0) {
depthwise_3x3x3_pad_1_<
@@ -2819,6 +2838,7 @@ static void depthwise_3x3x3_pad_1_(
num_threads);
}
}
+ delete[] C_int32_temp;
}
// Dispatch HAS_BIAS
@@ -2975,7 +2995,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
const int32_t* bias,
int thread_id,
int num_threads) {
- int32_t C_int32_temp[(K + 31) / 32 * 32];
+ int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
if (A_zero_point == 0 || col_offsets == nullptr) {
depthwise_3x3_per_channel_quantization_pad_1_<
FUSE_RELU,
@@ -3023,6 +3043,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_(
thread_id,
num_threads);
}
+ delete[] C_int32_temp;
}
// Dispatch HAS_BIAS
@@ -3329,7 +3350,7 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_(
const int32_t* bias,
int thread_id,
int num_threads) {
- int32_t C_int32_temp[(K + 31) / 32 * 32];
+ int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32];
if (A_zero_point == 0 || col_offsets == nullptr) {
depthwise_3x3x3_per_channel_quantization_pad_1_<
FUSE_RELU,
@@ -3381,6 +3402,7 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_(
thread_id,
num_threads);
}
+ delete[] C_int32_temp;
}
// Dispatch HAS_BIAS
diff --git a/src/FbgemmI8Spmdm.cc b/src/FbgemmI8Spmdm.cc
index 10e5a1b..edcc4e8 100644
--- a/src/FbgemmI8Spmdm.cc
+++ b/src/FbgemmI8Spmdm.cc
@@ -5,6 +5,7 @@
* LICENSE file in the root directory of this source tree.
*/
#include "fbgemm/FbgemmI8Spmdm.h"
+#include "fbgemm/Utils.h"
#include <algorithm>
#include <array>
@@ -70,9 +71,6 @@ void CompressedSparseColumn::SpMDM(
t_very_start = std::chrono::high_resolution_clock::now();
#endif
- alignas(64) uint8_t A_buffer[K * 32];
- alignas(64) int32_t C_buffer[N * 32];
-
// If we compute C = C + A * B, where B is a sparse matrix in CSC format, for
// each non-zero in B, we'd need to access the corresponding column in A.
// This results in strided access, which we want to avoid.
@@ -82,7 +80,7 @@ void CompressedSparseColumn::SpMDM(
// The cost of transpose is O(K*N) and we do O(NNZ*N) multiplications.
// If NNZ/K is small, it's not worth doing transpose so we just use this
// scalar loop.
- int32_t C_temp[block.row_size];
+ int32_t* C_temp = new int32_t[block.row_size];
if (accumulation) {
for (int j = 0; j < block.col_size; ++j) {
int k = colptr_[block.col_start + j];
@@ -141,6 +139,7 @@ void CompressedSparseColumn::SpMDM(
}
} // for each column of B
}
+ delete[] C_temp;
return;
}
@@ -152,12 +151,15 @@ void CompressedSparseColumn::SpMDM(
t_start = std::chrono::high_resolution_clock::now();
#endif
+ uint8_t* A_buffer = static_cast<uint8_t*>(ALIGNED_MALLOC(K * 32 * sizeof(uint8_t), 64));
+ int32_t* C_buffer = static_cast<int32_t*>(ALIGNED_MALLOC(N * 32 * sizeof(int32_t), 64));
+
// Take 32 rows at a time
int i_end = block.row_start + block.row_size;
for (int i1 = block.row_start; i1 < i_end; i1 += 32) {
// Transpose 32 x K submatrix of A
if (i_end - i1 < 32) {
- alignas(64) uint8_t A_temp_buffer[K * 32];
+ uint8_t* A_temp_buffer = static_cast<uint8_t*>(ALIGNED_MALLOC(K * 32 * sizeof(uint8_t), 64));
for (int i2 = 0; i2 < (i_end - i1) / 8 * 8; i2 += 8) {
transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32);
}
@@ -173,6 +175,7 @@ void CompressedSparseColumn::SpMDM(
for (int i2 = (i_end - i1) / 8 * 8; i2 < 32; i2 += 8) {
transpose_8rows(K, A_temp_buffer + i2 * K, K, A_buffer + i2, 32);
}
+ FREE(A_temp_buffer);
} else {
for (int i2 = 0; i2 < 32; i2 += 8) {
transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32);
@@ -250,6 +253,9 @@ void CompressedSparseColumn::SpMDM(
spmdm_run_time += (dt);
t_start = std::chrono::high_resolution_clock::now();
#endif
+
+ FREE(A_buffer);
+ FREE(C_buffer);
}
void CompressedSparseColumn::SparseConv(
diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc
index 7572a51..305a298 100644
--- a/src/PackAWithQuantRowOffset.cc
+++ b/src/PackAWithQuantRowOffset.cc
@@ -111,7 +111,7 @@ void PackAWithQuantRowOffset<T, accT>::pack(const block_type_t& block) {
(block.col_start % (this->numCols() / this->numGroups())) != 0;
int32_t* row_offset_buf = getRowOffsetBuffer();
- float smat_transposed[block.row_size * block.col_size];
+ float* smat_transposed = new float[block.row_size * block.col_size];
if (tr) {
transpose_simd(
block.col_size,
@@ -150,6 +150,8 @@ void PackAWithQuantRowOffset<T, accT>::pack(const block_type_t& block) {
out[i * BaseType::blockColSize() + j] = 0;
}
}
+
+ delete[] smat_transposed;
}
template <typename T, typename accT>
diff --git a/src/UtilsAvx512.cc b/src/UtilsAvx512.cc
index 44d4f9a..b286eab 100644
--- a/src/UtilsAvx512.cc
+++ b/src/UtilsAvx512.cc
@@ -103,6 +103,40 @@ inline void transpose_kernel_16x16_avx512(
// m1 n1 o1 p1 ...
// m2 n2 o2 p2 ...
// m3 n3 o3 p3 ...
+#ifdef _MSC_VER
+ a = reinterpret_cast<__m512&>(_mm512_unpacklo_pd(
+ reinterpret_cast<__m512d&>(ta), reinterpret_cast<__m512d&>(tc)));
+ b = reinterpret_cast<__m512&>(_mm512_unpackhi_pd(
+ reinterpret_cast<__m512d&>(ta), reinterpret_cast<__m512d&>(tc)));
+ c = reinterpret_cast<__m512&>(_mm512_unpacklo_pd(
+ reinterpret_cast<__m512d&>(tb), reinterpret_cast<__m512d&>(td)));
+ d = reinterpret_cast<__m512&>(_mm512_unpackhi_pd(
+ reinterpret_cast<__m512d&>(tb), reinterpret_cast<__m512d&>(td)));
+ e = reinterpret_cast<__m512&>(_mm512_unpacklo_pd(
+ reinterpret_cast<__m512d&>(te), reinterpret_cast<__m512d&>(tg)));
+ f = reinterpret_cast<__m512&>(_mm512_unpackhi_pd(
+ reinterpret_cast<__m512d&>(te), reinterpret_cast<__m512d&>(tg)));
+ g = reinterpret_cast<__m512&>(_mm512_unpacklo_pd(
+ reinterpret_cast<__m512d&>(tf), reinterpret_cast<__m512d&>(th)));
+ h = reinterpret_cast<__m512&>(_mm512_unpackhi_pd(
+ reinterpret_cast<__m512d&>(tf), reinterpret_cast<__m512d&>(th)));
+ i = reinterpret_cast<__m512&>(_mm512_unpacklo_pd(
+ reinterpret_cast<__m512d&>(ti), reinterpret_cast<__m512d&>(tk)));
+ j = reinterpret_cast<__m512&>(_mm512_unpackhi_pd(
+ reinterpret_cast<__m512d&>(ti), reinterpret_cast<__m512d&>(tk)));
+ k = reinterpret_cast<__m512&>(_mm512_unpacklo_pd(
+ reinterpret_cast<__m512d&>(tj), reinterpret_cast<__m512d&>(tl)));
+ l = reinterpret_cast<__m512&>(_mm512_unpackhi_pd(
+ reinterpret_cast<__m512d&>(tj), reinterpret_cast<__m512d&>(tl)));
+ m = reinterpret_cast<__m512&>(_mm512_unpacklo_pd(
+ reinterpret_cast<__m512d&>(tm), reinterpret_cast<__m512d&>(to)));
+ n = reinterpret_cast<__m512&>(_mm512_unpackhi_pd(
+ reinterpret_cast<__m512d&>(tm), reinterpret_cast<__m512d&>(to)));
+ o = reinterpret_cast<__m512&>(_mm512_unpacklo_pd(
+ reinterpret_cast<__m512d&>(tn), reinterpret_cast<__m512d&>(tq)));
+ p = reinterpret_cast<__m512&>(_mm512_unpackhi_pd(
+ reinterpret_cast<__m512d&>(tn), reinterpret_cast<__m512d&>(tq)));
+#else
a = reinterpret_cast<__m512>(_mm512_unpacklo_pd(
reinterpret_cast<__m512d>(ta), reinterpret_cast<__m512d>(tc)));
b = reinterpret_cast<__m512>(_mm512_unpackhi_pd(
@@ -135,6 +169,7 @@ inline void transpose_kernel_16x16_avx512(
reinterpret_cast<__m512d>(tn), reinterpret_cast<__m512d>(tq)));
p = reinterpret_cast<__m512>(_mm512_unpackhi_pd(
reinterpret_cast<__m512d>(tn), reinterpret_cast<__m512d>(tq)));
+#endif
// shuffle 128-bits (composed of 4 32-bit elements)
// a0 b0 c0 d0 a8 b8 c8 d8 e0 f0 g0 h0 e8 f8 g8 h8
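MSVC rejects the value-style reinterpret_cast between __m512 and __m512d that GCC accepts, which is why the _MSC_VER branch above casts through references. The same bit-level reinterpretation can also be written with the cast intrinsics, which compile on both toolchains; the sketch below shows that alternative and is not what the patch does.

    #include <immintrin.h>

    // _mm512_castps_pd / _mm512_castpd_ps are zero-cost reinterpretations, so
    // unpacking float data through the double domain matches the casts above.
    static inline __m512 unpacklo_pd_as_ps(__m512 x, __m512 y) {
      return _mm512_castpd_ps(
          _mm512_unpacklo_pd(_mm512_castps_pd(x), _mm512_castps_pd(y)));
    }
    static inline __m512 unpackhi_pd_as_ps(__m512 x, __m512 y) {
      return _mm512_castpd_ps(
          _mm512_unpackhi_pd(_mm512_castps_pd(x), _mm512_castps_pd(y)));
    }
    // e.g. a = unpacklo_pd_as_ps(ta, tc); b = unpackhi_pd_as_ps(ta, tc);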
diff --git a/src/codegen_fp16fp32.cc b/src/codegen_fp16fp32.cc
index 7c8e10c..8f80593 100644
--- a/src/codegen_fp16fp32.cc
+++ b/src/codegen_fp16fp32.cc
@@ -18,10 +18,16 @@
using namespace std;
+void addi(ofstream& of, string i, string asmstr = "", bool disable = false) {
+ if (disable == false)
+ of << " " + i + " //\"" + asmstr + "\\t\\n\"" + "\n";
+}
+#if 0
void addi(ofstream& of, string i, bool disable = false) {
if (disable == false)
of << " \"" + i + "\\t\\n\"" + "\n";
}
+#endif
struct ISA {
unsigned avx; // 1, 2 or 3
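The reworked addi above takes both the C++ statement and the original assembly string and writes the statement with the assembly preserved as a trailing comment, which is the statement-plus-asm-comment format visible throughout the regenerated kernels (the checked-in sources additionally align the comments). A standalone reproduction of the emitter, showing its output:

    #include <fstream>
    #include <string>
    using namespace std;

    // Same emitter shape as above, reproduced standalone to show the output.
    void addi(ofstream& of, string i, string asmstr = "", bool disable = false) {
      if (disable == false)
        of << "      " + i + " //\"" + asmstr + "\\t\\n\"" + "\n";
    }

    int main() {
      ofstream out("demo.cc");
      addi(out, "float* rax = r9;", "mov rax, r9");
      // demo.cc now contains:
      //       float* rax = r9; //"mov rax, r9\t\n"
    }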
@@ -88,7 +94,8 @@ int main() {
" * This source code is licensed under the BSD-style license found in the\n"
" * LICENSE file in the root directory of this source tree.\n"
" */\n";
- srcfile << "#include \"FbgemmFP16UKernelsAvx2.h\"\n\n";
+ srcfile << "#include \"FbgemmFP16UKernelsAvx2.h\"\n";
+ srcfile << "#include <immintrin.h>\n\n";
srcfile << "namespace fbgemm {\n\n";
if (iaca) {
srcfile << "#include \"iacaMarks.h\"\n";
@@ -111,6 +118,11 @@ int main() {
hdrfile << "namespace fbgemm {\n\n";
hdrfile << "using fp16 = float16;\n";
hdrfile << "using fp32 = float;\n";
+ hdrfile << "#ifdef _MSC_VER\n";
+ hdrfile << " #define NOINLINE_ATTR __declspec(noinline)\n";
+ hdrfile << "#else\n";
+ hdrfile << " #define NOINLINE_ATTR __attribute__((noinline))\n";
+ hdrfile << "#endif\n";
hdrfile
<< "struct GemmParams {\n uint64_t k;\n float* A;\n const fp16* B;\n"
" float* beta;\n uint64_t accum;\n float* C;\n uint64_t ldc;\n"
@@ -158,8 +170,9 @@ int main() {
fargs = "(" + p1 + ")";
+#if 1
fheader[k] =
- "void __attribute__((noinline)) " + funcname[k] + fargs;
+ "void NOINLINE_ATTR " + funcname[k] + fargs;
srcfile << fheader[k] << " {\n";
unsigned last_free_ymmreg = 0;
@@ -183,85 +196,92 @@ int main() {
assert(last_free_ymmreg <= 16);
- srcfile << " asm volatile(\n";
+ //srcfile << " asm volatile(\n";
- srcfile << "#if !defined(__clang__)"
- << "\n";
- addi(srcfile, "mov r14, %[gp]");
- srcfile << "#else\n";
- addi(srcfile, "mov %[gp], %%r14");
- addi(srcfile, ".intel_syntax noprefix");
- srcfile << "#endif\n";
+ //srcfile << "#if !defined(__clang__)"
+ //<< "\n";
+ addi(srcfile, "char* r14 = (char*)gp;", "mov r14, %[gp]");
+ //srcfile << "#else\n";
+ //addi(srcfile, "mov %[gp], %%r14");
+ //addi(srcfile, ".intel_syntax noprefix");
+ //srcfile << "#endif\n";
srcfile << "\n // Copy parameters\n";
- srcfile << " // k\n";
- addi(srcfile, "mov r8, [r14 + 0]");
- srcfile << " // A\n";
- addi(srcfile, "mov r9, [r14 + 8]");
- srcfile << " // B\n";
- addi(srcfile, "mov r10, [r14 + 16]");
- srcfile << " // beta\n";
- addi(srcfile, "mov r15, [r14 + 24]");
- srcfile << " // accum\n";
- addi(srcfile, "mov rdx, [r14 + 32]");
- srcfile << " // C\n";
- addi(srcfile, "mov r12, [r14 + 40]");
- srcfile << " // ldc\n";
- addi(srcfile, "mov r13, [r14 + 48]");
- srcfile << " // b_block_cols\n";
- addi(srcfile, "mov rdi, [r14 + 56]");
- srcfile << " // b_block_size\n";
- addi(srcfile, "mov rsi, [r14 + 64]");
+ srcfile << " // k\n"; addi(srcfile, "uint64_t r8 = *(uint64_t *)((char*)r14 + 0 );", "mov r8, [r14 + 0]");
+ srcfile << " // A\n"; addi(srcfile, "float* r9 = *(float* *)((char*)r14 + 8 );", "mov r9, [r14 + 8]");
+ srcfile << " // B\n"; addi(srcfile, "const fp16* r10 = *(const fp16**)((char*)r14 + 16);", "mov r10, [r14 + 16]");
+ srcfile << " // beta\n"; addi(srcfile, "float* r15 = *(float* *)((char*)r14 + 24);", "mov r15, [r14 + 24]");
+ srcfile << " // accum\n"; addi(srcfile, "uint64_t rdx = *(uint64_t *)((char*)r14 + 32);", "mov rdx, [r14 + 32]");
+ srcfile << " // C\n"; addi(srcfile, "float* r12 = *(float* *)((char*)r14 + 40);", "mov r12, [r14 + 40]");
+ srcfile << " // ldc\n"; addi(srcfile, "uint64_t r13 = *(uint64_t *)((char*)r14 + 48);", "mov r13, [r14 + 48]");
+ srcfile << " // b_block_cols\n"; addi(srcfile, "uint64_t rdi = *(uint64_t *)((char*)r14 + 56);", "mov rdi, [r14 + 56]");
+ srcfile << " // b_block_size\n"; addi(srcfile, "uint64_t rsi = *(uint64_t *)((char*)r14 + 64);", "mov rsi, [r14 + 64]");
srcfile << " // Make copies of A and C\n";
- addi(srcfile, "mov rax, r9");
- addi(srcfile, "mov rcx, r12");
+ addi(srcfile, "float* rax = r9;", "mov rax, r9");
+ addi(srcfile, "float* rcx = r12;", "mov rcx, r12");
srcfile << "\n";
- addi(srcfile, "mov rbx, 0");
+ addi(srcfile, "uint64_t rbx = 0;", "mov rbx, 0");
string exitlabel = "L_exit%=";
string label2 = "loop_outter%=";
- addi(srcfile, label2 + ":");
- addi(srcfile, "mov r14, 0");
+ addi(srcfile, "for (; rbx < rdi; ++rbx) {", "inc rbx; cmp rbx, rdi; jl " + label2);
+ addi(srcfile, "// ", label2 + ":");
+ addi(srcfile, " uint64_t r14_i = 0;", "mov r14, 0");
// set all vCtile regs to zeros
for (auto r = 0; r < vCtile.size(); r++) {
for (auto c = 0; c < vCtile[r].size(); c++) {
addi(
srcfile,
+ " __m256 " + vCtile[r][c] + " = _mm256_setzero_ps();",
"vxorps " + vCtile[r][c] + "," + vCtile[r][c] + "," +
vCtile[r][c]);
}
}
// start marker
- if (iaca) {
- addi(srcfile, "mov ebx, 111");
- addi(srcfile, ".byte 0x64, 0x67, 0x90");
- }
+ //if (iaca) {
+ // addi(srcfile, "mov ebx, 111");
+ // addi(srcfile, ".byte 0x64, 0x67, 0x90");
+ //}
- srcfile << "\n";
+ //srcfile << "\n";
srcfile << "\n";
string label = "loop_inner%=";
- addi(srcfile, label + ":");
- srcfile << "\n";
+ addi(srcfile, " for (; r14_i < r8; ++r14_i) {", "inc r14; cmp r14, r8; jl " + label);
+ addi(srcfile, " // " + label + ":");
+ //srcfile << "\n";
for (int c = 0; c < vCtile[0].size(); c++) {
addi(
+ srcfile,
+ " auto fp16mem" + to_string(16 * c) + " = _mm_load_si128((__m128i*)((char*)r10 + " + to_string(16 * c) + "));",
+ "vcvtph2ps " + vBcol[c] + ",XMMWORD PTR [r10 + " +
+ to_string(16 * c) + "]");
+ addi(
srcfile,
+ " auto " + vBcol[c] + " = _mm256_cvtph_ps(fp16mem" + to_string(16 * c) + ");",
"vcvtph2ps " + vBcol[c] + ",XMMWORD PTR [r10 + " +
to_string(16 * c) + "]");
}
for (int r = 0; r < vCtile.size(); r++) {
+ //addi(
+ // srcfile,
+ // ((r == 0) ? " auto " + vAtmp : "" + vAtmp) + " = _mm256_broadcastss_ps(r9 + " + to_string(4 * r) + ");",
+ // "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" +
+ // to_string(4 * r) + "]");
addi(
srcfile,
+ ((r == 0) ? " auto " + vAtmp : " " + vAtmp) + " = _mm256_broadcast_ss((float*)((char*)r9 + " + to_string(4 * r) + "));",
"vbroadcastss " + vAtmp + ",DWORD PTR [r9+" +
to_string(4 * r) + "]");
for (int c = 0; c < vCtile[0].size(); c++) {
addi(
srcfile,
+ " " + vCtile[r][c] + " = _mm256_fmadd_ps(" + vAtmp + ", " + vBcol[c] + ", " + vCtile[r][c] + ");",
"vfmadd231ps " + vCtile[r][c] + "," + vBcol[c] + "," +
vAtmp);
}
@@ -269,21 +289,25 @@ int main() {
addi(
srcfile,
+ " r9 = (float*)((char*)r9 + " + to_string(4 * ukernel_shape[k][0]) + ");",
"add r9," + to_string(4 * ukernel_shape[k][0]),
fixedA); // move A ptr
addi(
srcfile,
+ " r10 = (fp16*)((char*)r10 + " + to_string(16 * ukernel_shape[k][1]) + ");",
"add r10," + to_string(16 * ukernel_shape[k][1]),
fixedA); // move A ptr
- addi(srcfile, "inc r14");
- addi(srcfile, "cmp r14, r8");
- addi(srcfile, "jl " + label);
+          addi(srcfile, "    }", "inc r14; cmp r14, r8; jl " + label);
+ // move to for loop
+ //addi(srcfile, "inc r14");
+ //addi(srcfile, "cmp r14, r8");
+ //addi(srcfile, "jl " + label);
- srcfile << "\n";
+ //srcfile << "\n";
- addi(srcfile, exitlabel + ":");
+ //addi(srcfile, exitlabel + ":");
// addi(srcfile, "add r10, rsi");
srcfile << "\n";
@@ -294,29 +318,33 @@ int main() {
addi(srcfile, ".byte 0x64, 0x67, 0x90");
}
- addi(srcfile, "cmp rdx, 1");
- addi(srcfile, "je L_accum%=");
- srcfile << " // Dump C\n";
+ //addi(srcfile, "cmp rdx, 1");
+ addi(srcfile, " if(rdx != 1) {", "cmp rdx, 1; je L_accum%=");
+
+ srcfile << " // Dump C\n";
for (auto r = 0; r < vCtile.size(); r++) {
for (auto c = 0; c < vCtile[r].size(); c++) {
addi(
srcfile,
+ " _mm256_storeu_ps((float*)((char*)r12 + " + to_string(32 * c) + "), " + vCtile[r][c] + ");",
"vmovups YMMWORD PTR [r12 + " + to_string(32 * c) +
"], " + vCtile[r][c],
fixedC);
}
- addi(srcfile, "add r12, r13", fixedC); // move C ptr
+ if (r != vCtile.size() - 1)
+ addi(srcfile, " r12 = (float*)((char*)r12 + r13);", "add r12, r13", fixedC); // move C ptr
}
- addi(srcfile, "jmp L_done%=");
+ addi(srcfile, " } else {", "jmp L_done%=");
- srcfile << "\n";
- addi(srcfile, "L_accum%=:");
- srcfile << " // Dump C with accumulate\n";
+ //srcfile << "\n";
+ //addi(srcfile, "L_accum%=:");
+ srcfile << " // Dump C with accumulate\n";
string r_spare = (s.avx == 1) ? "ymm14" : "ymm15";
addi(
srcfile,
+ " auto " + r_spare + " = _mm256_broadcast_ss((float*)r15);",
"vbroadcastss " + r_spare + string(",DWORD PTR [r15]"),
fixedC);
// store out C
@@ -326,18 +354,29 @@ int main() {
case 1:
addi(
srcfile,
+ "not supported",
string("vmulps ymm15, ") + r_spare + comma +
"YMMWORD PTR [r12 + " + to_string(32 * c) + "]",
fixedC);
addi(
srcfile,
+ "not supported",
"vaddps " + vCtile[r][c] + "," + vCtile[r][c] + "," +
"ymm15",
fixedC);
break;
case 2:
+ //if (r == 0) {
+ addi(
+ srcfile,
+ ((r == 0) ? " auto r12_" + to_string(32 * c) : " r12_" + to_string(32 * c)) + " = _mm256_load_ps((float*)((char*)r12 + " + to_string(32 * c) + "));",
+ "vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," +
+ "YMMWORD PTR [r12 + " + to_string(32 * c) + "]",
+ fixedC);
+ //}
addi(
srcfile,
+ " " + vCtile[r][c] + " = _mm256_fmadd_ps(r12_" + to_string(32 * c) + ", " + r_spare + ", " + vCtile[r][c] + ");",
"vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," +
"YMMWORD PTR [r12 + " + to_string(32 * c) + "]",
fixedC);
@@ -347,46 +386,283 @@ int main() {
}
addi(
srcfile,
+ " _mm256_storeu_ps((float*)((char*)r12 + " + to_string(32 * c) + "), " + vCtile[r][c] + ");",
"vmovups YMMWORD PTR [r12 + " + to_string(32 * c) +
"], " + vCtile[r][c],
fixedC);
}
- addi(srcfile, "add r12, r13", fixedC); // move C ptr
+ if (r != vCtile.size() - 1)
+ addi(srcfile, " r12 = (float*)((char*)r12 + r13);", "add r12, r13", fixedC); // move C ptr
}
- srcfile << "\n";
- addi(srcfile, "L_done%=:");
+ //srcfile << "\n";
+ addi(srcfile, " }", "L_done%=:");
- srcfile << "\n // next outer iteration\n";
+ srcfile << "\n // next outer iteration\n";
// C
addi(
srcfile,
+ " rcx = (float*)((char*)rcx + " + to_string(32 * ukernel_shape[k][1]) + ");",
"add rcx, " + to_string(32 * ukernel_shape[k][1]),
fixedC);
- addi(srcfile, "mov r12, rcx", fixedC);
+ addi(srcfile, " r12 = rcx;", "mov r12, rcx", fixedC);
// A
- addi(srcfile, "mov r9, rax");
-
- addi(srcfile, "inc rbx");
- addi(srcfile, "cmp rbx, rdi");
- addi(srcfile, "jl " + label2);
-
- // output
- srcfile << " :\n";
- // input
- srcfile << " : [gp] \"rm\"(gp)\n";
-
- // clobbered
- srcfile
- << " : \"r8\",\n \"r9\",\n \"r10\",\n"
- " \"r11\",\n \"r15\",\n \"r13\",\n"
- " \"r14\",\n \"rax\",\n \"rcx\",\n"
- " \"rdx\",\n \"rsi\",\n \"rdi\",\n"
- " \"rbx\",\n \"r12\",\n"
- " \"memory\");\n";
- srcfile << "}\n";
+ addi(srcfile, " r9 = rax;", "mov r9, rax");
+
+      // move to top for loop
+ //addi(srcfile, "inc rbx");
+ //addi(srcfile, "cmp rbx, rdi");
+ //addi(srcfile, "jl " + label2);
+ addi(srcfile, "}", "inc rbx; cmp rbx, rdi; jl " + label2);
+
+ //// output
+ //srcfile << " :\n";
+ //// input
+ //srcfile << " : [gp] \"rm\"(gp)\n";
+
+ //// clobbered
+ //srcfile
+ // << " : \"r8\",\n \"r9\",\n \"r10\",\n"
+ // " \"r11\",\n \"r15\",\n \"r13\",\n"
+ // " \"r14\",\n \"rax\",\n \"rcx\",\n"
+ // " \"rdx\",\n \"rsi\",\n \"rdi\",\n"
+ // " \"rbx\",\n \"r12\",\n"
+ // " \"memory\");\n";
+ srcfile << "}\n\n";
+ }
+
+#else
+ fheader[k] =
+ "void __attribute__((noinline)) " + funcname[k] + fargs;
+ srcfile << fheader[k] << " {\n";
+
+ unsigned last_free_ymmreg = 0;
+ // produce register block of C
+ vector<vector<string>> vCtile(ukernel_shape[k][0]);
+ for (auto r = 0; r < ukernel_shape[k][0]; r++)
+ for (auto c = 0; c < ukernel_shape[k][1]; c++) {
+ vCtile[r].push_back("ymm" + to_string(last_free_ymmreg));
+ last_free_ymmreg++;
+ }
+ assert(last_free_ymmreg <= 14);
+
+ string vAtmp = "ymm" + to_string(last_free_ymmreg++);
+ // produce register block of B col
+ vector<string> vBcol(ukernel_shape[k][1]);
+
+ for (auto c = 0; c < ukernel_shape[k][1]; c++) {
+ vBcol[c] = ("ymm" + to_string(last_free_ymmreg));
+ last_free_ymmreg++;
+ }
+
+ assert(last_free_ymmreg <= 16);
+
+ srcfile << " asm volatile(\n";
+
+ srcfile << "#if !defined(__clang__)"
+ << "\n";
+ addi(srcfile, "mov r14, %[gp]");
+ srcfile << "#else\n";
+ addi(srcfile, "mov %[gp], %%r14");
+ addi(srcfile, ".intel_syntax noprefix");
+ srcfile << "#endif\n";
+
+ srcfile << "\n // Copy parameters\n";
+ srcfile << " // k\n";
+ addi(srcfile, "mov r8, [r14 + 0]");
+ srcfile << " // A\n";
+ addi(srcfile, "mov r9, [r14 + 8]");
+ srcfile << " // B\n";
+ addi(srcfile, "mov r10, [r14 + 16]");
+ srcfile << " // beta\n";
+ addi(srcfile, "mov r15, [r14 + 24]");
+ srcfile << " // accum\n";
+ addi(srcfile, "mov rdx, [r14 + 32]");
+ srcfile << " // C\n";
+ addi(srcfile, "mov r12, [r14 + 40]");
+ srcfile << " // ldc\n";
+ addi(srcfile, "mov r13, [r14 + 48]");
+ srcfile << " // b_block_cols\n";
+ addi(srcfile, "mov rdi, [r14 + 56]");
+ srcfile << " // b_block_size\n";
+ addi(srcfile, "mov rsi, [r14 + 64]");
+ srcfile << " // Make copies of A and C\n";
+ addi(srcfile, "mov rax, r9");
+ addi(srcfile, "mov rcx, r12");
+ srcfile << "\n";
+
+ addi(srcfile, "mov rbx, 0");
+
+ string exitlabel = "L_exit%=";
+ string label2 = "loop_outter%=";
+ addi(srcfile, label2 + ":");
+ addi(srcfile, "mov r14, 0");
+
+ // set all vCtile regs to zeros
+ for (auto r = 0; r < vCtile.size(); r++) {
+ for (auto c = 0; c < vCtile[r].size(); c++) {
+ addi(
+ srcfile,
+ "vxorps " + vCtile[r][c] + "," + vCtile[r][c] + "," +
+ vCtile[r][c]);
+ }
+ }
+
+ // start marker
+ if (iaca) {
+ addi(srcfile, "mov ebx, 111");
+ addi(srcfile, ".byte 0x64, 0x67, 0x90");
+ }
+
+ srcfile << "\n";
+
+ srcfile << "\n";
+ string label = "loop_inner%=";
+ addi(srcfile, label + ":");
+ srcfile << "\n";
+
+ for (int c = 0; c < vCtile[0].size(); c++) {
+ addi(
+ srcfile,
+ "vcvtph2ps " + vBcol[c] + ",XMMWORD PTR [r10 + " +
+ to_string(16 * c) + "]");
}
+ for (int r = 0; r < vCtile.size(); r++) {
+ addi(
+ srcfile,
+ "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" +
+ to_string(4 * r) + "]");
+ for (int c = 0; c < vCtile[0].size(); c++) {
+ addi(
+ srcfile,
+ "vfmadd231ps " + vCtile[r][c] + "," + vBcol[c] + "," +
+ vAtmp);
+ }
+ }
+
+ addi(
+ srcfile,
+ "add r9," + to_string(4 * ukernel_shape[k][0]),
+ fixedA); // move A ptr
+
+ addi(
+ srcfile,
+ "add r10," + to_string(16 * ukernel_shape[k][1]),
+ fixedA); // move A ptr
+
+ addi(srcfile, "inc r14");
+ addi(srcfile, "cmp r14, r8");
+ addi(srcfile, "jl " + label);
+
+ srcfile << "\n";
+
+ addi(srcfile, exitlabel + ":");
+
+ // addi(srcfile, "add r10, rsi");
+ srcfile << "\n";
+
+ // end marker
+ if (iaca) {
+ addi(srcfile, "mov ebx, 222");
+ addi(srcfile, ".byte 0x64, 0x67, 0x90");
+ }
+
+ addi(srcfile, "cmp rdx, 1");
+ addi(srcfile, "je L_accum%=");
+ srcfile << " // Dump C\n";
+
+ for (auto r = 0; r < vCtile.size(); r++) {
+ for (auto c = 0; c < vCtile[r].size(); c++) {
+ addi(
+ srcfile,
+ "vmovups YMMWORD PTR [r12 + " + to_string(32 * c) +
+ "], " + vCtile[r][c],
+ fixedC);
+ }
+ addi(srcfile, "add r12, r13", fixedC); // move C ptr
+ }
+ addi(srcfile, "jmp L_done%=");
+
+ srcfile << "\n";
+ addi(srcfile, "L_accum%=:");
+ srcfile << " // Dump C with accumulate\n";
+
+ string r_spare = (s.avx == 1) ? "ymm14" : "ymm15";
+ addi(
+ srcfile,
+ "vbroadcastss " + r_spare + string(",DWORD PTR [r15]"),
+ fixedC);
+ // store out C
+ for (auto r = 0; r < vCtile.size(); r++) {
+ for (auto c = 0; c < vCtile[r].size(); c++) {
+ switch (s.avx) {
+ case 1:
+ addi(
+ srcfile,
+ string("vmulps ymm15, ") + r_spare + comma +
+ "YMMWORD PTR [r12 + " + to_string(32 * c) + "]",
+ fixedC);
+ addi(
+ srcfile,
+ "vaddps " + vCtile[r][c] + "," + vCtile[r][c] + "," +
+ "ymm15",
+ fixedC);
+ break;
+ case 2:
+ addi(
+ srcfile,
+ "vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," +
+ "YMMWORD PTR [r12 + " + to_string(32 * c) + "]",
+ fixedC);
+ break;
+ default:
+ assert(0);
+ }
+ addi(
+ srcfile,
+ "vmovups YMMWORD PTR [r12 + " + to_string(32 * c) +
+ "], " + vCtile[r][c],
+ fixedC);
+ }
+ addi(srcfile, "add r12, r13", fixedC); // move C ptr
+ }
+
+ srcfile << "\n";
+ addi(srcfile, "L_done%=:");
+
+ srcfile << "\n // next outer iteration\n";
+ // C
+ addi(
+ srcfile,
+ "add rcx, " + to_string(32 * ukernel_shape[k][1]),
+ fixedC);
+ addi(srcfile, "mov r12, rcx", fixedC);
+ // A
+ addi(srcfile, "mov r9, rax");
+
+ addi(srcfile, "inc rbx");
+ addi(srcfile, "cmp rbx, rdi");
+ addi(srcfile, "jl " + label2);
+
+ // output
+ srcfile << " :\n";
+ // input
+ srcfile << " : [gp] \"rm\"(gp)\n";
+
+ // clobbered
+ srcfile
+ << " : \"r8\",\n \"r9\",\n \"r10\",\n"
+ " \"r11\",\n \"r15\",\n \"r13\",\n"
+ " \"r14\",\n \"rax\",\n \"rcx\",\n"
+ " \"rdx\",\n \"rsi\",\n \"rdi\",\n"
+ " \"rbx\",\n \"r12\",\n"
+ " \"memory\");\n";
+ srcfile << "}\n";
+ }
+
+#endif
+
for (unsigned k = 0; k < ukernel_shape.size(); k++) {
hdrfile << fheader[k] << ";\n";
}
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index a7e531b..42c3e1f 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -18,8 +18,12 @@ macro(add_gtest TESTNAME)
set_target_properties(${TESTNAME} PROPERTIES
CXX_STANDARD 11
CXX_EXTENSIONS NO)
- target_compile_options(${TESTNAME} PRIVATE
- "-m64" "-mavx2" "-mfma" "-masm=intel")
+ if(MSVC)
+ target_compile_options(${TESTNAME} PRIVATE "/DFBGEMM_STATIC")
+ else(MSVC)
+ target_compile_options(${TESTNAME} PRIVATE
+ "-m64" "-mavx2" "-mfma" "-masm=intel")
+ endif(MSVC)
target_link_libraries(${TESTNAME} gtest gmock gtest_main fbgemm)
add_dependencies(${TESTNAME} gtest fbgemm)
add_test(${TESTNAME} ${TESTNAME})
diff --git a/test/Im2ColFusedRequantizeTest.cc b/test/Im2ColFusedRequantizeTest.cc
index d9c2f75..b14303f 100644
--- a/test/Im2ColFusedRequantizeTest.cc
+++ b/test/Im2ColFusedRequantizeTest.cc
@@ -7,6 +7,7 @@
#include <cmath>
#include <cstdio>
#include <random>
+#include <numeric>
#ifdef _OPENMP
#include <omp.h>
diff --git a/test/PackedRequantizeAcc16Test.cc b/test/PackedRequantizeAcc16Test.cc
index 55f6e7f..20f860e 100644
--- a/test/PackedRequantizeAcc16Test.cc
+++ b/test/PackedRequantizeAcc16Test.cc
@@ -9,6 +9,7 @@
#include <cmath>
#include <random>
#include <vector>
+#include <numeric>
#ifdef _OPENMP
#include <omp.h>