Welcome to the mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/FBGEMM.git - Unnamed repository; edit this file 'description' to name the repository.
summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to 'src/OptimizedKernelsAvx2.cc')
-rw-r--r-- src/OptimizedKernelsAvx2.cc 51
1 file changed, 26 insertions, 25 deletions
diff --git a/src/OptimizedKernelsAvx2.cc b/src/OptimizedKernelsAvx2.cc
index e8c65c3..326bd72 100644
--- a/src/OptimizedKernelsAvx2.cc
+++ b/src/OptimizedKernelsAvx2.cc
@@ -7,6 +7,7 @@
#include "OptimizedKernelsAvx2.h"
#include <immintrin.h>
+#include "fbgemm/Utils.h"
using namespace std;
@@ -14,37 +15,37 @@ namespace fbgemm {
int32_t reduceAvx2(const uint8_t* A, int len) {
int32_t row_sum = 0;
-#if defined(__AVX2__)
- __m256i sum_v = _mm256_setzero_si256();
- __m256i one_epi16_v = _mm256_set1_epi16(1);
- __m256i one_epi8_v = _mm256_set1_epi8(1);
+ if (fbgemm::fbgemmHasAvx2Support()) {
+ __m256i sum_v = _mm256_setzero_si256();
+ __m256i one_epi16_v = _mm256_set1_epi16(1);
+ __m256i one_epi8_v = _mm256_set1_epi8(1);
- int i;
- // vectorized
- for (i = 0; i < len / 32 * 32; i += 32) {
- __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
- sum_v = _mm256_add_epi32(
+ int i;
+ // vectorized
+ for (i = 0; i < len / 32 * 32; i += 32) {
+ __m256i src_v = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(A + i));
+ sum_v = _mm256_add_epi32(
sum_v,
_mm256_madd_epi16(
- _mm256_maddubs_epi16(src_v, one_epi8_v), one_epi16_v));
- }
-
- alignas(64) int32_t temp[8];
- _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
- for (int k = 0; k < 8; ++k) {
- row_sum += temp[k];
- }
+ _mm256_maddubs_epi16(src_v, one_epi8_v), one_epi16_v));
+ }
- // scalar
- for (; i < len; ++i) {
- row_sum += A[i];
- }
+ alignas(64) int32_t temp[8];
+ _mm256_store_si256(reinterpret_cast<__m256i*>(temp), sum_v);
+ for (int k = 0; k < 8; ++k) {
+ row_sum += temp[k];
+ }
-#else
- for (int i = 0; i < len; ++i) {
- row_sum += A[i];
+ // scalar
+ for (; i < len; ++i) {
+ row_sum += A[i];
+ }
+ } else {
+ for (int i = 0; i < len; ++i) {
+ row_sum += A[i];
+ }
}
-#endif
+
return row_sum;
}