Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/torch/torch7.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjokeren <robinho364@gmail.com>2016-12-17 10:22:51 +0300
committerSoumith Chintala <soumith@gmail.com>2017-02-23 13:50:34 +0300
commitf223bb5e3978aa3bd034bc7cf7803d63c6ce043c (patch)
tree52bc93298aa28e4f134d7c9187cb65f4485eecac
parent8137de1b1258bd38e5bb582b41e0fed02cbfbb08 (diff)
Add fma cadd
-rw-r--r--lib/TH/CMakeLists.txt11
-rw-r--r--lib/TH/THVector.c2
-rw-r--r--lib/TH/cmake/FindSSE.cmake13
-rw-r--r--lib/TH/generic/THVectorDispatch.c6
-rw-r--r--lib/TH/vector/AVX.c40
5 files changed, 66 insertions, 6 deletions
diff --git a/lib/TH/CMakeLists.txt b/lib/TH/CMakeLists.txt
index e21ee16..e11a964 100644
--- a/lib/TH/CMakeLists.txt
+++ b/lib/TH/CMakeLists.txt
@@ -124,15 +124,18 @@ ENDIF(C_SSE3_FOUND)
IF(C_AVX_FOUND)
SET(CMAKE_C_FLAGS "${C_AVX_FLAGS} -DUSE_AVX ${CMAKE_C_FLAGS}")
ENDIF(C_AVX_FOUND)
+# When the AVX2 probe succeeded, prepend the compiler flags recorded for it
+# (C_AVX2_FLAGS) to CMAKE_C_FLAGS and define USE_AVX2, so C sources built with
+# CMAKE_C_FLAGS can test that macro to enable their AVX2 code paths.
+IF(C_AVX2_FOUND)
+ SET(CMAKE_C_FLAGS "${C_AVX2_FLAGS} -DUSE_AVX2 ${CMAKE_C_FLAGS}")
+ENDIF(C_AVX2_FOUND)
-IF(C_AVX_FOUND OR C_SSE4_2_FOUND OR C_SSE4_1_FOUND)
+IF(C_AVX2_FOUND OR C_AVX_FOUND OR C_SSE4_2_FOUND OR C_SSE4_1_FOUND)
SET(simd generic/simd/convolve.c)
IF(MSVC)
SET_SOURCE_FILES_PROPERTIES(generic/simd/convolve.c PROPERTIES COMPILE_FLAGS "/std:c99")
ELSE(MSVC)
SET_SOURCE_FILES_PROPERTIES(generic/simd/convolve.c PROPERTIES COMPILE_FLAGS "-std=c99")
ENDIF(MSVC)
-ENDIF(C_AVX_FOUND OR C_SSE4_2_FOUND OR C_SSE4_1_FOUND)
+ENDIF(C_AVX2_FOUND OR C_AVX_FOUND OR C_SSE4_2_FOUND OR C_SSE4_1_FOUND)
IF(C_SSE4_1_FOUND)
SET(CMAKE_C_FLAGS "${C_SSE4_1_FLAGS} -DUSE_SSE4_1 ${CMAKE_C_FLAGS}")
@@ -150,7 +153,7 @@ IF(C_SSE4_1_FOUND OR C_SSE4_2_FOUND)
SET(simd ${simd} generic/simd/convolve5x5_sse.c)
ENDIF(C_SSE4_1_FOUND OR C_SSE4_2_FOUND)
-IF(C_AVX_FOUND)
+IF(C_AVX_FOUND OR C_AVX2_FOUND)
SET(CMAKE_C_FLAGS "-DUSE_AVX ${CMAKE_C_FLAGS}")
IF(MSVC)
SET_SOURCE_FILES_PROPERTIES(generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "/Ox /fp:fast /arch:AVX /std:c99")
@@ -158,7 +161,7 @@ IF(C_AVX_FOUND)
SET_SOURCE_FILES_PROPERTIES(generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "-O3 -ffast-math -mavx -std=c99")
ENDIF(MSVC)
SET(simd ${simd} generic/simd/convolve5x5_avx.c)
-ENDIF(C_AVX_FOUND)
+ENDIF(C_AVX_FOUND OR C_AVX2_FOUND)
SET(hdr
THGeneral.h THHalf.h THAllocator.h THStorage.h THTensor.h THTensorApply.h THBlas.h THMath.h
diff --git a/lib/TH/THVector.c b/lib/TH/THVector.c
index af086fd..abf1e16 100644
--- a/lib/TH/THVector.c
+++ b/lib/TH/THVector.c
@@ -15,7 +15,7 @@
#include "vector/SSE.c"
#endif
-#if defined(USE_AVX)
+#if defined(USE_AVX) || defined(USE_AVX2)
#include "vector/AVX.c"
#endif
diff --git a/lib/TH/cmake/FindSSE.cmake b/lib/TH/cmake/FindSSE.cmake
index d03cc19..f84ce89 100644
--- a/lib/TH/cmake/FindSSE.cmake
+++ b/lib/TH/cmake/FindSSE.cmake
@@ -68,6 +68,17 @@ SET(AVX_CODE "
}
")
+# Probe program for the "AVX2" check.
+#
+# It exercises one AVX2 integer intrinsic and one FMA intrinsic, because the
+# flag candidates tried for AVX2 add FMA as well and the AVX2 code paths rely
+# on fused multiply-add. All inputs are initialised (reading an uninitialised
+# vector is undefined behaviour) and derived from volatile values so the
+# instructions cannot be folded away at compile time. The exit status is 0
+# only when both results are correct.
+# NOTE(review): presumably CHECK_SSE compiles *and runs* this program; confirm
+# against the macro body.
+SET(AVX2_CODE "
+  #include <immintrin.h>
+
+  int main()
+  {
+    volatile short neg_one = -1;
+    volatile float one = 1.0f;
+    short abs_out[16];
+    float fma_out[8];
+    __m256i a = _mm256_set1_epi16(neg_one);
+    __m256 f = _mm256_set1_ps(one);
+
+    a = _mm256_abs_epi16(a);      /* AVX2: |-1| = 1 */
+    f = _mm256_fmadd_ps(f, f, f); /* FMA: 1 * 1 + 1 = 2 */
+
+    _mm256_storeu_si256((__m256i *)abs_out, a);
+    _mm256_storeu_ps(fma_out, f);
+    return (abs_out[0] == 1 && fma_out[0] == 2.0f) ? 0 : 1;
+  }
+")
+
MACRO(CHECK_SSE lang type flags)
SET(__FLAG_I 1)
SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
@@ -103,6 +114,7 @@ CHECK_SSE(C "SSE3" " ;-msse3;/arch:SSE3")
CHECK_SSE(C "SSE4_1" " ;-msse4.1;-msse4;/arch:SSE4")
CHECK_SSE(C "SSE4_2" " ;-msse4.2;-msse4;/arch:SSE4")
CHECK_SSE(C "AVX" " ;-mavx;/arch:AVX")
+CHECK_SSE(C "AVX2" " ;-mavx2 -mfma;/arch:AVX2")
CHECK_SSE(CXX "SSE1" " ;-msse;/arch:SSE")
CHECK_SSE(CXX "SSE2" " ;-msse2;/arch:SSE2")
@@ -110,3 +122,4 @@ CHECK_SSE(CXX "SSE3" " ;-msse3;/arch:SSE3")
CHECK_SSE(CXX "SSE4_1" " ;-msse4.1;-msse4;/arch:SSE4")
CHECK_SSE(CXX "SSE4_2" " ;-msse4.2;-msse4;/arch:SSE4")
CHECK_SSE(CXX "AVX" " ;-mavx;/arch:AVX")
+CHECK_SSE(CXX "AVX2" " ;-mavx2 -mfma;/arch:AVX2")
diff --git a/lib/TH/generic/THVectorDispatch.c b/lib/TH/generic/THVectorDispatch.c
index 9563f48..8da42c4 100644
--- a/lib/TH/generic/THVectorDispatch.c
+++ b/lib/TH/generic/THVectorDispatch.c
@@ -52,6 +52,12 @@ static FunctionDescription THVector_(cadd_DISPATCHTABLE)[] = {
#endif
#endif
+ #if defined(USE_AVX2)
+ #if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
+ FUNCTION_IMPL(THVector_(cadd_AVX2), SIMDExtension_AVX2),
+ #endif
+ #endif
+
#if defined(USE_AVX)
#if defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)
FUNCTION_IMPL(THVector_(cadd_AVX), SIMDExtension_AVX),
diff --git a/lib/TH/vector/AVX.c b/lib/TH/vector/AVX.c
index 1abfccf..74d684a 100644
--- a/lib/TH/vector/AVX.c
+++ b/lib/TH/vector/AVX.c
@@ -93,7 +93,7 @@ static void THDoubleVector_mul_AVX(double *y, const double *x, const double c, c
static void THDoubleVector_cadd_AVX(double *z, const double *x, const double *y, const double c, const ptrdiff_t n) {
ptrdiff_t i;
__m256d YMM15 = _mm256_set_pd(c, c, c, c);
- __m256d YMM0, YMM1, YMM2, YMM3, YMM4, YMM5;
+ __m256d YMM0, YMM1, YMM2, YMM3;
for (i=0; i<=((n)-4); i+=4) {
YMM0 = _mm256_loadu_pd(y+i);
YMM1 = _mm256_loadu_pd(x+i);
@@ -106,6 +106,25 @@ static void THDoubleVector_cadd_AVX(double *z, const double *x, const double *y,
}
}
+#if defined(USE_AVX2)
+/*
+ * Computes z[i] = x[i] + c * y[i] for every i in [0, n).
+ *
+ * The main loop processes 8 doubles per iteration with two 256-bit fused
+ * multiply-adds; fewer than 8 leftover elements are handled by the scalar
+ * tail loop. Only unaligned loads/stores are used, so the pointers need no
+ * particular alignment.
+ *
+ * The function is compiled only when USE_AVX2 is defined: this file is also
+ * included for builds that define USE_AVX alone, where the FMA intrinsics are
+ * not enabled, and the cadd dispatch table references this symbol under the
+ * same USE_AVX2 guard.
+ *
+ * NOTE(review): the vector path rounds once per element (fused), whereas the
+ * scalar tail may round twice depending on compiler contraction settings, so
+ * tail elements can differ from the vector ones in the last bit.
+ */
+static void THDoubleVector_cadd_AVX2(double *z, const double *x, const double *y, const double c, const ptrdiff_t n) {
+  ptrdiff_t i;
+  const __m256d scale = _mm256_set1_pd(c);
+  __m256d y_lo, y_hi, acc_lo, acc_hi;
+  for (i = 0; i <= n - 8; i += 8) {
+    y_lo = _mm256_loadu_pd(y + i);
+    y_hi = _mm256_loadu_pd(y + i + 4);
+    acc_lo = _mm256_loadu_pd(x + i);
+    acc_hi = _mm256_loadu_pd(x + i + 4);
+    /* acc = y * c + x, computed with a single rounding step. */
+    acc_lo = _mm256_fmadd_pd(y_lo, scale, acc_lo);
+    acc_hi = _mm256_fmadd_pd(y_hi, scale, acc_hi);
+    _mm256_storeu_pd(z + i, acc_lo);
+    _mm256_storeu_pd(z + i + 4, acc_hi);
+  }
+  /* Scalar tail for the elements that do not fill a whole 8-wide step. */
+  for (; i < n; i++) {
+    z[i] = x[i] + y[i] * c;
+  }
+}
+#endif /* USE_AVX2 */
+
static void THDoubleVector_add_AVX(double *y, const double *x, const double c, const ptrdiff_t n) {
ptrdiff_t i;
__m256d YMM15 = _mm256_set_pd(c, c, c, c);
@@ -225,6 +244,25 @@ static void THFloatVector_cadd_AVX(float *z, const float *x, const float *y, con
}
}
+#if defined(USE_AVX2)
+/*
+ * Computes z[i] = x[i] + c * y[i] for every i in [0, n).
+ *
+ * The main loop processes 16 floats per iteration with two 256-bit fused
+ * multiply-adds; fewer than 16 leftover elements are handled by the scalar
+ * tail loop. Only unaligned loads/stores are used, so the pointers need no
+ * particular alignment.
+ *
+ * The function is compiled only when USE_AVX2 is defined: this file is also
+ * included for builds that define USE_AVX alone, where the FMA intrinsics are
+ * not enabled, and the cadd dispatch table references this symbol under the
+ * same USE_AVX2 guard.
+ *
+ * NOTE(review): the vector path rounds once per element (fused), whereas the
+ * scalar tail may round twice depending on compiler contraction settings, so
+ * tail elements can differ from the vector ones in the last bit.
+ */
+static void THFloatVector_cadd_AVX2(float *z, const float *x, const float *y, const float c, const ptrdiff_t n) {
+  ptrdiff_t i;
+  const __m256 scale = _mm256_set1_ps(c);
+  __m256 y_lo, y_hi, acc_lo, acc_hi;
+  for (i = 0; i <= n - 16; i += 16) {
+    y_lo = _mm256_loadu_ps(y + i);
+    y_hi = _mm256_loadu_ps(y + i + 8);
+    acc_lo = _mm256_loadu_ps(x + i);
+    acc_hi = _mm256_loadu_ps(x + i + 8);
+    /* acc = y * c + x, computed with a single rounding step. */
+    acc_lo = _mm256_fmadd_ps(y_lo, scale, acc_lo);
+    acc_hi = _mm256_fmadd_ps(y_hi, scale, acc_hi);
+    _mm256_storeu_ps(z + i, acc_lo);
+    _mm256_storeu_ps(z + i + 8, acc_hi);
+  }
+  /* Scalar tail for the elements that do not fill a whole 16-wide step. */
+  for (; i < n; i++) {
+    z[i] = x[i] + y[i] * c;
+  }
+}
+#endif /* USE_AVX2 */
+
static void THFloatVector_add_AVX(float *y, const float *x, const float c, const ptrdiff_t n) {
ptrdiff_t i;
__m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c);