Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'intgemm/stats.h')
-rw-r--r--intgemm/stats.h72
1 files changed, 72 insertions, 0 deletions
diff --git a/intgemm/stats.h b/intgemm/stats.h
new file mode 100644
index 0000000..6f9eda2
--- /dev/null
+++ b/intgemm/stats.h
@@ -0,0 +1,72 @@
+#pragma once
+
+#include <cmath>
+#include "intrinsics.h"
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+namespace intgemm {
+
+/* Horizontal max and sums. TODO make a template argument? */
+
+INTGEMM_SSE2 static inline float MaxFloat32(__m128 a) {
+ // Fold to just using the first 64 bits.
+ __m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2);
+ a = _mm_max_ps(a, second_half);
+ // Fold to just using the first 32 bits.
+ second_half = _mm_shuffle_ps(a, a, 1);
+ a = _mm_max_ps(a, second_half);
+ // This casting compiles to nothing.
+ return *reinterpret_cast<float*>(&a);
+}
+INTGEMM_SSE2 static inline float AddFloat32(__m128 a) {
+ // Fold to just using the first 64 bits.
+ __m128 second_half = _mm_shuffle_ps(a, a, 3 * 4 + 2);
+ a = _mm_add_ps(a, second_half);
+ // Fold to just using the first 32 bits.
+ second_half = _mm_shuffle_ps(a, a, 1);
+ a = _mm_add_ps(a, second_half);
+ // This casting compiles to nothing.
+ return *reinterpret_cast<float*>(&a);
+}
+
+INTGEMM_AVX2 static inline float MaxFloat32(__m256 a) {
+ return MaxFloat32(max_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1)));
+}
+INTGEMM_AVX2 static inline float AddFloat32(__m256 a) {
+ return AddFloat32(add_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1)));
+}
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+// Find the maximum float.
+INTGEMM_AVX512F static inline float MaxFloat32(__m512 a) {
+ // _mm512_extractf32x8_ps is AVX512DQ but we don't care about masking.
+ // So cast to pd, do AVX512F _mm512_extractf64x4_pd, then cast to ps.
+ __m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1));
+ return MaxFloat32(max_ps(_mm512_castps512_ps256(a), upper));
+}
+INTGEMM_AVX512F static inline float AddFloat32(__m512 a) {
+ __m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1));
+ return AddFloat32(add_ps(_mm512_castps512_ps256(a), upper));
+}
+#endif
+
+constexpr int32_t kFloatAbsoluteMask = 0x7fffffff;
+
+} // namespace intgemm
+
+#define INTGEMM_THIS_IS_SSE2
+#include "stats.inl"
+#undef INTGEMM_THIS_IS_SSE2
+
+#define INTGEMM_THIS_IS_AVX2
+#include "stats.inl"
+#undef INTGEMM_THIS_IS_AVX2
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+#define INTGEMM_THIS_IS_AVX512DQ
+#include "stats.inl"
+#undef INTGEMM_THIS_IS_AVX512DQ
+#endif