From 24ff10d324a9ca12820c4aa8bf74fabd4eab81dc Mon Sep 17 00:00:00 2001
From: Young Jin Kim
Date: Wed, 12 Jun 2019 17:49:16 -0700
Subject: Compile on both Windows and Linux

---
 CMakeLists.txt                          |   46 +-
 bench/AlignedVec.h                      |    9 +-
 bench/BenchUtils.cc                     |   15 +
 bench/CMakeLists.txt                    |    8 +-
 bench/ConvUnifiedBenchmark.cc           |    1 +
 bench/PackedRequantizeAcc16Benchmark.cc |    1 +
 cmake/modules/FindMKL.cmake             |    2 +-
 include/fbgemm/Fbgemm.h                 |   12 +
 include/fbgemm/FbgemmFP16.h             |   13 +-
 include/fbgemm/Types.h                  |   10 +-
 src/FbgemmFP16UKernelsAvx2.cc           | 1261 ++++++++++++++-----------------
 src/FbgemmFP16UKernelsAvx2.h            |   17 +-
 src/FbgemmI8DepthwiseAvx2.cc            |   84 +-
 src/FbgemmI8Spmdm.cc                    |   15 +-
 src/PackAWithQuantRowOffset.cc          |    4 +-
 src/UtilsAvx512.cc                      |   35 +
 src/codegen_fp16fp32.cc                 |  432 +++++++++--
 test/CMakeLists.txt                     |    8 +-
 test/Im2ColFusedRequantizeTest.cc       |    1 +
 test/PackedRequantizeAcc16Test.cc       |    1 +
 20 files changed, 1127 insertions(+), 848 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index b575e17..e6c7419 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -22,6 +22,10 @@ set(FBGEMM_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
 set(FBGEMM_THIRDPARTY_DIR ${FBGEMM_BINARY_DIR}/third_party)
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

+# Only a static library is available on Windows.
+if(MSVC)
+  set(FBGEMM_LIBRARY_TYPE "static")
+endif(MSVC)

 #All the source files that either use avx2 instructions statically or JIT
 #avx2/avx512 instructions.
@@ -49,10 +53,20 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
   src/Utils.cc)

 #check if compiler supports avx512
-include(CheckCXXCompilerFlag)
-CHECK_CXX_COMPILER_FLAG(-mavx512f COMPILER_SUPPORTS_AVX512)
-if(NOT COMPILER_SUPPORTS_AVX512)
-  message(FATAL_ERROR "A compiler with AVX512 support is required.")
+if (MSVC)
+  set(DISABLE_GLOBALLY "/wd\"4310\" /wd\"4324\"")
+  set(INTRINSICS "/arch:AVX2")
+  set(CMAKE_CXX_FLAGS "/EHsc /DWIN32 /D_WINDOWS /DUNICODE /D_UNICODE /D_CRT_NONSTDC_NO_WARNINGS /D_CRT_SECURE_NO_WARNINGS ${DISABLE_GLOBALLY}")
+  set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS} /MT /O2 ${INTRINSICS} /Zi /MP /GL /DNDEBUG")
+  set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} /MTd /Od /Ob0 ${INTRINSICS} /RTC1 /Zi /D_DEBUG")
+  set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DEBUG /LTCG:incremental /INCREMENTAL:NO /NODEFAULTLIB:MSVCRT /ignore:4049")
+  set(CMAKE_STATIC_LINKER_FLAGS "${CMAKE_STATIC_LINKER_FLAGS} /LTCG:incremental /NODEFAULTLIB:MSVCRT")
+else()
+  include(CheckCXXCompilerFlag)
+  CHECK_CXX_COMPILER_FLAG(-mavx512f COMPILER_SUPPORTS_AVX512)
+  if(NOT COMPILER_SUPPORTS_AVX512)
+    message(FATAL_ERROR "A compiler with AVX512 support is required.")
+  endif()
 endif()

 #We should default to a Release build
@@ -104,11 +118,13 @@ set_target_properties(fbgemm_generic fbgemm_avx2 fbgemm_avx512 PROPERTIES
       CXX_EXTENSIONS NO
       CXX_VISIBILITY_PRESET hidden)

-target_compile_options(fbgemm_avx2 PRIVATE
-  "-m64" "-mavx2" "-mfma" "-masm=intel")
-target_compile_options(fbgemm_avx512 PRIVATE
-  "-m64" "-mavx2" "-mfma" "-mavx512f" "-mavx512bw" "-mavx512dq"
-  "-mavx512vl" "-masm=intel")
+if (NOT MSVC)
+  target_compile_options(fbgemm_avx2 PRIVATE
+    "-m64" "-mavx2" "-mfma" "-masm=intel" "-mf16c")
+  target_compile_options(fbgemm_avx512 PRIVATE
+    "-m64" "-mavx2" "-mfma" "-mavx512f" "-mavx512bw" "-mavx512dq"
+    "-mavx512vl" "-masm=intel" "-mf16c")
+endif()

 if(NOT TARGET asmjit)
   #Download asmjit from github if ASMJIT_SRC_DIR is not specified.
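Note on the MSVC block above: /MT and /MTd select the static CRT and /GL turns on
whole-program optimization, which is why the linker flags pin /LTCG and exclude
MSVCRT. Anything that links the resulting fbgemm.lib must be built with the same
runtime flavor, or the link fails with LNK2038 RuntimeLibrary mismatches. A
minimal Windows configure for this patch might look like the following sketch
(the generator is illustrative, and passing FBGEMM_LIBRARY_TYPE explicitly is
redundant since the patch already forces "static" under MSVC):

    cmake -G "Visual Studio 15 2017 Win64" -DFBGEMM_LIBRARY_TYPE=static ..
    cmake --build . --config Release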
@@ -118,7 +134,7 @@ if(NOT TARGET asmjit)
   endif()

   #build asmjit
-  set(ASMJIT_STATIC ON)
+  set(ASMJIT_STATIC TRUE CACHE STRING "" FORCE)

   add_subdirectory("${ASMJIT_SRC_DIR}" "${FBGEMM_BINARY_DIR}/asmjit")
   set_property(TARGET asmjit PROPERTY POSITION_INDEPENDENT_CODE ON)
 endif()
@@ -135,6 +151,9 @@ if(NOT TARGET cpuinfo)
   set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE BOOL "Do not build cpuinfo mock tests")
   set(CPUINFO_BUILD_BENCHMARKS OFF CACHE BOOL "Do not build cpuinfo benchmarks")
   set(CPUINFO_LIBRARY_TYPE static)
+  if(MSVC)
+    set(CPUINFO_RUNTIME_TYPE static)
+  endif(MSVC)
   add_subdirectory("${CPUINFO_SOURCE_DIR}" "${FBGEMM_BINARY_DIR}/cpuinfo")
   set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON)
 endif()
@@ -177,9 +196,10 @@ elseif(FBGEMM_LIBRARY_TYPE STREQUAL "static")
     $<TARGET_OBJECTS:fbgemm_generic>
     $<TARGET_OBJECTS:fbgemm_avx2>
     $<TARGET_OBJECTS:fbgemm_avx512>)
-  target_compile_definitions(fbgemm_avx2 PRIVATE FBGEMM_STATIC)
-  target_compile_definitions(fbgemm_avx512 PRIVATE FBGEMM_STATIC)
-  target_compile_definitions(fbgemm PRIVATE FBGEMM_STATIC)
+  target_compile_definitions(fbgemm_generic PRIVATE FBGEMM_STATIC ASMJIT_STATIC)
+  target_compile_definitions(fbgemm_avx2 PRIVATE FBGEMM_STATIC ASMJIT_STATIC)
+  target_compile_definitions(fbgemm_avx512 PRIVATE FBGEMM_STATIC ASMJIT_STATIC)
+  target_compile_definitions(fbgemm PRIVATE FBGEMM_STATIC ASMJIT_STATIC)
 else()
   message(FATAL_ERROR "Unsupported library type ${FBGEMM_LIBRARY_TYPE}")
 endif()
diff --git a/bench/AlignedVec.h b/bench/AlignedVec.h
index fd4b88e..8b3ff14 100644
--- a/bench/AlignedVec.h
+++ b/bench/AlignedVec.h
@@ -99,8 +99,11 @@ class aligned_allocator {
     // Mallocator wraps malloc().
     void* pv = nullptr;
+#ifdef _MSC_VER
+    pv = _aligned_malloc(n * sizeof(T), Alignment);
+#else
     posix_memalign(&pv, Alignment, n * sizeof(T));
-    // pv = aligned_alloc(Alignment, n * sizeof(T));
+#endif

     // Allocators should throw std::bad_alloc in the case of memory allocation
     // failure.
@@ -112,7 +115,11 @@ class aligned_allocator {
   }

   void deallocate(T* const p, const std::size_t /*n*/) const {
+#ifdef _MSC_VER
+    _aligned_free(p);
+#else
     free(p);
+#endif
   }

   // The following will be the same for all allocators that ignore hints.
diff --git a/bench/BenchUtils.cc b/bench/BenchUtils.cc
index a5ab949..fb20988 100644
--- a/bench/BenchUtils.cc
+++ b/bench/BenchUtils.cc
@@ -24,6 +24,21 @@ void randFill(aligned_vector<T>& vec, T low, T high, std::true_type) {
   std::generate(vec.begin(), vec.end(), [&] { return dis(eng); });
 }

+// MSVC doesn't accept uint8_t or int8_t as a template argument of
+// std::uniform_int_distribution.
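+// (The C++ standard only guarantees short, int, long, long long and their
+// unsigned counterparts as the distribution's IntType, so the overloads below
+// draw from an int distribution and narrow the result.)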
+#ifdef _MSC_VER
+void randFill(aligned_vector<uint8_t>& vec, uint8_t low, uint8_t high, std::true_type) {
+  std::uniform_int_distribution<int> dis(low, high);
+  for (int i = 0; i < vec.size(); i++)
+    vec[i] = (uint8_t)dis(eng);
+}
+
+void randFill(aligned_vector<int8_t>& vec, int8_t low, int8_t high, std::true_type) {
+  std::uniform_int_distribution<int> dis(low, high);
+  for (int i = 0; i < vec.size(); i++)
+    vec[i] = (int8_t)dis(eng);
+}
+#endif
+
 template <typename T>
 void randFill(aligned_vector<T>& vec, T low, T high, std::false_type) {
   std::uniform_real_distribution<T> dis(low, high);
diff --git a/bench/CMakeLists.txt b/bench/CMakeLists.txt
index a58eed7..c2e76ef 100644
--- a/bench/CMakeLists.txt
+++ b/bench/CMakeLists.txt
@@ -12,8 +12,12 @@ macro(add_benchmark BENCHNAME)
   set_target_properties(${BENCHNAME} PROPERTIES
            CXX_STANDARD 11
            CXX_EXTENSIONS NO)
-  target_compile_options(${BENCHNAME} PRIVATE
-    "-m64" "-mavx2" "-mfma" "-masm=intel")
+  if(MSVC)
+    target_compile_options(${BENCHNAME} PRIVATE "/DFBGEMM_STATIC" "/MT")
+  else(MSVC)
+    target_compile_options(${BENCHNAME} PRIVATE
+      "-m64" "-mavx2" "-mfma" "-masm=intel")
+  endif(MSVC)
   target_link_libraries(${BENCHNAME} fbgemm)
   add_dependencies(${BENCHNAME} fbgemm)
   if(${MKL_FOUND})
diff --git a/bench/ConvUnifiedBenchmark.cc b/bench/ConvUnifiedBenchmark.cc
index 59079c7..6bc2cf4 100644
--- a/bench/ConvUnifiedBenchmark.cc
+++ b/bench/ConvUnifiedBenchmark.cc
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include

 #ifdef _OPENMP
 #include <omp.h>
diff --git a/bench/PackedRequantizeAcc16Benchmark.cc b/bench/PackedRequantizeAcc16Benchmark.cc
index 8706f96..40ff662 100644
--- a/bench/PackedRequantizeAcc16Benchmark.cc
+++ b/bench/PackedRequantizeAcc16Benchmark.cc
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include

 #ifdef _OPENMP
 #include <omp.h>
diff --git a/cmake/modules/FindMKL.cmake b/cmake/modules/FindMKL.cmake
index 6b38af7..7661a40 100644
--- a/cmake/modules/FindMKL.cmake
+++ b/cmake/modules/FindMKL.cmake
@@ -47,7 +47,7 @@ ELSE(CMAKE_COMPILER_IS_GNUCC)
     SET(mklifaces  "intel")
     SET(mklrtls "iomp5" "guide")
     IF (MSVC)
-      SET(mklrtls "libiomp5md")
+      SET(mklrtls "libiomp5mt")
     ENDIF (MSVC)
 ENDIF (CMAKE_COMPILER_IS_GNUCC)
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index 721b12f..90d1ee9 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -266,7 +266,11 @@ class PackMatrix {

   virtual ~PackMatrix() {
     if (bufAllocatedHere_) {
+#ifdef _MSC_VER
+      _aligned_free(buf_);
+#else
       free(buf_);
+#endif
     }
   }

@@ -512,7 +516,11 @@ class FBGEMM_API PackWeightMatrixForGConv {

   ~PackWeightMatrixForGConv() {
     if (bufAllocatedHere_) {
+#ifdef _MSC_VER
+      _aligned_free(pdata_);
+#else
       free(pdata_);
+#endif
     }
   }

@@ -1358,8 +1366,12 @@ FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p);
  */
 static void* fbgemmAlignedAlloc(size_t __align, size_t __size) {
   void* aligned_mem;
+#ifdef _MSC_VER
+  aligned_mem = _aligned_malloc(__size, __align);
+#else
   if (posix_memalign(&aligned_mem, __align, __size))
     return 0;
+#endif
   return aligned_mem;
 }
diff --git a/include/fbgemm/FbgemmFP16.h b/include/fbgemm/FbgemmFP16.h
index 2b4c42b..f6eea21 100644
--- a/include/fbgemm/FbgemmFP16.h
+++ b/include/fbgemm/FbgemmFP16.h
@@ -108,16 +108,23 @@ class PackedGemmMatrixFP16 {
     // allocate and initialize packed memory
     const int padding = 1024; // required by sw pipelined kernels
     size_ = (blockRowSize() * nbrow_) * (blockColSize() * nbcol_);
-    // pmat_ = (float16 *)aligned_alloc(64, matSize() * sizeof(float16) +
-    // padding);
+#ifdef _MSC_VER
+    pmat_ = (float16 *)_aligned_malloc(matSize() * sizeof(float16) + padding,
64); +#else posix_memalign((void**)&pmat_, 64, matSize() * sizeof(float16) + padding); +#endif for (auto i = 0; i < matSize(); i++) { pmat_[i] = tconv(0.f, pmat_[i]); } } ~PackedGemmMatrixFP16() { +#ifdef _MSC_VER + _aligned_free(pmat_); +#else free(pmat_); +#endif } // protected: @@ -166,7 +173,7 @@ class PackedGemmMatrixFP16 { } int matSize() const { - return size_; + return (int)size_; } int numRows() const { return nrow_; diff --git a/include/fbgemm/Types.h b/include/fbgemm/Types.h index d5f3f6a..c71e7a4 100644 --- a/include/fbgemm/Types.h +++ b/include/fbgemm/Types.h @@ -12,7 +12,11 @@ namespace fbgemm { +#ifdef _MSC_VER +typedef struct __declspec(align(2)) __f16 { +#else typedef struct __attribute__((aligned(2))) __f16 { +#endif uint16_t x; } float16; @@ -38,11 +42,11 @@ static inline float16 cpu_float2half_rn(float f) { // Get rid of +Inf/-Inf, +0/-0. if (u > 0x477fefff) { - ret.x = sign | 0x7c00U; + ret.x = (uint16_t) (sign | 0x7c00U); return ret; } if (u < 0x33000001) { - ret.x = (sign | 0x0000); + ret.x = (uint16_t)(sign | 0x0000); return ret; } @@ -72,7 +76,7 @@ static inline float16 cpu_float2half_rn(float f) { } } - ret.x = (sign | (exponent << 10) | mantissa); + ret.x = (uint16_t)(sign | (exponent << 10) | mantissa); return ret; } diff --git a/src/FbgemmFP16UKernelsAvx2.cc b/src/FbgemmFP16UKernelsAvx2.cc index 0c795b0..5f7492f 100644 --- a/src/FbgemmFP16UKernelsAvx2.cc +++ b/src/FbgemmFP16UKernelsAvx2.cc @@ -5,791 +5,642 @@ * LICENSE file in the root directory of this source tree. */ #include "FbgemmFP16UKernelsAvx2.h" +#include namespace fbgemm { -void __attribute__((noinline)) gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif +void NOINLINE_ATTR gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp) { + char* r14 = (char*)gp; //"mov r14, %[gp]\t\n" // Copy parameters // k - "mov r8, [r14 + 0]\t\n" + uint64_t r8 = *(uint64_t *)((char*)r14 + 0 ); //"mov r8, [r14 + 0]\t\n" // A - "mov r9, [r14 + 8]\t\n" + float* r9 = *(float* *)((char*)r14 + 8 ); //"mov r9, [r14 + 8]\t\n" // B - "mov r10, [r14 + 16]\t\n" + const fp16* r10 = *(const fp16**)((char*)r14 + 16); //"mov r10, [r14 + 16]\t\n" // beta - "mov r15, [r14 + 24]\t\n" + float* r15 = *(float* *)((char*)r14 + 24); //"mov r15, [r14 + 24]\t\n" // accum - "mov rdx, [r14 + 32]\t\n" + uint64_t rdx = *(uint64_t *)((char*)r14 + 32); //"mov rdx, [r14 + 32]\t\n" // C - "mov r12, [r14 + 40]\t\n" + float* r12 = *(float* *)((char*)r14 + 40); //"mov r12, [r14 + 40]\t\n" // ldc - "mov r13, [r14 + 48]\t\n" + uint64_t r13 = *(uint64_t *)((char*)r14 + 48); //"mov r13, [r14 + 48]\t\n" // b_block_cols - "mov rdi, [r14 + 56]\t\n" + uint64_t rdi = *(uint64_t *)((char*)r14 + 56); //"mov rdi, [r14 + 56]\t\n" // b_block_size - "mov rsi, [r14 + 64]\t\n" + uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n" // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm3,XMMWORD PTR [r10 + 0]\t\n" - "vcvtph2ps ymm4,XMMWORD PTR [r10 + 16]\t\n" - "vbroadcastss ymm2,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm3,ymm2\t\n" - "vfmadd231ps ymm1,ymm4,ymm2\t\n" - "add r9,4\t\n" - "add r10,32\t\n" - "inc r14\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD 
PTR [r12 + 0], ymm0\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 64\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); + float* rax = r9; //"mov rax, r9\t\n" + float* rcx = r12; //"mov rcx, r12\t\n" + + uint64_t rbx = 0; //"mov rbx, 0\t\n" + for (; rbx < rdi; ++rbx) { //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" + // //"loop_outter%=:\t\n" + uint64_t r14_i = 0; //"mov r14, 0\t\n" + __m256 ymm0 = _mm256_setzero_ps(); //"vxorps ymm0,ymm0,ymm0\t\n" + __m256 ymm1 = _mm256_setzero_ps(); //"vxorps ymm1,ymm1,ymm1\t\n" + + for (; r14_i < r8; ++r14_i) { //"inc r14; cmp r14, r8; jl loop_inner%=\t\n" + // loop_inner%=: //"\t\n" + auto fp16mem0 = _mm_load_si128((__m128i*)((char*)r10 + 0)); //"vcvtph2ps ymm3,XMMWORD PTR [r10 + 0]\t\n" + auto ymm3 = _mm256_cvtph_ps(fp16mem0); //"vcvtph2ps ymm3,XMMWORD PTR [r10 + 0]\t\n" + auto fp16mem16 = _mm_load_si128((__m128i*)((char*)r10 + 16)); //"vcvtph2ps ymm4,XMMWORD PTR [r10 + 16]\t\n" + auto ymm4 = _mm256_cvtph_ps(fp16mem16); //"vcvtph2ps ymm4,XMMWORD PTR [r10 + 16]\t\n" + auto ymm2 = _mm256_broadcast_ss((float*)((char*)r9 + 0)); //"vbroadcastss ymm2,DWORD PTR [r9+0]\t\n" + ymm0 = _mm256_fmadd_ps(ymm2, ymm3, ymm0); //"vfmadd231ps ymm0,ymm3,ymm2\t\n" + ymm1 = _mm256_fmadd_ps(ymm2, ymm4, ymm1); //"vfmadd231ps ymm1,ymm4,ymm2\t\n" + r9 = (float*)((char*)r9 + 4); //"add r9,4\t\n" + r10 = (fp16*)((char*)r10 + 32); //"add r10,32\t\n" + } //"inc r14; cmp r14, r8; jl loop_outter%=\t\n" + + if(rdx != 1) { //"cmp rdx, 1; je L_accum%=\t\n" + // Dump C + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" + } else { //"jmp L_done%=\t\n" + // Dump C with accumulate + auto ymm15 = _mm256_broadcast_ss((float*)r15); //"vbroadcastss ymm15,DWORD PTR [r15]\t\n" + auto r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm0 = _mm256_fmadd_ps(r12_0, ymm15, ymm0); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + auto r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm1 = _mm256_fmadd_ps(r12_32, ymm15, ymm1); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" + } //"L_done%=:\t\n" + + // next outer iteration + rcx = (float*)((char*)rcx + 64); //"add rcx, 64\t\n" + r12 = rcx; //"mov r12, rcx\t\n" + r9 = rax; //"mov r9, rax\t\n" + } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" } -void __attribute__((noinline)) gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif + +void 
NOINLINE_ATTR gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp) { + char* r14 = (char*)gp; //"mov r14, %[gp]\t\n" // Copy parameters // k - "mov r8, [r14 + 0]\t\n" + uint64_t r8 = *(uint64_t *)((char*)r14 + 0 ); //"mov r8, [r14 + 0]\t\n" // A - "mov r9, [r14 + 8]\t\n" + float* r9 = *(float* *)((char*)r14 + 8 ); //"mov r9, [r14 + 8]\t\n" // B - "mov r10, [r14 + 16]\t\n" + const fp16* r10 = *(const fp16**)((char*)r14 + 16); //"mov r10, [r14 + 16]\t\n" // beta - "mov r15, [r14 + 24]\t\n" + float* r15 = *(float* *)((char*)r14 + 24); //"mov r15, [r14 + 24]\t\n" // accum - "mov rdx, [r14 + 32]\t\n" + uint64_t rdx = *(uint64_t *)((char*)r14 + 32); //"mov rdx, [r14 + 32]\t\n" // C - "mov r12, [r14 + 40]\t\n" + float* r12 = *(float* *)((char*)r14 + 40); //"mov r12, [r14 + 40]\t\n" // ldc - "mov r13, [r14 + 48]\t\n" + uint64_t r13 = *(uint64_t *)((char*)r14 + 48); //"mov r13, [r14 + 48]\t\n" // b_block_cols - "mov rdi, [r14 + 56]\t\n" + uint64_t rdi = *(uint64_t *)((char*)r14 + 56); //"mov rdi, [r14 + 56]\t\n" // b_block_size - "mov rsi, [r14 + 64]\t\n" + uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n" // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm5,XMMWORD PTR [r10 + 0]\t\n" - "vcvtph2ps ymm6,XMMWORD PTR [r10 + 16]\t\n" - "vbroadcastss ymm4,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm5,ymm4\t\n" - "vfmadd231ps ymm1,ymm6,ymm4\t\n" - "vbroadcastss ymm4,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm2,ymm5,ymm4\t\n" - "vfmadd231ps ymm3,ymm6,ymm4\t\n" - "add r9,8\t\n" - "add r10,32\t\n" - "inc r14\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 64\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); + float* rax = r9; //"mov rax, r9\t\n" + float* rcx = r12; //"mov rcx, r12\t\n" + + uint64_t rbx = 0; //"mov rbx, 0\t\n" + for (; rbx < rdi; ++rbx) { //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" + // //"loop_outter%=:\t\n" + uint64_t r14_i = 0; //"mov r14, 0\t\n" + __m256 ymm0 = _mm256_setzero_ps(); //"vxorps ymm0,ymm0,ymm0\t\n" + __m256 ymm1 = _mm256_setzero_ps(); //"vxorps ymm1,ymm1,ymm1\t\n" + __m256 ymm2 = _mm256_setzero_ps(); //"vxorps ymm2,ymm2,ymm2\t\n" + __m256 ymm3 = _mm256_setzero_ps(); //"vxorps ymm3,ymm3,ymm3\t\n" + + for (; r14_i < r8; 
++r14_i) { //"inc r14; cmp r14, r8; jl loop_inner%=\t\n" + // loop_inner%=: //"\t\n" + auto fp16mem0 = _mm_load_si128((__m128i*)((char*)r10 + 0)); //"vcvtph2ps ymm5,XMMWORD PTR [r10 + 0]\t\n" + auto ymm5 = _mm256_cvtph_ps(fp16mem0); //"vcvtph2ps ymm5,XMMWORD PTR [r10 + 0]\t\n" + auto fp16mem16 = _mm_load_si128((__m128i*)((char*)r10 + 16)); //"vcvtph2ps ymm6,XMMWORD PTR [r10 + 16]\t\n" + auto ymm6 = _mm256_cvtph_ps(fp16mem16); //"vcvtph2ps ymm6,XMMWORD PTR [r10 + 16]\t\n" + auto ymm4 = _mm256_broadcast_ss((float*)((char*)r9 + 0)); //"vbroadcastss ymm4,DWORD PTR [r9+0]\t\n" + ymm0 = _mm256_fmadd_ps(ymm4, ymm5, ymm0); //"vfmadd231ps ymm0,ymm5,ymm4\t\n" + ymm1 = _mm256_fmadd_ps(ymm4, ymm6, ymm1); //"vfmadd231ps ymm1,ymm6,ymm4\t\n" + ymm4 = _mm256_broadcast_ss((float*)((char*)r9 + 4)); //"vbroadcastss ymm4,DWORD PTR [r9+4]\t\n" + ymm2 = _mm256_fmadd_ps(ymm4, ymm5, ymm2); //"vfmadd231ps ymm2,ymm5,ymm4\t\n" + ymm3 = _mm256_fmadd_ps(ymm4, ymm6, ymm3); //"vfmadd231ps ymm3,ymm6,ymm4\t\n" + r9 = (float*)((char*)r9 + 8); //"add r9,8\t\n" + r10 = (fp16*)((char*)r10 + 32); //"add r10,32\t\n" + } //"inc r14; cmp r14, r8; jl loop_outter%=\t\n" + + if(rdx != 1) { //"cmp rdx, 1; je L_accum%=\t\n" + // Dump C + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" + } else { //"jmp L_done%=\t\n" + // Dump C with accumulate + auto ymm15 = _mm256_broadcast_ss((float*)r15); //"vbroadcastss ymm15,DWORD PTR [r15]\t\n" + auto r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm0 = _mm256_fmadd_ps(r12_0, ymm15, ymm0); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + auto r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm1 = _mm256_fmadd_ps(r12_32, ymm15, ymm1); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm2 = _mm256_fmadd_ps(r12_0, ymm15, ymm2); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm3 = _mm256_fmadd_ps(r12_32, ymm15, ymm3); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" + } //"L_done%=:\t\n" + + // next outer iteration + rcx = (float*)((char*)rcx + 64); //"add rcx, 64\t\n" + r12 = rcx; //"mov r12, rcx\t\n" + r9 = rax; //"mov r9, rax\t\n" + } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" } -void __attribute__((noinline)) gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif + +void NOINLINE_ATTR 
gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp) { + char* r14 = (char*)gp; //"mov r14, %[gp]\t\n" // Copy parameters // k - "mov r8, [r14 + 0]\t\n" + uint64_t r8 = *(uint64_t *)((char*)r14 + 0 ); //"mov r8, [r14 + 0]\t\n" // A - "mov r9, [r14 + 8]\t\n" + float* r9 = *(float* *)((char*)r14 + 8 ); //"mov r9, [r14 + 8]\t\n" // B - "mov r10, [r14 + 16]\t\n" + const fp16* r10 = *(const fp16**)((char*)r14 + 16); //"mov r10, [r14 + 16]\t\n" // beta - "mov r15, [r14 + 24]\t\n" + float* r15 = *(float* *)((char*)r14 + 24); //"mov r15, [r14 + 24]\t\n" // accum - "mov rdx, [r14 + 32]\t\n" + uint64_t rdx = *(uint64_t *)((char*)r14 + 32); //"mov rdx, [r14 + 32]\t\n" // C - "mov r12, [r14 + 40]\t\n" + float* r12 = *(float* *)((char*)r14 + 40); //"mov r12, [r14 + 40]\t\n" // ldc - "mov r13, [r14 + 48]\t\n" + uint64_t r13 = *(uint64_t *)((char*)r14 + 48); //"mov r13, [r14 + 48]\t\n" // b_block_cols - "mov rdi, [r14 + 56]\t\n" + uint64_t rdi = *(uint64_t *)((char*)r14 + 56); //"mov rdi, [r14 + 56]\t\n" // b_block_size - "mov rsi, [r14 + 64]\t\n" + uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n" // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - "vxorps ymm4,ymm4,ymm4\t\n" - "vxorps ymm5,ymm5,ymm5\t\n" - - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm7,XMMWORD PTR [r10 + 0]\t\n" - "vcvtph2ps ymm8,XMMWORD PTR [r10 + 16]\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm7,ymm6\t\n" - "vfmadd231ps ymm1,ymm8,ymm6\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm2,ymm7,ymm6\t\n" - "vfmadd231ps ymm3,ymm8,ymm6\t\n" - "vbroadcastss ymm6,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm4,ymm7,ymm6\t\n" - "vfmadd231ps ymm5,ymm8,ymm6\t\n" - "add r9,12\t\n" - "add r10,32\t\n" - "inc r14\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 64\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); + float* rax = r9; //"mov rax, r9\t\n" + float* rcx = r12; //"mov rcx, 
r12\t\n" + + uint64_t rbx = 0; //"mov rbx, 0\t\n" + for (; rbx < rdi; ++rbx) { //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" + // //"loop_outter%=:\t\n" + uint64_t r14_i = 0; //"mov r14, 0\t\n" + __m256 ymm0 = _mm256_setzero_ps(); //"vxorps ymm0,ymm0,ymm0\t\n" + __m256 ymm1 = _mm256_setzero_ps(); //"vxorps ymm1,ymm1,ymm1\t\n" + __m256 ymm2 = _mm256_setzero_ps(); //"vxorps ymm2,ymm2,ymm2\t\n" + __m256 ymm3 = _mm256_setzero_ps(); //"vxorps ymm3,ymm3,ymm3\t\n" + __m256 ymm4 = _mm256_setzero_ps(); //"vxorps ymm4,ymm4,ymm4\t\n" + __m256 ymm5 = _mm256_setzero_ps(); //"vxorps ymm5,ymm5,ymm5\t\n" + + for (; r14_i < r8; ++r14_i) { //"inc r14; cmp r14, r8; jl loop_inner%=\t\n" + // loop_inner%=: //"\t\n" + auto fp16mem0 = _mm_load_si128((__m128i*)((char*)r10 + 0)); //"vcvtph2ps ymm7,XMMWORD PTR [r10 + 0]\t\n" + auto ymm7 = _mm256_cvtph_ps(fp16mem0); //"vcvtph2ps ymm7,XMMWORD PTR [r10 + 0]\t\n" + auto fp16mem16 = _mm_load_si128((__m128i*)((char*)r10 + 16)); //"vcvtph2ps ymm8,XMMWORD PTR [r10 + 16]\t\n" + auto ymm8 = _mm256_cvtph_ps(fp16mem16); //"vcvtph2ps ymm8,XMMWORD PTR [r10 + 16]\t\n" + auto ymm6 = _mm256_broadcast_ss((float*)((char*)r9 + 0)); //"vbroadcastss ymm6,DWORD PTR [r9+0]\t\n" + ymm0 = _mm256_fmadd_ps(ymm6, ymm7, ymm0); //"vfmadd231ps ymm0,ymm7,ymm6\t\n" + ymm1 = _mm256_fmadd_ps(ymm6, ymm8, ymm1); //"vfmadd231ps ymm1,ymm8,ymm6\t\n" + ymm6 = _mm256_broadcast_ss((float*)((char*)r9 + 4)); //"vbroadcastss ymm6,DWORD PTR [r9+4]\t\n" + ymm2 = _mm256_fmadd_ps(ymm6, ymm7, ymm2); //"vfmadd231ps ymm2,ymm7,ymm6\t\n" + ymm3 = _mm256_fmadd_ps(ymm6, ymm8, ymm3); //"vfmadd231ps ymm3,ymm8,ymm6\t\n" + ymm6 = _mm256_broadcast_ss((float*)((char*)r9 + 8)); //"vbroadcastss ymm6,DWORD PTR [r9+8]\t\n" + ymm4 = _mm256_fmadd_ps(ymm6, ymm7, ymm4); //"vfmadd231ps ymm4,ymm7,ymm6\t\n" + ymm5 = _mm256_fmadd_ps(ymm6, ymm8, ymm5); //"vfmadd231ps ymm5,ymm8,ymm6\t\n" + r9 = (float*)((char*)r9 + 12); //"add r9,12\t\n" + r10 = (fp16*)((char*)r10 + 32); //"add r10,32\t\n" + } //"inc r14; cmp r14, r8; jl loop_outter%=\t\n" + + if(rdx != 1) { //"cmp rdx, 1; je L_accum%=\t\n" + // Dump C + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" + } else { //"jmp L_done%=\t\n" + // Dump C with accumulate + auto ymm15 = _mm256_broadcast_ss((float*)r15); //"vbroadcastss ymm15,DWORD PTR [r15]\t\n" + auto r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm0 = _mm256_fmadd_ps(r12_0, ymm15, ymm0); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + auto r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm1 = _mm256_fmadd_ps(r12_32, ymm15, ymm1); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" 
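+      // r13 holds ldc in bytes, so this step moves r12 down to the next row of C before accumulating into it.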
+ r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm2 = _mm256_fmadd_ps(r12_0, ymm15, ymm2); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm3 = _mm256_fmadd_ps(r12_32, ymm15, ymm3); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm4 = _mm256_fmadd_ps(r12_0, ymm15, ymm4); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm5 = _mm256_fmadd_ps(r12_32, ymm15, ymm5); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" + } //"L_done%=:\t\n" + + // next outer iteration + rcx = (float*)((char*)rcx + 64); //"add rcx, 64\t\n" + r12 = rcx; //"mov r12, rcx\t\n" + r9 = rax; //"mov r9, rax\t\n" + } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" } -void __attribute__((noinline)) gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif + +void NOINLINE_ATTR gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp) { + char* r14 = (char*)gp; //"mov r14, %[gp]\t\n" // Copy parameters // k - "mov r8, [r14 + 0]\t\n" + uint64_t r8 = *(uint64_t *)((char*)r14 + 0 ); //"mov r8, [r14 + 0]\t\n" // A - "mov r9, [r14 + 8]\t\n" + float* r9 = *(float* *)((char*)r14 + 8 ); //"mov r9, [r14 + 8]\t\n" // B - "mov r10, [r14 + 16]\t\n" + const fp16* r10 = *(const fp16**)((char*)r14 + 16); //"mov r10, [r14 + 16]\t\n" // beta - "mov r15, [r14 + 24]\t\n" + float* r15 = *(float* *)((char*)r14 + 24); //"mov r15, [r14 + 24]\t\n" // accum - "mov rdx, [r14 + 32]\t\n" + uint64_t rdx = *(uint64_t *)((char*)r14 + 32); //"mov rdx, [r14 + 32]\t\n" // C - "mov r12, [r14 + 40]\t\n" + float* r12 = *(float* *)((char*)r14 + 40); //"mov r12, [r14 + 40]\t\n" // ldc - "mov r13, [r14 + 48]\t\n" + uint64_t r13 = *(uint64_t *)((char*)r14 + 48); //"mov r13, [r14 + 48]\t\n" // b_block_cols - "mov rdi, [r14 + 56]\t\n" + uint64_t rdi = *(uint64_t *)((char*)r14 + 56); //"mov rdi, [r14 + 56]\t\n" // b_block_size - "mov rsi, [r14 + 64]\t\n" + uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n" // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - "vxorps ymm4,ymm4,ymm4\t\n" - "vxorps ymm5,ymm5,ymm5\t\n" - "vxorps ymm6,ymm6,ymm6\t\n" - "vxorps ymm7,ymm7,ymm7\t\n" - - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm9,XMMWORD PTR [r10 + 0]\t\n" - "vcvtph2ps ymm10,XMMWORD PTR [r10 + 16]\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm9,ymm8\t\n" - "vfmadd231ps ymm1,ymm10,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR 
[r9+4]\t\n" - "vfmadd231ps ymm2,ymm9,ymm8\t\n" - "vfmadd231ps ymm3,ymm10,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm4,ymm9,ymm8\t\n" - "vfmadd231ps ymm5,ymm10,ymm8\t\n" - "vbroadcastss ymm8,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm6,ymm9,ymm8\t\n" - "vfmadd231ps ymm7,ymm10,ymm8\t\n" - "add r9,16\t\n" - "add r10,32\t\n" - "inc r14\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm7\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm7\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 64\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); + float* rax = r9; //"mov rax, r9\t\n" + float* rcx = r12; //"mov rcx, r12\t\n" + + uint64_t rbx = 0; //"mov rbx, 0\t\n" + for (; rbx < rdi; ++rbx) { //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" + // //"loop_outter%=:\t\n" + uint64_t r14_i = 0; //"mov r14, 0\t\n" + __m256 ymm0 = _mm256_setzero_ps(); //"vxorps ymm0,ymm0,ymm0\t\n" + __m256 ymm1 = _mm256_setzero_ps(); //"vxorps ymm1,ymm1,ymm1\t\n" + __m256 ymm2 = _mm256_setzero_ps(); //"vxorps ymm2,ymm2,ymm2\t\n" + __m256 ymm3 = _mm256_setzero_ps(); //"vxorps ymm3,ymm3,ymm3\t\n" + __m256 ymm4 = _mm256_setzero_ps(); //"vxorps ymm4,ymm4,ymm4\t\n" + __m256 ymm5 = _mm256_setzero_ps(); //"vxorps ymm5,ymm5,ymm5\t\n" + __m256 ymm6 = _mm256_setzero_ps(); //"vxorps ymm6,ymm6,ymm6\t\n" + __m256 ymm7 = _mm256_setzero_ps(); //"vxorps ymm7,ymm7,ymm7\t\n" + + for (; r14_i < r8; ++r14_i) { //"inc r14; cmp r14, r8; jl loop_inner%=\t\n" + // loop_inner%=: //"\t\n" + auto fp16mem0 = _mm_load_si128((__m128i*)((char*)r10 + 0)); //"vcvtph2ps ymm9,XMMWORD PTR [r10 + 0]\t\n" + auto ymm9 = _mm256_cvtph_ps(fp16mem0); //"vcvtph2ps ymm9,XMMWORD PTR [r10 + 0]\t\n" + auto fp16mem16 = _mm_load_si128((__m128i*)((char*)r10 + 16)); //"vcvtph2ps ymm10,XMMWORD PTR [r10 + 16]\t\n" + auto ymm10 = _mm256_cvtph_ps(fp16mem16); //"vcvtph2ps ymm10,XMMWORD PTR [r10 + 16]\t\n" + auto ymm8 = 
_mm256_broadcast_ss((float*)((char*)r9 + 0)); //"vbroadcastss ymm8,DWORD PTR [r9+0]\t\n" + ymm0 = _mm256_fmadd_ps(ymm8, ymm9, ymm0); //"vfmadd231ps ymm0,ymm9,ymm8\t\n" + ymm1 = _mm256_fmadd_ps(ymm8, ymm10, ymm1); //"vfmadd231ps ymm1,ymm10,ymm8\t\n" + ymm8 = _mm256_broadcast_ss((float*)((char*)r9 + 4)); //"vbroadcastss ymm8,DWORD PTR [r9+4]\t\n" + ymm2 = _mm256_fmadd_ps(ymm8, ymm9, ymm2); //"vfmadd231ps ymm2,ymm9,ymm8\t\n" + ymm3 = _mm256_fmadd_ps(ymm8, ymm10, ymm3); //"vfmadd231ps ymm3,ymm10,ymm8\t\n" + ymm8 = _mm256_broadcast_ss((float*)((char*)r9 + 8)); //"vbroadcastss ymm8,DWORD PTR [r9+8]\t\n" + ymm4 = _mm256_fmadd_ps(ymm8, ymm9, ymm4); //"vfmadd231ps ymm4,ymm9,ymm8\t\n" + ymm5 = _mm256_fmadd_ps(ymm8, ymm10, ymm5); //"vfmadd231ps ymm5,ymm10,ymm8\t\n" + ymm8 = _mm256_broadcast_ss((float*)((char*)r9 + 12)); //"vbroadcastss ymm8,DWORD PTR [r9+12]\t\n" + ymm6 = _mm256_fmadd_ps(ymm8, ymm9, ymm6); //"vfmadd231ps ymm6,ymm9,ymm8\t\n" + ymm7 = _mm256_fmadd_ps(ymm8, ymm10, ymm7); //"vfmadd231ps ymm7,ymm10,ymm8\t\n" + r9 = (float*)((char*)r9 + 16); //"add r9,16\t\n" + r10 = (fp16*)((char*)r10 + 32); //"add r10,32\t\n" + } //"inc r14; cmp r14, r8; jl loop_outter%=\t\n" + + if(rdx != 1) { //"cmp rdx, 1; je L_accum%=\t\n" + // Dump C + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm6); //"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm7); //"vmovups YMMWORD PTR [r12 + 32], ymm7\t\n" + } else { //"jmp L_done%=\t\n" + // Dump C with accumulate + auto ymm15 = _mm256_broadcast_ss((float*)r15); //"vbroadcastss ymm15,DWORD PTR [r15]\t\n" + auto r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm0 = _mm256_fmadd_ps(r12_0, ymm15, ymm0); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + auto r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm1 = _mm256_fmadd_ps(r12_32, ymm15, ymm1); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm2 = _mm256_fmadd_ps(r12_0, ymm15, ymm2); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm3 = _mm256_fmadd_ps(r12_32, ymm15, ymm3); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" + 
_mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm4 = _mm256_fmadd_ps(r12_0, ymm15, ymm4); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm5 = _mm256_fmadd_ps(r12_32, ymm15, ymm5); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm6 = _mm256_fmadd_ps(r12_0, ymm15, ymm6); //"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm6); //"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm7 = _mm256_fmadd_ps(r12_32, ymm15, ymm7); //"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm7); //"vmovups YMMWORD PTR [r12 + 32], ymm7\t\n" + } //"L_done%=:\t\n" + + // next outer iteration + rcx = (float*)((char*)rcx + 64); //"add rcx, 64\t\n" + r12 = rcx; //"mov r12, rcx\t\n" + r9 = rax; //"mov r9, rax\t\n" + } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" } -void __attribute__((noinline)) gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif + +void NOINLINE_ATTR gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp) { + char* r14 = (char*)gp; //"mov r14, %[gp]\t\n" // Copy parameters // k - "mov r8, [r14 + 0]\t\n" + uint64_t r8 = *(uint64_t *)((char*)r14 + 0 ); //"mov r8, [r14 + 0]\t\n" // A - "mov r9, [r14 + 8]\t\n" + float* r9 = *(float* *)((char*)r14 + 8 ); //"mov r9, [r14 + 8]\t\n" // B - "mov r10, [r14 + 16]\t\n" + const fp16* r10 = *(const fp16**)((char*)r14 + 16); //"mov r10, [r14 + 16]\t\n" // beta - "mov r15, [r14 + 24]\t\n" + float* r15 = *(float* *)((char*)r14 + 24); //"mov r15, [r14 + 24]\t\n" // accum - "mov rdx, [r14 + 32]\t\n" + uint64_t rdx = *(uint64_t *)((char*)r14 + 32); //"mov rdx, [r14 + 32]\t\n" // C - "mov r12, [r14 + 40]\t\n" + float* r12 = *(float* *)((char*)r14 + 40); //"mov r12, [r14 + 40]\t\n" // ldc - "mov r13, [r14 + 48]\t\n" + uint64_t r13 = *(uint64_t *)((char*)r14 + 48); //"mov r13, [r14 + 48]\t\n" // b_block_cols - "mov rdi, [r14 + 56]\t\n" + uint64_t rdi = *(uint64_t *)((char*)r14 + 56); //"mov rdi, [r14 + 56]\t\n" // b_block_size - "mov rsi, [r14 + 64]\t\n" + uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n" // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - "vxorps ymm4,ymm4,ymm4\t\n" - "vxorps ymm5,ymm5,ymm5\t\n" - "vxorps ymm6,ymm6,ymm6\t\n" - "vxorps ymm7,ymm7,ymm7\t\n" - "vxorps ymm8,ymm8,ymm8\t\n" - "vxorps ymm9,ymm9,ymm9\t\n" - - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm11,XMMWORD PTR [r10 + 0]\t\n" - "vcvtph2ps ymm12,XMMWORD PTR [r10 + 
16]\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm11,ymm10\t\n" - "vfmadd231ps ymm1,ymm12,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm2,ymm11,ymm10\t\n" - "vfmadd231ps ymm3,ymm12,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm4,ymm11,ymm10\t\n" - "vfmadd231ps ymm5,ymm12,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm6,ymm11,ymm10\t\n" - "vfmadd231ps ymm7,ymm12,ymm10\t\n" - "vbroadcastss ymm10,DWORD PTR [r9+16]\t\n" - "vfmadd231ps ymm8,ymm11,ymm10\t\n" - "vfmadd231ps ymm9,ymm12,ymm10\t\n" - "add r9,20\t\n" - "add r10,32\t\n" - "inc r14\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm7\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm9\t\n" - "add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm7\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm9\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 64\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); + float* rax = r9; //"mov rax, r9\t\n" + float* rcx = r12; //"mov rcx, r12\t\n" + + uint64_t rbx = 0; //"mov rbx, 0\t\n" + for (; rbx < rdi; ++rbx) { //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" + // //"loop_outter%=:\t\n" + uint64_t r14_i = 0; //"mov r14, 0\t\n" + __m256 ymm0 = _mm256_setzero_ps(); //"vxorps ymm0,ymm0,ymm0\t\n" + __m256 ymm1 = _mm256_setzero_ps(); //"vxorps ymm1,ymm1,ymm1\t\n" + __m256 ymm2 = _mm256_setzero_ps(); //"vxorps ymm2,ymm2,ymm2\t\n" + __m256 ymm3 = _mm256_setzero_ps(); //"vxorps ymm3,ymm3,ymm3\t\n" + __m256 ymm4 = _mm256_setzero_ps(); //"vxorps ymm4,ymm4,ymm4\t\n" + __m256 ymm5 = _mm256_setzero_ps(); //"vxorps ymm5,ymm5,ymm5\t\n" + __m256 ymm6 = 
_mm256_setzero_ps(); //"vxorps ymm6,ymm6,ymm6\t\n" + __m256 ymm7 = _mm256_setzero_ps(); //"vxorps ymm7,ymm7,ymm7\t\n" + __m256 ymm8 = _mm256_setzero_ps(); //"vxorps ymm8,ymm8,ymm8\t\n" + __m256 ymm9 = _mm256_setzero_ps(); //"vxorps ymm9,ymm9,ymm9\t\n" + + for (; r14_i < r8; ++r14_i) { //"inc r14; cmp r14, r8; jl loop_inner%=\t\n" + // loop_inner%=: //"\t\n" + auto fp16mem0 = _mm_load_si128((__m128i*)((char*)r10 + 0)); //"vcvtph2ps ymm11,XMMWORD PTR [r10 + 0]\t\n" + auto ymm11 = _mm256_cvtph_ps(fp16mem0); //"vcvtph2ps ymm11,XMMWORD PTR [r10 + 0]\t\n" + auto fp16mem16 = _mm_load_si128((__m128i*)((char*)r10 + 16)); //"vcvtph2ps ymm12,XMMWORD PTR [r10 + 16]\t\n" + auto ymm12 = _mm256_cvtph_ps(fp16mem16); //"vcvtph2ps ymm12,XMMWORD PTR [r10 + 16]\t\n" + auto ymm10 = _mm256_broadcast_ss((float*)((char*)r9 + 0)); //"vbroadcastss ymm10,DWORD PTR [r9+0]\t\n" + ymm0 = _mm256_fmadd_ps(ymm10, ymm11, ymm0); //"vfmadd231ps ymm0,ymm11,ymm10\t\n" + ymm1 = _mm256_fmadd_ps(ymm10, ymm12, ymm1); //"vfmadd231ps ymm1,ymm12,ymm10\t\n" + ymm10 = _mm256_broadcast_ss((float*)((char*)r9 + 4)); //"vbroadcastss ymm10,DWORD PTR [r9+4]\t\n" + ymm2 = _mm256_fmadd_ps(ymm10, ymm11, ymm2); //"vfmadd231ps ymm2,ymm11,ymm10\t\n" + ymm3 = _mm256_fmadd_ps(ymm10, ymm12, ymm3); //"vfmadd231ps ymm3,ymm12,ymm10\t\n" + ymm10 = _mm256_broadcast_ss((float*)((char*)r9 + 8)); //"vbroadcastss ymm10,DWORD PTR [r9+8]\t\n" + ymm4 = _mm256_fmadd_ps(ymm10, ymm11, ymm4); //"vfmadd231ps ymm4,ymm11,ymm10\t\n" + ymm5 = _mm256_fmadd_ps(ymm10, ymm12, ymm5); //"vfmadd231ps ymm5,ymm12,ymm10\t\n" + ymm10 = _mm256_broadcast_ss((float*)((char*)r9 + 12)); //"vbroadcastss ymm10,DWORD PTR [r9+12]\t\n" + ymm6 = _mm256_fmadd_ps(ymm10, ymm11, ymm6); //"vfmadd231ps ymm6,ymm11,ymm10\t\n" + ymm7 = _mm256_fmadd_ps(ymm10, ymm12, ymm7); //"vfmadd231ps ymm7,ymm12,ymm10\t\n" + ymm10 = _mm256_broadcast_ss((float*)((char*)r9 + 16)); //"vbroadcastss ymm10,DWORD PTR [r9+16]\t\n" + ymm8 = _mm256_fmadd_ps(ymm10, ymm11, ymm8); //"vfmadd231ps ymm8,ymm11,ymm10\t\n" + ymm9 = _mm256_fmadd_ps(ymm10, ymm12, ymm9); //"vfmadd231ps ymm9,ymm12,ymm10\t\n" + r9 = (float*)((char*)r9 + 20); //"add r9,20\t\n" + r10 = (fp16*)((char*)r10 + 32); //"add r10,32\t\n" + } //"inc r14; cmp r14, r8; jl loop_outter%=\t\n" + + if(rdx != 1) { //"cmp rdx, 1; je L_accum%=\t\n" + // Dump C + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm6); //"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm7); //"vmovups YMMWORD PTR [r12 + 32], ymm7\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm8); //"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm9); //"vmovups YMMWORD PTR [r12 + 32], ymm9\t\n" + } else { //"jmp L_done%=\t\n" + // Dump 
C with accumulate + auto ymm15 = _mm256_broadcast_ss((float*)r15); //"vbroadcastss ymm15,DWORD PTR [r15]\t\n" + auto r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm0 = _mm256_fmadd_ps(r12_0, ymm15, ymm0); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + auto r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm1 = _mm256_fmadd_ps(r12_32, ymm15, ymm1); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm2 = _mm256_fmadd_ps(r12_0, ymm15, ymm2); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm3 = _mm256_fmadd_ps(r12_32, ymm15, ymm3); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm4 = _mm256_fmadd_ps(r12_0, ymm15, ymm4); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm5 = _mm256_fmadd_ps(r12_32, ymm15, ymm5); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm6 = _mm256_fmadd_ps(r12_0, ymm15, ymm6); //"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm6); //"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm7 = _mm256_fmadd_ps(r12_32, ymm15, ymm7); //"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm7); //"vmovups YMMWORD PTR [r12 + 32], ymm7\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm8 = _mm256_fmadd_ps(r12_0, ymm15, ymm8); //"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm8); //"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm9 = _mm256_fmadd_ps(r12_32, ymm15, ymm9); //"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm9); //"vmovups YMMWORD PTR [r12 + 32], ymm9\t\n" + } //"L_done%=:\t\n" + + // next outer iteration + rcx = (float*)((char*)rcx + 64); //"add rcx, 64\t\n" + r12 = rcx; //"mov r12, rcx\t\n" + r9 = rax; //"mov r9, rax\t\n" 
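+    // rax/rcx are the copies of A and the C block base saved at entry; rewinding r9
+    // replays the same A panel against the next 16-column block of B (rcx advanced
+    // by 64 bytes = 16 floats).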
+ } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" } -void __attribute__((noinline)) gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp) { - asm volatile( -#if !defined(__clang__) - "mov r14, %[gp]\t\n" -#else - "mov %[gp], %%r14\t\n" - ".intel_syntax noprefix\t\n" -#endif + +void NOINLINE_ATTR gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp) { + char* r14 = (char*)gp; //"mov r14, %[gp]\t\n" // Copy parameters // k - "mov r8, [r14 + 0]\t\n" + uint64_t r8 = *(uint64_t *)((char*)r14 + 0 ); //"mov r8, [r14 + 0]\t\n" // A - "mov r9, [r14 + 8]\t\n" + float* r9 = *(float* *)((char*)r14 + 8 ); //"mov r9, [r14 + 8]\t\n" // B - "mov r10, [r14 + 16]\t\n" + const fp16* r10 = *(const fp16**)((char*)r14 + 16); //"mov r10, [r14 + 16]\t\n" // beta - "mov r15, [r14 + 24]\t\n" + float* r15 = *(float* *)((char*)r14 + 24); //"mov r15, [r14 + 24]\t\n" // accum - "mov rdx, [r14 + 32]\t\n" + uint64_t rdx = *(uint64_t *)((char*)r14 + 32); //"mov rdx, [r14 + 32]\t\n" // C - "mov r12, [r14 + 40]\t\n" + float* r12 = *(float* *)((char*)r14 + 40); //"mov r12, [r14 + 40]\t\n" // ldc - "mov r13, [r14 + 48]\t\n" + uint64_t r13 = *(uint64_t *)((char*)r14 + 48); //"mov r13, [r14 + 48]\t\n" // b_block_cols - "mov rdi, [r14 + 56]\t\n" + uint64_t rdi = *(uint64_t *)((char*)r14 + 56); //"mov rdi, [r14 + 56]\t\n" // b_block_size - "mov rsi, [r14 + 64]\t\n" + uint64_t rsi = *(uint64_t *)((char*)r14 + 64); //"mov rsi, [r14 + 64]\t\n" // Make copies of A and C - "mov rax, r9\t\n" - "mov rcx, r12\t\n" - - "mov rbx, 0\t\n" - "loop_outter%=:\t\n" - "mov r14, 0\t\n" - "vxorps ymm0,ymm0,ymm0\t\n" - "vxorps ymm1,ymm1,ymm1\t\n" - "vxorps ymm2,ymm2,ymm2\t\n" - "vxorps ymm3,ymm3,ymm3\t\n" - "vxorps ymm4,ymm4,ymm4\t\n" - "vxorps ymm5,ymm5,ymm5\t\n" - "vxorps ymm6,ymm6,ymm6\t\n" - "vxorps ymm7,ymm7,ymm7\t\n" - "vxorps ymm8,ymm8,ymm8\t\n" - "vxorps ymm9,ymm9,ymm9\t\n" - "vxorps ymm10,ymm10,ymm10\t\n" - "vxorps ymm11,ymm11,ymm11\t\n" - - - "loop_inner%=:\t\n" - - "vcvtph2ps ymm13,XMMWORD PTR [r10 + 0]\t\n" - "vcvtph2ps ymm14,XMMWORD PTR [r10 + 16]\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+0]\t\n" - "vfmadd231ps ymm0,ymm13,ymm12\t\n" - "vfmadd231ps ymm1,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+4]\t\n" - "vfmadd231ps ymm2,ymm13,ymm12\t\n" - "vfmadd231ps ymm3,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+8]\t\n" - "vfmadd231ps ymm4,ymm13,ymm12\t\n" - "vfmadd231ps ymm5,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+12]\t\n" - "vfmadd231ps ymm6,ymm13,ymm12\t\n" - "vfmadd231ps ymm7,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+16]\t\n" - "vfmadd231ps ymm8,ymm13,ymm12\t\n" - "vfmadd231ps ymm9,ymm14,ymm12\t\n" - "vbroadcastss ymm12,DWORD PTR [r9+20]\t\n" - "vfmadd231ps ymm10,ymm13,ymm12\t\n" - "vfmadd231ps ymm11,ymm14,ymm12\t\n" - "add r9,24\t\n" - "add r10,32\t\n" - "inc r14\t\n" - "cmp r14, r8\t\n" - "jl loop_inner%=\t\n" - - "L_exit%=:\t\n" - - "cmp rdx, 1\t\n" - "je L_accum%=\t\n" - // Dump C - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm7\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm9\t\n" - "add r12, r13\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm11\t\n" - 
"add r12, r13\t\n" - "jmp L_done%=\t\n" - - "L_accum%=:\t\n" - // Dump C with accumulate - "vbroadcastss ymm15,DWORD PTR [r15]\t\n" - "vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" - "vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" - "vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" - "vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" - "vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm7\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" - "vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm9\t\n" - "add r12, r13\t\n" - "vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" - "vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" - "vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 32]\t\n" - "vmovups YMMWORD PTR [r12 + 32], ymm11\t\n" - "add r12, r13\t\n" - - "L_done%=:\t\n" - - // next outer iteration - "add rcx, 64\t\n" - "mov r12, rcx\t\n" - "mov r9, rax\t\n" - "inc rbx\t\n" - "cmp rbx, rdi\t\n" - "jl loop_outter%=\t\n" - : - : [gp] "rm"(gp) - : "r8", - "r9", - "r10", - "r11", - "r15", - "r13", - "r14", - "rax", - "rcx", - "rdx", - "rsi", - "rdi", - "rbx", - "r12", - "memory"); + float* rax = r9; //"mov rax, r9\t\n" + float* rcx = r12; //"mov rcx, r12\t\n" + + uint64_t rbx = 0; //"mov rbx, 0\t\n" + for (; rbx < rdi; ++rbx) { //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" + // //"loop_outter%=:\t\n" + uint64_t r14_i = 0; //"mov r14, 0\t\n" + __m256 ymm0 = _mm256_setzero_ps(); //"vxorps ymm0,ymm0,ymm0\t\n" + __m256 ymm1 = _mm256_setzero_ps(); //"vxorps ymm1,ymm1,ymm1\t\n" + __m256 ymm2 = _mm256_setzero_ps(); //"vxorps ymm2,ymm2,ymm2\t\n" + __m256 ymm3 = _mm256_setzero_ps(); //"vxorps ymm3,ymm3,ymm3\t\n" + __m256 ymm4 = _mm256_setzero_ps(); //"vxorps ymm4,ymm4,ymm4\t\n" + __m256 ymm5 = _mm256_setzero_ps(); //"vxorps ymm5,ymm5,ymm5\t\n" + __m256 ymm6 = _mm256_setzero_ps(); //"vxorps ymm6,ymm6,ymm6\t\n" + __m256 ymm7 = _mm256_setzero_ps(); //"vxorps ymm7,ymm7,ymm7\t\n" + __m256 ymm8 = _mm256_setzero_ps(); //"vxorps ymm8,ymm8,ymm8\t\n" + __m256 ymm9 = _mm256_setzero_ps(); //"vxorps ymm9,ymm9,ymm9\t\n" + __m256 ymm10 = _mm256_setzero_ps(); //"vxorps ymm10,ymm10,ymm10\t\n" + __m256 ymm11 = _mm256_setzero_ps(); //"vxorps ymm11,ymm11,ymm11\t\n" + + for (; r14_i < r8; ++r14_i) { //"inc r14; cmp r14, r8; jl loop_inner%=\t\n" + // loop_inner%=: //"\t\n" + auto fp16mem0 = _mm_load_si128((__m128i*)((char*)r10 + 0)); //"vcvtph2ps ymm13,XMMWORD PTR [r10 + 0]\t\n" + auto ymm13 = _mm256_cvtph_ps(fp16mem0); //"vcvtph2ps ymm13,XMMWORD PTR [r10 + 0]\t\n" + auto fp16mem16 = _mm_load_si128((__m128i*)((char*)r10 + 16)); //"vcvtph2ps ymm14,XMMWORD PTR [r10 + 16]\t\n" + auto ymm14 = _mm256_cvtph_ps(fp16mem16); //"vcvtph2ps ymm14,XMMWORD PTR [r10 + 16]\t\n" + auto ymm12 = _mm256_broadcast_ss((float*)((char*)r9 + 0)); //"vbroadcastss ymm12,DWORD PTR [r9+0]\t\n" + ymm0 = _mm256_fmadd_ps(ymm12, ymm13, ymm0); //"vfmadd231ps ymm0,ymm13,ymm12\t\n" + ymm1 = 
_mm256_fmadd_ps(ymm12, ymm14, ymm1); //"vfmadd231ps ymm1,ymm14,ymm12\t\n" + ymm12 = _mm256_broadcast_ss((float*)((char*)r9 + 4)); //"vbroadcastss ymm12,DWORD PTR [r9+4]\t\n" + ymm2 = _mm256_fmadd_ps(ymm12, ymm13, ymm2); //"vfmadd231ps ymm2,ymm13,ymm12\t\n" + ymm3 = _mm256_fmadd_ps(ymm12, ymm14, ymm3); //"vfmadd231ps ymm3,ymm14,ymm12\t\n" + ymm12 = _mm256_broadcast_ss((float*)((char*)r9 + 8)); //"vbroadcastss ymm12,DWORD PTR [r9+8]\t\n" + ymm4 = _mm256_fmadd_ps(ymm12, ymm13, ymm4); //"vfmadd231ps ymm4,ymm13,ymm12\t\n" + ymm5 = _mm256_fmadd_ps(ymm12, ymm14, ymm5); //"vfmadd231ps ymm5,ymm14,ymm12\t\n" + ymm12 = _mm256_broadcast_ss((float*)((char*)r9 + 12)); //"vbroadcastss ymm12,DWORD PTR [r9+12]\t\n" + ymm6 = _mm256_fmadd_ps(ymm12, ymm13, ymm6); //"vfmadd231ps ymm6,ymm13,ymm12\t\n" + ymm7 = _mm256_fmadd_ps(ymm12, ymm14, ymm7); //"vfmadd231ps ymm7,ymm14,ymm12\t\n" + ymm12 = _mm256_broadcast_ss((float*)((char*)r9 + 16)); //"vbroadcastss ymm12,DWORD PTR [r9+16]\t\n" + ymm8 = _mm256_fmadd_ps(ymm12, ymm13, ymm8); //"vfmadd231ps ymm8,ymm13,ymm12\t\n" + ymm9 = _mm256_fmadd_ps(ymm12, ymm14, ymm9); //"vfmadd231ps ymm9,ymm14,ymm12\t\n" + ymm12 = _mm256_broadcast_ss((float*)((char*)r9 + 20)); //"vbroadcastss ymm12,DWORD PTR [r9+20]\t\n" + ymm10 = _mm256_fmadd_ps(ymm12, ymm13, ymm10); //"vfmadd231ps ymm10,ymm13,ymm12\t\n" + ymm11 = _mm256_fmadd_ps(ymm12, ymm14, ymm11); //"vfmadd231ps ymm11,ymm14,ymm12\t\n" + r9 = (float*)((char*)r9 + 24); //"add r9,24\t\n" + r10 = (fp16*)((char*)r10 + 32); //"add r10,32\t\n" + } //"inc r14; cmp r14, r8; jl loop_outter%=\t\n" + + if(rdx != 1) { //"cmp rdx, 1; je L_accum%=\t\n" + // Dump C + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm6); //"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm7); //"vmovups YMMWORD PTR [r12 + 32], ymm7\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm8); //"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm9); //"vmovups YMMWORD PTR [r12 + 32], ymm9\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm10); //"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm11); //"vmovups YMMWORD PTR [r12 + 32], ymm11\t\n" + } else { //"jmp L_done%=\t\n" + // Dump C with accumulate + auto ymm15 = _mm256_broadcast_ss((float*)r15); //"vbroadcastss ymm15,DWORD PTR [r15]\t\n" + auto r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm0 = _mm256_fmadd_ps(r12_0, ymm15, ymm0); //"vfmadd231ps ymm0,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm0); //"vmovups YMMWORD PTR [r12 + 0], ymm0\t\n" + auto 
r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm1 = _mm256_fmadd_ps(r12_32, ymm15, ymm1); //"vfmadd231ps ymm1,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm1); //"vmovups YMMWORD PTR [r12 + 32], ymm1\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm2 = _mm256_fmadd_ps(r12_0, ymm15, ymm2); //"vfmadd231ps ymm2,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm2); //"vmovups YMMWORD PTR [r12 + 0], ymm2\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm3 = _mm256_fmadd_ps(r12_32, ymm15, ymm3); //"vfmadd231ps ymm3,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm3); //"vmovups YMMWORD PTR [r12 + 32], ymm3\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm4 = _mm256_fmadd_ps(r12_0, ymm15, ymm4); //"vfmadd231ps ymm4,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm4); //"vmovups YMMWORD PTR [r12 + 0], ymm4\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm5 = _mm256_fmadd_ps(r12_32, ymm15, ymm5); //"vfmadd231ps ymm5,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm5); //"vmovups YMMWORD PTR [r12 + 32], ymm5\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm6 = _mm256_fmadd_ps(r12_0, ymm15, ymm6); //"vfmadd231ps ymm6,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm6); //"vmovups YMMWORD PTR [r12 + 0], ymm6\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm7 = _mm256_fmadd_ps(r12_32, ymm15, ymm7); //"vfmadd231ps ymm7,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm7); //"vmovups YMMWORD PTR [r12 + 32], ymm7\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm8 = _mm256_fmadd_ps(r12_0, ymm15, ymm8); //"vfmadd231ps ymm8,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm8); //"vmovups YMMWORD PTR [r12 + 0], ymm8\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm9 = _mm256_fmadd_ps(r12_32, ymm15, ymm9); //"vfmadd231ps ymm9,ymm15,YMMWORD PTR [r12 + 32]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 32), ymm9); //"vmovups YMMWORD PTR [r12 + 32], ymm9\t\n" + r12 = (float*)((char*)r12 + r13); //"add r12, r13\t\n" + r12_0 = _mm256_load_ps((float*)((char*)r12 + 0)); //"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" + ymm10 = _mm256_fmadd_ps(r12_0, ymm15, ymm10); //"vfmadd231ps ymm10,ymm15,YMMWORD PTR [r12 + 0]\t\n" + _mm256_storeu_ps((float*)((char*)r12 + 0), ymm10); //"vmovups YMMWORD PTR [r12 + 0], ymm10\t\n" + r12_32 = _mm256_load_ps((float*)((char*)r12 + 32)); //"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 32]\t\n" + ymm11 = _mm256_fmadd_ps(r12_32, ymm15, ymm11); //"vfmadd231ps ymm11,ymm15,YMMWORD PTR [r12 + 32]\t\n" + 
_mm256_storeu_ps((float*)((char*)r12 + 32), ymm11); //"vmovups YMMWORD PTR [r12 + 32], ymm11\t\n" + } //"L_done%=:\t\n" + + // next outer iteration + rcx = (float*)((char*)rcx + 64); //"add rcx, 64\t\n" + r12 = rcx; //"mov r12, rcx\t\n" + r9 = rax; //"mov r9, rax\t\n" + } //"inc rbx; cmp rbx, rdi; jl loop_outter%=\t\n" } + } // namespace fbgemm diff --git a/src/FbgemmFP16UKernelsAvx2.h b/src/FbgemmFP16UKernelsAvx2.h index 6e7dfbc..d48a88e 100644 --- a/src/FbgemmFP16UKernelsAvx2.h +++ b/src/FbgemmFP16UKernelsAvx2.h @@ -13,6 +13,11 @@ namespace fbgemm { using fp16 = float16; using fp32 = float; +#ifdef _MSC_VER + #define NOINLINE_ATTR __declspec(noinline) +#else + #define NOINLINE_ATTR __attribute__((noinline)) +#endif struct GemmParams { uint64_t k; float* A; @@ -24,12 +29,12 @@ struct GemmParams { uint64_t b_block_cols; uint64_t b_block_size; }; -void __attribute__((noinline)) gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp); -void __attribute__((noinline)) gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp); +void NOINLINE_ATTR gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp); +void NOINLINE_ATTR gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp); +void NOINLINE_ATTR gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp); +void NOINLINE_ATTR gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp); +void NOINLINE_ATTR gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp); +void NOINLINE_ATTR gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp); typedef void (*funcptr_fp16)(GemmParams* gp); ; diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc index ee39faf..3ad24ce 100644 --- a/src/FbgemmI8DepthwiseAvx2.cc +++ b/src/FbgemmI8DepthwiseAvx2.cc @@ -13,6 +13,12 @@ #include +#ifdef _MSC_VER + #define ALWAYS_INLINE __forceinline +#else + #define ALWAYS_INLINE __attribute__((always_inline)) +#endif + using namespace std; namespace fbgemm { @@ -36,7 +42,7 @@ PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix( const int8_t* smat) : K_(K) { // Transpose the input matrix to make packing faster. 
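// MSVC has no C99-style variable-length arrays, so runtime-sized stack
// buffers like the one replaced below must move to the heap. A minimal
// sketch of the transformation (illustrative only, not from the commit):
//
//   int8_t buf[K * KERNEL_PROD];                // VLA: GCC/Clang only
//
// becomes
//
//   int8_t* buf = new int8_t[K * KERNEL_PROD];  // portable C++
//   ...
//   delete[] buf;                               // array delete, not scalar delete
//
// Note that alignas(64) on a pointer declaration aligns the pointer
// variable itself, not the allocation it points to; later commits in this
// series switch these buffers to explicitly aligned allocation for
// exactly that reason.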
- alignas(64) int8_t smat_transposed[K * KERNEL_PROD]; + alignas(64) int8_t* smat_transposed = new int8_t[K * KERNEL_PROD]; for (int i = 0; i < KERNEL_PROD; ++i) { for (int j = 0; j < K; ++j) { smat_transposed[i * K + j] = smat[i + j * KERNEL_PROD]; @@ -45,12 +51,15 @@ PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix( // Allocate packed arrays constexpr int KERNEL_PROD_ALIGNED = (KERNEL_PROD + 1) / 2 * 2; - // pmat_ = static_cast(fbgemmAlignedAlloc( - // 64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t))); +#ifdef _MSC_VER + pmat_ = static_cast(_aligned_malloc( + ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t), 64)); +#else posix_memalign( (void**)&pmat_, 64, ((K + 31) / 32) * KERNEL_PROD_ALIGNED * 32 * sizeof(int8_t)); +#endif // Pack input matrix // The layout is optimized to use vpmaddubsw efficiently (see @@ -160,11 +169,17 @@ PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix( b_interleaved_epi32[i]); } } + + delete smat_transposed; } template PackedDepthWiseConvMatrix::~PackedDepthWiseConvMatrix() { +#ifdef _MSC_VER + _aligned_free(pmat_); +#else free(pmat_); +#endif } template class PackedDepthWiseConvMatrix<3 * 3>; @@ -179,7 +194,7 @@ template class PackedDepthWiseConvMatrix<3 * 3 * 3>; // c2_v: c[8:12], c[24:28] // c3_v: c[12:16], c[28:32] template -static inline __attribute__((always_inline)) void madd_epi16x4_packed( +static inline ALWAYS_INLINE void madd_epi16x4_packed( __m256i a0_v, __m256i a1_v, __m256i a2_v, @@ -238,7 +253,7 @@ static inline __attribute__((always_inline)) void madd_epi16x4_packed( // c2_v: c[8:12], c[24:28] // c3_v: c[12:16], c[28:32] template -static inline __attribute__((always_inline)) void madd_epi16x3_packed( +static inline ALWAYS_INLINE void madd_epi16x3_packed( __m256i a0_v, __m256i a1_v, __m256i a2_v, @@ -298,7 +313,7 @@ static inline __attribute__((always_inline)) void madd_epi16x3_packed( // c2_v: c[16:20], c[20:24] // c3_v: c[24:28], c[28:32] template -static inline __attribute__((always_inline)) void madd_epi16x2_packed( +static inline ALWAYS_INLINE void madd_epi16x2_packed( __m256i a0_v, __m256i a1_v, const __m256i* b, @@ -339,7 +354,7 @@ static inline __attribute__((always_inline)) void madd_epi16x2_packed( // c2_v: c[16:20], c[20:24] // c3_v: c[24:28], c[28:32] template -static inline __attribute__((always_inline)) void madd_epi16_packed( +static inline ALWAYS_INLINE void madd_epi16_packed( __m256i a_v, const __m256i* b, __m256i* c0_v, @@ -374,7 +389,7 @@ static inline __attribute__((always_inline)) void madd_epi16_packed( // K is the number of accumulations we're doing template -static inline __attribute__((always_inline)) void inner_prod_packed_( +static inline ALWAYS_INLINE void inner_prod_packed_( const __m256i* a_v, const __m256i* Bp, int32_t* C, @@ -514,7 +529,7 @@ static inline __attribute__((always_inline)) void inner_prod_packed_( } template -static inline __attribute__((always_inline)) void inner_prod_3x3_packed_( +static inline ALWAYS_INLINE void inner_prod_3x3_packed_( const __m256i* a_v, const __m256i* Bp, int32_t* C, @@ -531,7 +546,7 @@ template < bool PER_CHANNEL_QUANTIZATION, bool A_SYMMETRIC, bool B_SYMMETRIC> -static inline __attribute__((always_inline)) void requantize_( +static inline ALWAYS_INLINE void requantize_( int32_t A_zero_point, const float* C_multiplier, int32_t C_zero_point, @@ -745,7 +760,7 @@ static inline __attribute__((always_inline)) void requantize_( } template -static inline __attribute__((always_inline)) __m256i load_a( +static inline ALWAYS_INLINE __m256i load_a( const 
uint8_t* A, __m256i mask_v) { if (REMAINDER) { @@ -759,7 +774,7 @@ template < bool SUM_A, bool REMAINDER = false, bool PER_CHANNEL_QUANTIZATION = false> -static inline __attribute__((always_inline)) void inner_prod_3x3_packed_( +static inline ALWAYS_INLINE void inner_prod_3x3_packed_( int H, int W, int K, @@ -870,7 +885,7 @@ template < bool SUM_A, bool REMAINDER = false, bool PER_CHANNEL_QUANTIZATION = false> -static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_( +static inline ALWAYS_INLINE void inner_prod_3x3x3_packed_( int T, int H, int W, @@ -1118,7 +1133,7 @@ static inline __attribute__((always_inline)) void inner_prod_3x3x3_packed_( } template -static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( +static inline ALWAYS_INLINE void depthwise_3x3_kernel_( int H, int W, int K, @@ -1194,7 +1209,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_kernel_( } template -static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( +static inline ALWAYS_INLINE void depthwise_3x3x3_kernel_( int T, int H, int W, @@ -1279,7 +1294,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_kernel_( } template -static inline __attribute__((always_inline)) void +static inline ALWAYS_INLINE void depthwise_3x3_per_channel_quantization_kernel_( int H, int W, @@ -1362,7 +1377,7 @@ depthwise_3x3_per_channel_quantization_kernel_( } template -static inline __attribute__((always_inline)) void +static inline ALWAYS_INLINE void depthwise_3x3x3_per_channel_quantization_kernel_( int T, int H, @@ -1465,7 +1480,7 @@ static pair closest_factors_(int n) { // filter shapes by parameterizing with R and S but restricting it to just 3x3 // for now. template -static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( +static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( int N, int H, int W, @@ -1491,7 +1506,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; const int8_t* Bp = B.PackedMat(); - int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); + alignas(64) int32_t* row_offsets = new int32_t[(K + 31) / 32 * 32]; int n_begin, n_end; int h_begin, h_end, w_begin, w_end; @@ -1748,10 +1763,11 @@ static inline __attribute__((always_inline)) void depthwise_3x3_pad_1_( } } } // for each n + delete row_offsets; }; template -static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( +static inline ALWAYS_INLINE void depthwise_3x3x3_pad_1_( int N, int T, int H, @@ -1781,7 +1797,7 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; const int8_t* Bp = B.PackedMat(); - int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); + alignas(64) int32_t* row_offsets = new int32_t[(K + 31) / 32 * 32]; // __attribute__((aligned(64))); int n_begin, n_end; int t_begin, t_end, h_begin, h_end; @@ -1858,10 +1874,12 @@ static inline __attribute__((always_inline)) void depthwise_3x3x3_pad_1_( } // h } // t } // for each n + + delete row_offsets; }; template -static inline __attribute__((always_inline)) void +static inline ALWAYS_INLINE void depthwise_3x3_per_channel_quantization_pad_1_( int N, int H, @@ -1888,7 +1906,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; const int8_t* Bp = B.PackedMat(); - int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); + alignas(64) int32_t* 
row_offsets = new int32_t[(K + 31) / 32 * 32]; // __attribute__((aligned(64))); int n_begin, n_end; int h_begin, h_end, w_begin, w_end; @@ -2172,10 +2190,12 @@ depthwise_3x3_per_channel_quantization_pad_1_( } } } // for each n + + delete row_offsets; }; template -static inline __attribute__((always_inline)) void +static inline ALWAYS_INLINE void depthwise_3x3x3_per_channel_quantization_pad_1_( int N, int T, @@ -2206,7 +2226,7 @@ depthwise_3x3x3_per_channel_quantization_pad_1_( int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; const int8_t* Bp = B.PackedMat(); - int32_t row_offsets[(K + 31) / 32 * 32] __attribute__((aligned(64))); + alignas(64) int32_t* row_offsets = new int32_t[(K + 31) / 32 * 32]; // __attribute__((aligned(64))); int n_begin, n_end; int t_begin, t_end, h_begin, h_end; @@ -2282,6 +2302,8 @@ depthwise_3x3x3_per_channel_quantization_pad_1_( } // h } // t } // for each n + + delete row_offsets; }; // Dispatch A_SYMMETRIC and B_SYMMETRIC @@ -2304,7 +2326,7 @@ static void depthwise_3x3_pad_1_( const int32_t* bias, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; + int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; if (A_zero_point == 0 || col_offsets == nullptr) { if (B_zero_point == 0) { depthwise_3x3_pad_1_< @@ -2406,6 +2428,7 @@ static void depthwise_3x3_pad_1_( num_threads); } } + delete C_int32_temp; } // Dispatch HAS_BIAS @@ -2709,7 +2732,7 @@ static void depthwise_3x3x3_pad_1_( const int32_t* bias, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; + int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; if (A_zero_point == 0 || col_offsets == nullptr) { if (B_zero_point == 0) { depthwise_3x3x3_pad_1_< @@ -2819,6 +2842,7 @@ static void depthwise_3x3x3_pad_1_( num_threads); } } + delete C_int32_temp; } // Dispatch HAS_BIAS @@ -2975,7 +2999,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( const int32_t* bias, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; + int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; if (A_zero_point == 0 || col_offsets == nullptr) { depthwise_3x3_per_channel_quantization_pad_1_< FUSE_RELU, @@ -3023,6 +3047,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( thread_id, num_threads); } + delete C_int32_temp; } // Dispatch HAS_BIAS @@ -3329,7 +3354,7 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_( const int32_t* bias, int thread_id, int num_threads) { - int32_t C_int32_temp[(K + 31) / 32 * 32]; + int32_t* C_int32_temp = new int32_t[(K + 31) / 32 * 32]; if (A_zero_point == 0 || col_offsets == nullptr) { depthwise_3x3x3_per_channel_quantization_pad_1_< FUSE_RELU, @@ -3381,6 +3406,7 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_( thread_id, num_threads); } + delete C_int32_temp; } // Dispatch HAS_BIAS diff --git a/src/FbgemmI8Spmdm.cc b/src/FbgemmI8Spmdm.cc index 10e5a1b..74b7211 100644 --- a/src/FbgemmI8Spmdm.cc +++ b/src/FbgemmI8Spmdm.cc @@ -70,8 +70,8 @@ void CompressedSparseColumn::SpMDM( t_very_start = std::chrono::high_resolution_clock::now(); #endif - alignas(64) uint8_t A_buffer[K * 32]; - alignas(64) int32_t C_buffer[N * 32]; + alignas(64) uint8_t* A_buffer = new uint8_t[K * 32]; + alignas(64) int32_t* C_buffer = new int32_t[N * 32]; // If we compute C = C + A * B, where B is a sparse matrix in CSC format, for // each non-zero in B, we'd need to access the corresponding column in A. 
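// The comment above captures the key design point: with B in compressed
// sparse column (CSC) form, every nonzero B(k, j) multiplies an entire
// column k of A, which is a strided access unless A is transposed first.
// A minimal scalar sketch of C += A * B over CSC storage (function and
// parameter names here are illustrative, not taken from this patch):

#include <cstdint>

static void spmdm_scalar_ref(int M, int N,
                             const uint8_t* A, int lda,
                             const int* colptr, const int* rowidx,
                             const int8_t* values,
                             int32_t* C, int ldc) {
  for (int j = 0; j < N; ++j) {                  // each column of B
    for (int p = colptr[j]; p < colptr[j + 1]; ++p) {
      int k = rowidx[p];                         // row of this nonzero
      for (int i = 0; i < M; ++i) {              // strided walk down column k of A
        C[i * ldc + j] += A[i * lda + k] * values[p];
      }
    }
  }
}

// When NNZ/K is large, transposing a 32-row strip of A first makes the
// inner loop unit-stride, which is what the blocked path below does.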
@@ -82,7 +82,7 @@ void CompressedSparseColumn::SpMDM( // The cost of transpose is O(K*N) and we do O(NNZ*N) multiplications. // If NNZ/K is small, it's not worth doing transpose so we just use this // scalar loop. - int32_t C_temp[block.row_size]; + int32_t* C_temp = new int32_t[block.row_size]; if (accumulation) { for (int j = 0; j < block.col_size; ++j) { int k = colptr_[block.col_start + j]; @@ -141,6 +141,9 @@ void CompressedSparseColumn::SpMDM( } } // for each column of B } + delete C_temp; + delete A_buffer; + delete C_buffer; return; } @@ -157,7 +160,7 @@ void CompressedSparseColumn::SpMDM( for (int i1 = block.row_start; i1 < i_end; i1 += 32) { // Transpose 32 x K submatrix of A if (i_end - i1 < 32) { - alignas(64) uint8_t A_temp_buffer[K * 32]; + alignas(64) uint8_t* A_temp_buffer = new uint8_t[K * 32]; for (int i2 = 0; i2 < (i_end - i1) / 8 * 8; i2 += 8) { transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32); } @@ -173,6 +176,7 @@ void CompressedSparseColumn::SpMDM( for (int i2 = (i_end - i1) / 8 * 8; i2 < 32; i2 += 8) { transpose_8rows(K, A_temp_buffer + i2 * K, K, A_buffer + i2, 32); } + delete A_temp_buffer; } else { for (int i2 = 0; i2 < 32; i2 += 8) { transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32); @@ -233,6 +237,9 @@ void CompressedSparseColumn::SpMDM( reinterpret_cast(C + (i1 - block.row_start) * ldc), ldc); + delete A_buffer; + delete C_buffer; + #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN t_end = std::chrono::high_resolution_clock::now(); dt = std::chrono::duration_cast(t_end - t_start) diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc index 7572a51..c9f8a65 100644 --- a/src/PackAWithQuantRowOffset.cc +++ b/src/PackAWithQuantRowOffset.cc @@ -111,7 +111,7 @@ void PackAWithQuantRowOffset::pack(const block_type_t& block) { (block.col_start % (this->numCols() / this->numGroups())) != 0; int32_t* row_offset_buf = getRowOffsetBuffer(); - float smat_transposed[block.row_size * block.col_size]; + float* smat_transposed = new float[block.row_size * block.col_size]; if (tr) { transpose_simd( block.col_size, @@ -150,6 +150,8 @@ void PackAWithQuantRowOffset::pack(const block_type_t& block) { out[i * BaseType::blockColSize() + j] = 0; } } + + delete smat_transposed; } template diff --git a/src/UtilsAvx512.cc b/src/UtilsAvx512.cc index 44d4f9a..b286eab 100644 --- a/src/UtilsAvx512.cc +++ b/src/UtilsAvx512.cc @@ -103,6 +103,40 @@ inline void transpose_kernel_16x16_avx512( // m1 n1 o1 p1 ... // m2 n2 o2 p2 ... // m3 n3 o3 p3 ... 
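// The _MSC_VER branch below exists because MSVC does not allow
// reinterpret_cast between distinct vector value types such as __m512
// and __m512d (a GCC/Clang extension); casting through an lvalue
// reference reinterprets the same bits without a copy. A minimal sketch
// of the pattern, assuming AVX-512F (the helper name is illustrative;
// _mm512_castpd_ps would be the portable intrinsic equivalent):

#include <immintrin.h>

static inline __m512 bits_as_ps(__m512d x) {
#ifdef _MSC_VER
  return reinterpret_cast<__m512&>(x);  // lvalue cast accepted by MSVC
#else
  return reinterpret_cast<__m512>(x);   // vector value cast, GCC/Clang
#endif
}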
+#ifdef _MSC_VER + a = reinterpret_cast<__m512&>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d&>(ta), reinterpret_cast<__m512d&>(tc))); + b = reinterpret_cast<__m512&>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d&>(ta), reinterpret_cast<__m512d&>(tc))); + c = reinterpret_cast<__m512&>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d&>(tb), reinterpret_cast<__m512d&>(td))); + d = reinterpret_cast<__m512&>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d&>(tb), reinterpret_cast<__m512d&>(td))); + e = reinterpret_cast<__m512&>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d&>(te), reinterpret_cast<__m512d&>(tg))); + f = reinterpret_cast<__m512&>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d&>(te), reinterpret_cast<__m512d&>(tg))); + g = reinterpret_cast<__m512&>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d&>(tf), reinterpret_cast<__m512d&>(th))); + h = reinterpret_cast<__m512&>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d&>(tf), reinterpret_cast<__m512d&>(th))); + i = reinterpret_cast<__m512&>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d&>(ti), reinterpret_cast<__m512d&>(tk))); + j = reinterpret_cast<__m512&>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d&>(ti), reinterpret_cast<__m512d&>(tk))); + k = reinterpret_cast<__m512&>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d&>(tj), reinterpret_cast<__m512d&>(tl))); + l = reinterpret_cast<__m512&>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d&>(tj), reinterpret_cast<__m512d&>(tl))); + m = reinterpret_cast<__m512&>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d&>(tm), reinterpret_cast<__m512d&>(to))); + n = reinterpret_cast<__m512&>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d&>(tm), reinterpret_cast<__m512d&>(to))); + o = reinterpret_cast<__m512&>(_mm512_unpacklo_pd( + reinterpret_cast<__m512d&>(tn), reinterpret_cast<__m512d&>(tq))); + p = reinterpret_cast<__m512&>(_mm512_unpackhi_pd( + reinterpret_cast<__m512d&>(tn), reinterpret_cast<__m512d&>(tq))); +#else a = reinterpret_cast<__m512>(_mm512_unpacklo_pd( reinterpret_cast<__m512d>(ta), reinterpret_cast<__m512d>(tc))); b = reinterpret_cast<__m512>(_mm512_unpackhi_pd( @@ -135,6 +169,7 @@ inline void transpose_kernel_16x16_avx512( reinterpret_cast<__m512d>(tn), reinterpret_cast<__m512d>(tq))); p = reinterpret_cast<__m512>(_mm512_unpackhi_pd( reinterpret_cast<__m512d>(tn), reinterpret_cast<__m512d>(tq))); +#endif // shuffle 128-bits (composed of 4 32-bit elements) // a0 b0 c0 d0 a8 b8 c8 d8 e0 f0 g0 h0 e8 f8 g8 h8 diff --git a/src/codegen_fp16fp32.cc b/src/codegen_fp16fp32.cc index 7c8e10c..8f80593 100644 --- a/src/codegen_fp16fp32.cc +++ b/src/codegen_fp16fp32.cc @@ -18,10 +18,16 @@ using namespace std; +void addi(ofstream& of, string i, string asmstr = "", bool disable = false) { + if (disable == false) + of << " " + i + " //\"" + asmstr + "\\t\\n\"" + "\n"; +} +#if 0 void addi(ofstream& of, string i, bool disable = false) { if (disable == false) of << " \"" + i + "\\t\\n\"" + "\n"; } +#endif struct ISA { unsigned avx; // 1, 2 or 3 @@ -88,7 +94,8 @@ int main() { " * This source code is licensed under the BSD-style license found in the\n" " * LICENSE file in the root directory of this source tree.\n" " */\n"; - srcfile << "#include \"FbgemmFP16UKernelsAvx2.h\"\n\n"; + srcfile << "#include \"FbgemmFP16UKernelsAvx2.h\"\n"; + srcfile << "#include \n\n"; srcfile << "namespace fbgemm {\n\n"; if (iaca) { srcfile << "#include \"iacaMarks.h\"\n"; @@ -111,6 +118,11 @@ int main() { hdrfile << "namespace fbgemm {\n\n"; hdrfile << "using fp16 = float16;\n"; hdrfile << "using fp32 = float;\n"; + 
hdrfile << "#ifdef _MSC_VER\n"; + hdrfile << " #define NOINLINE_ATTR __declspec(noinline)\n"; + hdrfile << "#else\n"; + hdrfile << " #define NOINLINE_ATTR __attribute__((noinline))\n"; + hdrfile << "#endif\n"; hdrfile << "struct GemmParams {\n uint64_t k;\n float* A;\n const fp16* B;\n" " float* beta;\n uint64_t accum;\n float* C;\n uint64_t ldc;\n" @@ -158,8 +170,9 @@ int main() { fargs = "(" + p1 + ")"; +#if 1 fheader[k] = - "void __attribute__((noinline)) " + funcname[k] + fargs; + "void NOINLINE_ATTR " + funcname[k] + fargs; srcfile << fheader[k] << " {\n"; unsigned last_free_ymmreg = 0; @@ -183,85 +196,92 @@ int main() { assert(last_free_ymmreg <= 16); - srcfile << " asm volatile(\n"; + //srcfile << " asm volatile(\n"; - srcfile << "#if !defined(__clang__)" - << "\n"; - addi(srcfile, "mov r14, %[gp]"); - srcfile << "#else\n"; - addi(srcfile, "mov %[gp], %%r14"); - addi(srcfile, ".intel_syntax noprefix"); - srcfile << "#endif\n"; + //srcfile << "#if !defined(__clang__)" + //<< "\n"; + addi(srcfile, "char* r14 = (char*)gp;", "mov r14, %[gp]"); + //srcfile << "#else\n"; + //addi(srcfile, "mov %[gp], %%r14"); + //addi(srcfile, ".intel_syntax noprefix"); + //srcfile << "#endif\n"; srcfile << "\n // Copy parameters\n"; - srcfile << " // k\n"; - addi(srcfile, "mov r8, [r14 + 0]"); - srcfile << " // A\n"; - addi(srcfile, "mov r9, [r14 + 8]"); - srcfile << " // B\n"; - addi(srcfile, "mov r10, [r14 + 16]"); - srcfile << " // beta\n"; - addi(srcfile, "mov r15, [r14 + 24]"); - srcfile << " // accum\n"; - addi(srcfile, "mov rdx, [r14 + 32]"); - srcfile << " // C\n"; - addi(srcfile, "mov r12, [r14 + 40]"); - srcfile << " // ldc\n"; - addi(srcfile, "mov r13, [r14 + 48]"); - srcfile << " // b_block_cols\n"; - addi(srcfile, "mov rdi, [r14 + 56]"); - srcfile << " // b_block_size\n"; - addi(srcfile, "mov rsi, [r14 + 64]"); + srcfile << " // k\n"; addi(srcfile, "uint64_t r8 = *(uint64_t *)((char*)r14 + 0 );", "mov r8, [r14 + 0]"); + srcfile << " // A\n"; addi(srcfile, "float* r9 = *(float* *)((char*)r14 + 8 );", "mov r9, [r14 + 8]"); + srcfile << " // B\n"; addi(srcfile, "const fp16* r10 = *(const fp16**)((char*)r14 + 16);", "mov r10, [r14 + 16]"); + srcfile << " // beta\n"; addi(srcfile, "float* r15 = *(float* *)((char*)r14 + 24);", "mov r15, [r14 + 24]"); + srcfile << " // accum\n"; addi(srcfile, "uint64_t rdx = *(uint64_t *)((char*)r14 + 32);", "mov rdx, [r14 + 32]"); + srcfile << " // C\n"; addi(srcfile, "float* r12 = *(float* *)((char*)r14 + 40);", "mov r12, [r14 + 40]"); + srcfile << " // ldc\n"; addi(srcfile, "uint64_t r13 = *(uint64_t *)((char*)r14 + 48);", "mov r13, [r14 + 48]"); + srcfile << " // b_block_cols\n"; addi(srcfile, "uint64_t rdi = *(uint64_t *)((char*)r14 + 56);", "mov rdi, [r14 + 56]"); + srcfile << " // b_block_size\n"; addi(srcfile, "uint64_t rsi = *(uint64_t *)((char*)r14 + 64);", "mov rsi, [r14 + 64]"); srcfile << " // Make copies of A and C\n"; - addi(srcfile, "mov rax, r9"); - addi(srcfile, "mov rcx, r12"); + addi(srcfile, "float* rax = r9;", "mov rax, r9"); + addi(srcfile, "float* rcx = r12;", "mov rcx, r12"); srcfile << "\n"; - addi(srcfile, "mov rbx, 0"); + addi(srcfile, "uint64_t rbx = 0;", "mov rbx, 0"); string exitlabel = "L_exit%="; string label2 = "loop_outter%="; - addi(srcfile, label2 + ":"); - addi(srcfile, "mov r14, 0"); + addi(srcfile, "for (; rbx < rdi; ++rbx) {", "inc rbx; cmp rbx, rdi; jl " + label2); + addi(srcfile, "// ", label2 + ":"); + addi(srcfile, " uint64_t r14_i = 0;", "mov r14, 0"); // set all vCtile regs to zeros for (auto r = 0; r < vCtile.size(); 
r++) { for (auto c = 0; c < vCtile[r].size(); c++) { addi( srcfile, + " __m256 " + vCtile[r][c] + " = _mm256_setzero_ps();", "vxorps " + vCtile[r][c] + "," + vCtile[r][c] + "," + vCtile[r][c]); } } // start marker - if (iaca) { - addi(srcfile, "mov ebx, 111"); - addi(srcfile, ".byte 0x64, 0x67, 0x90"); - } + //if (iaca) { + // addi(srcfile, "mov ebx, 111"); + // addi(srcfile, ".byte 0x64, 0x67, 0x90"); + //} - srcfile << "\n"; + //srcfile << "\n"; srcfile << "\n"; string label = "loop_inner%="; - addi(srcfile, label + ":"); - srcfile << "\n"; + addi(srcfile, " for (; r14_i < r8; ++r14_i) {", "inc r14; cmp r14, r8; jl " + label); + addi(srcfile, " // " + label + ":"); + //srcfile << "\n"; for (int c = 0; c < vCtile[0].size(); c++) { + addi( + srcfile, + " auto fp16mem" + to_string(16 * c) + " = _mm_load_si128((__m128i*)((char*)r10 + " + to_string(16 * c) + "));", + "vcvtph2ps " + vBcol[c] + ",XMMWORD PTR [r10 + " + + to_string(16 * c) + "]"); addi( srcfile, + " auto " + vBcol[c] + " = _mm256_cvtph_ps(fp16mem" + to_string(16 * c) + ");", "vcvtph2ps " + vBcol[c] + ",XMMWORD PTR [r10 + " + to_string(16 * c) + "]"); } for (int r = 0; r < vCtile.size(); r++) { + //addi( + // srcfile, + // ((r == 0) ? " auto " + vAtmp : "" + vAtmp) + " = _mm256_broadcastss_ps(r9 + " + to_string(4 * r) + ");", + // "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + + // to_string(4 * r) + "]"); addi( srcfile, + ((r == 0) ? " auto " + vAtmp : " " + vAtmp) + " = _mm256_broadcast_ss((float*)((char*)r9 + " + to_string(4 * r) + "));", "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + to_string(4 * r) + "]"); for (int c = 0; c < vCtile[0].size(); c++) { addi( srcfile, + " " + vCtile[r][c] + " = _mm256_fmadd_ps(" + vAtmp + ", " + vBcol[c] + ", " + vCtile[r][c] + ");", "vfmadd231ps " + vCtile[r][c] + "," + vBcol[c] + "," + vAtmp); } @@ -269,21 +289,25 @@ int main() { addi( srcfile, + " r9 = (float*)((char*)r9 + " + to_string(4 * ukernel_shape[k][0]) + ");", "add r9," + to_string(4 * ukernel_shape[k][0]), fixedA); // move A ptr addi( srcfile, + " r10 = (fp16*)((char*)r10 + " + to_string(16 * ukernel_shape[k][1]) + ");", "add r10," + to_string(16 * ukernel_shape[k][1]), fixedA); // move A ptr - addi(srcfile, "inc r14"); - addi(srcfile, "cmp r14, r8"); - addi(srcfile, "jl " + label); + addi(srcfile, " }", "inc r14; cmp r14, r8; jl " + label2); + // move to for loop + //addi(srcfile, "inc r14"); + //addi(srcfile, "cmp r14, r8"); + //addi(srcfile, "jl " + label); - srcfile << "\n"; + //srcfile << "\n"; - addi(srcfile, exitlabel + ":"); + //addi(srcfile, exitlabel + ":"); // addi(srcfile, "add r10, rsi"); srcfile << "\n"; @@ -294,29 +318,33 @@ int main() { addi(srcfile, ".byte 0x64, 0x67, 0x90"); } - addi(srcfile, "cmp rdx, 1"); - addi(srcfile, "je L_accum%="); - srcfile << " // Dump C\n"; + //addi(srcfile, "cmp rdx, 1"); + addi(srcfile, " if(rdx != 1) {", "cmp rdx, 1; je L_accum%="); + + srcfile << " // Dump C\n"; for (auto r = 0; r < vCtile.size(); r++) { for (auto c = 0; c < vCtile[r].size(); c++) { addi( srcfile, + " _mm256_storeu_ps((float*)((char*)r12 + " + to_string(32 * c) + "), " + vCtile[r][c] + ");", "vmovups YMMWORD PTR [r12 + " + to_string(32 * c) + "], " + vCtile[r][c], fixedC); } - addi(srcfile, "add r12, r13", fixedC); // move C ptr + if (r != vCtile.size() - 1) + addi(srcfile, " r12 = (float*)((char*)r12 + r13);", "add r12, r13", fixedC); // move C ptr } - addi(srcfile, "jmp L_done%="); + addi(srcfile, " } else {", "jmp L_done%="); - srcfile << "\n"; - addi(srcfile, "L_accum%=:"); - srcfile << " // Dump C with 
accumulate\n"; + //srcfile << "\n"; + //addi(srcfile, "L_accum%=:"); + srcfile << " // Dump C with accumulate\n"; string r_spare = (s.avx == 1) ? "ymm14" : "ymm15"; addi( srcfile, + " auto " + r_spare + " = _mm256_broadcast_ss((float*)r15);", "vbroadcastss " + r_spare + string(",DWORD PTR [r15]"), fixedC); // store out C @@ -326,18 +354,29 @@ int main() { case 1: addi( srcfile, + "not supported", string("vmulps ymm15, ") + r_spare + comma + "YMMWORD PTR [r12 + " + to_string(32 * c) + "]", fixedC); addi( srcfile, + "not supported", "vaddps " + vCtile[r][c] + "," + vCtile[r][c] + "," + "ymm15", fixedC); break; case 2: + //if (r == 0) { + addi( + srcfile, + ((r == 0) ? " auto r12_" + to_string(32 * c) : " r12_" + to_string(32 * c)) + " = _mm256_load_ps((float*)((char*)r12 + " + to_string(32 * c) + "));", + "vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," + + "YMMWORD PTR [r12 + " + to_string(32 * c) + "]", + fixedC); + //} addi( srcfile, + " " + vCtile[r][c] + " = _mm256_fmadd_ps(r12_" + to_string(32 * c) + ", " + r_spare + ", " + vCtile[r][c] + ");", "vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," + "YMMWORD PTR [r12 + " + to_string(32 * c) + "]", fixedC); @@ -347,46 +386,283 @@ int main() { } addi( srcfile, + " _mm256_storeu_ps((float*)((char*)r12 + " + to_string(32 * c) + "), " + vCtile[r][c] + ");", "vmovups YMMWORD PTR [r12 + " + to_string(32 * c) + "], " + vCtile[r][c], fixedC); } - addi(srcfile, "add r12, r13", fixedC); // move C ptr + if (r != vCtile.size() - 1) + addi(srcfile, " r12 = (float*)((char*)r12 + r13);", "add r12, r13", fixedC); // move C ptr } - srcfile << "\n"; - addi(srcfile, "L_done%=:"); + //srcfile << "\n"; + addi(srcfile, " }", "L_done%=:"); - srcfile << "\n // next outer iteration\n"; + srcfile << "\n // next outer iteration\n"; // C addi( srcfile, + " rcx = (float*)((char*)rcx + " + to_string(32 * ukernel_shape[k][1]) + ");", "add rcx, " + to_string(32 * ukernel_shape[k][1]), fixedC); - addi(srcfile, "mov r12, rcx", fixedC); + addi(srcfile, " r12 = rcx;", "mov r12, rcx", fixedC); // A - addi(srcfile, "mov r9, rax"); - - addi(srcfile, "inc rbx"); - addi(srcfile, "cmp rbx, rdi"); - addi(srcfile, "jl " + label2); - - // output - srcfile << " :\n"; - // input - srcfile << " : [gp] \"rm\"(gp)\n"; - - // clobbered - srcfile - << " : \"r8\",\n \"r9\",\n \"r10\",\n" - " \"r11\",\n \"r15\",\n \"r13\",\n" - " \"r14\",\n \"rax\",\n \"rcx\",\n" - " \"rdx\",\n \"rsi\",\n \"rdi\",\n" - " \"rbx\",\n \"r12\",\n" - " \"memory\");\n"; - srcfile << "}\n"; + addi(srcfile, " r9 = rax;", "mov r9, rax"); + + // move to top for looop + //addi(srcfile, "inc rbx"); + //addi(srcfile, "cmp rbx, rdi"); + //addi(srcfile, "jl " + label2); + addi(srcfile, "}", "inc rbx; cmp rbx, rdi; jl " + label2); + + //// output + //srcfile << " :\n"; + //// input + //srcfile << " : [gp] \"rm\"(gp)\n"; + + //// clobbered + //srcfile + // << " : \"r8\",\n \"r9\",\n \"r10\",\n" + // " \"r11\",\n \"r15\",\n \"r13\",\n" + // " \"r14\",\n \"rax\",\n \"rcx\",\n" + // " \"rdx\",\n \"rsi\",\n \"rdi\",\n" + // " \"rbx\",\n \"r12\",\n" + // " \"memory\");\n"; + srcfile << "}\n\n"; + } + +#else + fheader[k] = + "void __attribute__((noinline)) " + funcname[k] + fargs; + srcfile << fheader[k] << " {\n"; + + unsigned last_free_ymmreg = 0; + // produce register block of C + vector> vCtile(ukernel_shape[k][0]); + for (auto r = 0; r < ukernel_shape[k][0]; r++) + for (auto c = 0; c < ukernel_shape[k][1]; c++) { + vCtile[r].push_back("ymm" + to_string(last_free_ymmreg)); + last_free_ymmreg++; + } + assert(last_free_ymmreg 
<= 14); + + string vAtmp = "ymm" + to_string(last_free_ymmreg++); + // produce register block of B col + vector vBcol(ukernel_shape[k][1]); + + for (auto c = 0; c < ukernel_shape[k][1]; c++) { + vBcol[c] = ("ymm" + to_string(last_free_ymmreg)); + last_free_ymmreg++; + } + + assert(last_free_ymmreg <= 16); + + srcfile << " asm volatile(\n"; + + srcfile << "#if !defined(__clang__)" + << "\n"; + addi(srcfile, "mov r14, %[gp]"); + srcfile << "#else\n"; + addi(srcfile, "mov %[gp], %%r14"); + addi(srcfile, ".intel_syntax noprefix"); + srcfile << "#endif\n"; + + srcfile << "\n // Copy parameters\n"; + srcfile << " // k\n"; + addi(srcfile, "mov r8, [r14 + 0]"); + srcfile << " // A\n"; + addi(srcfile, "mov r9, [r14 + 8]"); + srcfile << " // B\n"; + addi(srcfile, "mov r10, [r14 + 16]"); + srcfile << " // beta\n"; + addi(srcfile, "mov r15, [r14 + 24]"); + srcfile << " // accum\n"; + addi(srcfile, "mov rdx, [r14 + 32]"); + srcfile << " // C\n"; + addi(srcfile, "mov r12, [r14 + 40]"); + srcfile << " // ldc\n"; + addi(srcfile, "mov r13, [r14 + 48]"); + srcfile << " // b_block_cols\n"; + addi(srcfile, "mov rdi, [r14 + 56]"); + srcfile << " // b_block_size\n"; + addi(srcfile, "mov rsi, [r14 + 64]"); + srcfile << " // Make copies of A and C\n"; + addi(srcfile, "mov rax, r9"); + addi(srcfile, "mov rcx, r12"); + srcfile << "\n"; + + addi(srcfile, "mov rbx, 0"); + + string exitlabel = "L_exit%="; + string label2 = "loop_outter%="; + addi(srcfile, label2 + ":"); + addi(srcfile, "mov r14, 0"); + + // set all vCtile regs to zeros + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + addi( + srcfile, + "vxorps " + vCtile[r][c] + "," + vCtile[r][c] + "," + + vCtile[r][c]); + } + } + + // start marker + if (iaca) { + addi(srcfile, "mov ebx, 111"); + addi(srcfile, ".byte 0x64, 0x67, 0x90"); + } + + srcfile << "\n"; + + srcfile << "\n"; + string label = "loop_inner%="; + addi(srcfile, label + ":"); + srcfile << "\n"; + + for (int c = 0; c < vCtile[0].size(); c++) { + addi( + srcfile, + "vcvtph2ps " + vBcol[c] + ",XMMWORD PTR [r10 + " + + to_string(16 * c) + "]"); } + for (int r = 0; r < vCtile.size(); r++) { + addi( + srcfile, + "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + + to_string(4 * r) + "]"); + for (int c = 0; c < vCtile[0].size(); c++) { + addi( + srcfile, + "vfmadd231ps " + vCtile[r][c] + "," + vBcol[c] + "," + + vAtmp); + } + } + + addi( + srcfile, + "add r9," + to_string(4 * ukernel_shape[k][0]), + fixedA); // move A ptr + + addi( + srcfile, + "add r10," + to_string(16 * ukernel_shape[k][1]), + fixedA); // move A ptr + + addi(srcfile, "inc r14"); + addi(srcfile, "cmp r14, r8"); + addi(srcfile, "jl " + label); + + srcfile << "\n"; + + addi(srcfile, exitlabel + ":"); + + // addi(srcfile, "add r10, rsi"); + srcfile << "\n"; + + // end marker + if (iaca) { + addi(srcfile, "mov ebx, 222"); + addi(srcfile, ".byte 0x64, 0x67, 0x90"); + } + + addi(srcfile, "cmp rdx, 1"); + addi(srcfile, "je L_accum%="); + srcfile << " // Dump C\n"; + + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + addi( + srcfile, + "vmovups YMMWORD PTR [r12 + " + to_string(32 * c) + + "], " + vCtile[r][c], + fixedC); + } + addi(srcfile, "add r12, r13", fixedC); // move C ptr + } + addi(srcfile, "jmp L_done%="); + + srcfile << "\n"; + addi(srcfile, "L_accum%=:"); + srcfile << " // Dump C with accumulate\n"; + + string r_spare = (s.avx == 1) ? 
"ymm14" : "ymm15"; + addi( + srcfile, + "vbroadcastss " + r_spare + string(",DWORD PTR [r15]"), + fixedC); + // store out C + for (auto r = 0; r < vCtile.size(); r++) { + for (auto c = 0; c < vCtile[r].size(); c++) { + switch (s.avx) { + case 1: + addi( + srcfile, + string("vmulps ymm15, ") + r_spare + comma + + "YMMWORD PTR [r12 + " + to_string(32 * c) + "]", + fixedC); + addi( + srcfile, + "vaddps " + vCtile[r][c] + "," + vCtile[r][c] + "," + + "ymm15", + fixedC); + break; + case 2: + addi( + srcfile, + "vfmadd231ps " + vCtile[r][c] + "," + r_spare + "," + + "YMMWORD PTR [r12 + " + to_string(32 * c) + "]", + fixedC); + break; + default: + assert(0); + } + addi( + srcfile, + "vmovups YMMWORD PTR [r12 + " + to_string(32 * c) + + "], " + vCtile[r][c], + fixedC); + } + addi(srcfile, "add r12, r13", fixedC); // move C ptr + } + + srcfile << "\n"; + addi(srcfile, "L_done%=:"); + + srcfile << "\n // next outer iteration\n"; + // C + addi( + srcfile, + "add rcx, " + to_string(32 * ukernel_shape[k][1]), + fixedC); + addi(srcfile, "mov r12, rcx", fixedC); + // A + addi(srcfile, "mov r9, rax"); + + addi(srcfile, "inc rbx"); + addi(srcfile, "cmp rbx, rdi"); + addi(srcfile, "jl " + label2); + + // output + srcfile << " :\n"; + // input + srcfile << " : [gp] \"rm\"(gp)\n"; + + // clobbered + srcfile + << " : \"r8\",\n \"r9\",\n \"r10\",\n" + " \"r11\",\n \"r15\",\n \"r13\",\n" + " \"r14\",\n \"rax\",\n \"rcx\",\n" + " \"rdx\",\n \"rsi\",\n \"rdi\",\n" + " \"rbx\",\n \"r12\",\n" + " \"memory\");\n"; + srcfile << "}\n"; + } + +#endif + for (unsigned k = 0; k < ukernel_shape.size(); k++) { hdrfile << fheader[k] << ";\n"; } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a7e531b..42c3e1f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -18,8 +18,12 @@ macro(add_gtest TESTNAME) set_target_properties(${TESTNAME} PROPERTIES CXX_STANDARD 11 CXX_EXTENSIONS NO) - target_compile_options(${TESTNAME} PRIVATE - "-m64" "-mavx2" "-mfma" "-masm=intel") + if(MSVC) + target_compile_options(${TESTNAME} PRIVATE "/DFBGEMM_STATIC") + else(MSVC) + target_compile_options(${TESTNAME} PRIVATE + "-m64" "-mavx2" "-mfma" "-masm=intel") + endif(MSVC) target_link_libraries(${TESTNAME} gtest gmock gtest_main fbgemm) add_dependencies(${TESTNAME} gtest fbgemm) add_test(${TESTNAME} ${TESTNAME}) diff --git a/test/Im2ColFusedRequantizeTest.cc b/test/Im2ColFusedRequantizeTest.cc index d9c2f75..b14303f 100644 --- a/test/Im2ColFusedRequantizeTest.cc +++ b/test/Im2ColFusedRequantizeTest.cc @@ -7,6 +7,7 @@ #include #include #include +#include #ifdef _OPENMP #include diff --git a/test/PackedRequantizeAcc16Test.cc b/test/PackedRequantizeAcc16Test.cc index 55f6e7f..20f860e 100644 --- a/test/PackedRequantizeAcc16Test.cc +++ b/test/PackedRequantizeAcc16Test.cc @@ -9,6 +9,7 @@ #include #include #include +#include #ifdef _OPENMP #include -- cgit v1.2.3 From d402bed4f186a90fddcbdb43dd655b221e7673c2 Mon Sep 17 00:00:00 2001 From: Young Jin Kim Date: Thu, 13 Jun 2019 13:36:28 -0700 Subject: turn off forceinline due to the compile speed --- src/FbgemmI8DepthwiseAvx2.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc index 3ad24ce..402ffbd 100644 --- a/src/FbgemmI8DepthwiseAvx2.cc +++ b/src/FbgemmI8DepthwiseAvx2.cc @@ -14,7 +14,7 @@ #include #ifdef _MSC_VER - #define ALWAYS_INLINE __forceinline + #define ALWAYS_INLINE #else #define ALWAYS_INLINE __attribute__((always_inline)) #endif -- cgit v1.2.3 From b4e3a9ceb703e186637a884959c1153fc6e0f9b4 Mon Sep 17 
00:00:00 2001 From: Young Jin Kim Date: Fri, 14 Jun 2019 09:22:49 -0700 Subject: Improve some memory allocation code --- bench/AlignedVec.h | 3 ++- include/fbgemm/FbgemmFP16.h | 3 ++- src/FbgemmI8DepthwiseAvx2.cc | 18 +++++++++--------- src/PackAWithQuantRowOffset.cc | 2 +- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/bench/AlignedVec.h b/bench/AlignedVec.h index 8b3ff14..0b8c7ce 100644 --- a/bench/AlignedVec.h +++ b/bench/AlignedVec.h @@ -102,7 +102,8 @@ class aligned_allocator { #ifdef _MSC_VER pv = _aligned_malloc(n * sizeof(T), Alignment); #else - posix_memalign(&pv, Alignment, n * sizeof(T)); + int result = posix_memalign(&pv, Alignment, n * sizeof(T)); + assert(result == 0); #endif // Allocators should throw std::bad_alloc in the case of memory allocation diff --git a/include/fbgemm/FbgemmFP16.h b/include/fbgemm/FbgemmFP16.h index f6eea21..3d84977 100644 --- a/include/fbgemm/FbgemmFP16.h +++ b/include/fbgemm/FbgemmFP16.h @@ -112,7 +112,8 @@ class PackedGemmMatrixFP16 { pmat_ = (float16 *)_aligned_malloc(matSize() * sizeof(float16) + padding, 64); #else - posix_memalign((void**)&pmat_, 64, matSize() * sizeof(float16) + padding); + int result = posix_memalign((void**)&pmat_, 64, matSize() * sizeof(float16) + padding); + assert(result == 0); #endif for (auto i = 0; i < matSize(); i++) { pmat_[i] = tconv(0.f, pmat_[i]); diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc index 402ffbd..d90d47a 100644 --- a/src/FbgemmI8DepthwiseAvx2.cc +++ b/src/FbgemmI8DepthwiseAvx2.cc @@ -170,7 +170,7 @@ PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix( } } - delete smat_transposed; + delete[] smat_transposed; } template @@ -1763,7 +1763,7 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( } } } // for each n - delete row_offsets; + delete[] row_offsets; }; template @@ -1875,7 +1875,7 @@ static inline ALWAYS_INLINE void depthwise_3x3x3_pad_1_( } // t } // for each n - delete row_offsets; + delete[] row_offsets; }; template @@ -2191,7 +2191,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( } } // for each n - delete row_offsets; + delete[] row_offsets; }; template @@ -2303,7 +2303,7 @@ depthwise_3x3x3_per_channel_quantization_pad_1_( } // t } // for each n - delete row_offsets; + delete[] row_offsets; }; // Dispatch A_SYMMETRIC and B_SYMMETRIC @@ -2428,7 +2428,7 @@ static void depthwise_3x3_pad_1_( num_threads); } } - delete C_int32_temp; + delete[] C_int32_temp; } // Dispatch HAS_BIAS @@ -2842,7 +2842,7 @@ static void depthwise_3x3x3_pad_1_( num_threads); } } - delete C_int32_temp; + delete[] C_int32_temp; } // Dispatch HAS_BIAS @@ -3047,7 +3047,7 @@ static void depthwise_3x3_per_channel_quantization_pad_1_( thread_id, num_threads); } - delete C_int32_temp; + delete[] C_int32_temp; } // Dispatch HAS_BIAS @@ -3406,7 +3406,7 @@ static void depthwise_3x3x3_per_channel_quantization_pad_1_( thread_id, num_threads); } - delete C_int32_temp; + delete[] C_int32_temp; } // Dispatch HAS_BIAS diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc index c9f8a65..305a298 100644 --- a/src/PackAWithQuantRowOffset.cc +++ b/src/PackAWithQuantRowOffset.cc @@ -151,7 +151,7 @@ void PackAWithQuantRowOffset::pack(const block_type_t& block) { } } - delete smat_transposed; + delete[] smat_transposed; } template -- cgit v1.2.3 From 696a8f5a6e0285e12dde6dddc7dcd204e8e80068 Mon Sep 17 00:00:00 2001 From: Young Jin Kim Date: Fri, 14 Jun 2019 09:23:22 -0700 Subject: missed file --- src/FbgemmI8Spmdm.cc | 18 ++++++++---------- 1 file changed, 8
insertions(+), 10 deletions(-) diff --git a/src/FbgemmI8Spmdm.cc b/src/FbgemmI8Spmdm.cc index 74b7211..0865bff 100644 --- a/src/FbgemmI8Spmdm.cc +++ b/src/FbgemmI8Spmdm.cc @@ -70,9 +70,6 @@ void CompressedSparseColumn::SpMDM( t_very_start = std::chrono::high_resolution_clock::now(); #endif - alignas(64) uint8_t* A_buffer = new uint8_t[K * 32]; - alignas(64) int32_t* C_buffer = new int32_t[N * 32]; - // If we compute C = C + A * B, where B is a sparse matrix in CSC format, for // each non-zero in B, we'd need to access the corresponding column in A. // This results in strided access, which we want to avoid. @@ -141,9 +138,7 @@ void CompressedSparseColumn::SpMDM( } } // for each column of B } - delete C_temp; - delete A_buffer; - delete C_buffer; + delete[] C_temp; return; } @@ -155,6 +150,9 @@ void CompressedSparseColumn::SpMDM( t_start = std::chrono::high_resolution_clock::now(); #endif + alignas(64) uint8_t* A_buffer = new uint8_t[K * 32]; + alignas(64) int32_t* C_buffer = new int32_t[N * 32]; + // Take 32 rows at a time int i_end = block.row_start + block.row_size; for (int i1 = block.row_start; i1 < i_end; i1 += 32) { @@ -176,7 +174,7 @@ void CompressedSparseColumn::SpMDM( for (int i2 = (i_end - i1) / 8 * 8; i2 < 32; i2 += 8) { transpose_8rows(K, A_temp_buffer + i2 * K, K, A_buffer + i2, 32); } - delete A_temp_buffer; + delete[] A_temp_buffer; } else { for (int i2 = 0; i2 < 32; i2 += 8) { transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32); @@ -237,9 +235,6 @@ void CompressedSparseColumn::SpMDM( reinterpret_cast(C + (i1 - block.row_start) * ldc), ldc); - delete A_buffer; - delete C_buffer; - #ifdef FBGEMM_MEASURE_TIME_BREAKDOWN t_end = std::chrono::high_resolution_clock::now(); dt = std::chrono::duration_cast(t_end - t_start) @@ -257,6 +252,9 @@ void CompressedSparseColumn::SpMDM( spmdm_run_time += (dt); t_start = std::chrono::high_resolution_clock::now(); #endif + + delete[] A_buffer; + delete[] C_buffer; } void CompressedSparseColumn::SparseConv( -- cgit v1.2.3 From a838fc2a9c354b7da8a2663d300b64686a234247 Mon Sep 17 00:00:00 2001 From: Young Jin Kim Date: Fri, 14 Jun 2019 14:41:53 -0700 Subject: Fix memory allocation bug --- include/fbgemm/Utils.h | 10 ++++++++++ src/FbgemmI8DepthwiseAvx2.cc | 28 ++++++++++++---------------- src/FbgemmI8Spmdm.cc | 13 +++++++------ 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h index ef1d4ab..9f8e1ee 100644 --- a/include/fbgemm/Utils.h +++ b/include/fbgemm/Utils.h @@ -10,6 +10,16 @@ #include "FbgemmBuild.h" #include "UtilsAvx2.h" +#ifdef _MSC_VER +# define ALWAYS_INLINE // __forceinline +# define ALIGNED_MALLOC(size, alignment) _aligned_malloc(size, alignment) +# define FREE(ptr) _aligned_free(ptr) +#else +# define ALWAYS_INLINE __attribute__((always_inline)) +# define ALIGNED_MALLOC(size, alignment) aligned_alloc(alignment, size) +# define FREE(ptr) free(ptr) +#endif + namespace fbgemm { /** diff --git a/src/FbgemmI8DepthwiseAvx2.cc b/src/FbgemmI8DepthwiseAvx2.cc index d90d47a..f96d1d2 100644 --- a/src/FbgemmI8DepthwiseAvx2.cc +++ b/src/FbgemmI8DepthwiseAvx2.cc @@ -5,6 +5,7 @@ * LICENSE file in the root directory of this source tree. 
*/ #include "fbgemm/FbgemmI8DepthwiseAvx2.h" +#include "fbgemm/Utils.h" #include // for min and max #include @@ -13,12 +14,6 @@ #include -#ifdef _MSC_VER - #define ALWAYS_INLINE -#else - #define ALWAYS_INLINE __attribute__((always_inline)) -#endif - using namespace std; namespace fbgemm { @@ -42,7 +37,8 @@ PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix( const int8_t* smat) : K_(K) { // Transpose the input matrix to make packing faster. - alignas(64) int8_t* smat_transposed = new int8_t[K * KERNEL_PROD]; + int8_t* smat_transposed = static_cast(ALIGNED_MALLOC( + K * KERNEL_PROD * sizeof(int8_t), 64)); for (int i = 0; i < KERNEL_PROD; ++i) { for (int j = 0; j < K; ++j) { smat_transposed[i * K + j] = smat[i + j * KERNEL_PROD]; @@ -170,7 +166,7 @@ PackedDepthWiseConvMatrix::PackedDepthWiseConvMatrix( } } - delete[] smat_transposed; + FREE(smat_transposed); } template @@ -1506,7 +1502,7 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; const int8_t* Bp = B.PackedMat(); - alignas(64) int32_t* row_offsets = new int32_t[(K + 31) / 32 * 32]; + int32_t* row_offsets = static_cast(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); int n_begin, n_end; int h_begin, h_end, w_begin, w_end; @@ -1763,7 +1759,7 @@ static inline ALWAYS_INLINE void depthwise_3x3_pad_1_( } } } // for each n - delete[] row_offsets; + FREE(row_offsets); }; template @@ -1797,7 +1793,7 @@ static inline ALWAYS_INLINE void depthwise_3x3x3_pad_1_( int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; const int8_t* Bp = B.PackedMat(); - alignas(64) int32_t* row_offsets = new int32_t[(K + 31) / 32 * 32]; // __attribute__((aligned(64))); + int32_t* row_offsets = static_cast(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64))); int n_begin, n_end; int t_begin, t_end, h_begin, h_end; @@ -1875,7 +1871,7 @@ static inline ALWAYS_INLINE void depthwise_3x3x3_pad_1_( } // t } // for each n - delete[] row_offsets; + FREE(row_offsets); }; template @@ -1906,7 +1902,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( int W_OUT = (W + PAD_L + PAD_R - S) / stride_w + 1; const int8_t* Bp = B.PackedMat(); - alignas(64) int32_t* row_offsets = new int32_t[(K + 31) / 32 * 32]; // __attribute__((aligned(64))); + int32_t* row_offsets = static_cast(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64))); int n_begin, n_end; int h_begin, h_end, w_begin, w_end; @@ -2191,7 +2187,7 @@ depthwise_3x3_per_channel_quantization_pad_1_( } } // for each n - delete[] row_offsets; + FREE(row_offsets); }; template @@ -2226,7 +2222,7 @@ depthwise_3x3x3_per_channel_quantization_pad_1_( int W_OUT = (W + PAD_L + PAD_R - K_W) / stride_w + 1; const int8_t* Bp = B.PackedMat(); - alignas(64) int32_t* row_offsets = new int32_t[(K + 31) / 32 * 32]; // __attribute__((aligned(64))); + int32_t* row_offsets = static_cast(ALIGNED_MALLOC(((K + 31) / 32 * 32)*sizeof(int32_t), 64)); // __attribute__((aligned(64))); int n_begin, n_end; int t_begin, t_end, h_begin, h_end; @@ -2303,7 +2299,7 @@ depthwise_3x3x3_per_channel_quantization_pad_1_( } // t } // for each n - delete[] row_offsets; + FREE(row_offsets); }; // Dispatch A_SYMMETRIC and B_SYMMETRIC diff --git a/src/FbgemmI8Spmdm.cc b/src/FbgemmI8Spmdm.cc index 0865bff..edcc4e8 100644 --- a/src/FbgemmI8Spmdm.cc +++ b/src/FbgemmI8Spmdm.cc @@ -5,6 +5,7 @@ * LICENSE file in the root directory of this source tree. 
*/ #include "fbgemm/FbgemmI8Spmdm.h" +#include "fbgemm/Utils.h" #include #include @@ -150,15 +151,15 @@ void CompressedSparseColumn::SpMDM( t_start = std::chrono::high_resolution_clock::now(); #endif - alignas(64) uint8_t* A_buffer = new uint8_t[K * 32]; - alignas(64) int32_t* C_buffer = new int32_t[N * 32]; + uint8_t* A_buffer = static_cast(ALIGNED_MALLOC(K * 32 * sizeof(uint8_t), 64)); + int32_t* C_buffer = static_cast(ALIGNED_MALLOC(N * 32 * sizeof(int32_t), 64)); // Take 32 rows at a time int i_end = block.row_start + block.row_size; for (int i1 = block.row_start; i1 < i_end; i1 += 32) { // Transpose 32 x K submatrix of A if (i_end - i1 < 32) { - alignas(64) uint8_t* A_temp_buffer = new uint8_t[K * 32]; + uint8_t* A_temp_buffer = static_cast(ALIGNED_MALLOC(K * 32 * sizeof(uint8_t), 64)); for (int i2 = 0; i2 < (i_end - i1) / 8 * 8; i2 += 8) { transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32); } @@ -174,7 +175,7 @@ void CompressedSparseColumn::SpMDM( for (int i2 = (i_end - i1) / 8 * 8; i2 < 32; i2 += 8) { transpose_8rows(K, A_temp_buffer + i2 * K, K, A_buffer + i2, 32); } - delete[] A_temp_buffer; + FREE(A_temp_buffer); } else { for (int i2 = 0; i2 < 32; i2 += 8) { transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32); @@ -253,8 +254,8 @@ void CompressedSparseColumn::SpMDM( t_start = std::chrono::high_resolution_clock::now(); #endif - delete[] A_buffer; - delete[] C_buffer; + FREE(A_buffer); + FREE(C_buffer); } void CompressedSparseColumn::SparseConv( -- cgit v1.2.3
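The ALIGNED_MALLOC/FREE macros added to include/fbgemm/Utils.h give every
call site a single portable pattern for 64-byte-aligned scratch buffers.
A minimal usage sketch (illustrative only, not part of any commit above):

#include "fbgemm/Utils.h"
#include <cstdint>

void scratch_example(int K) {
  // round up to a whole number of 32-lane groups, as the depthwise kernels do
  int32_t* row_offsets = static_cast<int32_t*>(
      ALIGNED_MALLOC(((K + 31) / 32 * 32) * sizeof(int32_t), 64));
  // ... fill and consume row_offsets ...
  FREE(row_offsets);
}

One caveat: the POSIX branch maps to aligned_alloc(alignment, size), and C11
expects size to be an integral multiple of the alignment, so fully portable
callers should round the byte count up to a multiple of 64.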