github.com/torch/torch7.git

author     jokeren <robinho364@gmail.com>          2016-12-17 12:17:40 +0300
committer  Soumith Chintala <soumith@gmail.com>    2017-02-23 13:50:34 +0300
commit     90681d4004b05e11b424d0c411b2b68660ee597b (patch)
tree       bfb9a878e8e50f680312656de1d4ff53f6afe0cd
parent     f223bb5e3978aa3bd034bc7cf7803d63c6ce043c (diff)
Fix AVX2 bugs
-rw-r--r--  lib/TH/THVector.c      4
-rw-r--r--  lib/TH/vector/AVX.c   38
-rw-r--r--  lib/TH/vector/AVX2.c  44
3 files changed, 48 insertions(+), 38 deletions(-)
diff --git a/lib/TH/THVector.c b/lib/TH/THVector.c
index abf1e16..1c9ea24 100644
--- a/lib/TH/THVector.c
+++ b/lib/TH/THVector.c
@@ -19,6 +19,10 @@
#include "vector/AVX.c"
#endif
+#if defined(USE_AVX2)
+#include "vector/AVX2.c"
+#endif
+
#include "generic/THVectorDefault.c"
#include "THGenerateAllTypes.h"
diff --git a/lib/TH/vector/AVX.c b/lib/TH/vector/AVX.c
index 74d684a..101134f 100644
--- a/lib/TH/vector/AVX.c
+++ b/lib/TH/vector/AVX.c
@@ -106,25 +106,6 @@ static void THDoubleVector_cadd_AVX(double *z, const double *x, const double *y,
}
}
-static void THDoubleVector_cadd_AVX2(double *z, const double *x, const double *y, const double c, const ptrdiff_t n) {
- ptrdiff_t i;
- __m256d YMM15 = _mm256_set_pd(c, c, c, c);
- __m256d YMM0, YMM1, YMM2, YMM3;
- for (i=0; i<=((n)-8); i+=8) {
- YMM0 = _mm256_loadu_pd(y+i);
- YMM1 = _mm256_loadu_pd(y+i+4);
- YMM2 = _mm256_loadu_pd(x+i);
- YMM3 = _mm256_loadu_pd(x+i+4);
- YMM2 = _mm256_fmadd_pd(YMM0, YMM15, YMM2);
- YMM3 = _mm256_fmadd_pd(YMM1, YMM15, YMM3);
- _mm256_storeu_pd(z+i, YMM2);
- _mm256_storeu_pd(z+i+4, YMM3);
- }
- for (; i<(n); i++) {
- z[i] = x[i] + y[i] * c;
- }
-}
-
static void THDoubleVector_add_AVX(double *y, const double *x, const double c, const ptrdiff_t n) {
ptrdiff_t i;
__m256d YMM15 = _mm256_set_pd(c, c, c, c);
@@ -244,25 +225,6 @@ static void THFloatVector_cadd_AVX(float *z, const float *x, const float *y, con
}
}
-static void THFloatVector_cadd_AVX2(float *z, const float *x, const float *y, const float c, const ptrdiff_t n) {
- ptrdiff_t i;
- __m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c);
- __m256 YMM0, YMM1, YMM2, YMM3;
- for (i=0; i<=((n)-16); i+=16) {
- YMM0 = _mm256_loadu_ps(y+i);
- YMM1 = _mm256_loadu_ps(y+i+8);
- YMM2 = _mm256_loadu_ps(x+i);
- YMM3 = _mm256_loadu_ps(x+i+8);
- YMM2 = _mm256_fmadd_ps(YMM0, YMM15, YMM2);
- YMM3 = _mm256_fmadd_ps(YMM1, YMM15, YMM3);
- _mm256_storeu_ps(z+i, YMM2);
- _mm256_storeu_ps(z+i+8, YMM3);
- }
- for (; i<(n); i++) {
- z[i] = x[i] + y[i] * c;
- }
-}
-
static void THFloatVector_add_AVX(float *y, const float *x, const float c, const ptrdiff_t n) {
ptrdiff_t i;
__m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c);
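
The two kernels deleted above are the same bodies that reappear in the new AVX2.c below; the bug being fixed is that they call _mm256_fmadd_pd/_mm256_fmadd_ps, FMA intrinsics that an AVX-only translation unit (or AVX-only hardware) cannot be assumed to support. The AVX-only THDoubleVector_cadd_AVX that remains in this file lies outside the hunk, so its body is not visible here; what follows is a hedged sketch of the multiply-then-add shape such a kernel would presumably take (the name cadd_avx_sketch and the exact unrolling are assumptions).

#include <stddef.h>      /* ptrdiff_t */
#include <immintrin.h>   /* AVX intrinsics */

/* Hedged sketch only: the real THDoubleVector_cadd_AVX body is not in this
 * diff, so this is an assumed AVX-only equivalent, not a copy of it. */
static void cadd_avx_sketch(double *z, const double *x, const double *y,
                            const double c, const ptrdiff_t n) {
  ptrdiff_t i;
  __m256d vc = _mm256_set_pd(c, c, c, c);
  for (i = 0; i <= n - 4; i += 4) {
    __m256d vy = _mm256_loadu_pd(y + i);
    __m256d vx = _mm256_loadu_pd(x + i);
    /* separate multiply and add: plain AVX instructions, no FMA required */
    _mm256_storeu_pd(z + i, _mm256_add_pd(vx, _mm256_mul_pd(vy, vc)));
  }
  for (; i < n; i++)
    z[i] = x[i] + y[i] * c;
}

Splitting the FMA variants into their own file presumably lets the build compile AVX.c without FMA-enabling flags and reserve them for AVX2.c.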
diff --git a/lib/TH/vector/AVX2.c b/lib/TH/vector/AVX2.c
new file mode 100644
index 0000000..3ccfc82
--- /dev/null
+++ b/lib/TH/vector/AVX2.c
@@ -0,0 +1,44 @@
+#ifndef _MSC_VER
+#include <x86intrin.h>
+#else
+#include <intrin.h>
+#endif
+
+static void THDoubleVector_cadd_AVX2(double *z, const double *x, const double *y, const double c, const ptrdiff_t n) {
+ ptrdiff_t i;
+ __m256d YMM15 = _mm256_set_pd(c, c, c, c);
+ __m256d YMM0, YMM1, YMM2, YMM3;
+ for (i=0; i<=((n)-8); i+=8) {
+ YMM0 = _mm256_loadu_pd(y+i);
+ YMM1 = _mm256_loadu_pd(y+i+4);
+ YMM2 = _mm256_loadu_pd(x+i);
+ YMM3 = _mm256_loadu_pd(x+i+4);
+ YMM2 = _mm256_fmadd_pd(YMM0, YMM15, YMM2);
+ YMM3 = _mm256_fmadd_pd(YMM1, YMM15, YMM3);
+ _mm256_storeu_pd(z+i, YMM2);
+ _mm256_storeu_pd(z+i+4, YMM3);
+ }
+ for (; i<(n); i++) {
+ z[i] = x[i] + y[i] * c;
+ }
+}
+
+static void THFloatVector_cadd_AVX2(float *z, const float *x, const float *y, const float c, const ptrdiff_t n) {
+ ptrdiff_t i;
+ __m256 YMM15 = _mm256_set_ps(c, c, c, c, c, c, c, c);
+ __m256 YMM0, YMM1, YMM2, YMM3;
+ for (i=0; i<=((n)-16); i+=16) {
+ YMM0 = _mm256_loadu_ps(y+i);
+ YMM1 = _mm256_loadu_ps(y+i+8);
+ YMM2 = _mm256_loadu_ps(x+i);
+ YMM3 = _mm256_loadu_ps(x+i+8);
+ YMM2 = _mm256_fmadd_ps(YMM0, YMM15, YMM2);
+ YMM3 = _mm256_fmadd_ps(YMM1, YMM15, YMM3);
+ _mm256_storeu_ps(z+i, YMM2);
+ _mm256_storeu_ps(z+i+8, YMM3);
+ }
+ for (; i<(n); i++) {
+ z[i] = x[i] + y[i] * c;
+ }
+}
+
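
Both new kernels finish with a scalar tail loop, so the per-element contract is simply z[i] = x[i] + y[i] * c. Here is a hedged, self-contained check of the double-precision kernel against that scalar reference; the file name, include path, build flags, array length, and tolerance are assumptions, not part of the commit.

/* check_cadd_avx2.c -- hedged test sketch, not part of this commit.
 * Assumed build: gcc -mavx2 -mfma -O2 check_cadd_avx2.c -o check_cadd_avx2 -lm
 */
#include <stdio.h>
#include <math.h>
#include <stddef.h>

#include "lib/TH/vector/AVX2.c"   /* path assumed; pulls in the static kernels */

int main(void) {
  const ptrdiff_t n = 13;          /* odd length exercises the scalar tail */
  const double c = 0.5;
  double x[13], y[13], z[13], ref[13];
  for (ptrdiff_t i = 0; i < n; i++) {
    x[i] = (double)i;
    y[i] = (double)(n - i);
    ref[i] = x[i] + y[i] * c;      /* scalar reference */
  }
  THDoubleVector_cadd_AVX2(z, x, y, c, n);
  for (ptrdiff_t i = 0; i < n; i++) {
    if (fabs(z[i] - ref[i]) > 1e-12) {
      printf("mismatch at %td: %f vs %f\n", i, z[i], ref[i]);
      return 1;
    }
  }
  printf("AVX2 cadd matches the scalar reference for n=%td\n", n);
  return 0;
}

With these inputs the fused and unfused orderings round identically, so the comparison is effectively exact; the 1e-12 tolerance is just slack.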