diff options
author | jokeren <robinho364@gmail.com> | 2016-12-17 10:22:51 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-02-23 13:50:34 +0300 |
commit | f223bb5e3978aa3bd034bc7cf7803d63c6ce043c (patch) | |
tree | 52bc93298aa28e4f134d7c9187cb65f4485eecac | |
parent | 8137de1b1258bd38e5bb582b41e0fed02cbfbb08 (diff) |
Add fma cadd
-rw-r--r-- | lib/TH/CMakeLists.txt | 11 | ||||
-rw-r--r-- | lib/TH/THVector.c | 2 | ||||
-rw-r--r-- | lib/TH/cmake/FindSSE.cmake | 13 | ||||
-rw-r--r-- | lib/TH/generic/THVectorDispatch.c | 6 | ||||
-rw-r--r-- | lib/TH/vector/AVX.c | 40 |
5 files changed, 66 insertions, 6 deletions
diff --git a/lib/TH/CMakeLists.txt b/lib/TH/CMakeLists.txt index e21ee16..e11a964 100644 --- a/lib/TH/CMakeLists.txt +++ b/lib/TH/CMakeLists.txt @@ -124,15 +124,18 @@ ENDIF(C_SSE3_FOUND) IF(C_AVX_FOUND) SET(CMAKE_C_FLAGS "${C_AVX_FLAGS} -DUSE_AVX ${CMAKE_C_FLAGS}") ENDIF(C_AVX_FOUND) +IF(C_AVX2_FOUND) + SET(CMAKE_C_FLAGS "${C_AVX2_FLAGS} -DUSE_AVX2 ${CMAKE_C_FLAGS}") +ENDIF(C_AVX2_FOUND) -IF(C_AVX_FOUND OR C_SSE4_2_FOUND OR C_SSE4_1_FOUND) +IF(C_AVX2_FOUND OR C_AVX_FOUND OR C_SSE4_2_FOUND OR C_SSE4_1_FOUND) SET(simd generic/simd/convolve.c) IF(MSVC) SET_SOURCE_FILES_PROPERTIES(generic/simd/convolve.c PROPERTIES COMPILE_FLAGS "/std:c99") ELSE(MSVC) SET_SOURCE_FILES_PROPERTIES(generic/simd/convolve.c PROPERTIES COMPILE_FLAGS "-std=c99") ENDIF(MSVC) -ENDIF(C_AVX_FOUND OR C_SSE4_2_FOUND OR C_SSE4_1_FOUND) +ENDIF(C_AVX2_FOUND OR C_AVX_FOUND OR C_SSE4_2_FOUND OR C_SSE4_1_FOUND) IF(C_SSE4_1_FOUND) SET(CMAKE_C_FLAGS "${C_SSE4_1_FLAGS} -DUSE_SSE4_1 ${CMAKE_C_FLAGS}") @@ -150,7 +153,7 @@ IF(C_SSE4_1_FOUND OR C_SSE4_2_FOUND) SET(simd ${simd} generic/simd/convolve5x5_sse.c) ENDIF(C_SSE4_1_FOUND OR C_SSE4_2_FOUND) -IF(C_AVX_FOUND) +IF(C_AVX_FOUND OR C_AVX2_FOUND) SET(CMAKE_C_FLAGS "-DUSE_AVX ${CMAKE_C_FLAGS}") IF(MSVC) SET_SOURCE_FILES_PROPERTIES(generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "/Ox /fp:fast /arch:AVX /std:c99") @@ -158,7 +161,7 @@ IF(C_AVX_FOUND) SET_SOURCE_FILES_PROPERTIES(generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "-O3 -ffast-math -mavx -std=c99") ENDIF(MSVC) SET(simd ${simd} generic/simd/convolve5x5_avx.c) -ENDIF(C_AVX_FOUND) +ENDIF(C_AVX_FOUND OR C_AVX2_FOUND) SET(hdr THGeneral.h THHalf.h THAllocator.h THStorage.h THTensor.h THTensorApply.h THBlas.h THMath.h diff --git a/lib/TH/THVector.c b/lib/TH/THVector.c index af086fd..abf1e16 100644 --- a/lib/TH/THVector.c +++ b/lib/TH/THVector.c @@ -15,7 +15,7 @@ #include "vector/SSE.c" #endif -#if defined(USE_AVX) +#if defined(USE_AVX) || defined(USE_AVX2) #include "vector/AVX.c" #endif diff --git a/lib/TH/cmake/FindSSE.cmake b/lib/TH/cmake/FindSSE.cmake index d03cc19..f84ce89 100644 --- a/lib/TH/cmake/FindSSE.cmake +++ b/lib/TH/cmake/FindSSE.cmake @@ -68,6 +68,17 @@ SET(AVX_CODE " } ") +SET(AVX2_CODE " + #include <immintrin.h> + + int main() + { + __m256i a; + a = _mm256_abs_epi16(a); + return 0; + } +") + MACRO(CHECK_SSE lang type flags) SET(__FLAG_I 1) SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) @@ -103,6 +114,7 @@ CHECK_SSE(C "SSE3" " ;-msse3;/arch:SSE3") CHECK_SSE(C "SSE4_1" " ;-msse4.1;-msse4;/arch:SSE4") CHECK_SSE(C "SSE4_2" " ;-msse4.2;-msse4;/arch:SSE4") CHECK_SSE(C "AVX" " ;-mavx;/arch:AVX") +CHECK_SSE(C "AVX2" " ;-mavx2 -mfma;/arch:AVX2") CHECK_SSE(CXX "SSE1" " ;-msse;/arch:SSE") CHECK_SSE(CXX "SSE2" " ;-msse2;/arch:SSE2") @@ -110,3 +122,4 @@ CHECK_SSE(CXX "SSE3" " ;-msse3;/arch:SSE3") CHECK_SSE(CXX "SSE4_1" " ;-msse4.1;-msse4;/arch:SSE4") CHECK_SSE(CXX "SSE4_2" " ;-msse4.2;-msse4;/arch:SSE4") CHECK_SSE(CXX "AVX" " ;-mavx;/arch:AVX") +CHECK_SSE(CXX "AVX2" " ;-mavx2 -mfma;/arch:AVX2") diff --git a/lib/TH/generic/THVectorDispatch.c b/lib/TH/generic/THVectorDispatch.c index 9563f48..8da42c4 100644 --- a/lib/TH/generic/THVectorDispatch.c +++ b/lib/TH/generic/THVectorDispatch.c @@ -52,6 +52,12 @@ static FunctionDescription THVector_(cadd_DISPATCHTABLE)[] = { #endif #endif + #if defined(USE_AVX2) + #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT) + FUNCTION_IMPL(THVector_(cadd_AVX2), SIMDExtension_AVX2), + #endif + #endif + #if defined(USE_AVX) #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT) FUNCTION_IMPL(THVector_(cadd_AVX), SIMDExtension_AVX), diff --git a/lib/TH/vector/AVX.c b/lib/TH/vector/AVX.c index 1abfccf..74d684a 100644 --- a/lib/TH/vector/AVX.c +++ b/lib/TH/vector/AVX.c @@ -93,7 +93,7 @@ static void THDoubleVector_mul_AVX(double *y, const double *x, const double c, c static void THDoubleVector_cadd_AVX(double *z, const double *x, const double *y, const double c, const ptrdiff_t n) { ptrdiff_t i; __m256d YMM15 = _mm256_set_pd(c, c, c, c); - __m256d YMM0, YMM1, YMM2, YMM3, YMM4, YMM5; + __m256d YMM0, YMM1, YMM2, YMM3; for (i=0; i<=((n)-4); i+=4) { YMM0 = _mm256_loadu_pd(y+i); YMM1 = _mm256_loadu_pd(x+i); @@ -106,6 +106,25 @@ static void THDoubleVector_cadd_AVX(double *z, const double *x, const double *y, } } +static void THDoubleVector_cadd_AVX2(double *z, const double *x, const double *y, const double c, const ptrdiff_t n) { + ptrdiff_t i; + __m256d YMM15 = _mm256_set_pd(c, c, c, c); + __m256d YMM0, YMM1, YMM2, YMM3; + for (i=0; i<=((n)-8); i+=8) { + YMM0 = _mm256_loadu_pd(y+i); + YMM1 = _mm256_loadu_pd(y+i+4); + YMM2 = _mm256_loadu_pd(x+i); + YMM3 = _mm256_loadu_pd(x+i+4); + YMM2 = _mm256_fmadd_pd(YMM0, YMM15, YMM2); + YMM3 = _mm256_fmadd_pd(YMM1, YMM15, YMM3); + _mm256_storeu_pd(z+i, YMM2); + _mm256_storeu_pd(z+i+4, YMM3); + } + for (; i<(n); i++) { + z[i] = x[i] + y[i] * c; + } +} + static void THDoubleVector_add_AVX(double *y, const double *x, const double c, const ptrdiff_t n) { ptrdiff_t i; __m256d YMM15 = _mm256_set_pd(c, c, c, c); @@ -225,6 +244,25 @@ static void THFloatVector_cadd_AVX(float *z, const float *x, const float *y, con } } +static void THFloatVector_cadd_AVX2(float *z, const float *x, const float *y, const float c, const ptrdiff_t n) { + ptrdiff_t i; + __m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c); + __m256 YMM0, YMM1, YMM2, YMM3; + for (i=0; i<=((n)-16); i+=16) { + YMM0 = _mm256_loadu_ps(y+i); + YMM1 = _mm256_loadu_ps(y+i+8); + YMM2 = _mm256_loadu_ps(x+i); + YMM3 = _mm256_loadu_ps(x+i+8); + YMM2 = _mm256_fmadd_ps(YMM0, YMM15, YMM2); + YMM3 = _mm256_fmadd_ps(YMM1, YMM15, YMM3); + _mm256_storeu_ps(z+i, YMM2); + _mm256_storeu_ps(z+i+8, YMM3); + } + for (; i<(n); i++) { + z[i] = x[i] + y[i] * c; + } +} + static void THFloatVector_add_AVX(float *y, const float *x, const float c, const ptrdiff_t n) { ptrdiff_t i; __m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c); |