diff options
Diffstat (limited to 'intgemm/intgemm.h')
-rw-r--r-- | intgemm/intgemm.h | 106 |
1 files changed, 27 insertions, 79 deletions
diff --git a/intgemm/intgemm.h b/intgemm/intgemm.h index 8e2da02..977210d 100644 --- a/intgemm/intgemm.h +++ b/intgemm/intgemm.h @@ -49,11 +49,14 @@ #include "avx512_gemm.h" #include "avx512vnni_gemm.h" -#if defined(__INTEL_COMPILER) +#if defined(WASM) +// No header for CPUID since it's hard-coded. +#elif defined(__INTEL_COMPILER) #include <immintrin.h> #elif defined(_MSC_VER) #include <intrin.h> -#elif defined(__GNUC__) || defined(__clang__) +#else +// Assume GCC and clang style. #include <cpuid.h> #endif @@ -124,17 +127,25 @@ struct Unsupported_8bit { #ifndef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI // These won't ever be called in this capacity, but it does let the code below compile. -namespace avx512vnni { +namespace AVX512VNNI { typedef Unsupported_8bit Kernels8; -} // namespace avx512vnni +} // namespace AVX512VNNI #endif #ifndef INTGEMM_COMPILER_SUPPORTS_AVX512BW -namespace avx512bw { +namespace AVX512BW { +typedef Unsupported_8bit Kernels8; +typedef Unsupported_16bit Kernels16; +} // namespace AVX512BW +#endif +#ifndef INTGEMM_COMPILER_SUPPORTS_AVX2 +namespace AVX2 { typedef Unsupported_8bit Kernels8; typedef Unsupported_16bit Kernels16; -} // namespace avx512bw +} // namespace AVX2 #endif +CPUType GetCPUID(); + /* Returns: * axx512vnni if the CPU supports AVX512VNNI * @@ -148,72 +159,9 @@ typedef Unsupported_16bit Kernels16; * * unsupported otherwise */ -template <class T> T ChooseCPU(T -#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI - avx512vnni -#endif - , T -#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW - avx512bw -#endif - , T avx2, T ssse3, T sse2, T unsupported) { -#if defined(__INTEL_COMPILER) -# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI - if (_may_i_use_cpu_feature(_FEATURE_AVX512_VNNI)) return avx512vnni; -# endif -# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW - if (_may_i_use_cpu_feature(_FEATURE_AVX512BW)) return avx512bw; -# endif - if (_may_i_use_cpu_feature(_FEATURE_AVX2)) return avx2; - if (_may_i_use_cpu_feature(_FEATURE_SSSE3)) return ssse3; - if (_may_i_use_cpu_feature(_FEATURE_SSE2)) return sse2; - return unsupported; -#else -// Everybody except Intel compiler. -# if defined(_MSC_VER) - int regs[4]; - int &eax = regs[0], &ebx = regs[1], &ecx = regs[2], &edx = regs[3]; - __cpuid(regs, 0); - int m = eax; -# else - /* gcc and clang. - * If intgemm is compiled by gcc 6.4.1 then dlopened into an executable - * compiled by gcc 7.3.0, there will be a undefined symbol __cpu_info. - * Work around this by calling the intrinsics more directly instead of - * __builtin_cpu_supports. - * - * clang 6.0.0-1ubuntu2 supports vnni but doesn't have - * __builtin_cpu_supports("avx512vnni") - * so use the hand-coded CPUID for clang. - */ - unsigned int m = __get_cpuid_max(0, 0); - unsigned int eax, ebx, ecx, edx; -# endif - if (m >= 7) { -# if defined(_MSC_VER) - __cpuid(regs, 7); -# else - __cpuid_count(7, 0, eax, ebx, ecx, edx); -# endif -# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI - if (ecx & (1 << 11)) return avx512vnni; -# endif -# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW - if (ebx & (1 << 30)) return avx512bw; -# endif - if (ebx & (1 << 5)) return avx2; - } - if (m >= 1) { -# if defined(_MSC_VER) - __cpuid(regs, 1); -# else - __cpuid_count(1, 0, eax, ebx, ecx, edx); -# endif - if (ecx & (1 << 9)) return ssse3; - if (edx & (1 << 26)) return sse2; - } - return unsupported; -#endif +template <class T> T ChooseCPU(T avx512vnni, T avx512bw, T avx2, T ssse3, T sse2, T unsupported) { + const T ret[] = {unsupported, sse2, ssse3, avx2, avx512bw, avx512vnni}; + return ret[(int)GetCPUID()]; } struct TileInfo { @@ -280,7 +228,7 @@ private: }; template <typename Callback> -void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, avx512vnni::Kernels8>, OMPParallelWrap<Callback, avx512bw::Kernels8>, OMPParallelWrap<Callback, avx2::Kernels8>, OMPParallelWrap<Callback, ssse3::Kernels8>, Unsupported_8bit::Multiply<Callback>, Unsupported_8bit::Multiply<Callback>); +void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512VNNI::Kernels8>, OMPParallelWrap<Callback, AVX512BW::Kernels8>, OMPParallelWrap<Callback, AVX2::Kernels8>, OMPParallelWrap<Callback, SSSE3::Kernels8>, Unsupported_8bit::Multiply<Callback>, Unsupported_8bit::Multiply<Callback>); /* * 8-bit matrix multiplication with shifting A by 127 @@ -344,14 +292,14 @@ private: template <class Callback> void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU( - OMPParallelWrap8Shift<Callback, avx512vnni::Kernels8>, - OMPParallelWrap8Shift<Callback, avx512bw::Kernels8>, - OMPParallelWrap8Shift<Callback, avx2::Kernels8>, - OMPParallelWrap8Shift<Callback, ssse3::Kernels8>, + OMPParallelWrap8Shift<Callback, AVX512VNNI::Kernels8>, + OMPParallelWrap8Shift<Callback, AVX512BW::Kernels8>, + OMPParallelWrap8Shift<Callback, AVX2::Kernels8>, + OMPParallelWrap8Shift<Callback, SSSE3::Kernels8>, Unsupported_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift<Callback>); template <class Callback> -void (*Int8Shift::PrepareBiasImpl<Callback>::run)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(avx512vnni::Kernels8::PrepareBias<Callback>, avx512bw::Kernels8::PrepareBias<Callback>, avx2::Kernels8::PrepareBias<Callback>, ssse3::Kernels8::PrepareBias<Callback>, ssse3::Kernels8::PrepareBias<Callback>, Unsupported_8bit::PrepareBias); +void (*Int8Shift::PrepareBiasImpl<Callback>::run)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI::Kernels8::PrepareBias<Callback>, AVX512BW::Kernels8::PrepareBias<Callback>, AVX2::Kernels8::PrepareBias<Callback>, SSSE3::Kernels8::PrepareBias<Callback>, SSSE3::Kernels8::PrepareBias<Callback>, Unsupported_8bit::PrepareBias); /* * 16-bit matrix multiplication @@ -407,7 +355,7 @@ private: }; template <typename Callback> -void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, avx512bw::Kernels16> /*TODO VNNI 16-bit. */, OMPParallelWrap<Callback, avx512bw::Kernels16>, OMPParallelWrap<Callback, avx2::Kernels16>, OMPParallelWrap<Callback, sse2::Kernels16>, OMPParallelWrap<Callback, sse2::Kernels16>, Unsupported_16bit::Multiply<Callback>); +void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512BW::Kernels16> /*TODO VNNI 16-bit. */, OMPParallelWrap<Callback, AVX512BW::Kernels16>, OMPParallelWrap<Callback, AVX2::Kernels16>, OMPParallelWrap<Callback, SSE2::Kernels16>, OMPParallelWrap<Callback, SSE2::Kernels16>, Unsupported_16bit::Multiply<Callback>); extern const CPUType kCPU; |