diff options
author | jokeren <robinho364@gmail.com> | 2016-12-17 12:17:40 +0300 |
---|---|---|
committer | Soumith Chintala <soumith@gmail.com> | 2017-02-23 13:50:34 +0300 |
commit | 90681d4004b05e11b424d0c411b2b68660ee597b (patch) | |
tree | bfb9a878e8e50f680312656de1d4ff53f6afe0cd | |
parent | f223bb5e3978aa3bd034bc7cf7803d63c6ce043c (diff) |
Fix AVX2 bugs
-rw-r--r-- | lib/TH/THVector.c | 4 | ||||
-rw-r--r-- | lib/TH/vector/AVX.c | 38 | ||||
-rw-r--r-- | lib/TH/vector/AVX2.c | 44 |
3 files changed, 48 insertions, 38 deletions
diff --git a/lib/TH/THVector.c b/lib/TH/THVector.c
index abf1e16..1c9ea24 100644
--- a/lib/TH/THVector.c
+++ b/lib/TH/THVector.c
@@ -19,6 +19,10 @@
 #include "vector/AVX.c"
 #endif
 
+#if defined(USE_AVX2)
+#include "vector/AVX2.c"
+#endif
+
 #include "generic/THVectorDefault.c"
 #include "THGenerateAllTypes.h"
 
diff --git a/lib/TH/vector/AVX.c b/lib/TH/vector/AVX.c
index 74d684a..101134f 100644
--- a/lib/TH/vector/AVX.c
+++ b/lib/TH/vector/AVX.c
@@ -106,25 +106,6 @@ static void THDoubleVector_cadd_AVX(double *z, const double *x, const double *y,
   }
 }
 
-static void THDoubleVector_cadd_AVX2(double *z, const double *x, const double *y, const double c, const ptrdiff_t n) {
-  ptrdiff_t i;
-  __m256d YMM15 = _mm256_set_pd(c, c, c, c);
-  __m256d YMM0, YMM1, YMM2, YMM3;
-  for (i=0; i<=((n)-8); i+=8) {
-    YMM0 = _mm256_loadu_pd(y+i);
-    YMM1 = _mm256_loadu_pd(y+i+4);
-    YMM2 = _mm256_loadu_pd(x+i);
-    YMM3 = _mm256_loadu_pd(x+i+4);
-    YMM2 = _mm256_fmadd_pd(YMM0, YMM15, YMM2);
-    YMM3 = _mm256_fmadd_pd(YMM1, YMM15, YMM3);
-    _mm256_storeu_pd(z+i, YMM2);
-    _mm256_storeu_pd(z+i+4, YMM3);
-  }
-  for (; i<(n); i++) {
-    z[i] = x[i] + y[i] * c;
-  }
-}
-
 static void THDoubleVector_add_AVX(double *y, const double *x, const double c, const ptrdiff_t n) {
   ptrdiff_t i;
   __m256d YMM15 = _mm256_set_pd(c, c, c, c);
@@ -244,25 +225,6 @@ static void THFloatVector_cadd_AVX(float *z, const float *x, const float *y, con
   }
 }
 
-static void THFloatVector_cadd_AVX2(float *z, const float *x, const float *y, const float c, const ptrdiff_t n) {
-  ptrdiff_t i;
-  __m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c);
-  __m256 YMM0, YMM1, YMM2, YMM3;
-  for (i=0; i<=((n)-16); i+=16) {
-    YMM0 = _mm256_loadu_ps(y+i);
-    YMM1 = _mm256_loadu_ps(y+i+8);
-    YMM2 = _mm256_loadu_ps(x+i);
-    YMM3 = _mm256_loadu_ps(x+i+8);
-    YMM2 = _mm256_fmadd_ps(YMM0, YMM15, YMM2);
-    YMM3 = _mm256_fmadd_ps(YMM1, YMM15, YMM3);
-    _mm256_storeu_ps(z+i, YMM2);
-    _mm256_storeu_ps(z+i+8, YMM3);
-  }
-  for (; i<(n); i++) {
-    z[i] = x[i] + y[i] * c;
-  }
-}
-
 static void THFloatVector_add_AVX(float *y, const float *x, const float c, const ptrdiff_t n) {
   ptrdiff_t i;
   __m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c);
diff --git a/lib/TH/vector/AVX2.c b/lib/TH/vector/AVX2.c
new file mode 100644
index 0000000..3ccfc82
--- /dev/null
+++ b/lib/TH/vector/AVX2.c
@@ -0,0 +1,44 @@
+#ifndef _MSC_VER
+#include <x86intrin.h>
+#else
+#include <intrin.h>
+#endif
+
+static void THDoubleVector_cadd_AVX2(double *z, const double *x, const double *y, const double c, const ptrdiff_t n) {
+  ptrdiff_t i;
+  __m256d YMM15 = _mm256_set_pd(c, c, c, c);
+  __m256d YMM0, YMM1, YMM2, YMM3;
+  for (i=0; i<=((n)-8); i+=8) {
+    YMM0 = _mm256_loadu_pd(y+i);
+    YMM1 = _mm256_loadu_pd(y+i+4);
+    YMM2 = _mm256_loadu_pd(x+i);
+    YMM3 = _mm256_loadu_pd(x+i+4);
+    YMM2 = _mm256_fmadd_pd(YMM0, YMM15, YMM2);
+    YMM3 = _mm256_fmadd_pd(YMM1, YMM15, YMM3);
+    _mm256_storeu_pd(z+i, YMM2);
+    _mm256_storeu_pd(z+i+4, YMM3);
+  }
+  for (; i<(n); i++) {
+    z[i] = x[i] + y[i] * c;
+  }
+}
+
+static void THFloatVector_cadd_AVX2(float *z, const float *x, const float *y, const float c, const ptrdiff_t n) {
+  ptrdiff_t i;
+  __m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c);
+  __m256 YMM0, YMM1, YMM2, YMM3;
+  for (i=0; i<=((n)-16); i+=16) {
+    YMM0 = _mm256_loadu_ps(y+i);
+    YMM1 = _mm256_loadu_ps(y+i+8);
+    YMM2 = _mm256_loadu_ps(x+i);
+    YMM3 = _mm256_loadu_ps(x+i+8);
+    YMM2 = _mm256_fmadd_ps(YMM0, YMM15, YMM2);
+    YMM3 = _mm256_fmadd_ps(YMM1, YMM15, YMM3);
+    _mm256_storeu_ps(z+i, YMM2);
+    _mm256_storeu_ps(z+i+8, YMM3);
+  }
+  for (; i<(n); i++) {
+    z[i] = x[i] + y[i] * c;
+  }
+}
+