Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/marian-nmt/intgemm/intgemm.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Nikolay Bogoychev <nheart@gmail.com> 2020-01-17 18:27:12 +0300
committer Nikolay Bogoychev <nheart@gmail.com> 2020-01-17 18:27:12 +0300
commit 7701239ac364b93dfa932d3042be8d30be2e6590 (patch)
tree b35e36239617771a9360e4d830731acd2e3c886d
parent 7176bae962f4d1347caa49e81434cae654fa6217 (diff)
parent 86feaac3c5049b27f4ef571965242d4a8fb1943c (diff)
Merge branch 'master' into debug_add127
-rw-r--r-- CMakeLists.txt 1
-rw-r--r-- avx2_gemm.h 4
-rw-r--r-- benchmarks/benchmark.cc 31
-rw-r--r-- benchmarks/biasmultiply.cc 4
-rw-r--r-- callbacks.h 24
-rw-r--r-- callbacks/avx2.h 13
-rw-r--r-- callbacks/avx512.h 19
-rw-r--r-- callbacks/implementations.inl 10
-rw-r--r-- callbacks/output_buffer_info.h 2
-rw-r--r-- callbacks/sse2.h 13
-rw-r--r-- compile_test_avx512vnni.cc 15
-rw-r--r-- interleave.h 52
-rw-r--r-- intgemm.h 44
-rw-r--r-- intrinsics.h 27
-rw-r--r-- kernels.h 25
-rw-r--r-- kernels/avx2.h 13
-rw-r--r-- kernels/avx512.h 19
-rw-r--r-- kernels/implementations.inl 7
-rw-r--r-- kernels/sse2.h 13
-rw-r--r-- multiply.h 7
-rw-r--r-- sse2_gemm.h 2
-rw-r--r-- ssse3_gemm.h 2
-rw-r--r-- test/add127_test.cc 2
-rw-r--r-- test/kernels/add_bias_test.cc 6
-rw-r--r-- test/kernels/bitwise_not_test.cc 6
-rw-r--r-- test/kernels/downcast_test.cc 6
-rw-r--r-- test/kernels/exp_test.cc 6
-rw-r--r-- test/kernels/floor_test.cc 6
-rw-r--r-- test/kernels/multiply_sat_test.cc 6
-rw-r--r-- test/kernels/multiply_test.cc 6
-rw-r--r-- test/kernels/quantize_test.cc 6
-rw-r--r-- test/kernels/relu_test.cc 6
-rw-r--r-- test/kernels/rescale_test.cc 6
-rw-r--r-- test/kernels/sigmoid_test.cc 6
-rw-r--r-- test/kernels/tanh_test.cc 6
-rw-r--r-- test/kernels/unquantize_test.cc 6
-rw-r--r-- test/kernels/upcast_test.cc 6
-rw-r--r-- test/kernels/write_test.cc 6
-rw-r--r-- test/multiply_test.cc 12
-rw-r--r-- test/quantize_test.cc 12
-rw-r--r-- test/test.cc 2
-rw-r--r-- test/test.h 6
-rw-r--r-- test/utils_test.cc 4
-rw-r--r-- types.h 1
44 files changed, 237 insertions, 239 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9036c14..891cf11 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -39,7 +39,6 @@ endif()
# Generate configure file
configure_file(${CMAKE_CURRENT_SOURCE_DIR}/intgemm_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/intgemm_config.h)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR})
foreach(exe benchmark biasmultiply)
diff --git a/avx2_gemm.h b/avx2_gemm.h
index e2bbca0..7d240b9 100644
--- a/avx2_gemm.h
+++ b/avx2_gemm.h
@@ -83,7 +83,7 @@ struct AVX2_16bit {
INTGEMM_MULTIPLY16(__m256i, INTGEMM_AVX2, CPUType::AVX2)
- constexpr static const char *const kName = "16-bit INTGEMM_AVX2";
+ constexpr static const char *const kName = "16-bit AVX2";
static const CPUType kUses = CPUType::AVX2;
};
@@ -222,7 +222,7 @@ struct AVX2_8bit {
INTGEMM_PREPAREBIASFOR8(__m256i, INTGEMM_AVX2, CPUType::AVX2)
- constexpr static const char *const kName = "8-bit INTGEMM_AVX2";
+ constexpr static const char *const kName = "8-bit AVX2";
static const CPUType kUses = CPUType::AVX2;
};
diff --git a/benchmarks/benchmark.cc b/benchmarks/benchmark.cc
index 2810936..c36afbd 100644
--- a/benchmarks/benchmark.cc
+++ b/benchmarks/benchmark.cc
@@ -1,12 +1,12 @@
-#include "aligned.h"
+#include "../aligned.h"
#include "intgemm_config.h"
-#include "avx512_gemm.h"
-#include "sse2_gemm.h"
-#include "avx2_gemm.h"
-#include "ssse3_gemm.h"
-#include "intgemm.h"
-#include "stop_watch.h"
-#include "callbacks.h"
+#include "../avx512_gemm.h"
+#include "../sse2_gemm.h"
+#include "../avx2_gemm.h"
+#include "../ssse3_gemm.h"
+#include "../intgemm.h"
+#include "../stop_watch.h"
+#include "../callbacks.h"
#include <algorithm>
#include <cassert>
@@ -101,6 +101,7 @@ struct BackendStats {
std::vector<std::vector<uint64_t>> ssse3_8bit;
std::vector<std::vector<uint64_t>> avx2_8bit;
std::vector<std::vector<uint64_t>> avx512_8bit;
+ std::vector<std::vector<uint64_t>> avx512vnni_8bit;
std::vector<std::vector<uint64_t>> sse2_16bit;
std::vector<std::vector<uint64_t>> avx2_16bit;
std::vector<std::vector<uint64_t>> avx512_16bit;
@@ -122,12 +123,12 @@ void Summarize(std::vector<uint64_t> &stats) {
stddev += off * off;
}
stddev = sqrt(stddev / (keep - stats.begin() - 1));
- std::cout << std::setw(8) << *std::min_element(stats.begin(), stats.end()) << '\t' << std::setw(8) << avg << '\t' << std::setw(8) << stddev;
+ std::cout << std::setw(10) << *std::min_element(stats.begin(), stats.end()) << '\t' << std::setw(8) << avg << '\t' << std::setw(8) << stddev;
}
template <class Backend> void Print(std::vector<std::vector<uint64_t>> &stats, int index) {
if (stats.empty()) return;
- std::cout << Backend::kName << '\t';
+ std::cout << std::setw(16) << Backend::kName << '\t';
Summarize(stats[index]);
std::cout << '\n';
}
@@ -208,6 +209,13 @@ int main(int argc, char ** argv) {
RunAll<AVX512_16bit>(matrices, end, stats.avx512_16bit);
}
#endif
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
+ std::cerr << "AVX512VNNI 8bit, 100 samples..." << std::endl;
+ for (int samples = 0; samples < kSamples; ++samples) {
+ RandomMatrices *end = (samples < 4) ? matrices_end : full_sample;
+ RunAll<AVX512VNNI_8bit>(matrices, end, stats.avx512vnni_8bit);
+ }
+#endif
if (stats.sse2_16bit.empty()) {
std::cerr << "No CPU support." << std::endl;
@@ -220,6 +228,9 @@ int main(int argc, char ** argv) {
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
Print<AVX512_8bit>(stats.avx512_8bit, i);
#endif
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
+ Print<AVX512VNNI_8bit>(stats.avx512vnni_8bit, i);
+#endif
Print<SSE2_16bit>(stats.sse2_16bit, i);
Print<AVX2_16bit>(stats.avx2_16bit, i);
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
diff --git a/benchmarks/biasmultiply.cc b/benchmarks/biasmultiply.cc
index ec147ce..ec8ca95 100644
--- a/benchmarks/biasmultiply.cc
+++ b/benchmarks/biasmultiply.cc
@@ -1,5 +1,5 @@
-#include "intgemm.h"
-#include "aligned.h"
+#include "../intgemm.h"
+#include "../aligned.h"
#include <chrono>
#include <random>
#include <iostream>
diff --git a/callbacks.h b/callbacks.h
index da3e88f..c8a29df 100644
--- a/callbacks.h
+++ b/callbacks.h
@@ -3,6 +3,24 @@
#include "callbacks/configs.h"
#include "callbacks/output_buffer_info.h"
-#include "callbacks/sse2.h"
-#include "callbacks/avx2.h"
-#include "callbacks/avx512.h"
+#include "intgemm_config.h"
+#include "intrinsics.h"
+#include "kernels.h"
+#include "types.h"
+#include "utils.h"
+#include "vec_traits.h"
+
+#define CALLBACKS_THIS_IS_SSE2
+#include "callbacks/implementations.inl"
+#undef CALLBACKS_THIS_IS_SSE2
+
+#define CALLBACKS_THIS_IS_AVX2
+#include "callbacks/implementations.inl"
+#undef CALLBACKS_THIS_IS_AVX2
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
+#define CALLBACKS_THIS_IS_AVX512BW
+#include "callbacks/implementations.inl"
+#undef CALLBACKS_THIS_IS_AVX512BW
+#endif
+
diff --git a/callbacks/avx2.h b/callbacks/avx2.h
deleted file mode 100644
index 76b2605..0000000
--- a/callbacks/avx2.h
+++ /dev/null
@@ -1,13 +0,0 @@
-#pragma once
-
-#define CALLBACKS_THIS_IS_AVX2
-#include "callbacks/implementations.inl"
-#undef CALLBACKS_THIS_IS_AVX2
-
-namespace intgemm {
-namespace callbacks {
-
-// Put here callbacks supported only by AVX2...
-
-}
-}
diff --git a/callbacks/avx512.h b/callbacks/avx512.h
deleted file mode 100644
index 3e101dd..0000000
--- a/callbacks/avx512.h
+++ /dev/null
@@ -1,19 +0,0 @@
-#pragma once
-
-#include "intgemm_config.h"
-
-#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
-
-#define CALLBACKS_THIS_IS_AVX512BW
-#include "callbacks/implementations.inl"
-#undef CALLBACKS_THIS_IS_AVX512BW
-
-namespace intgemm {
-namespace callbacks {
-
-// Put here callbacks supported only by AVX512BW...
-
-}
-}
-
-#endif
diff --git a/callbacks/implementations.inl b/callbacks/implementations.inl
index 4541664..dce89b2 100644
--- a/callbacks/implementations.inl
+++ b/callbacks/implementations.inl
@@ -1,12 +1,4 @@
-#include "callbacks/configs.h"
-#include "callbacks/output_buffer_info.h"
-
-#include "intrinsics.h"
-#include "kernels.h"
-#include "types.h"
-#include "utils.h"
-#include "vec_traits.h"
-
+/* This file is included multiple times, once per architecture. */
#if defined(CALLBACKS_THIS_IS_SSE2)
#define CPU_NAME SSE2
#define CPU_ATTR INTGEMM_SSE2
diff --git a/callbacks/output_buffer_info.h b/callbacks/output_buffer_info.h
index fa86587..213aef4 100644
--- a/callbacks/output_buffer_info.h
+++ b/callbacks/output_buffer_info.h
@@ -1,6 +1,6 @@
#pragma once
-#include "types.h"
+#include "../types.h"
namespace intgemm {
namespace callbacks {
diff --git a/callbacks/sse2.h b/callbacks/sse2.h
deleted file mode 100644
index a53b8ef..0000000
--- a/callbacks/sse2.h
+++ /dev/null
@@ -1,13 +0,0 @@
-#pragma once
-
-#define CALLBACKS_THIS_IS_SSE2
-#include "callbacks/implementations.inl"
-#undef CALLBACKS_THIS_IS_SSE2
-
-namespace intgemm {
-namespace callbacks {
-
-// Put here callbacks supported only by SSE2...
-
-}
-}
diff --git a/compile_test_avx512vnni.cc b/compile_test_avx512vnni.cc
index fc1c3dd..611cc53 100644
--- a/compile_test_avx512vnni.cc
+++ b/compile_test_avx512vnni.cc
@@ -1,6 +1,11 @@
#include <immintrin.h>
-__attribute__ ((target ("avx512f,avx512bw,avx512dq,avx512vnni"))) bool Foo() {
+#ifdef __INTEL_COMPILER
+__attribute__ ((target ("avx512f")))
+#else
+__attribute__ ((target ("avx512f,avx512bw,avx512dq,avx512vnni")))
+#endif
+bool Foo() {
// AVX512F
__m512i value = _mm512_set1_epi32(1);
// AVX512BW
@@ -14,5 +19,11 @@ __attribute__ ((target ("avx512f,avx512bw,avx512dq,avx512vnni"))) bool Foo() {
}
int main() {
- return Foo() && __builtin_cpu_supports("avx512vnni");
+ return Foo() &&
+#ifdef __INTEL_COMPILER
+ _may_i_use_cpu_feature(_FEATURE_AVX512_VNNI)
+#else
+ __builtin_cpu_supports("avx512vnni")
+#endif
+ ;
}
diff --git a/interleave.h b/interleave.h
index 6b8edbd..5f9d525 100644
--- a/interleave.h
+++ b/interleave.h
@@ -10,45 +10,31 @@
namespace intgemm {
-/* This macro defines functions that interleave their arguments like
- * inline void Interleave8(__m256i &first, __m256i &second) {
- * __m256i temp = _mm256_unpacklo_epi8(first, second);
- * second = _mm256_unpackhi_epi8(first, second);
- * first = temp;
- * }
- *
- * Example usage:
- * INTGEMM_INTERLEAVE(__m128i, )
- * INTGEMM_INTERLEAVE(__m256i, 256)
- * INTGEMM_INTERLEAVE(__m512i, 512)
+/*
+ * Interleave vectors.
*/
-#define INTGEMM_INTERLEAVE(target, type, prefix) \
-target static inline void Interleave8(type &first, type &second) { \
- type temp = _mm##prefix##_unpacklo_epi8(first, second); \
- second = _mm##prefix##_unpackhi_epi8(first, second); \
- first = temp; \
-} \
-target static inline void Interleave16(type &first, type &second) { \
- type temp = _mm##prefix##_unpacklo_epi16(first, second); \
- second = _mm##prefix##_unpackhi_epi16(first, second); \
- first = temp; \
-} \
-target static inline void Interleave32(type &first, type &second) { \
- type temp = _mm##prefix##_unpacklo_epi32(first, second); \
- second = _mm##prefix##_unpackhi_epi32(first, second); \
- first = temp; \
-} \
-target static inline void Interleave64(type &first, type &second) { \
- type temp = _mm##prefix##_unpacklo_epi64(first, second); \
- second = _mm##prefix##_unpackhi_epi64(first, second); \
+#define INTGEMM_INTERLEAVE_N(target, type, N) \
+target static inline void Interleave##N(type &first, type &second) { \
+ type temp = unpacklo_epi##N(first, second); \
+ second = unpackhi_epi##N(first, second); \
first = temp; \
}
-INTGEMM_INTERLEAVE(INTGEMM_SSE2, __m128i, )
-INTGEMM_INTERLEAVE(INTGEMM_AVX2, __m256i, 256)
+#define INTGEMM_INTERLEAVE(target, type) \
+INTGEMM_INTERLEAVE_N(target, type, 8) \
+INTGEMM_INTERLEAVE_N(target, type, 16) \
+INTGEMM_INTERLEAVE_N(target, type, 32) \
+INTGEMM_INTERLEAVE_N(target, type, 64)
+
+INTGEMM_INTERLEAVE(INTGEMM_SSE2, __m128i)
+INTGEMM_INTERLEAVE(INTGEMM_AVX2, __m256i)
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
-INTGEMM_INTERLEAVE(INTGEMM_AVX512BW, __m512i, 512)
+INTGEMM_INTERLEAVE(INTGEMM_AVX512BW, __m512i)
#endif
+
+/*
+ * Swap vectors.
+ */
#define INTGEMM_SWAP(target, Register) \
target static inline void Swap(Register &a, Register &b) { \
Register tmp = a; \
diff --git a/intgemm.h b/intgemm.h
index 6f6839c..8940085 100644
--- a/intgemm.h
+++ b/intgemm.h
@@ -1,5 +1,4 @@
#pragma once
-
/* Main interface for integer matrix multiplication.
*
* We are computing C = A * B with an optional scaling factor.
@@ -52,6 +51,10 @@
#include "avx512_gemm.h"
#include "avx512vnni_gemm.h"
+#if defined(__GNUC__) && defined(INTGEMM_COMPILER_SUPPORTS_AVX512)
+#include "cpuid.h"
+#endif
+
/* Dispatch to functions based on runtime CPUID. This adds one call-by-variable to each call. */
namespace intgemm {
@@ -117,9 +120,29 @@ static inline float MaxAbsolute(const float *begin, const float *end) {
typedef Unsupported_8bit AVX512VNNI_8bit;
#endif
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
+// Runtime check for AVX512BW. gcc 5.4.0 bizarrely supports avx512bw targets but not __builtin_cpu_supports("avx512bw"), so GNU-compatible compilers query CPUID manually.
+inline bool CheckAVX512BW() {
+#ifdef __INTEL_COMPILER
+ return _may_i_use_cpu_feature(_FEATURE_AVX512BW);
+#elif defined(__GNUC__)
+ const unsigned int max_leaf = __get_cpuid_max(0, nullptr);
+ if (max_leaf < 7) return false; // CPUID leaf 7 is not available, so the feature bit cannot be read.
+ unsigned int eax, ebx, ecx, edx;
+ __cpuid_count(7, 0, eax, ebx, ecx, edx);
+ const unsigned int avx512bw_bit = (1u << 30); // CPUID.(EAX=7,ECX=0):EBX bit 30 reports AVX512BW.
+ return (ebx & avx512bw_bit) != 0;
+#else
+ return __builtin_cpu_supports("avx512bw");
+#endif
+}
+#endif
+
/* Returns:
- * avx512 if the CPU supports AVX512F (though really it should be AVX512BW, but
- * cloud providers lie). TODO: don't catch Knights processors with this.
+ * avx512vnni if the CPU supports AVX512VNNI
+ *
+ * avx512bw if the CPU supports AVX512BW
*
* avx2 if the CPU supports AVX2
*
@@ -129,16 +152,21 @@ typedef Unsupported_8bit AVX512VNNI_8bit;
*
* unsupported otherwise
*/
-template <class T> T ChooseCPU(T avx512vnni, T avx512, T avx2, T ssse3, T sse2, T unsupported) {
- // TODO: don't catch Knights processors here!
+template <class T> T ChooseCPU(T avx512vnni, T avx512bw, T avx2, T ssse3, T sse2, T unsupported) {
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
- if (__builtin_cpu_supports("avx512vnni")) {
+ if (
+#ifdef __INTEL_COMPILER
+ _may_i_use_cpu_feature(_FEATURE_AVX512_VNNI)
+#else
+ __builtin_cpu_supports("avx512vnni")
+#endif
+ ) {
return avx512vnni;
}
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
- if (__builtin_cpu_supports("avx512f")) {
- return avx512;
+ if (CheckAVX512BW()) {
+ return avx512bw;
}
#endif
if (__builtin_cpu_supports("avx2")) {
diff --git a/intrinsics.h b/intrinsics.h
index 0b08493..5fe3159 100644
--- a/intrinsics.h
+++ b/intrinsics.h
@@ -188,6 +188,15 @@ INTGEMM_SSE2 static inline __m128i unpackhi_epi16(__m128i a, __m128i b) {
INTGEMM_SSE2 static inline __m128i unpacklo_epi32(__m128i a, __m128i b) {
return _mm_unpacklo_epi32(a, b);
}
+INTGEMM_SSE2 static inline __m128i unpackhi_epi32(__m128i a, __m128i b) {
+ return _mm_unpackhi_epi32(a, b);
+}
+INTGEMM_SSE2 static inline __m128i unpacklo_epi64(__m128i a, __m128i b) {
+ return _mm_unpacklo_epi64(a, b);
+}
+INTGEMM_SSE2 static inline __m128i unpackhi_epi64(__m128i a, __m128i b) {
+ return _mm_unpackhi_epi64(a, b);
+}
INTGEMM_SSE2 static inline __m128i xor_si(__m128i a, __m128i b) {
return _mm_xor_si128(a, b);
}
@@ -354,6 +363,15 @@ INTGEMM_AVX2 static inline __m256i unpackhi_epi16(__m256i a, __m256i b) {
INTGEMM_AVX2 static inline __m256i unpacklo_epi32(__m256i a, __m256i b) {
return _mm256_unpacklo_epi32(a, b);
}
+INTGEMM_AVX2 static inline __m256i unpackhi_epi32(__m256i a, __m256i b) {
+ return _mm256_unpackhi_epi32(a, b);
+}
+INTGEMM_AVX2 static inline __m256i unpacklo_epi64(__m256i a, __m256i b) {
+ return _mm256_unpacklo_epi64(a, b);
+}
+INTGEMM_AVX2 static inline __m256i unpackhi_epi64(__m256i a, __m256i b) {
+ return _mm256_unpackhi_epi64(a, b);
+}
INTGEMM_AVX2 static inline __m256i xor_si(__m256i a, __m256i b) {
return _mm256_xor_si256(a, b);
}
@@ -522,6 +540,15 @@ INTGEMM_AVX512BW static inline __m512i unpackhi_epi16(__m512i a, __m512i b) {
INTGEMM_AVX512BW static inline __m512i unpacklo_epi32(__m512i a, __m512i b) {
return _mm512_unpacklo_epi32(a, b);
}
+INTGEMM_AVX512BW static inline __m512i unpackhi_epi32(__m512i a, __m512i b) {
+ return _mm512_unpackhi_epi32(a, b);
+}
+INTGEMM_AVX512BW static inline __m512i unpacklo_epi64(__m512i a, __m512i b) {
+ return _mm512_unpacklo_epi64(a, b);
+}
+INTGEMM_AVX512BW static inline __m512i unpackhi_epi64(__m512i a, __m512i b) {
+ return _mm512_unpackhi_epi64(a, b);
+}
INTGEMM_AVX512BW static inline __m512i xor_si(__m512i a, __m512i b) {
return _mm512_xor_si512(a, b);
}
diff --git a/kernels.h b/kernels.h
index 4ab937c..ef63fec 100644
--- a/kernels.h
+++ b/kernels.h
@@ -1,5 +1,24 @@
#pragma once
-#include "kernels/sse2.h"
-#include "kernels/avx2.h"
-#include "kernels/avx512.h"
+#include "intgemm_config.h"
+#include "intrinsics.h"
+#include "types.h"
+#include "utils.h"
+#include "vec_traits.h"
+
+#include <cstdlib>
+
+#define KERNELS_THIS_IS_SSE2
+#include "kernels/implementations.inl"
+#undef KERNELS_THIS_IS_SSE2
+
+#define KERNELS_THIS_IS_AVX2
+#include "kernels/implementations.inl"
+#undef KERNELS_THIS_IS_AVX2
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
+#define KERNELS_THIS_IS_AVX512BW
+#include "kernels/implementations.inl"
+#undef KERNELS_THIS_IS_AVX512BW
+#endif
+
diff --git a/kernels/avx2.h b/kernels/avx2.h
deleted file mode 100644
index c7f29ca..0000000
--- a/kernels/avx2.h
+++ /dev/null
@@ -1,13 +0,0 @@
-#pragma once
-
-#define KERNELS_THIS_IS_AVX2
-#include "kernels/implementations.inl"
-#undef KERNELS_THIS_IS_AVX2
-
-namespace intgemm {
-namespace kernels {
-
-// Put here kernels supported only by AVX2...
-
-}
-}
diff --git a/kernels/avx512.h b/kernels/avx512.h
deleted file mode 100644
index e472422..0000000
--- a/kernels/avx512.h
+++ /dev/null
@@ -1,19 +0,0 @@
-#pragma once
-
-#include "intgemm_config.h"
-
-#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512
-
-#define KERNELS_THIS_IS_AVX512BW
-#include "kernels/implementations.inl"
-#undef KERNELS_THIS_IS_AVX512BW
-
-namespace intgemm {
-namespace kernels {
-
-// Put here kernels supported only by AVX512BW...
-
-}
-}
-
-#endif
diff --git a/kernels/implementations.inl b/kernels/implementations.inl
index fda4a04..80347fc 100644
--- a/kernels/implementations.inl
+++ b/kernels/implementations.inl
@@ -1,9 +1,4 @@
-#include "intrinsics.h"
-#include "types.h"
-#include "utils.h"
-#include "vec_traits.h"
-
-#include <cstdlib>
+/* This file is included multiple times, once for each backend instruction set. */
#if defined(KERNELS_THIS_IS_SSE2)
#define CPU_NAME SSE2
diff --git a/kernels/sse2.h b/kernels/sse2.h
deleted file mode 100644
index 322fd37..0000000
--- a/kernels/sse2.h
+++ /dev/null
@@ -1,13 +0,0 @@
-#pragma once
-
-#define KERNELS_THIS_IS_SSE2
-#include "kernels/implementations.inl"
-#undef KERNELS_THIS_IS_SSE2
-
-namespace intgemm {
-namespace kernels {
-
-// Put here kernels supported only by SSE2...
-
-}
-}
diff --git a/multiply.h b/multiply.h
index 1d4113c..823fa6d 100644
--- a/multiply.h
+++ b/multiply.h
@@ -50,8 +50,11 @@ INTGEMM_AVX512BW static inline __m256i PermuteSummer(__m512i pack0123, __m512i p
}
// Find the maximum float.
-static inline INTGEMM_AVX512DQ float MaxFloat32(__m512 a) {
- return MaxFloat32(max_ps(_mm512_castps512_ps256(a), _mm512_extractf32x8_ps(a, 1)));
+static inline INTGEMM_AVX512F float MaxFloat32(__m512 a) {
+ // _mm512_extractf32x8_ps is AVX512DQ but we don't care about masking.
+ // So cast to pd, do AVX512F _mm512_extractf64x4_pd, then cast to ps.
+ __m256 upper = _mm256_castpd_ps(_mm512_extractf64x4_pd(_mm512_castps_pd(a), 1));
+ return MaxFloat32(max_ps(_mm512_castps512_ps256(a), upper));
}
#endif
diff --git a/sse2_gemm.h b/sse2_gemm.h
index 4e8f885..ef81daf 100644
--- a/sse2_gemm.h
+++ b/sse2_gemm.h
@@ -77,7 +77,7 @@ struct SSE2_16bit {
}
INTGEMM_MULTIPLY16(__m128i, INTGEMM_SSE2, CPUType::SSE2)
- constexpr static const char *const kName = "16-bit INTGEMM_SSE2";
+ constexpr static const char *const kName = "16-bit SSE2";
static const CPUType kUses = CPUType::SSE2;
};
diff --git a/ssse3_gemm.h b/ssse3_gemm.h
index 9dd290f..a2d74dd 100644
--- a/ssse3_gemm.h
+++ b/ssse3_gemm.h
@@ -147,7 +147,7 @@ struct SSSE3_8bit {
INTGEMM_PREPAREBIASFOR8(__m128i, INTGEMM_SSSE3, CPUType::SSE2)
- constexpr static const char *const kName = "8-bit INTGEMM_SSSE3";
+ constexpr static const char *const kName = "8-bit SSSE3";
static const CPUType kUses = CPUType::SSSE3;
};
diff --git a/test/add127_test.cc b/test/add127_test.cc
index 8a25a7b..5b8d426 100644
--- a/test/add127_test.cc
+++ b/test/add127_test.cc
@@ -1,4 +1,4 @@
-#include "test/test.h"
+#include "test.h"
namespace intgemm {
diff --git a/test/kernels/add_bias_test.cc b/test/kernels/add_bias_test.cc
index 3c4a593..4a2060e 100644
--- a/test/kernels/add_bias_test.cc
+++ b/test/kernels/add_bias_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/kernels/bitwise_not_test.cc b/test/kernels/bitwise_not_test.cc
index 3b78aa8..889e1bb 100644
--- a/test/kernels/bitwise_not_test.cc
+++ b/test/kernels/bitwise_not_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/kernels/downcast_test.cc b/test/kernels/downcast_test.cc
index 056c1e7..b25889f 100644
--- a/test/kernels/downcast_test.cc
+++ b/test/kernels/downcast_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/kernels/exp_test.cc b/test/kernels/exp_test.cc
index 2e4fecc..d4e100e 100644
--- a/test/kernels/exp_test.cc
+++ b/test/kernels/exp_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/kernels/floor_test.cc b/test/kernels/floor_test.cc
index 8f21af3..3f4fdf3 100644
--- a/test/kernels/floor_test.cc
+++ b/test/kernels/floor_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/kernels/multiply_sat_test.cc b/test/kernels/multiply_sat_test.cc
index 86bf581..83ce5ac 100644
--- a/test/kernels/multiply_sat_test.cc
+++ b/test/kernels/multiply_sat_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/kernels/multiply_test.cc b/test/kernels/multiply_test.cc
index 9673e89..90607f5 100644
--- a/test/kernels/multiply_test.cc
+++ b/test/kernels/multiply_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/kernels/quantize_test.cc b/test/kernels/quantize_test.cc
index 29a5ecc..e666654 100644
--- a/test/kernels/quantize_test.cc
+++ b/test/kernels/quantize_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/kernels/relu_test.cc b/test/kernels/relu_test.cc
index 7631623..fdf7c0e 100644
--- a/test/kernels/relu_test.cc
+++ b/test/kernels/relu_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/kernels/rescale_test.cc b/test/kernels/rescale_test.cc
index 9c0d581..1d7f556 100644
--- a/test/kernels/rescale_test.cc
+++ b/test/kernels/rescale_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/kernels/sigmoid_test.cc b/test/kernels/sigmoid_test.cc
index f38e890..e0e008e 100644
--- a/test/kernels/sigmoid_test.cc
+++ b/test/kernels/sigmoid_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/kernels/tanh_test.cc b/test/kernels/tanh_test.cc
index 3a4294a..7407a11 100644
--- a/test/kernels/tanh_test.cc
+++ b/test/kernels/tanh_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/kernels/unquantize_test.cc b/test/kernels/unquantize_test.cc
index b23e7bd..439970e 100644
--- a/test/kernels/unquantize_test.cc
+++ b/test/kernels/unquantize_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/kernels/upcast_test.cc b/test/kernels/upcast_test.cc
index bef4e41..5c13dfd 100644
--- a/test/kernels/upcast_test.cc
+++ b/test/kernels/upcast_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/kernels/write_test.cc b/test/kernels/write_test.cc
index 8d85600..53a0ea6 100644
--- a/test/kernels/write_test.cc
+++ b/test/kernels/write_test.cc
@@ -1,6 +1,6 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "kernels.h"
+#include "../test.h"
+#include "../../aligned.h"
+#include "../../kernels.h"
#include <numeric>
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index c2dea32..c972489 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -1,9 +1,9 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "interleave.h"
-#include "intgemm.h"
-#include "multiply.h"
-#include "callbacks.h"
+#include "test.h"
+#include "../aligned.h"
+#include "../interleave.h"
+#include "../intgemm.h"
+#include "../multiply.h"
+#include "../callbacks.h"
#include <algorithm>
#include <cassert>
diff --git a/test/quantize_test.cc b/test/quantize_test.cc
index 0d47ee6..83b1d20 100644
--- a/test/quantize_test.cc
+++ b/test/quantize_test.cc
@@ -1,9 +1,9 @@
-#include "test/test.h"
-#include "aligned.h"
-#include "avx2_gemm.h"
-#include "avx512_gemm.h"
-#include "sse2_gemm.h"
-#include "ssse3_gemm.h"
+#include "test.h"
+#include "../aligned.h"
+#include "../avx2_gemm.h"
+#include "../avx512_gemm.h"
+#include "../sse2_gemm.h"
+#include "../ssse3_gemm.h"
#include <cstring>
#include <iostream>
diff --git a/test/test.cc b/test/test.cc
index e1656ac..2986d82 100644
--- a/test/test.cc
+++ b/test/test.cc
@@ -1,5 +1,5 @@
#define CATCH_CONFIG_RUNNER
-#include "test/test.h"
+#include "test.h"
int main(int argc, char ** argv) {
return Catch::Session().run(argc, argv);
diff --git a/test/test.h b/test/test.h
index e9baa77..291ff45 100644
--- a/test/test.h
+++ b/test/test.h
@@ -1,9 +1,9 @@
#pragma once
-#include "3rd_party/catch.hpp"
+#include "../3rd_party/catch.hpp"
#include <sstream>
-#include "intgemm.h"
-#include "aligned.h"
+#include "../intgemm.h"
+#include "../aligned.h"
#include "intgemm_config.h"
diff --git a/test/utils_test.cc b/test/utils_test.cc
index 580a872..782027e 100644
--- a/test/utils_test.cc
+++ b/test/utils_test.cc
@@ -1,5 +1,5 @@
-#include "test/test.h"
-#include "utils.h"
+#include "test.h"
+#include "../utils.h"
namespace intgemm {
namespace {
diff --git a/types.h b/types.h
index fb010b0..d901e18 100644
--- a/types.h
+++ b/types.h
@@ -10,7 +10,6 @@
#define INTGEMM_AVX512F __attribute__ ((target ("avx512f")))
#define INTGEMM_AVX512BW __attribute__ ((target ("avx512f")))
#define INTGEMM_AVX512DQ __attribute__ ((target ("avx512f")))
-// TODO is this right?
#define INTGEMM_AVX512VNNI __attribute__ ((target ("avx512f")))
#else
#define INTGEMM_AVX512F __attribute__ ((target ("avx512f")))