diff options
Diffstat (limited to 'intgemm/intgemm.h')
-rw-r--r-- | intgemm/intgemm.h | 26 |
1 files changed, 13 insertions, 13 deletions
diff --git a/intgemm/intgemm.h b/intgemm/intgemm.h index a354b60..029a8ec 100644 --- a/intgemm/intgemm.h +++ b/intgemm/intgemm.h @@ -127,21 +127,21 @@ struct Unsupported_8bit { #ifndef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI // These won't ever be called in this capacity, but it does let the code below compile. -namespace avx512vnni { +namespace AVX512VNNI { typedef Unsupported_8bit Kernels8; -} // namespace avx512vnni +} // namespace AVX512VNNI #endif #ifndef INTGEMM_COMPILER_SUPPORTS_AVX512BW -namespace avx512bw { +namespace AVX512BW { typedef Unsupported_8bit Kernels8; typedef Unsupported_16bit Kernels16; -} // namespace avx512bw +} // namespace AVX512BW #endif #ifndef INTGEMM_COMPILER_SUPPORTS_AVX2 -namespace avx2 { +namespace AVX2 { typedef Unsupported_8bit Kernels8; typedef Unsupported_16bit Kernels16; -} // namespace avx2 +} // namespace AVX2 #endif @@ -309,7 +309,7 @@ private: }; template <typename Callback> -void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, avx512vnni::Kernels8>, OMPParallelWrap<Callback, avx512bw::Kernels8>, OMPParallelWrap<Callback, avx2::Kernels8>, OMPParallelWrap<Callback, ssse3::Kernels8>, Unsupported_8bit::Multiply<Callback>, Unsupported_8bit::Multiply<Callback>); +void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512VNNI::Kernels8>, OMPParallelWrap<Callback, AVX512BW::Kernels8>, OMPParallelWrap<Callback, AVX2::Kernels8>, OMPParallelWrap<Callback, SSSE3::Kernels8>, Unsupported_8bit::Multiply<Callback>, Unsupported_8bit::Multiply<Callback>); /* * 8-bit matrix multiplication with shifting A by 127 @@ -373,14 +373,14 @@ private: template <class Callback> void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU( - OMPParallelWrap8Shift<Callback, avx512vnni::Kernels8>, - OMPParallelWrap8Shift<Callback, avx512bw::Kernels8>, - OMPParallelWrap8Shift<Callback, avx2::Kernels8>, - OMPParallelWrap8Shift<Callback, ssse3::Kernels8>, + OMPParallelWrap8Shift<Callback, AVX512VNNI::Kernels8>, + OMPParallelWrap8Shift<Callback, AVX512BW::Kernels8>, + OMPParallelWrap8Shift<Callback, AVX2::Kernels8>, + OMPParallelWrap8Shift<Callback, SSSE3::Kernels8>, Unsupported_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift<Callback>); template <class Callback> -void (*Int8Shift::PrepareBiasImpl<Callback>::run)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(avx512vnni::Kernels8::PrepareBias<Callback>, avx512bw::Kernels8::PrepareBias<Callback>, avx2::Kernels8::PrepareBias<Callback>, ssse3::Kernels8::PrepareBias<Callback>, ssse3::Kernels8::PrepareBias<Callback>, Unsupported_8bit::PrepareBias); +void (*Int8Shift::PrepareBiasImpl<Callback>::run)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI::Kernels8::PrepareBias<Callback>, AVX512BW::Kernels8::PrepareBias<Callback>, AVX2::Kernels8::PrepareBias<Callback>, SSSE3::Kernels8::PrepareBias<Callback>, SSSE3::Kernels8::PrepareBias<Callback>, Unsupported_8bit::PrepareBias); /* * 16-bit matrix multiplication @@ -436,7 +436,7 @@ private: }; template <typename Callback> -void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, avx512bw::Kernels16> /*TODO VNNI 16-bit. */, OMPParallelWrap<Callback, avx512bw::Kernels16>, OMPParallelWrap<Callback, avx2::Kernels16>, OMPParallelWrap<Callback, sse2::Kernels16>, OMPParallelWrap<Callback, sse2::Kernels16>, Unsupported_16bit::Multiply<Callback>); +void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512BW::Kernels16> /*TODO VNNI 16-bit. */, OMPParallelWrap<Callback, AVX512BW::Kernels16>, OMPParallelWrap<Callback, AVX2::Kernels16>, OMPParallelWrap<Callback, SSE2::Kernels16>, OMPParallelWrap<Callback, SSE2::Kernels16>, Unsupported_16bit::Multiply<Callback>); extern const CPUType kCPU; |