Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'intgemm/intgemm.h')
-rw-r--r--intgemm/intgemm.h106
1 files changed, 27 insertions, 79 deletions
diff --git a/intgemm/intgemm.h b/intgemm/intgemm.h
index 8e2da02..977210d 100644
--- a/intgemm/intgemm.h
+++ b/intgemm/intgemm.h
@@ -49,11 +49,14 @@
#include "avx512_gemm.h"
#include "avx512vnni_gemm.h"
-#if defined(__INTEL_COMPILER)
+#if defined(WASM)
+// No header for CPUID since it's hard-coded.
+#elif defined(__INTEL_COMPILER)
#include <immintrin.h>
#elif defined(_MSC_VER)
#include <intrin.h>
-#elif defined(__GNUC__) || defined(__clang__)
+#else
+// Assume GCC and clang style.
#include <cpuid.h>
#endif
@@ -124,17 +127,25 @@ struct Unsupported_8bit {
#ifndef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
// These won't ever be called in this capacity, but it does let the code below compile.
-namespace avx512vnni {
+namespace AVX512VNNI {
typedef Unsupported_8bit Kernels8;
-} // namespace avx512vnni
+} // namespace AVX512VNNI
#endif
#ifndef INTGEMM_COMPILER_SUPPORTS_AVX512BW
-namespace avx512bw {
+namespace AVX512BW {
+typedef Unsupported_8bit Kernels8;
+typedef Unsupported_16bit Kernels16;
+} // namespace AVX512BW
+#endif
+#ifndef INTGEMM_COMPILER_SUPPORTS_AVX2
+namespace AVX2 {
typedef Unsupported_8bit Kernels8;
typedef Unsupported_16bit Kernels16;
-} // namespace avx512bw
+} // namespace AVX2
#endif
+CPUType GetCPUID();
+
/* Returns:
* axx512vnni if the CPU supports AVX512VNNI
*
@@ -148,72 +159,9 @@ typedef Unsupported_16bit Kernels16;
*
* unsupported otherwise
*/
-template <class T> T ChooseCPU(T
-#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
- avx512vnni
-#endif
- , T
-#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- avx512bw
-#endif
- , T avx2, T ssse3, T sse2, T unsupported) {
-#if defined(__INTEL_COMPILER)
-# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
- if (_may_i_use_cpu_feature(_FEATURE_AVX512_VNNI)) return avx512vnni;
-# endif
-# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- if (_may_i_use_cpu_feature(_FEATURE_AVX512BW)) return avx512bw;
-# endif
- if (_may_i_use_cpu_feature(_FEATURE_AVX2)) return avx2;
- if (_may_i_use_cpu_feature(_FEATURE_SSSE3)) return ssse3;
- if (_may_i_use_cpu_feature(_FEATURE_SSE2)) return sse2;
- return unsupported;
-#else
-// Everybody except Intel compiler.
-# if defined(_MSC_VER)
- int regs[4];
- int &eax = regs[0], &ebx = regs[1], &ecx = regs[2], &edx = regs[3];
- __cpuid(regs, 0);
- int m = eax;
-# else
- /* gcc and clang.
- * If intgemm is compiled by gcc 6.4.1 then dlopened into an executable
- * compiled by gcc 7.3.0, there will be a undefined symbol __cpu_info.
- * Work around this by calling the intrinsics more directly instead of
- * __builtin_cpu_supports.
- *
- * clang 6.0.0-1ubuntu2 supports vnni but doesn't have
- * __builtin_cpu_supports("avx512vnni")
- * so use the hand-coded CPUID for clang.
- */
- unsigned int m = __get_cpuid_max(0, 0);
- unsigned int eax, ebx, ecx, edx;
-# endif
- if (m >= 7) {
-# if defined(_MSC_VER)
- __cpuid(regs, 7);
-# else
- __cpuid_count(7, 0, eax, ebx, ecx, edx);
-# endif
-# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
- if (ecx & (1 << 11)) return avx512vnni;
-# endif
-# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- if (ebx & (1 << 30)) return avx512bw;
-# endif
- if (ebx & (1 << 5)) return avx2;
- }
- if (m >= 1) {
-# if defined(_MSC_VER)
- __cpuid(regs, 1);
-# else
- __cpuid_count(1, 0, eax, ebx, ecx, edx);
-# endif
- if (ecx & (1 << 9)) return ssse3;
- if (edx & (1 << 26)) return sse2;
- }
- return unsupported;
-#endif
+template <class T> T ChooseCPU(T avx512vnni, T avx512bw, T avx2, T ssse3, T sse2, T unsupported) {
+ const T ret[] = {unsupported, sse2, ssse3, avx2, avx512bw, avx512vnni};
+ return ret[(int)GetCPUID()];
}
struct TileInfo {
@@ -280,7 +228,7 @@ private:
};
template <typename Callback>
-void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, avx512vnni::Kernels8>, OMPParallelWrap<Callback, avx512bw::Kernels8>, OMPParallelWrap<Callback, avx2::Kernels8>, OMPParallelWrap<Callback, ssse3::Kernels8>, Unsupported_8bit::Multiply<Callback>, Unsupported_8bit::Multiply<Callback>);
+void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512VNNI::Kernels8>, OMPParallelWrap<Callback, AVX512BW::Kernels8>, OMPParallelWrap<Callback, AVX2::Kernels8>, OMPParallelWrap<Callback, SSSE3::Kernels8>, Unsupported_8bit::Multiply<Callback>, Unsupported_8bit::Multiply<Callback>);
/*
* 8-bit matrix multiplication with shifting A by 127
@@ -344,14 +292,14 @@ private:
template <class Callback>
void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(
- OMPParallelWrap8Shift<Callback, avx512vnni::Kernels8>,
- OMPParallelWrap8Shift<Callback, avx512bw::Kernels8>,
- OMPParallelWrap8Shift<Callback, avx2::Kernels8>,
- OMPParallelWrap8Shift<Callback, ssse3::Kernels8>,
+ OMPParallelWrap8Shift<Callback, AVX512VNNI::Kernels8>,
+ OMPParallelWrap8Shift<Callback, AVX512BW::Kernels8>,
+ OMPParallelWrap8Shift<Callback, AVX2::Kernels8>,
+ OMPParallelWrap8Shift<Callback, SSSE3::Kernels8>,
Unsupported_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift<Callback>);
template <class Callback>
-void (*Int8Shift::PrepareBiasImpl<Callback>::run)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(avx512vnni::Kernels8::PrepareBias<Callback>, avx512bw::Kernels8::PrepareBias<Callback>, avx2::Kernels8::PrepareBias<Callback>, ssse3::Kernels8::PrepareBias<Callback>, ssse3::Kernels8::PrepareBias<Callback>, Unsupported_8bit::PrepareBias);
+void (*Int8Shift::PrepareBiasImpl<Callback>::run)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI::Kernels8::PrepareBias<Callback>, AVX512BW::Kernels8::PrepareBias<Callback>, AVX2::Kernels8::PrepareBias<Callback>, SSSE3::Kernels8::PrepareBias<Callback>, SSSE3::Kernels8::PrepareBias<Callback>, Unsupported_8bit::PrepareBias);
/*
* 16-bit matrix multiplication
@@ -407,7 +355,7 @@ private:
};
template <typename Callback>
-void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, avx512bw::Kernels16> /*TODO VNNI 16-bit. */, OMPParallelWrap<Callback, avx512bw::Kernels16>, OMPParallelWrap<Callback, avx2::Kernels16>, OMPParallelWrap<Callback, sse2::Kernels16>, OMPParallelWrap<Callback, sse2::Kernels16>, Unsupported_16bit::Multiply<Callback>);
+void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512BW::Kernels16> /*TODO VNNI 16-bit. */, OMPParallelWrap<Callback, AVX512BW::Kernels16>, OMPParallelWrap<Callback, AVX2::Kernels16>, OMPParallelWrap<Callback, SSE2::Kernels16>, OMPParallelWrap<Callback, SSE2::Kernels16>, Unsupported_16bit::Multiply<Callback>);
extern const CPUType kCPU;