github.com/marian-nmt/intgemm/intgemm.git
author    Nikolay Bogoychev <nheart@gmail.com>  2021-06-24 00:49:54 +0300
committer GitHub <noreply@github.com>           2021-06-24 00:49:54 +0300
commit    e4b82c15a368f21903657a2d3fb3259cd0f502c8
tree      831a3e08ab24efa58ba608f62f0525e24ae2458b
parent    8abde25b13c3ab210c0dec8e23f4944e3953812d
parent    6228d016ecc63470d2dbb76bd4ab7b0abe097993

Merge branch 'kpu:master' into master
 .github/workflows/intel-19.yml                                   |  35
 CMakeLists.txt                                                   |  49
 README.md                                                        |   1
 benchmarks/benchmark.cc                                          |  31
 benchmarks/benchmark_quantizer.cc                                |   8
 benchmarks/biasmultiply.cc                                       | 170
 compile_test/avx2.cc                                             |  17
 compile_test/avx512bw.cc (renamed from compile_test_avx512bw.cc) |   0
 compile_test/avx512vnni.cc (renamed from compile_test_avx512vnni.cc) | 0
 intgemm/aligned.h                                                |  23
 intgemm/avx2_gemm.h                                              |  40
 intgemm/avx512_gemm.h                                            |   6
 intgemm/avx512vnni_gemm.h                                        |  12
 intgemm/callbacks.h                                              |   2
 intgemm/callbacks/configs.h                                      |  15
 intgemm/callbacks/implementations.inl                            | 101
 intgemm/interleave.h                                             |   7
 intgemm/intgemm.cc                                               | 168
 intgemm/intgemm.h                                                | 106
 intgemm/intgemm_config.h.in                                      |   1
 intgemm/intrinsics.h                                             |  20
 intgemm/kernels.h                                                |   2
 intgemm/multiply.h                                               |  17
 intgemm/sse2_gemm.h                                              |   4
 intgemm/ssse3_gemm.h                                             |   8
 intgemm/stats.h                                                  |   4
 intgemm/stats.inl                                                |   6
 intgemm/types.h                                                  |  69
 intgemm/vec_traits.h                                             |   2
 test/add127_test.cc                                              | 231
 test/kernels/add_bias_test.cc                                    |   2
 test/kernels/bitwise_not_test.cc                                 |   2
 test/kernels/downcast_test.cc                                    |   6
 test/kernels/exp_test.cc                                         |   2
 test/kernels/floor_test.cc                                       |   2
 test/kernels/multiply_test.cc                                    |   2
 test/kernels/quantize_test.cc                                    |   2
 test/kernels/relu_test.cc                                        |   2
 test/kernels/rescale_test.cc                                     |   2
 test/kernels/sigmoid_test.cc                                     |   2
 test/kernels/tanh_test.cc                                        |   2
 test/kernels/unquantize_test.cc                                  |   2
 test/kernels/upcast_test.cc                                      |   6
 test/kernels/write_test.cc                                       |   2
 test/multiply_test.cc                                            | 488
 test/prepare_b_quantized_transposed.cc                           |  14
 test/prepare_b_transposed.cc                                     |  22
 test/quantize_test.cc                                            |  94
 48 files changed, 1207 insertions(+), 602 deletions(-)
diff --git a/.github/workflows/intel-19.yml b/.github/workflows/intel-19.yml
new file mode 100644
index 0000000..95371fb
--- /dev/null
+++ b/.github/workflows/intel-19.yml
@@ -0,0 +1,35 @@
+name: Intel Compiler
+
+on:
+ push:
+ branches: [master, static]
+ pull_request:
+ branches: [master, static]
+jobs:
+ build_linux_apt_cpp:
+ runs-on: ubuntu-20.04
+ defaults:
+ run:
+ shell: bash
+ steps:
+ - uses: actions/checkout@v2
+ - name: setup repo
+ run: |
+ wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS-2023.PUB
+ sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS-2023.PUB
+ sudo echo "deb https://apt.repos.intel.com/oneapi all main" | sudo tee /etc/apt/sources.list.d/oneAPI.list
+ sudo apt-get update
+ - name: install
+ run: sudo apt-get install -y intel-oneapi-compiler-dpcpp-cpp-and-cpp-classic
+ - name: cmake
+ run: |
+ source /opt/intel/oneapi/setvars.sh
+ mkdir -p build
+ cd build
+ cmake -DCMAKE_C_COMPILER=icc -DCMAKE_CXX_COMPILER=icpc ..
+ - name: Compile
+ working-directory: build
+ run: make -j2
+ - name: Test
+ working-directory: build
+ run: ctest -j
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d1885f5..af27542 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -14,25 +14,44 @@ if(MSVC)
add_compile_options(/W4 /WX)
else()
add_compile_options(-Wall -Wextra -pedantic -Werror -Wno-unknown-pragmas)
+ if (COMPILE_WASM)
+ # Disabling Pthreads + memory growth warning to be an error for WASM
+ # Pthreads + memory growth causes JS accessing the wasm memory to be slow
+ # https://github.com/WebAssembly/design/issues/1271
+ add_compile_options(-Wno-error=pthreads-mem-growth)
+ endif()
endif()
+# Check if compiler supports AVX2 (this should only catch emscripten)
+try_compile(INTGEMM_COMPILER_SUPPORTS_AVX2
+ ${CMAKE_CURRENT_BINARY_DIR}/compile_tests
+ ${CMAKE_CURRENT_SOURCE_DIR}/compile_test/avx2.cc)
+
# Check if compiler supports AVX512BW
try_compile(INTGEMM_COMPILER_SUPPORTS_AVX512BW
${CMAKE_CURRENT_BINARY_DIR}/compile_tests
- ${CMAKE_CURRENT_SOURCE_DIR}/compile_test_avx512bw.cc)
-
-if(NOT INTGEMM_COMPILER_SUPPORTS_AVX512BW)
- message(WARNING "${Orange}Not building AVX512BW-based multiplication because your compiler is too old.\nFor details rerun cmake with --debug-trycompile then try to build in compile_tests/CMakeFiles/CMakeTmp.${ColourReset}")
-endif()
+ ${CMAKE_CURRENT_SOURCE_DIR}/compile_test/avx512bw.cc)
+# Check if the compiler supports AVX512VNNI
try_compile(INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
${CMAKE_CURRENT_BINARY_DIR}/compile_tests
- ${CMAKE_CURRENT_SOURCE_DIR}/compile_test_avx512vnni.cc)
-#No compiler flags for this test; that's part of the test!
-if(NOT INTGEMM_COMPILER_SUPPORTS_AVX512VNNI)
- message(WARNING "${Orange}Not building AVX512VNNI-based multiplication because your compiler is too old.\nFor details rerun cmake with --debug-trycompile then try to build in compile_tests/CMakeFiles/CMakeTmp.${ColourReset}")
+ ${CMAKE_CURRENT_SOURCE_DIR}/compile_test/avx512vnni.cc)
+
+if (NOT INTGEMM_COMPILER_SUPPORTS_AVX2 OR NOT INTGEMM_COMPILER_SUPPORTS_AVX512BW OR NOT INTGEMM_COMPILER_SUPPORTS_AVX512VNNI)
+ set(UNSUPPORTED "Your compiler is too old to support")
+ if (NOT INTGEMM_COMPILER_SUPPORTS_AVX2)
+ set(UNSUPPORTED "${UNSUPPORTED} AVX2")
+ endif()
+ if (NOT INTGEMM_COMPILER_SUPPORTS_AVX512BW)
+ set(UNSUPPORTED "${UNSUPPORTED} AVX512BW")
+ endif()
+ if (NOT INTGEMM_COMPILER_SUPPORTS_AVX512VNNI)
+ set(UNSUPPORTED "${UNSUPPORTED} AVX512VNNI")
+ endif()
+ message(WARNING "${Orange}${UNSUPPORTED}. Multiplication will be slower on CPUs that support these instructions. For details rerun cmake with --debug-trycompile then try to build in compile_tests/CMakeFiles/CMakeTmp.${ColourReset}")
endif()
+
add_library(intgemm STATIC intgemm/intgemm.cc)
# Generate configure file
@@ -42,7 +61,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR})
target_include_directories(intgemm PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
# This isn't necessary since intgemm uses entirely relative paths but source code depending on it may want to #include <intgemm/intgemm.h>
-target_include_directories(intgemm INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
+target_include_directories(intgemm INTERFACE .)
option(USE_OPENMP "Use OpenMP" OFF)
if (USE_OPENMP)
@@ -55,6 +74,16 @@ if (USE_OPENMP)
target_link_libraries(intgemm PUBLIC OpenMP::OpenMP_CXX)
endif()
+if (COMPILE_WASM)
+ # A compile defintion to compile intgemm on WASM platform
+ target_compile_definitions(intgemm PUBLIC WASM)
+endif()
+
+option(WORMHOLE "Use WASM wormhole https://bugzilla.mozilla.org/show_bug.cgi?id=1672160" OFF)
+if (WORMHOLE)
+ target_compile_definitions(intgemm PUBLIC INTGEMM_WORMHOLE)
+endif()
+
if(INTGEMM_DONT_BUILD_TESTS)
return()
endif()
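
The try_compile checks above feed the INTGEMM_COMPILER_SUPPORTS_* defines in the generated intgemm_config.h, and the interface include directory is what lets dependents #include <intgemm/intgemm.h>. A minimal consumer sketch, assuming intgemm is linked and that intgemm.h transitively pulls in the generated config header (the per-ISA headers it includes do):

// Hypothetical downstream consumer; not part of this diff.
#include <intgemm/intgemm.h>
#include <iostream>

int main() {
  // Name of the 8-bit backend intgemm picked at runtime (see intgemm/intgemm.cc below).
  std::cout << "8-bit backend: " << intgemm::Int8::kName << '\n';
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
  // Compile-time guard produced by the try_compile check in CMakeLists.txt.
  std::cout << "This build can contain AVX2 kernels.\n";
#endif
  return 0;
}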
diff --git a/README.md b/README.md
index 30469fc..b8388dc 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,7 @@
![Build Ubuntu OpenMP](https://github.com/kpu/intgemm/workflows/Ubuntu%20OpenMP/badge.svg)
![Build Windows](https://github.com/kpu/intgemm/workflows/Windows/badge.svg)
![Build Mac](https://github.com/kpu/intgemm/workflows/Mac/badge.svg)
+[![Intel Compiler](https://github.com/kpu/intgemm/actions/workflows/intel-19.yml/badge.svg)](https://github.com/kpu/intgemm/actions/workflows/intel-19.yml)
# Integer Matrix Multiplication
diff --git a/benchmarks/benchmark.cc b/benchmarks/benchmark.cc
index c6133bf..512d3ec 100644
--- a/benchmarks/benchmark.cc
+++ b/benchmarks/benchmark.cc
@@ -145,45 +145,46 @@ int main(int, char ** argv) {
std::cerr << "SSSE3 8bit, 100 samples..." << std::endl;
for (int samples = 0; samples < kSamples; ++samples) {
RandomMatrices *end = (samples < 4) ? matrices_end : full_sample;
- RunAll<ssse3::Kernels8>(matrices, end, stats.ssse3_8bit);
+ RunAll<SSSE3::Kernels8>(matrices, end, stats.ssse3_8bit);
}
std::cerr << "SSE2 16bit, 100 samples..." << std::endl;
for (int samples = 0; samples < kSamples; ++samples) {
RandomMatrices *end = (samples < 4) ? matrices_end : full_sample;
- RunAll<sse2::Kernels16>(matrices, end, stats.sse2_16bit);
+ RunAll<SSE2::Kernels16>(matrices, end, stats.sse2_16bit);
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
std::cerr << "AVX2 8bit, 100 samples..." << std::endl;
for (int samples = 0; samples < kSamples; ++samples) {
RandomMatrices *end = (samples < 4) ? matrices_end : full_sample;
- RunAll<avx2::Kernels8>(matrices, end, stats.avx2_8bit);
+ RunAll<AVX2::Kernels8>(matrices, end, stats.avx2_8bit);
}
std::cerr << "AVX2 16bit, 100 samples..." << std::endl;
for (int samples = 0; samples < kSamples; ++samples) {
RandomMatrices *end = (samples < 4) ? matrices_end : full_sample;
- RunAll<avx2::Kernels16>(matrices, end, stats.avx2_16bit);
+ RunAll<AVX2::Kernels16>(matrices, end, stats.avx2_16bit);
}
-
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
std::cerr << "AVX512 8bit, 100 samples..." << std::endl;
for (int samples = 0; samples < kSamples; ++samples) {
RandomMatrices *end = (samples < 4) ? matrices_end : full_sample;
- RunAll<avx512bw::Kernels8>(matrices, end, stats.avx512_8bit);
+ RunAll<AVX512BW::Kernels8>(matrices, end, stats.avx512_8bit);
}
std::cerr << "AVX512 16bit, 100 samples..." << std::endl;
for (int samples = 0; samples < kSamples; ++samples) {
RandomMatrices *end = (samples < 4) ? matrices_end : full_sample;
- RunAll<avx512bw::Kernels16>(matrices, end, stats.avx512_16bit);
+ RunAll<AVX512BW::Kernels16>(matrices, end, stats.avx512_16bit);
}
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
std::cerr << "AVX512VNNI 8bit, 100 samples..." << std::endl;
for (int samples = 0; samples < kSamples; ++samples) {
RandomMatrices *end = (samples < 4) ? matrices_end : full_sample;
- RunAll<avx512vnni::Kernels8>(matrices, end, stats.avx512vnni_8bit);
+ RunAll<AVX512VNNI::Kernels8>(matrices, end, stats.avx512vnni_8bit);
}
#endif
@@ -193,18 +194,18 @@ int main(int, char ** argv) {
}
for (std::size_t i = 0; i < sizeof(matrices) / sizeof(RandomMatrices); ++i) {
std::cout << "Multiply\t" << matrices[i].A_rows << '\t' << matrices[i].width << '\t' << matrices[i].B_cols << '\t' << "Samples=" << (kOutlierThreshold * stats.sse2_16bit[i].size()) << '\n';
- Print<ssse3::Kernels8>(stats.ssse3_8bit, i);
- Print<avx2::Kernels8>(stats.avx2_8bit, i);
+ Print<SSSE3::Kernels8>(stats.ssse3_8bit, i);
+ Print<AVX2::Kernels8>(stats.avx2_8bit, i);
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- Print<avx512bw::Kernels8>(stats.avx512_8bit, i);
+ Print<AVX512BW::Kernels8>(stats.avx512_8bit, i);
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
- Print<avx512vnni::Kernels8>(stats.avx512vnni_8bit, i);
+ Print<AVX512VNNI::Kernels8>(stats.avx512vnni_8bit, i);
#endif
- Print<sse2::Kernels16>(stats.sse2_16bit, i);
- Print<avx2::Kernels16>(stats.avx2_16bit, i);
+ Print<SSE2::Kernels16>(stats.sse2_16bit, i);
+ Print<AVX2::Kernels16>(stats.avx2_16bit, i);
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- Print<avx512bw::Kernels16>(stats.avx512_16bit, i);
+ Print<AVX512BW::Kernels16>(stats.avx512_16bit, i);
#endif
}
return 0;
diff --git a/benchmarks/benchmark_quantizer.cc b/benchmarks/benchmark_quantizer.cc
index 5f36bd7..5235b1e 100644
--- a/benchmarks/benchmark_quantizer.cc
+++ b/benchmarks/benchmark_quantizer.cc
@@ -63,10 +63,12 @@ int main() {
for (float &element : in) {
element = dist(gen);
}
- QuantizerBench<intgemm::ssse3::Kernels8>(in.begin(), out.begin(), static_cast<intgemm::Index>(count));
- QuantizerBench<intgemm::avx2::Kernels8>(in.begin(), out.begin(), static_cast<intgemm::Index>(count));
+ QuantizerBench<intgemm::SSSE3::Kernels8>(in.begin(), out.begin(), static_cast<intgemm::Index>(count));
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+ QuantizerBench<intgemm::AVX2::Kernels8>(in.begin(), out.begin(), static_cast<intgemm::Index>(count));
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- QuantizerBench<intgemm::avx512bw::Kernels8>(in.begin(), out.begin(), static_cast<intgemm::Index>(count));
+ QuantizerBench<intgemm::AVX512BW::Kernels8>(in.begin(), out.begin(), static_cast<intgemm::Index>(count));
#endif
}
}
diff --git a/benchmarks/biasmultiply.cc b/benchmarks/biasmultiply.cc
index 490bf3b..c835b61 100644
--- a/benchmarks/biasmultiply.cc
+++ b/benchmarks/biasmultiply.cc
@@ -125,149 +125,151 @@ int main(int argc, char ** argv) {
repeat = atoi(argv[1]);
}
- std::chrono::duration<double> oldSSSE3_nobias = testOld_nobias<ssse3::Kernels8>(1, 64, 8);
+ std::chrono::duration<double> oldSSSE3_nobias = testOld_nobias<SSSE3::Kernels8>(1, 64, 8);
for (int i = 0; i<repeat; i++) {
- oldSSSE3_nobias += testOld_nobias<ssse3::Kernels8>(8, 256, 256);
- oldSSSE3_nobias += testOld_nobias<ssse3::Kernels8>(8, 2048, 256);
- oldSSSE3_nobias += testOld_nobias<ssse3::Kernels8>(320, 256, 256);
- oldSSSE3_nobias += testOld_nobias<ssse3::Kernels8>(472, 256, 256);
- oldSSSE3_nobias += testOld_nobias<ssse3::Kernels8>(248, 256, 256);
- oldSSSE3_nobias += testOld_nobias<ssse3::Kernels8>(200, 256, 256);
+ oldSSSE3_nobias += testOld_nobias<SSSE3::Kernels8>(8, 256, 256);
+ oldSSSE3_nobias += testOld_nobias<SSSE3::Kernels8>(8, 2048, 256);
+ oldSSSE3_nobias += testOld_nobias<SSSE3::Kernels8>(320, 256, 256);
+ oldSSSE3_nobias += testOld_nobias<SSSE3::Kernels8>(472, 256, 256);
+ oldSSSE3_nobias += testOld_nobias<SSSE3::Kernels8>(248, 256, 256);
+ oldSSSE3_nobias += testOld_nobias<SSSE3::Kernels8>(200, 256, 256);
}
std::cout << repeat << " iterations of SSSE3 without bias took: " << oldSSSE3_nobias.count() << " seconds." << std::endl;
- std::chrono::duration<double> oldSSSE3 = testOld<ssse3::Kernels8>(1, 64, 8);
+ std::chrono::duration<double> oldSSSE3 = testOld<SSSE3::Kernels8>(1, 64, 8);
for (int i = 0; i<repeat; i++) {
- oldSSSE3 += testOld<ssse3::Kernels8>(8, 256, 256);
- oldSSSE3 += testOld<ssse3::Kernels8>(8, 2048, 256);
- oldSSSE3 += testOld<ssse3::Kernels8>(320, 256, 256);
- oldSSSE3 += testOld<ssse3::Kernels8>(472, 256, 256);
- oldSSSE3 += testOld<ssse3::Kernels8>(248, 256, 256);
- oldSSSE3 += testOld<ssse3::Kernels8>(200, 256, 256);
+ oldSSSE3 += testOld<SSSE3::Kernels8>(8, 256, 256);
+ oldSSSE3 += testOld<SSSE3::Kernels8>(8, 2048, 256);
+ oldSSSE3 += testOld<SSSE3::Kernels8>(320, 256, 256);
+ oldSSSE3 += testOld<SSSE3::Kernels8>(472, 256, 256);
+ oldSSSE3 += testOld<SSSE3::Kernels8>(248, 256, 256);
+ oldSSSE3 += testOld<SSSE3::Kernels8>(200, 256, 256);
}
std::cout << repeat << " iterations of SSSE3 took: " << oldSSSE3.count() << " seconds." << std::endl;
- std::chrono::duration<double> newTimeSSSE3 = testOld<ssse3::Kernels8>(1, 64, 8);
+ std::chrono::duration<double> newTimeSSSE3 = testOld<SSSE3::Kernels8>(1, 64, 8);
for (int i = 0; i<repeat; i++) {
- newTimeSSSE3 += testNew<ssse3::Kernels8>(8, 256, 256);
- newTimeSSSE3 += testNew<ssse3::Kernels8>(8, 2048, 256);
- newTimeSSSE3 += testNew<ssse3::Kernels8>(320, 256, 256);
- newTimeSSSE3 += testNew<ssse3::Kernels8>(472, 256, 256);
- newTimeSSSE3 += testNew<ssse3::Kernels8>(248, 256, 256);
- newTimeSSSE3 += testNew<ssse3::Kernels8>(200, 256, 256);
+ newTimeSSSE3 += testNew<SSSE3::Kernels8>(8, 256, 256);
+ newTimeSSSE3 += testNew<SSSE3::Kernels8>(8, 2048, 256);
+ newTimeSSSE3 += testNew<SSSE3::Kernels8>(320, 256, 256);
+ newTimeSSSE3 += testNew<SSSE3::Kernels8>(472, 256, 256);
+ newTimeSSSE3 += testNew<SSSE3::Kernels8>(248, 256, 256);
+ newTimeSSSE3 += testNew<SSSE3::Kernels8>(200, 256, 256);
}
std::cout << repeat << " iterations of Shifted SSSE3 took: " << newTimeSSSE3.count() << " seconds." << std::endl;
- std::chrono::duration<double> oldAVX2_nobias = testOld_nobias<avx2::Kernels8>(1, 64, 8);
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+ std::chrono::duration<double> oldAVX2_nobias = testOld_nobias<AVX2::Kernels8>(1, 64, 8);
for (int i = 0; i<repeat; i++) {
- oldAVX2_nobias += testOld_nobias<avx2::Kernels8>(8, 256, 256);
- oldAVX2_nobias += testOld_nobias<avx2::Kernels8>(8, 2048, 256);
- oldAVX2_nobias += testOld_nobias<avx2::Kernels8>(320, 256, 256);
- oldAVX2_nobias += testOld_nobias<avx2::Kernels8>(472, 256, 256);
- oldAVX2_nobias += testOld_nobias<avx2::Kernels8>(248, 256, 256);
- oldAVX2_nobias += testOld_nobias<avx2::Kernels8>(200, 256, 256);
+ oldAVX2_nobias += testOld_nobias<AVX2::Kernels8>(8, 256, 256);
+ oldAVX2_nobias += testOld_nobias<AVX2::Kernels8>(8, 2048, 256);
+ oldAVX2_nobias += testOld_nobias<AVX2::Kernels8>(320, 256, 256);
+ oldAVX2_nobias += testOld_nobias<AVX2::Kernels8>(472, 256, 256);
+ oldAVX2_nobias += testOld_nobias<AVX2::Kernels8>(248, 256, 256);
+ oldAVX2_nobias += testOld_nobias<AVX2::Kernels8>(200, 256, 256);
}
std::cout << repeat << " iterations of AVX2 without bias took: " << oldAVX2_nobias.count() << " seconds." << std::endl;
- std::chrono::duration<double> oldAVX2 = testOld<avx2::Kernels8>(1, 64, 8);
+ std::chrono::duration<double> oldAVX2 = testOld<AVX2::Kernels8>(1, 64, 8);
for (int i = 0; i<repeat; i++) {
- oldAVX2 += testOld<avx2::Kernels8>(8, 256, 256);
- oldAVX2 += testOld<avx2::Kernels8>(8, 2048, 256);
- oldAVX2 += testOld<avx2::Kernels8>(320, 256, 256);
- oldAVX2 += testOld<avx2::Kernels8>(472, 256, 256);
- oldAVX2 += testOld<avx2::Kernels8>(248, 256, 256);
- oldAVX2 += testOld<avx2::Kernels8>(200, 256, 256);
+ oldAVX2 += testOld<AVX2::Kernels8>(8, 256, 256);
+ oldAVX2 += testOld<AVX2::Kernels8>(8, 2048, 256);
+ oldAVX2 += testOld<AVX2::Kernels8>(320, 256, 256);
+ oldAVX2 += testOld<AVX2::Kernels8>(472, 256, 256);
+ oldAVX2 += testOld<AVX2::Kernels8>(248, 256, 256);
+ oldAVX2 += testOld<AVX2::Kernels8>(200, 256, 256);
}
std::cout << repeat << " iterations of AVX2 took: " << oldAVX2.count() << " seconds." << std::endl;
- std::chrono::duration<double> newTimeAVX2 = testOld<avx2::Kernels8>(1, 64, 8);
+ std::chrono::duration<double> newTimeAVX2 = testOld<AVX2::Kernels8>(1, 64, 8);
for (int i = 0; i<repeat; i++) {
- newTimeAVX2 += testNew<avx2::Kernels8>(8, 256, 256);
- newTimeAVX2 += testNew<avx2::Kernels8>(8, 2048, 256);
- newTimeAVX2 += testNew<avx2::Kernels8>(320, 256, 256);
- newTimeAVX2 += testNew<avx2::Kernels8>(472, 256, 256);
- newTimeAVX2 += testNew<avx2::Kernels8>(248, 256, 256);
- newTimeAVX2 += testNew<avx2::Kernels8>(200, 256, 256);
+ newTimeAVX2 += testNew<AVX2::Kernels8>(8, 256, 256);
+ newTimeAVX2 += testNew<AVX2::Kernels8>(8, 2048, 256);
+ newTimeAVX2 += testNew<AVX2::Kernels8>(320, 256, 256);
+ newTimeAVX2 += testNew<AVX2::Kernels8>(472, 256, 256);
+ newTimeAVX2 += testNew<AVX2::Kernels8>(248, 256, 256);
+ newTimeAVX2 += testNew<AVX2::Kernels8>(200, 256, 256);
}
std::cout << repeat << " iterations of Shifted AVX2 took: " << newTimeAVX2.count() << " seconds." << std::endl;
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
if (kCPU < CPUType::AVX512BW) return 0;
- std::chrono::duration<double> oldAVX512_nobias = testOld_nobias<avx512bw::Kernels8>(1, 64, 8);
+ std::chrono::duration<double> oldAVX512_nobias = testOld_nobias<AVX512BW::Kernels8>(1, 64, 8);
for (int i = 0; i<repeat; i++) {
- oldAVX512_nobias += testOld_nobias<avx512bw::Kernels8>(8, 256, 256);
- oldAVX512_nobias += testOld_nobias<avx512bw::Kernels8>(8, 2048, 256);
- oldAVX512_nobias += testOld_nobias<avx512bw::Kernels8>(320, 256, 256);
- oldAVX512_nobias += testOld_nobias<avx512bw::Kernels8>(472, 256, 256);
- oldAVX512_nobias += testOld_nobias<avx512bw::Kernels8>(248, 256, 256);
- oldAVX512_nobias += testOld_nobias<avx512bw::Kernels8>(200, 256, 256);
+ oldAVX512_nobias += testOld_nobias<AVX512BW::Kernels8>(8, 256, 256);
+ oldAVX512_nobias += testOld_nobias<AVX512BW::Kernels8>(8, 2048, 256);
+ oldAVX512_nobias += testOld_nobias<AVX512BW::Kernels8>(320, 256, 256);
+ oldAVX512_nobias += testOld_nobias<AVX512BW::Kernels8>(472, 256, 256);
+ oldAVX512_nobias += testOld_nobias<AVX512BW::Kernels8>(248, 256, 256);
+ oldAVX512_nobias += testOld_nobias<AVX512BW::Kernels8>(200, 256, 256);
}
std::cout << repeat << " iterations of AVX512 without bias took: " << oldAVX512_nobias.count() << " seconds." << std::endl;
- std::chrono::duration<double> oldAVX512 = testOld<avx512bw::Kernels8>(1, 64, 8);
+ std::chrono::duration<double> oldAVX512 = testOld<AVX512BW::Kernels8>(1, 64, 8);
for (int i = 0; i<repeat; i++) {
- oldAVX512 += testOld<avx512bw::Kernels8>(8, 256, 256);
- oldAVX512 += testOld<avx512bw::Kernels8>(8, 2048, 256);
- oldAVX512 += testOld<avx512bw::Kernels8>(320, 256, 256);
- oldAVX512 += testOld<avx512bw::Kernels8>(472, 256, 256);
- oldAVX512 += testOld<avx512bw::Kernels8>(248, 256, 256);
- oldAVX512 += testOld<avx512bw::Kernels8>(200, 256, 256);
+ oldAVX512 += testOld<AVX512BW::Kernels8>(8, 256, 256);
+ oldAVX512 += testOld<AVX512BW::Kernels8>(8, 2048, 256);
+ oldAVX512 += testOld<AVX512BW::Kernels8>(320, 256, 256);
+ oldAVX512 += testOld<AVX512BW::Kernels8>(472, 256, 256);
+ oldAVX512 += testOld<AVX512BW::Kernels8>(248, 256, 256);
+ oldAVX512 += testOld<AVX512BW::Kernels8>(200, 256, 256);
}
std::cout << repeat << " iterations of AVX512 took: " << oldAVX512.count() << " seconds." << std::endl;
- std::chrono::duration<double> newTimeAVX512 = testOld<avx512bw::Kernels8>(1, 64, 8);
+ std::chrono::duration<double> newTimeAVX512 = testOld<AVX512BW::Kernels8>(1, 64, 8);
for (int i = 0; i<repeat; i++) {
- newTimeAVX512 += testNew<avx512bw::Kernels8>(8, 256, 256);
- newTimeAVX512 += testNew<avx512bw::Kernels8>(8, 2048, 256);
- newTimeAVX512 += testNew<avx512bw::Kernels8>(320, 256, 256);
- newTimeAVX512 += testNew<avx512bw::Kernels8>(472, 256, 256);
- newTimeAVX512 += testNew<avx512bw::Kernels8>(248, 256, 256);
- newTimeAVX512 += testNew<avx512bw::Kernels8>(200, 256, 256);
+ newTimeAVX512 += testNew<AVX512BW::Kernels8>(8, 256, 256);
+ newTimeAVX512 += testNew<AVX512BW::Kernels8>(8, 2048, 256);
+ newTimeAVX512 += testNew<AVX512BW::Kernels8>(320, 256, 256);
+ newTimeAVX512 += testNew<AVX512BW::Kernels8>(472, 256, 256);
+ newTimeAVX512 += testNew<AVX512BW::Kernels8>(248, 256, 256);
+ newTimeAVX512 += testNew<AVX512BW::Kernels8>(200, 256, 256);
}
std::cout << repeat << " iterations of Shifted AVX512 took: " << newTimeAVX512.count() << " seconds." << std::endl;
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
if (kCPU < CPUType::AVX512VNNI) return 0;
- std::chrono::duration<double> oldAVX512VNNI_nobias = testOld_nobias<avx512bw::Kernels8>(1, 64, 8);
+ std::chrono::duration<double> oldAVX512VNNI_nobias = testOld_nobias<AVX512BW::Kernels8>(1, 64, 8);
for (int i = 0; i<repeat; i++) {
- oldAVX512VNNI_nobias += testOld_nobias<avx512vnni::Kernels8>(8, 256, 256);
- oldAVX512VNNI_nobias += testOld_nobias<avx512vnni::Kernels8>(8, 2048, 256);
- oldAVX512VNNI_nobias += testOld_nobias<avx512vnni::Kernels8>(320, 256, 256);
- oldAVX512VNNI_nobias += testOld_nobias<avx512vnni::Kernels8>(472, 256, 256);
- oldAVX512VNNI_nobias += testOld_nobias<avx512vnni::Kernels8>(248, 256, 256);
- oldAVX512VNNI_nobias += testOld_nobias<avx512vnni::Kernels8>(200, 256, 256);
+ oldAVX512VNNI_nobias += testOld_nobias<AVX512VNNI::Kernels8>(8, 256, 256);
+ oldAVX512VNNI_nobias += testOld_nobias<AVX512VNNI::Kernels8>(8, 2048, 256);
+ oldAVX512VNNI_nobias += testOld_nobias<AVX512VNNI::Kernels8>(320, 256, 256);
+ oldAVX512VNNI_nobias += testOld_nobias<AVX512VNNI::Kernels8>(472, 256, 256);
+ oldAVX512VNNI_nobias += testOld_nobias<AVX512VNNI::Kernels8>(248, 256, 256);
+ oldAVX512VNNI_nobias += testOld_nobias<AVX512VNNI::Kernels8>(200, 256, 256);
}
std::cout << repeat << " iterations of AVX512VNNI without bias took: " << oldAVX512VNNI_nobias.count() << " seconds." << std::endl;
- std::chrono::duration<double> oldAVX512VNNI = testOld<avx512bw::Kernels8>(1, 64, 8);
+ std::chrono::duration<double> oldAVX512VNNI = testOld<AVX512BW::Kernels8>(1, 64, 8);
for (int i = 0; i<repeat; i++) {
- oldAVX512VNNI += testOld<avx512vnni::Kernels8>(8, 256, 256);
- oldAVX512VNNI += testOld<avx512vnni::Kernels8>(8, 2048, 256);
- oldAVX512VNNI += testOld<avx512vnni::Kernels8>(320, 256, 256);
- oldAVX512VNNI += testOld<avx512vnni::Kernels8>(472, 256, 256);
- oldAVX512VNNI += testOld<avx512vnni::Kernels8>(248, 256, 256);
- oldAVX512VNNI += testOld<avx512vnni::Kernels8>(200, 256, 256);
+ oldAVX512VNNI += testOld<AVX512VNNI::Kernels8>(8, 256, 256);
+ oldAVX512VNNI += testOld<AVX512VNNI::Kernels8>(8, 2048, 256);
+ oldAVX512VNNI += testOld<AVX512VNNI::Kernels8>(320, 256, 256);
+ oldAVX512VNNI += testOld<AVX512VNNI::Kernels8>(472, 256, 256);
+ oldAVX512VNNI += testOld<AVX512VNNI::Kernels8>(248, 256, 256);
+ oldAVX512VNNI += testOld<AVX512VNNI::Kernels8>(200, 256, 256);
}
std::cout << repeat << " iterations of AVX512VNNI took: " << oldAVX512VNNI.count() << " seconds." << std::endl;
- std::chrono::duration<double> newTimeAVX512VNNI = testOld<avx512bw::Kernels8>(1, 64, 8);
+ std::chrono::duration<double> newTimeAVX512VNNI = testOld<AVX512BW::Kernels8>(1, 64, 8);
for (int i = 0; i<repeat; i++) {
- newTimeAVX512VNNI += testNew<avx512vnni::Kernels8>(8, 256, 256);
- newTimeAVX512VNNI += testNew<avx512vnni::Kernels8>(8, 2048, 256);
- newTimeAVX512VNNI += testNew<avx512vnni::Kernels8>(320, 256, 256);
- newTimeAVX512VNNI += testNew<avx512vnni::Kernels8>(472, 256, 256);
- newTimeAVX512VNNI += testNew<avx512vnni::Kernels8>(248, 256, 256);
- newTimeAVX512VNNI += testNew<avx512vnni::Kernels8>(200, 256, 256);
+ newTimeAVX512VNNI += testNew<AVX512VNNI::Kernels8>(8, 256, 256);
+ newTimeAVX512VNNI += testNew<AVX512VNNI::Kernels8>(8, 2048, 256);
+ newTimeAVX512VNNI += testNew<AVX512VNNI::Kernels8>(320, 256, 256);
+ newTimeAVX512VNNI += testNew<AVX512VNNI::Kernels8>(472, 256, 256);
+ newTimeAVX512VNNI += testNew<AVX512VNNI::Kernels8>(248, 256, 256);
+ newTimeAVX512VNNI += testNew<AVX512VNNI::Kernels8>(200, 256, 256);
}
std::cout << repeat << " iterations of Shifted AVX512VNNI took: " << newTimeAVX512VNNI.count() << " seconds." << std::endl;
diff --git a/compile_test/avx2.cc b/compile_test/avx2.cc
new file mode 100644
index 0000000..8460fc0
--- /dev/null
+++ b/compile_test/avx2.cc
@@ -0,0 +1,17 @@
+// Some compilers don't have AVX2 support. Test for them.
+#include <immintrin.h>
+
+#if defined(_MSC_VER)
+#define INTGEMM_AVX2
+#else
+#define INTGEMM_AVX2 __attribute__ ((target ("avx2")))
+#endif
+
+INTGEMM_AVX2 int Test() {
+ __m256i value = _mm256_set1_epi32(1);
+ value = _mm256_abs_epi8(value);
+ return *(int*)&value;
+}
+
+int main() {
+}
diff --git a/compile_test_avx512bw.cc b/compile_test/avx512bw.cc
index 2cd4c6a..2cd4c6a 100644
--- a/compile_test_avx512bw.cc
+++ b/compile_test/avx512bw.cc
diff --git a/compile_test_avx512vnni.cc b/compile_test/avx512vnni.cc
index 1485cde..1485cde 100644
--- a/compile_test_avx512vnni.cc
+++ b/compile_test/avx512vnni.cc
diff --git a/intgemm/aligned.h b/intgemm/aligned.h
index 7500a8c..6fda369 100644
--- a/intgemm/aligned.h
+++ b/intgemm/aligned.h
@@ -5,24 +5,39 @@
#include <malloc.h>
#endif
-// 64-byte aligned simple vector.
+// Aligned simple vector.
namespace intgemm {
template <class T> class AlignedVector {
public:
- explicit AlignedVector(std::size_t size)
+ AlignedVector() : mem_(nullptr), size_(0) {}
+
+ explicit AlignedVector(std::size_t size, std::size_t alignment = 64 /* CPU cares about this */)
: size_(size) {
#ifdef _MSC_VER
- mem_ = static_cast<T*>(_aligned_malloc(size * sizeof(T), 64));
+ mem_ = static_cast<T*>(_aligned_malloc(size * sizeof(T), alignment));
if (!mem_) throw std::bad_alloc();
#else
- if (posix_memalign(reinterpret_cast<void **>(&mem_), 64, size * sizeof(T))) {
+ if (posix_memalign(reinterpret_cast<void **>(&mem_), alignment, size * sizeof(T))) {
throw std::bad_alloc();
}
#endif
}
+ AlignedVector(AlignedVector &&from) : mem_(from.mem_), size_(from.size_) {
+ from.mem_ = nullptr;
+ from.size_ = 0;
+ }
+
+ AlignedVector &operator=(AlignedVector &&from) {
+ mem_ = from.mem_;
+ size_ = from.size_;
+ from.mem_ = nullptr;
+ from.size_ = 0;
+ return *this;
+ }
+
AlignedVector(const AlignedVector&) = delete;
AlignedVector& operator=(const AlignedVector&) = delete;
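
A minimal usage sketch of the AlignedVector additions above (default constructor, optional alignment argument, move assignment); the range-for relies on the class's existing begin()/end(), which lie outside this hunk and are assumed unchanged:

#include <utility>
#include "intgemm/aligned.h"

void Example() {
  intgemm::AlignedVector<float> a(1024);          // default 64-byte alignment, as before
  intgemm::AlignedVector<float> page(1024, 4096); // new optional alignment argument
  intgemm::AlignedVector<float> b;                // new default constructor: empty state
  b = std::move(a);                               // new move assignment; a becomes empty
  for (float &x : b) x = 0.0f;                    // uses the existing begin()/end()
  (void)page;
}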
diff --git a/intgemm/avx2_gemm.h b/intgemm/avx2_gemm.h
index 5e81475..d93ac8e 100644
--- a/intgemm/avx2_gemm.h
+++ b/intgemm/avx2_gemm.h
@@ -1,5 +1,9 @@
#pragma once
+#include "intgemm/intgemm_config.h"
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+
#include "interleave.h"
#include "kernels.h"
#include "multiply.h"
@@ -9,7 +13,7 @@
#include <cstring>
namespace intgemm {
-namespace avx2 {
+namespace AVX2 {
INTGEMM_AVX2 inline Register QuantizerGrab(const float *input, const __m256 quant_mult_reg) {
return kernels::quantize(loadu_ps<FRegister>(input), quant_mult_reg);
@@ -69,14 +73,14 @@ struct Kernels16 {
static const Index kBTileCol = 8;
/*
INTGEMM_AVX2 static void PrepareB(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) {
- PrepareBFor16(input, output, avx2::QuantizeTile16(quant_mult), rows, cols);
+ PrepareBFor16(input, output, AVX2::QuantizeTile16(quant_mult), rows, cols);
}*/
- INTGEMM_PREPARE_B_16(INTGEMM_AVX2, avx2::QuantizeTile16)
+ INTGEMM_PREPARE_B_16(INTGEMM_AVX2, AVX2::QuantizeTile16)
INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(INTGEMM_AVX2, int16_t)
- INTGEMM_PREPARE_B_TRANSPOSED(INTGEMM_AVX2, avx2::QuantizeTile16, int16_t)
+ INTGEMM_PREPARE_B_TRANSPOSED(INTGEMM_AVX2, AVX2::QuantizeTile16, int16_t)
INTGEMM_AVX2 static void SelectColumnsB(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
- avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end);
+ AVX2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end);
}
INTGEMM_MULTIPLY16(__m256i, INTGEMM_AVX2, CPUType::AVX2)
@@ -125,10 +129,10 @@ class QuantizeTile8 {
const __m256i neg127 = _mm256_set1_epi8(-127);
const __m256i shuffle_param = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
// Grab 4 registers at a time in 32-bit format.
- __m256i g0 = avx2::QuantizerGrab(input0, quant_mult);
- __m256i g1 = avx2::QuantizerGrab(input1, quant_mult);
- __m256i g2 = avx2::QuantizerGrab(input2, quant_mult);
- __m256i g3 = avx2::QuantizerGrab(input3, quant_mult);
+ __m256i g0 = AVX2::QuantizerGrab(input0, quant_mult);
+ __m256i g1 = AVX2::QuantizerGrab(input1, quant_mult);
+ __m256i g2 = AVX2::QuantizerGrab(input2, quant_mult);
+ __m256i g3 = AVX2::QuantizerGrab(input3, quant_mult);
// Pack 32-bit to 16-bit.
__m256i packed0 = _mm256_packs_epi32(g0, g1);
__m256i packed1 = _mm256_packs_epi32(g2, g3);
@@ -151,10 +155,10 @@ class QuantizeTile8 {
const __m256i pos127 = _mm256_set1_epi8(127);
const __m256i shuffle_param = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
// Grab 4 registers at a time in 32-bit format.
- __m256i g0 = avx2::QuantizerGrab(input0, quant_mult);
- __m256i g1 = avx2::QuantizerGrab(input1, quant_mult);
- __m256i g2 = avx2::QuantizerGrab(input2, quant_mult);
- __m256i g3 = avx2::QuantizerGrab(input3, quant_mult);
+ __m256i g0 = AVX2::QuantizerGrab(input0, quant_mult);
+ __m256i g1 = AVX2::QuantizerGrab(input1, quant_mult);
+ __m256i g2 = AVX2::QuantizerGrab(input2, quant_mult);
+ __m256i g3 = AVX2::QuantizerGrab(input3, quant_mult);
// Pack 32-bit to 16-bit.
__m256i packed0 = _mm256_packs_epi32(g0, g1);
__m256i packed1 = _mm256_packs_epi32(g2, g3);
@@ -203,12 +207,12 @@ struct Kernels8 {
static const Index kBTileRow = 32;
static const Index kBTileCol = 8;
- INTGEMM_PREPARE_B_8(INTGEMM_AVX2, avx2::QuantizeTile8)
+ INTGEMM_PREPARE_B_8(INTGEMM_AVX2, AVX2::QuantizeTile8)
INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(INTGEMM_AVX2, int8_t)
- INTGEMM_PREPARE_B_TRANSPOSED(INTGEMM_AVX2, avx2::QuantizeTile8, int8_t)
+ INTGEMM_PREPARE_B_TRANSPOSED(INTGEMM_AVX2, AVX2::QuantizeTile8, int8_t)
INTGEMM_AVX2 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
- avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows, cols_begin, cols_end);
+ AVX2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows, cols_begin, cols_end);
}
INTGEMM_MULTIPLY8(__m256i, INTGEMM_AVX2, CPUType::AVX2)
@@ -222,5 +226,7 @@ struct Kernels8 {
static const CPUType kUses = CPUType::AVX2;
};
-} // namespace avx2
+} // namespace AVX2
} // namespace intgemm
+
+#endif
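
Mirroring the benchmark changes earlier in this diff, direct use of the per-ISA kernels now pairs the AVX2:: spelling with the INTGEMM_COMPILER_SUPPORTS_AVX2 guard, while the top-level Int8/Int16 entry points keep dispatching automatically. A rough sketch, assuming the usual Quantize signature and that kCPU/CPUType compare as in biasmultiply.cc above:

#include <cstdint>
#include "intgemm/intgemm.h"

template <class Backend>
void QuantizeWith(const float *in, int8_t *out, float mult, intgemm::Index n) {
  Backend::Quantize(in, out, mult, n);
}

void Dispatch(const float *in, int8_t *out, float mult, intgemm::Index n) {
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
  if (intgemm::kCPU >= intgemm::CPUType::AVX2) {
    QuantizeWith<intgemm::AVX2::Kernels8>(in, out, mult, n);  // renamed from avx2::Kernels8
    return;
  }
#endif
  QuantizeWith<intgemm::SSSE3::Kernels8>(in, out, mult, n);   // renamed from ssse3::Kernels8
}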
diff --git a/intgemm/avx512_gemm.h b/intgemm/avx512_gemm.h
index f9fb1eb..90f67ee 100644
--- a/intgemm/avx512_gemm.h
+++ b/intgemm/avx512_gemm.h
@@ -31,7 +31,7 @@ namespace intgemm {
// So conversion in memory uses these, but I also implement a wider version for
// rearranging B.
-namespace avx512bw {
+namespace AVX512BW {
// Load from memory, multiply, and convert to int32_t.
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
@@ -391,7 +391,7 @@ struct Kernels8 {
Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
auto total = PermuteSummer(pack0123, pack4567);
- callback_impl(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols));
+ callback_impl.Run(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols));
}
}
}
@@ -405,7 +405,7 @@ struct Kernels8 {
static const CPUType kUses = CPUType::AVX512BW;
};
-} // namespace avx512bw
+} // namespace AVX512BW
} // namespace intgemm
#endif
diff --git a/intgemm/avx512vnni_gemm.h b/intgemm/avx512vnni_gemm.h
index c660168..28e8c14 100644
--- a/intgemm/avx512vnni_gemm.h
+++ b/intgemm/avx512vnni_gemm.h
@@ -7,7 +7,7 @@
#include "types.h"
namespace intgemm {
-namespace avx512vnni {
+namespace AVX512VNNI {
// Workaround extra vmovdqa64 https://gcc.gnu.org/bugzilla/show_bug.cgi?id=94663
INTGEMM_AVX512VNNI static inline void VNNI8(__m512i &c, __m512i a, __m512i b) {
@@ -18,7 +18,7 @@ INTGEMM_AVX512VNNI static inline void VNNI8(__m512i &c, __m512i a, __m512i b) {
#endif
}
-struct Kernels8 : public avx512bw::Kernels8 {
+struct Kernels8 : public AVX512BW::Kernels8 {
template <typename Callback>
INTGEMM_AVX512VNNI static void Multiply(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) {
assert(width % sizeof(Register) == 0);
@@ -75,7 +75,7 @@ struct Kernels8 : public avx512bw::Kernels8 {
Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
auto total = PermuteSummer(pack0123, pack4567);
- callback_impl(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols));
+ callback_impl.Run(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols));
}
}
}
@@ -116,7 +116,7 @@ struct Kernels8 : public avx512bw::Kernels8 {
Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
auto total = PermuteSummer(pack0123, pack4567);
- callback_impl(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols));
+ callback_impl.Run(total, callbacks::OutputBufferInfo(A_rowidx, B0_colidx, A_rows, B_cols));
}
}
}
@@ -153,7 +153,7 @@ struct Kernels8 : public avx512bw::Kernels8 {
Register pack0123 = Pack0123(sum0, sum1, sum2, sum3);
Register pack4567 = Pack0123(sum4, sum5, sum6, sum7);
auto total = PermuteSummer(pack0123, pack4567);
- callback_impl(total, callbacks::OutputBufferInfo(0, B0_colidx, 1, B_cols));
+ callback_impl.Run(total, callbacks::OutputBufferInfo(0, B0_colidx, 1, B_cols));
}
}
@@ -162,7 +162,7 @@ struct Kernels8 : public avx512bw::Kernels8 {
static const CPUType kUses = CPUType::AVX512VNNI;
};
-} // namespace avx512vnni
+} // namespace AVX512VNNI
} // namespace intgemm
#endif
diff --git a/intgemm/callbacks.h b/intgemm/callbacks.h
index 23d3be1..c304466 100644
--- a/intgemm/callbacks.h
+++ b/intgemm/callbacks.h
@@ -14,9 +14,11 @@
#include "callbacks/implementations.inl"
#undef CALLBACKS_THIS_IS_SSE2
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
#define CALLBACKS_THIS_IS_AVX2
#include "callbacks/implementations.inl"
#undef CALLBACKS_THIS_IS_AVX2
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
#define CALLBACKS_THIS_IS_AVX512BW
diff --git a/intgemm/callbacks/configs.h b/intgemm/callbacks/configs.h
index 1222448..d2fbe98 100644
--- a/intgemm/callbacks/configs.h
+++ b/intgemm/callbacks/configs.h
@@ -39,6 +39,13 @@ struct UnquantizeAndWrite {
UnquantizeAndWrite(float unquant_mult, float* output_addr) : unquant_mult(unquant_mult), output_addr(output_addr) {}
};
+struct UnquantizeAndWriteRelu {
+ float unquant_mult;
+ float* output_addr;
+
+ UnquantizeAndWriteRelu(float unquant_mult, float* output_addr) : unquant_mult(unquant_mult), output_addr(output_addr) {}
+};
+
struct AddBiasAndWrite {
const int* bias_addr;
int* output_addr;
@@ -54,5 +61,13 @@ struct UnquantizeAndAddBiasAndWrite {
UnquantizeAndAddBiasAndWrite(float unquant_mult, const float* bias_addr, float* output_addr) : unquant_mult(unquant_mult), bias_addr(bias_addr), output_addr(output_addr) {}
};
+struct UnquantizeAndAddBiasAndWriteRelu {
+ float unquant_mult;
+ const float* bias_addr;
+ float* output_addr;
+
+ UnquantizeAndAddBiasAndWriteRelu(float unquant_mult, const float* bias_addr, float* output_addr) : unquant_mult(unquant_mult), bias_addr(bias_addr), output_addr(output_addr) {}
+};
+
}
}
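
The new Relu variants slot into multiply calls exactly like the existing configs; a hedged sketch, with the Multiply-with-callback signature taken from the existing intgemm API and the buffer names purely illustrative:

#include <cstdint>
#include "intgemm/intgemm.h"

// prepared_A/prepared_B are assumed to come from the usual PrepareA/PrepareB steps.
void MultiplyWithRelu(const int8_t *prepared_A, const int8_t *prepared_B,
                      intgemm::Index A_rows, intgemm::Index width, intgemm::Index B_cols,
                      float unquant_mult, const float *bias, float *output) {
  // New in this diff: behaves like UnquantizeAndAddBiasAndWrite, then clamps negatives to zero.
  intgemm::Int8::Multiply(prepared_A, prepared_B, A_rows, width, B_cols,
      intgemm::callbacks::UnquantizeAndAddBiasAndWriteRelu(unquant_mult, bias, output));
}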
diff --git a/intgemm/callbacks/implementations.inl b/intgemm/callbacks/implementations.inl
index 47d2aa4..126701d 100644
--- a/intgemm/callbacks/implementations.inl
+++ b/intgemm/callbacks/implementations.inl
@@ -1,13 +1,13 @@
/* This file is included multiple times, once per architecture. */
#if defined(CALLBACKS_THIS_IS_SSE2)
#define CPU_NAME SSE2
- #define CPU_ATTR INTGEMM_SSE2
+ #define INTGEMM_TARGET INTGEMM_SSE2
#elif defined(CALLBACKS_THIS_IS_AVX2)
#define CPU_NAME AVX2
- #define CPU_ATTR INTGEMM_AVX2
+ #define INTGEMM_TARGET INTGEMM_AVX2
#elif defined(CALLBACKS_THIS_IS_AVX512BW)
#define CPU_NAME AVX512BW
- #define CPU_ATTR INTGEMM_AVX512BW
+ #define INTGEMM_TARGET INTGEMM_AVX512BW
#else
#error "Only SSE2, AVX2 and AVX512BW are supported"
#endif
@@ -22,6 +22,13 @@
#define vd vector_t<CPUType::AVX2, double>
#endif
+/* Intel compiler 19.1.0.166 20191121 fails to link constructors with target attributes */
+#ifdef __INTEL_COMPILER
+#define INTGEMM_TARGET_CONSTRUCTOR
+#else
+#define INTGEMM_TARGET_CONSTRUCTOR INTGEMM_TARGET
+#endif
+
namespace intgemm {
namespace callbacks {
@@ -42,9 +49,9 @@ namespace callbacks {
template <typename... Configs>
class CallbackImpl<CPUType::CPU_NAME, std::tuple<Configs...>> {
public:
- CPU_ATTR CallbackImpl(const std::tuple<Configs...>& configs) : callbacks(init_callbacks(configs, make_sequence<sizeof...(Configs)>())) {}
+ explicit CallbackImpl(const std::tuple<Configs...>& configs) : callbacks(init_callbacks(configs, make_sequence<sizeof...(Configs)>())) {}
- CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
+ INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) {
run_callbacks(input, info, callbacks, make_sequence<sizeof...(Configs)>());
}
@@ -60,11 +67,11 @@ private:
#define RUN_CALLBACKS_PIPELINE_IMPL(vtype) \
template <unsigned FirstIndex> \
- CPU_ATTR static inline void run_callbacks(vtype input, const OutputBufferInfo& info, CallbacksTupleType& tuple, sequence<FirstIndex>) { \
+ INTGEMM_TARGET static inline void run_callbacks(vtype input, const OutputBufferInfo& info, CallbacksTupleType& tuple, sequence<FirstIndex>) { \
std::get<FirstIndex>(tuple)(input, info); \
} \
template <unsigned FirstIndex, unsigned SecondIndex, unsigned... RestIndices> \
- CPU_ATTR static inline void run_callbacks(vtype input, const OutputBufferInfo& info, CallbacksTupleType& tuple, sequence<FirstIndex, SecondIndex, RestIndices...>) { \
+ INTGEMM_TARGET static inline void run_callbacks(vtype input, const OutputBufferInfo& info, CallbacksTupleType& tuple, sequence<FirstIndex, SecondIndex, RestIndices...>) { \
auto output = std::get<FirstIndex>(tuple)(input, info); \
run_callbacks(output, info, tuple, sequence<SecondIndex, RestIndices...>()); \
}
@@ -81,8 +88,8 @@ private:
*/
template <> class CallbackImpl<CPUType::CPU_NAME, Dummy> {
public:
- CPU_ATTR CallbackImpl(const Dummy&) {}
- CPU_ATTR void operator()(vi, const OutputBufferInfo&) {}
+ explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const Dummy&) {}
+ INTGEMM_TARGET void Run(vi, const OutputBufferInfo&) {}
};
/*
@@ -91,9 +98,9 @@ public:
template <typename Type>
class CallbackImpl<CPUType::CPU_NAME, Write<Type>> {
public:
- CPU_ATTR CallbackImpl(const Write<Type>& config) : config(config) {}
+ explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const Write<Type>& config) : config(config) {}
- CPU_ATTR void operator()(vector_t<CPUType::CPU_NAME, Type> input, const OutputBufferInfo& info) {
+ INTGEMM_TARGET void Run(vector_t<CPUType::CPU_NAME, Type> input, const OutputBufferInfo& info) {
kernels::write(input, config.output_addr, info.row_idx * info.cols + info.col_idx);
}
@@ -106,11 +113,11 @@ private:
*/
template <> class CallbackImpl<CPUType::CPU_NAME, Unquantize> {
public:
- CPU_ATTR CallbackImpl(const Unquantize& config) : config(config) {
+ explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const Unquantize& config) : config(config) {
unquant_mult = set1_ps<vf>(config.unquant_mult);
}
- CPU_ATTR vf operator()(vi input, const OutputBufferInfo&) {
+ INTGEMM_TARGET vf Run(vi input, const OutputBufferInfo&) {
return kernels::unquantize(input, unquant_mult);
}
@@ -124,11 +131,11 @@ private:
*/
template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndWrite> {
public:
- CPU_ATTR CallbackImpl(const UnquantizeAndWrite& config) : config(config) {
+ explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const UnquantizeAndWrite& config) : config(config) {
unquant_mult = set1_ps<vf>(config.unquant_mult);
}
- CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
+ INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) {
// Workaround gcc 5 internal compiler error that can't read register members in debug.
vf mult_reg;
#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
@@ -146,13 +153,40 @@ private:
};
/*
+ * UnquantizeAndWriteRelu
+ */
+template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndWriteRelu> {
+public:
+ explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const UnquantizeAndWriteRelu& config) : config(config) {
+ unquant_mult = set1_ps<vf>(config.unquant_mult);
+ }
+
+ INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) {
+ // Workaround gcc 5 internal compiler error that can't read register members in debug.
+ vf mult_reg;
+#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+ asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
+#else
+ mult_reg = unquant_mult;
+#endif
+ auto result = kernels::relu<float>(kernels::unquantize(input, mult_reg));
+ kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
+ }
+
+private:
+ vf unquant_mult;
+ UnquantizeAndWriteRelu config;
+};
+
+
+/*
* AddBiasAndWrite
*/
template <> class CallbackImpl<CPUType::CPU_NAME, AddBiasAndWrite> {
public:
- CPU_ATTR CallbackImpl(const AddBiasAndWrite& config) : config(config) {}
+ explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const AddBiasAndWrite& config) : config(config) {}
- CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
+ INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) {
auto result = kernels::add_bias(input, config.bias_addr, info.col_idx);
kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
}
@@ -166,11 +200,11 @@ private:
*/
template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndAddBiasAndWrite> {
public:
- CPU_ATTR CallbackImpl(const UnquantizeAndAddBiasAndWrite& config) : config(config) {
+ explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const UnquantizeAndAddBiasAndWrite& config) : config(config) {
unquant_mult = set1_ps<vf>(config.unquant_mult);
}
- CPU_ATTR void operator()(vi input, const OutputBufferInfo& info) {
+ INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) {
// Workaround gcc 5 internal compiler error that can't read register members in debug.
vf mult_reg;
#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
@@ -187,11 +221,38 @@ private:
UnquantizeAndAddBiasAndWrite config;
};
+/*
+ * UnquantizeAndAddBiasAndWrite
+ */
+template <> class CallbackImpl<CPUType::CPU_NAME, UnquantizeAndAddBiasAndWriteRelu> {
+public:
+ explicit INTGEMM_TARGET_CONSTRUCTOR CallbackImpl(const UnquantizeAndAddBiasAndWriteRelu& config) : config(config) {
+ unquant_mult = set1_ps<vf>(config.unquant_mult);
+ }
+
+ INTGEMM_TARGET void Run(vi input, const OutputBufferInfo& info) {
+ // Workaround gcc 5 internal compiler error that can't read register members in debug.
+ vf mult_reg;
+#if !defined(__OPTIMIZE__) && (__GNUC__ == 5) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+ asm ("vmovdqa %1, %0" : "=x" (mult_reg) : "m" (unquant_mult));
+#else
+ mult_reg = unquant_mult;
+#endif
+ auto result = kernels::unquantize(input, mult_reg);
+ result = kernels::add_bias(result, config.bias_addr, info.col_idx);
+ result = kernels::relu<float>(result);
+ kernels::write(result, config.output_addr, info.row_idx * info.cols + info.col_idx);
+ }
+private:
+ vf unquant_mult;
+ UnquantizeAndAddBiasAndWriteRelu config;
+};
+
}
}
#undef CPU_NAME
-#undef CPU_ATTR
+#undef INTGEMM_TARGET
#undef vi
#undef vf
#undef vd
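
The INTGEMM_TARGET_CONSTRUCTOR workaround above reduces to the following pattern (a simplified sketch with an illustrative class, not the library's actual callback): under the Intel compiler the constructor drops its target attribute so it links, while the hot Run() method keeps it:

// Illustrative macro values and class; the real definitions live in implementations.inl.
#if defined(_MSC_VER)
  #define INTGEMM_TARGET
#else
  #define INTGEMM_TARGET __attribute__((target("avx2")))
#endif

#ifdef __INTEL_COMPILER
  // icc 19.1 fails to link constructors that carry target attributes.
  #define INTGEMM_TARGET_CONSTRUCTOR
#else
  #define INTGEMM_TARGET_CONSTRUCTOR INTGEMM_TARGET
#endif

class MyCallback {
 public:
  explicit INTGEMM_TARGET_CONSTRUCTOR MyCallback(float scale) : scale_(scale) {}
  INTGEMM_TARGET void Run(float *data, int n) const {
    for (int i = 0; i < n; ++i) data[i] *= scale_;  // real callbacks use vector intrinsics here
  }
 private:
  float scale_;
};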
diff --git a/intgemm/interleave.h b/intgemm/interleave.h
index 1ec686b..95f05ce 100644
--- a/intgemm/interleave.h
+++ b/intgemm/interleave.h
@@ -26,7 +26,10 @@ INTGEMM_INTERLEAVE_N(target, type, 32) \
INTGEMM_INTERLEAVE_N(target, type, 64)
INTGEMM_INTERLEAVE(INTGEMM_SSE2, __m128i)
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
INTGEMM_INTERLEAVE(INTGEMM_AVX2, __m256i)
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
INTGEMM_INTERLEAVE(INTGEMM_AVX512BW, __m512i)
#endif
@@ -42,7 +45,9 @@ target static inline void Swap(Register &a, Register &b) { \
} \
INTGEMM_SWAP(INTGEMM_SSE2, __m128i)
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
INTGEMM_SWAP(INTGEMM_AVX2, __m256i)
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_SWAP(INTGEMM_AVX512BW, __m512i)
@@ -95,7 +100,9 @@ target static inline void Transpose16InLane(Register &r0, Register &r1, Register
} \
INTGEMM_TRANSPOSE16(INTGEMM_SSE2, __m128i)
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
INTGEMM_TRANSPOSE16(INTGEMM_AVX2, __m256i)
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_TRANSPOSE16(INTGEMM_AVX512BW, __m512i)
diff --git a/intgemm/intgemm.cc b/intgemm/intgemm.cc
index f859b9a..9b38e08 100644
--- a/intgemm/intgemm.cc
+++ b/intgemm/intgemm.cc
@@ -1,8 +1,110 @@
#include "intgemm.h"
#include "stats.h"
+#include <stdlib.h>
+
+#include <iostream>
+
namespace intgemm {
+namespace {
+
+// Return the maximum CPU model that's found and supported at compile time.
+CPUType RealCPUID() {
+#if defined(WASM)
+ // emscripten does SSE4.1 but we only use up to SSSE3.
+ return CPUType::SSSE3;
+#elif defined(__INTEL_COMPILER)
+# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
+ if (_may_i_use_cpu_feature(_FEATURE_AVX512_VNNI)) return CPUType::AVX512VNNI;
+# endif
+# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+ if (_may_i_use_cpu_feature(_FEATURE_AVX512BW)) return CPUType::AVX512BW;
+# endif
+# ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+ if (_may_i_use_cpu_feature(_FEATURE_AVX2)) return CPUType::AVX2;
+# endif
+ if (_may_i_use_cpu_feature(_FEATURE_SSSE3)) return CPUType::SSSE3;
+ if (_may_i_use_cpu_feature(_FEATURE_SSE2)) return CPUType::SSE2;
+ return CPUType::UNSUPPORTED;
+#else
+// Not emscripten, not Intel compiler
+# if defined(_MSC_VER)
+ int regs[4];
+ int &eax = regs[0], &ebx = regs[1], &ecx = regs[2], &edx = regs[3];
+ __cpuid(regs, 0);
+ int m = eax;
+# else
+ /* gcc and clang.
+ * If intgemm is compiled by gcc 6.4.1 then dlopened into an executable
+ * compiled by gcc 7.3.0, there will be a undefined symbol __cpu_info.
+ * Work around this by calling the intrinsics more directly instead of
+ * __builtin_cpu_supports.
+ *
+ * clang 6.0.0-1ubuntu2 supports vnni but doesn't have
+ * __builtin_cpu_supports("avx512vnni")
+ * so use the hand-coded CPUID for clang.
+ */
+ unsigned int m = __get_cpuid_max(0, 0);
+ unsigned int eax, ebx, ecx, edx;
+# endif
+ if (m >= 7) {
+# if defined(_MSC_VER)
+ __cpuid(regs, 7);
+# else
+ __cpuid_count(7, 0, eax, ebx, ecx, edx);
+# endif
+# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
+ if (ecx & (1 << 11)) return CPUType::AVX512VNNI;
+# endif
+# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
+ if (ebx & (1 << 30)) return CPUType::AVX512BW;
+# endif
+# ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+ if (ebx & (1 << 5)) return CPUType::AVX2;
+# endif
+ }
+ if (m >= 1) {
+# if defined(_MSC_VER)
+ __cpuid(regs, 1);
+# else
+ __cpuid_count(1, 0, eax, ebx, ecx, edx);
+# endif
+ if (ecx & (1 << 9)) return CPUType::SSSE3;
+ if (edx & (1 << 26)) return CPUType::SSE2;
+ }
+ return CPUType::UNSUPPORTED;
+#endif
+}
+
+CPUType EnvironmentCPUID() {
+#if defined(_MSC_VER)
+ char env_override[11];
+ size_t len = 0;
+ if (getenv_s(&len, env_override, sizeof(env_override), "INTGEMM_CPUID")) return CPUType::AVX512VNNI;
+ if (!len) return CPUType::AVX512VNNI;
+#else
+ const char *env_override = getenv("INTGEMM_CPUID");
+ if (!env_override) return CPUType::AVX512VNNI; /* This will be capped to actual ID */
+#endif
+ if (!strcmp(env_override, "AVX512VNNI")) return CPUType::AVX512VNNI;
+ if (!strcmp(env_override, "AVX512BW")) return CPUType::AVX512BW;
+ if (!strcmp(env_override, "AVX2")) return CPUType::AVX2;
+ if (!strcmp(env_override, "SSSE3")) return CPUType::SSSE3;
+ if (!strcmp(env_override, "SSE2")) return CPUType::SSE2;
+ std::cerr << "Unrecognized INTGEMM_CPUID " << env_override << std::endl;
+ return CPUType::AVX512VNNI;
+}
+
+} // namespace
+
+CPUType GetCPUID() {
+ static const CPUType kLocalCPU = std::min(RealCPUID(), EnvironmentCPUID());
+ return kLocalCPU;
+}
+
+const CPUType kCPU = GetCPUID();
+
float Unsupported_MaxAbsolute(const float * /*begin*/, const float * /*end*/) {
throw UnsupportedCPU();
}
@@ -11,61 +113,67 @@ MeanStd Unsupported_VectorMeanStd(const float * /*begin*/, const float * /*end*/
throw UnsupportedCPU();
}
-void (*Int16::Quantize)(const float *input, int16_t *output, float quant_mult, Index size) = ChooseCPU(avx512bw::Kernels16::Quantize, avx512bw::Kernels16::Quantize, avx2::Kernels16::Quantize, sse2::Kernels16::Quantize, sse2::Kernels16::Quantize, Unsupported_16bit::Quantize);
+void (*Int16::Quantize)(const float *input, int16_t *output, float quant_mult, Index size) = ChooseCPU(AVX512BW::Kernels16::Quantize, AVX512BW::Kernels16::Quantize, AVX2::Kernels16::Quantize, SSE2::Kernels16::Quantize, SSE2::Kernels16::Quantize, Unsupported_16bit::Quantize);
-void (*Int16::PrepareB)(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(avx512bw::Kernels16::PrepareB, avx512bw::Kernels16::PrepareB, avx2::Kernels16::PrepareB, sse2::Kernels16::PrepareB, sse2::Kernels16::PrepareB, Unsupported_16bit::PrepareB);
+void (*Int16::PrepareB)(const float *input, int16_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512BW::Kernels16::PrepareB, AVX512BW::Kernels16::PrepareB, AVX2::Kernels16::PrepareB, SSE2::Kernels16::PrepareB, SSE2::Kernels16::PrepareB, Unsupported_16bit::PrepareB);
-void (*Int16::PrepareBQuantizedTransposed)(const int16_t *input, int16_t *output, Index inner, Index B_untransposed_cols) = ChooseCPU(avx512bw::Kernels16::PrepareBQuantizedTransposed, avx512bw::Kernels16::PrepareBQuantizedTransposed, avx2::Kernels16::PrepareBQuantizedTransposed, sse2::Kernels16::PrepareBQuantizedTransposed, sse2::Kernels16::PrepareBQuantizedTransposed, Unsupported_16bit::PrepareBQuantizedTransposed);
+void (*Int16::PrepareBQuantizedTransposed)(const int16_t *input, int16_t *output, Index inner, Index B_untransposed_cols) = ChooseCPU(AVX512BW::Kernels16::PrepareBQuantizedTransposed, AVX512BW::Kernels16::PrepareBQuantizedTransposed, AVX2::Kernels16::PrepareBQuantizedTransposed, SSE2::Kernels16::PrepareBQuantizedTransposed, SSE2::Kernels16::PrepareBQuantizedTransposed, Unsupported_16bit::PrepareBQuantizedTransposed);
-void (*Int16::PrepareBTransposed)(const float *input, int16_t *output, float quant_mult, Index inner, Index B_untransposed_cols) = ChooseCPU(avx512bw::Kernels16::PrepareBTransposed, avx512bw::Kernels16::PrepareBTransposed, avx2::Kernels16::PrepareBTransposed, sse2::Kernels16::PrepareBTransposed, sse2::Kernels16::PrepareBTransposed, Unsupported_16bit::PrepareBTransposed);
+void (*Int16::PrepareBTransposed)(const float *input, int16_t *output, float quant_mult, Index inner, Index B_untransposed_cols) = ChooseCPU(AVX512BW::Kernels16::PrepareBTransposed, AVX512BW::Kernels16::PrepareBTransposed, AVX2::Kernels16::PrepareBTransposed, SSE2::Kernels16::PrepareBTransposed, SSE2::Kernels16::PrepareBTransposed, Unsupported_16bit::PrepareBTransposed);
-void (*Int16::SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(avx512bw::Kernels16::SelectColumnsB, avx512bw::Kernels16::SelectColumnsB, avx2::Kernels16::SelectColumnsB, sse2::Kernels16::SelectColumnsB, sse2::Kernels16::SelectColumnsB, Unsupported_16bit::SelectColumnsB);
+void (*Int16::SelectColumnsB)(const int16_t *input, int16_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512BW::Kernels16::SelectColumnsB, AVX512BW::Kernels16::SelectColumnsB, AVX2::Kernels16::SelectColumnsB, SSE2::Kernels16::SelectColumnsB, SSE2::Kernels16::SelectColumnsB, Unsupported_16bit::SelectColumnsB);
-const char *const Int16::kName = ChooseCPU(avx512bw::Kernels16::kName, avx512bw::Kernels16::kName, avx2::Kernels16::kName, sse2::Kernels16::kName, sse2::Kernels16::kName, Unsupported_16bit::kName);
+const char *const Int16::kName = ChooseCPU(AVX512BW::Kernels16::kName, AVX512BW::Kernels16::kName, AVX2::Kernels16::kName, SSE2::Kernels16::kName, SSE2::Kernels16::kName, Unsupported_16bit::kName);
-void (*Int8::Quantize)(const float *input, int8_t *output, float quant_mult, Index size) = ChooseCPU(avx512vnni::Kernels8::Quantize, avx512bw::Kernels8::Quantize, avx2::Kernels8::Quantize, ssse3::Kernels8::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize);
+void (*Int8::Quantize)(const float *input, int8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512VNNI::Kernels8::Quantize, AVX512BW::Kernels8::Quantize, AVX2::Kernels8::Quantize, SSSE3::Kernels8::Quantize, Unsupported_8bit::Quantize, Unsupported_8bit::Quantize);
-void (*Int8::QuantizeU)(const float *input, uint8_t *output, float quant_mult, Index size) = ChooseCPU(avx512vnni::Kernels8::QuantizeU, avx512bw::Kernels8::QuantizeU, avx2::Kernels8::QuantizeU, ssse3::Kernels8::QuantizeU, Unsupported_8bit::QuantizeU, Unsupported_8bit::QuantizeU);
+void (*Int8::QuantizeU)(const float *input, uint8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512VNNI::Kernels8::QuantizeU, AVX512BW::Kernels8::QuantizeU, AVX2::Kernels8::QuantizeU, SSSE3::Kernels8::QuantizeU, Unsupported_8bit::QuantizeU, Unsupported_8bit::QuantizeU);
-void (*Int8::PrepareB)(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(avx512vnni::Kernels8::PrepareB, avx512bw::Kernels8::PrepareB, avx2::Kernels8::PrepareB, ssse3::Kernels8::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB);
+void (*Int8::PrepareB)(const float *input, int8_t *output, float quant_mult, Index rows, Index cols) = ChooseCPU(AVX512VNNI::Kernels8::PrepareB, AVX512BW::Kernels8::PrepareB, AVX2::Kernels8::PrepareB, SSSE3::Kernels8::PrepareB, Unsupported_8bit::PrepareB, Unsupported_8bit::PrepareB);
-void (*Int8::PrepareBQuantizedTransposed)(const int8_t *input, int8_t *output, Index inner, Index B_untransposed_cols) = ChooseCPU(avx512bw::Kernels8::PrepareBQuantizedTransposed, avx512bw::Kernels8::PrepareBQuantizedTransposed, avx2::Kernels8::PrepareBQuantizedTransposed, ssse3::Kernels8::PrepareBQuantizedTransposed, Unsupported_8bit::PrepareBQuantizedTransposed, Unsupported_8bit::PrepareBQuantizedTransposed);
+void (*Int8::PrepareBQuantizedTransposed)(const int8_t *input, int8_t *output, Index inner, Index B_untransposed_cols) = ChooseCPU(AVX512BW::Kernels8::PrepareBQuantizedTransposed, AVX512BW::Kernels8::PrepareBQuantizedTransposed, AVX2::Kernels8::PrepareBQuantizedTransposed, SSSE3::Kernels8::PrepareBQuantizedTransposed, Unsupported_8bit::PrepareBQuantizedTransposed, Unsupported_8bit::PrepareBQuantizedTransposed);
-void (*Int8::PrepareBTransposed)(const float *input, int8_t *output, float quant_mult, Index inner, Index B_untransposed_cols) = ChooseCPU(avx512bw::Kernels8::PrepareBTransposed, avx512bw::Kernels8::PrepareBTransposed, avx2::Kernels8::PrepareBTransposed, ssse3::Kernels8::PrepareBTransposed, Unsupported_8bit::PrepareBTransposed, Unsupported_8bit::PrepareBTransposed);
+void (*Int8::PrepareBTransposed)(const float *input, int8_t *output, float quant_mult, Index inner, Index B_untransposed_cols) = ChooseCPU(AVX512BW::Kernels8::PrepareBTransposed, AVX512BW::Kernels8::PrepareBTransposed, AVX2::Kernels8::PrepareBTransposed, SSSE3::Kernels8::PrepareBTransposed, Unsupported_8bit::PrepareBTransposed, Unsupported_8bit::PrepareBTransposed);
-void (*Int8::SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(avx512vnni::Kernels8::SelectColumnsB, avx512bw::Kernels8::SelectColumnsB, avx2::Kernels8::SelectColumnsB, ssse3::Kernels8::SelectColumnsB, Unsupported_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB);
+void (*Int8::SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) = ChooseCPU(AVX512VNNI::Kernels8::SelectColumnsB, AVX512BW::Kernels8::SelectColumnsB, AVX2::Kernels8::SelectColumnsB, SSSE3::Kernels8::SelectColumnsB, Unsupported_8bit::SelectColumnsB, Unsupported_8bit::SelectColumnsB);
-const char *const Int8::kName = ChooseCPU(avx512vnni::Kernels8::kName, avx512bw::Kernels8::kName, avx2::Kernels8::kName, ssse3::Kernels8::kName, Unsupported_8bit::kName, Unsupported_8bit::kName);
+const char *const Int8::kName = ChooseCPU(AVX512VNNI::Kernels8::kName, AVX512BW::Kernels8::kName, AVX2::Kernels8::kName, SSSE3::Kernels8::kName, Unsupported_8bit::kName, Unsupported_8bit::kName);
-void (*Int8Shift::QuantizeU)(const float *input, uint8_t *output, float quant_mult, Index size) = ChooseCPU(avx512vnni::Kernels8::QuantizeU, avx512bw::Kernels8::QuantizeU, avx2::Kernels8::QuantizeU, ssse3::Kernels8::QuantizeU, Unsupported_8bit::QuantizeU, Unsupported_8bit::QuantizeU);
+void (*Int8Shift::QuantizeU)(const float *input, uint8_t *output, float quant_mult, Index size) = ChooseCPU(AVX512VNNI::Kernels8::QuantizeU, AVX512BW::Kernels8::QuantizeU, AVX2::Kernels8::QuantizeU, SSSE3::Kernels8::QuantizeU, Unsupported_8bit::QuantizeU, Unsupported_8bit::QuantizeU);
-const char *const Int8Shift::kName = ChooseCPU(avx512vnni::Kernels8::kName, avx512bw::Kernels8::kName, avx2::Kernels8::kName, ssse3::Kernels8::kName, Unsupported_8bit::kName, Unsupported_8bit::kName);
-
-const CPUType kCPU = ChooseCPU(CPUType::AVX512VNNI, CPUType::AVX512BW, CPUType::AVX2, CPUType::SSSE3, CPUType::SSE2, CPUType::UNSUPPORTED);
+const char *const Int8Shift::kName = ChooseCPU(AVX512VNNI::Kernels8::kName, AVX512BW::Kernels8::kName, AVX2::Kernels8::kName, SSSE3::Kernels8::kName, Unsupported_8bit::kName, Unsupported_8bit::kName);
+#if !defined(INTGEMM_COMPILER_SUPPORTS_AVX2)
+namespace AVX2{
+using SSE2::MaxAbsolute;
+using SSE2::VectorMeanStd;
+} // namespace AVX2
+#endif
#if !defined(INTGEMM_COMPILER_SUPPORTS_AVX512BW)
-namespace avx512bw {
-using avx2::MaxAbsolute;
-using avx2::VectorMeanStd;
-} // namespace avx512bw
+namespace AVX512BW {
+using AVX2::MaxAbsolute;
+using AVX2::VectorMeanStd;
+} // namespace AVX512BW
#endif
-float (*MaxAbsolute)(const float *begin, const float *end) = ChooseCPU(avx512bw::MaxAbsolute, avx512bw::MaxAbsolute, avx2::MaxAbsolute, sse2::MaxAbsolute, sse2::MaxAbsolute, Unsupported_MaxAbsolute);
+float (*MaxAbsolute)(const float *begin, const float *end) = ChooseCPU(AVX512BW::MaxAbsolute, AVX512BW::MaxAbsolute, AVX2::MaxAbsolute, SSE2::MaxAbsolute, SSE2::MaxAbsolute, Unsupported_MaxAbsolute);
-MeanStd (*VectorMeanStd)(const float *begin, const float *end, bool absolute) = ChooseCPU(avx512bw::VectorMeanStd, avx512bw::VectorMeanStd, avx2::VectorMeanStd, sse2::VectorMeanStd, sse2::VectorMeanStd, Unsupported_VectorMeanStd);
+MeanStd (*VectorMeanStd)(const float *begin, const float *end, bool absolute) = ChooseCPU(AVX512BW::VectorMeanStd, AVX512BW::VectorMeanStd, AVX2::VectorMeanStd, SSE2::VectorMeanStd, SSE2::VectorMeanStd, Unsupported_VectorMeanStd);
constexpr const char *const Unsupported_16bit::kName;
constexpr const char *const Unsupported_8bit::kName;
-constexpr const char *const sse2::Kernels16::kName;
-constexpr const char *const ssse3::Kernels8::kName;
-constexpr const char *const avx2::Kernels8::kName;
-constexpr const char *const avx2::Kernels16::kName;
+constexpr const char *const SSE2::Kernels16::kName;
+constexpr const char *const SSSE3::Kernels8::kName;
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+constexpr const char *const AVX2::Kernels8::kName;
+constexpr const char *const AVX2::Kernels16::kName;
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
-constexpr const char *const avx512bw::Kernels8::kName;
-constexpr const char *const avx512bw::Kernels16::kName;
+constexpr const char *const AVX512BW::Kernels8::kName;
+constexpr const char *const AVX512BW::Kernels16::kName;
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
-constexpr const char *const avx512vnni::Kernels8::kName;
+constexpr const char *const AVX512VNNI::Kernels8::kName;
#endif
}
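The entry points above are plain function pointers, filled in once by ChooseCPU during static initialization, so calling code never branches on the CPU itself. A minimal usage sketch under that assumption; QuantizeExample is a made-up name, and scaling by 127 over the maximum absolute value is a common convention for the 8-bit path rather than something this patch requires:

    #include <cstdint>
    #include "intgemm/intgemm.h"

    // Sketch: quantize a float buffer to int8 with the runtime-dispatched kernels.
    void QuantizeExample(const float *in, int8_t *out, intgemm::Index size) {
      // Map the largest absolute value to 127 (common 8-bit convention).
      float quant_mult = 127.0f / intgemm::MaxAbsolute(in, in + size);
      intgemm::Int8::Quantize(in, out, quant_mult, size);
    }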
diff --git a/intgemm/intgemm.h b/intgemm/intgemm.h
index 8e2da02..977210d 100644
--- a/intgemm/intgemm.h
+++ b/intgemm/intgemm.h
@@ -49,11 +49,14 @@
#include "avx512_gemm.h"
#include "avx512vnni_gemm.h"
-#if defined(__INTEL_COMPILER)
+#if defined(WASM)
+// No header for CPUID since it's hard-coded.
+#elif defined(__INTEL_COMPILER)
#include <immintrin.h>
#elif defined(_MSC_VER)
#include <intrin.h>
-#elif defined(__GNUC__) || defined(__clang__)
+#else
+// Assume GCC and clang style.
#include <cpuid.h>
#endif
@@ -124,17 +127,25 @@ struct Unsupported_8bit {
#ifndef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
// These won't ever be called in this capacity, but it does let the code below compile.
-namespace avx512vnni {
+namespace AVX512VNNI {
typedef Unsupported_8bit Kernels8;
-} // namespace avx512vnni
+} // namespace AVX512VNNI
#endif
#ifndef INTGEMM_COMPILER_SUPPORTS_AVX512BW
-namespace avx512bw {
+namespace AVX512BW {
+typedef Unsupported_8bit Kernels8;
+typedef Unsupported_16bit Kernels16;
+} // namespace AVX512BW
+#endif
+#ifndef INTGEMM_COMPILER_SUPPORTS_AVX2
+namespace AVX2 {
typedef Unsupported_8bit Kernels8;
typedef Unsupported_16bit Kernels16;
-} // namespace avx512bw
+} // namespace AVX2
#endif
+CPUType GetCPUID();
+
/* Returns:
 * avx512vnni if the CPU supports AVX512VNNI
*
@@ -148,72 +159,9 @@ typedef Unsupported_16bit Kernels16;
*
* unsupported otherwise
*/
-template <class T> T ChooseCPU(T
-#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
- avx512vnni
-#endif
- , T
-#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- avx512bw
-#endif
- , T avx2, T ssse3, T sse2, T unsupported) {
-#if defined(__INTEL_COMPILER)
-# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
- if (_may_i_use_cpu_feature(_FEATURE_AVX512_VNNI)) return avx512vnni;
-# endif
-# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- if (_may_i_use_cpu_feature(_FEATURE_AVX512BW)) return avx512bw;
-# endif
- if (_may_i_use_cpu_feature(_FEATURE_AVX2)) return avx2;
- if (_may_i_use_cpu_feature(_FEATURE_SSSE3)) return ssse3;
- if (_may_i_use_cpu_feature(_FEATURE_SSE2)) return sse2;
- return unsupported;
-#else
-// Everybody except Intel compiler.
-# if defined(_MSC_VER)
- int regs[4];
- int &eax = regs[0], &ebx = regs[1], &ecx = regs[2], &edx = regs[3];
- __cpuid(regs, 0);
- int m = eax;
-# else
- /* gcc and clang.
- * If intgemm is compiled by gcc 6.4.1 then dlopened into an executable
- * compiled by gcc 7.3.0, there will be a undefined symbol __cpu_info.
- * Work around this by calling the intrinsics more directly instead of
- * __builtin_cpu_supports.
- *
- * clang 6.0.0-1ubuntu2 supports vnni but doesn't have
- * __builtin_cpu_supports("avx512vnni")
- * so use the hand-coded CPUID for clang.
- */
- unsigned int m = __get_cpuid_max(0, 0);
- unsigned int eax, ebx, ecx, edx;
-# endif
- if (m >= 7) {
-# if defined(_MSC_VER)
- __cpuid(regs, 7);
-# else
- __cpuid_count(7, 0, eax, ebx, ecx, edx);
-# endif
-# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
- if (ecx & (1 << 11)) return avx512vnni;
-# endif
-# ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- if (ebx & (1 << 30)) return avx512bw;
-# endif
- if (ebx & (1 << 5)) return avx2;
- }
- if (m >= 1) {
-# if defined(_MSC_VER)
- __cpuid(regs, 1);
-# else
- __cpuid_count(1, 0, eax, ebx, ecx, edx);
-# endif
- if (ecx & (1 << 9)) return ssse3;
- if (edx & (1 << 26)) return sse2;
- }
- return unsupported;
-#endif
+template <class T> T ChooseCPU(T avx512vnni, T avx512bw, T avx2, T ssse3, T sse2, T unsupported) {
+ const T ret[] = {unsupported, sse2, ssse3, avx2, avx512bw, avx512vnni};
+ return ret[(int)GetCPUID()];
}
struct TileInfo {
@@ -280,7 +228,7 @@ private:
};
template <typename Callback>
-void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, avx512vnni::Kernels8>, OMPParallelWrap<Callback, avx512bw::Kernels8>, OMPParallelWrap<Callback, avx2::Kernels8>, OMPParallelWrap<Callback, ssse3::Kernels8>, Unsupported_8bit::Multiply<Callback>, Unsupported_8bit::Multiply<Callback>);
+void (*Int8::MultiplyImpl<Callback>::run)(const int8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512VNNI::Kernels8>, OMPParallelWrap<Callback, AVX512BW::Kernels8>, OMPParallelWrap<Callback, AVX2::Kernels8>, OMPParallelWrap<Callback, SSSE3::Kernels8>, Unsupported_8bit::Multiply<Callback>, Unsupported_8bit::Multiply<Callback>);
/*
* 8-bit matrix multiplication with shifting A by 127
@@ -344,14 +292,14 @@ private:
template <class Callback>
void (*Int8Shift::MultiplyImpl<Callback>::run)(const uint8_t *A, const int8_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(
- OMPParallelWrap8Shift<Callback, avx512vnni::Kernels8>,
- OMPParallelWrap8Shift<Callback, avx512bw::Kernels8>,
- OMPParallelWrap8Shift<Callback, avx2::Kernels8>,
- OMPParallelWrap8Shift<Callback, ssse3::Kernels8>,
+ OMPParallelWrap8Shift<Callback, AVX512VNNI::Kernels8>,
+ OMPParallelWrap8Shift<Callback, AVX512BW::Kernels8>,
+ OMPParallelWrap8Shift<Callback, AVX2::Kernels8>,
+ OMPParallelWrap8Shift<Callback, SSSE3::Kernels8>,
Unsupported_8bit::Multiply8Shift<Callback>, Unsupported_8bit::Multiply8Shift<Callback>);
template <class Callback>
-void (*Int8Shift::PrepareBiasImpl<Callback>::run)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(avx512vnni::Kernels8::PrepareBias<Callback>, avx512bw::Kernels8::PrepareBias<Callback>, avx2::Kernels8::PrepareBias<Callback>, ssse3::Kernels8::PrepareBias<Callback>, ssse3::Kernels8::PrepareBias<Callback>, Unsupported_8bit::PrepareBias);
+void (*Int8Shift::PrepareBiasImpl<Callback>::run)(const int8_t *B, Index width, Index B_cols, Callback callback) = ChooseCPU(AVX512VNNI::Kernels8::PrepareBias<Callback>, AVX512BW::Kernels8::PrepareBias<Callback>, AVX2::Kernels8::PrepareBias<Callback>, SSSE3::Kernels8::PrepareBias<Callback>, SSSE3::Kernels8::PrepareBias<Callback>, Unsupported_8bit::PrepareBias);
/*
* 16-bit matrix multiplication
@@ -407,7 +355,7 @@ private:
};
template <typename Callback>
-void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, avx512bw::Kernels16> /*TODO VNNI 16-bit. */, OMPParallelWrap<Callback, avx512bw::Kernels16>, OMPParallelWrap<Callback, avx2::Kernels16>, OMPParallelWrap<Callback, sse2::Kernels16>, OMPParallelWrap<Callback, sse2::Kernels16>, Unsupported_16bit::Multiply<Callback>);
+void (*Int16::MultiplyImpl<Callback>::run)(const int16_t *A, const int16_t *B, Index A_rows, Index width, Index B_cols, Callback callback) = ChooseCPU(OMPParallelWrap<Callback, AVX512BW::Kernels16> /*TODO VNNI 16-bit. */, OMPParallelWrap<Callback, AVX512BW::Kernels16>, OMPParallelWrap<Callback, AVX2::Kernels16>, OMPParallelWrap<Callback, SSE2::Kernels16>, OMPParallelWrap<Callback, SSE2::Kernels16>, Unsupported_16bit::Multiply<Callback>);
extern const CPUType kCPU;
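With the CPUID probing moved behind GetCPUID(), ChooseCPU reduces to an array lookup, which only works because the CPUType enumerators (see types.h further down) are numbered 0 through 5 in exactly the order of ChooseCPU's parameters. A standalone sketch of that coupling, not library code:

    #include <cassert>

    enum class Level { Unsupported = 0, SSE2 = 1, SSSE3 = 2, AVX2 = 3, AVX512BW = 4, AVX512VNNI = 5 };

    // Same pattern as ChooseCPU: the table order must match the enumerator values.
    template <class T> T Choose(T vnni, T bw, T avx2, T ssse3, T sse2, T none, Level detected) {
      const T table[] = {none, sse2, ssse3, avx2, bw, vnni};
      return table[static_cast<int>(detected)];
    }

    int main() {
      assert(Choose(5, 4, 3, 2, 1, 0, Level::AVX2) == 3);  // the AVX2 candidate is picked
      return 0;
    }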
diff --git a/intgemm/intgemm_config.h.in b/intgemm/intgemm_config.h.in
index 920e9ae..a2c8cbd 100644
--- a/intgemm/intgemm_config.h.in
+++ b/intgemm/intgemm_config.h.in
@@ -1,4 +1,5 @@
#pragma once
+#cmakedefine INTGEMM_COMPILER_SUPPORTS_AVX2
#cmakedefine INTGEMM_COMPILER_SUPPORTS_AVX512BW
#cmakedefine INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
diff --git a/intgemm/intrinsics.h b/intgemm/intrinsics.h
index 480f421..9f370cd 100644
--- a/intgemm/intrinsics.h
+++ b/intgemm/intrinsics.h
@@ -5,8 +5,13 @@
#include <tmmintrin.h>
#include <emmintrin.h>
-#include <immintrin.h>
#include <xmmintrin.h>
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+#include <immintrin.h>
+#endif
+#ifdef INTGEMM_WORMHOLE
+#include <wasm_simd128.h>
+#endif
#include <cstdint>
@@ -90,10 +95,20 @@ template <> INTGEMM_SSE2 inline __m128 loadu_ps(const float* mem_addr) {
return _mm_loadu_ps(mem_addr);
}
INTGEMM_SSE2 static inline __m128i madd_epi16(__m128i first, __m128i second) {
+// https://bugzilla.mozilla.org/show_bug.cgi?id=1672160
+#ifdef INTGEMM_WORMHOLE
+ return wasm_v8x16_shuffle(first, second, 31, 0, 30, 2, 29, 4, 28, 6, 27, 8, 26, 10, 25, 12, 24, 2 /* PMADDWD */);
+#else
return _mm_madd_epi16(first, second);
+#endif
}
INTGEMM_SSSE3 static inline __m128i maddubs_epi16(__m128i first, __m128i second) {
+// https://bugzilla.mozilla.org/show_bug.cgi?id=1672160
+#ifdef INTGEMM_WORMHOLE
+ return wasm_v8x16_shuffle(first, second, 31, 0, 30, 2, 29, 4, 28, 6, 27, 8, 26, 10, 25, 12, 24, 1 /* PMADDUBSW */);
+#else
return _mm_maddubs_epi16(first, second);
+#endif
}
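The INTGEMM_WORMHOLE branches emit wasm_v8x16_shuffle calls with a fixed, otherwise unremarkable shuffle pattern whose final lane index selects an operation (2 for PMADDWD, 1 for PMADDUBSW, per the inline comments); a WASM engine patched for the linked Mozilla bug recognizes the pattern and lowers it to the native multiply-add instruction, while the non-wormhole branch keeps the ordinary x86 intrinsics. For reference, a scalar model of what madd_epi16 computes (illustrative only, not part of the patch):

    #include <cstdint>

    // Scalar model of _mm_madd_epi16 / madd_epi16: multiply corresponding signed
    // 16-bit lanes and add adjacent products into signed 32-bit results.
    void madd_epi16_ref(const int16_t a[8], const int16_t b[8], int32_t out[4]) {
      for (int i = 0; i < 4; ++i) {
        out[i] = static_cast<int32_t>(a[2 * i]) * b[2 * i] +
                 static_cast<int32_t>(a[2 * i + 1]) * b[2 * i + 1];
      }
    }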
/*
* Missing max_epi8 for SSE2
@@ -215,6 +230,8 @@ INTGEMM_SSE2 static inline __m128i xor_si(__m128i a, __m128i b) {
* AVX2
*
*/
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
INTGEMM_AVX2 static inline __m256i abs_epi8(__m256i arg) {
return _mm256_abs_epi8(arg);
}
@@ -390,6 +407,7 @@ INTGEMM_AVX2 static inline __m256i unpackhi_epi64(__m256i a, __m256i b) {
INTGEMM_AVX2 static inline __m256i xor_si(__m256i a, __m256i b) {
return _mm256_xor_si256(a, b);
}
+#endif
/*
*
diff --git a/intgemm/kernels.h b/intgemm/kernels.h
index ee35966..57036f4 100644
--- a/intgemm/kernels.h
+++ b/intgemm/kernels.h
@@ -12,9 +12,11 @@
#include "kernels/implementations.inl"
#undef KERNELS_THIS_IS_SSE2
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
#define KERNELS_THIS_IS_AVX2
#include "kernels/implementations.inl"
#undef KERNELS_THIS_IS_AVX2
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
#define KERNELS_THIS_IS_AVX512BW
diff --git a/intgemm/multiply.h b/intgemm/multiply.h
index e201e09..8d411f3 100644
--- a/intgemm/multiply.h
+++ b/intgemm/multiply.h
@@ -13,6 +13,7 @@ INTGEMM_SSE2 static inline dvector_t<CPUType::SSE2, int> PermuteSummer(__m128i p
return { pack0123, pack4567 };
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
INTGEMM_AVX2 static inline __m256i PermuteSummer(__m256i pack0123, __m256i pack4567) {
// This instruction generates 1s 2s 3s 4s 5f 6f 7f 8f
__m256i rev = _mm256_permute2f128_si256(pack0123, pack4567, 0x21);
@@ -20,7 +21,7 @@ INTGEMM_AVX2 static inline __m256i PermuteSummer(__m256i pack0123, __m256i pack4
__m256i blended = _mm256_blend_epi32(pack0123, pack4567, 0xf0);
return _mm256_add_epi32(rev, blended);
}
-
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
@@ -99,7 +100,9 @@ target inline Register Pack0123(Register sum0, Register sum1, Register sum2, Reg
} \
INTGEMM_PACK0123(INTGEMM_SSE2, __m128i)
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
INTGEMM_PACK0123(INTGEMM_AVX2, __m256i)
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
/* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */
INTGEMM_PACK0123(INTGEMM_AVX512BW, __m512i)
@@ -107,14 +110,16 @@ INTGEMM_PACK0123(INTGEMM_AVX512BW, __m512i)
template <typename Callback>
INTGEMM_SSE2 static inline void RunCallback(Callback& callback_impl, dvector_t<CPUType::SSE2, int> total, Index row_idx, Index col_idx, Index rows, Index cols) {
- callback_impl(total.first, callbacks::OutputBufferInfo(row_idx, col_idx, rows, cols));
- callback_impl(total.second, callbacks::OutputBufferInfo(row_idx, col_idx + 4, rows, cols));
+ callback_impl.Run(total.first, callbacks::OutputBufferInfo(row_idx, col_idx, rows, cols));
+ callback_impl.Run(total.second, callbacks::OutputBufferInfo(row_idx, col_idx + 4, rows, cols));
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template <typename Callback>
INTGEMM_AVX2 static inline void RunCallback(Callback& callback_impl, vector_t<CPUType::AVX2, int> total, Index row_idx, Index col_idx, Index rows, Index cols) {
- callback_impl(total, callbacks::OutputBufferInfo(row_idx, col_idx, rows, cols));
+ callback_impl.Run(total, callbacks::OutputBufferInfo(row_idx, col_idx, rows, cols));
}
+#endif
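The change from callback_impl(...) to callback_impl.Run(...) means every callback implementation must now expose a Run method taking the accumulator register plus an OutputBufferInfo. The real implementations are generated in intgemm/callbacks/implementations.inl; the struct below is purely a hypothetical illustration of that interface, not library code:

    // Hypothetical callback implementation, shown only to illustrate the Run interface.
    struct CountBlocksImpl {
      int blocks_seen = 0;
      template <class Register>
      void Run(Register /*total*/, intgemm::callbacks::OutputBufferInfo /*info*/) {
        ++blocks_seen;  // a real callback would unquantize `total` and write it out here
      }
    };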
// 16-bit multiplier for INTGEMM_SSE2, INTGEMM_AVX2, and AVX512.
// C = A * B * unquant_mult
@@ -374,7 +379,7 @@ template <typename Callback> target static void Multiply(const int16_t *A, const
* 256-bit. We had to wait for INTGEMM_AVX2 to get 256-bit versions of vpsignb and
* vpmaddubsw. That's why this code is generic over 128-bit or 256-bit.
*/
-
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
INTGEMM_AVX2 inline static void InnerINTGEMM_AVX2(
__m256i a, const __m256i *b,
__m256i &sum0, __m256i &sum1, __m256i &sum2, __m256i &sum3,
@@ -514,7 +519,7 @@ INTGEMM_AVX2 inline static void InnerINTGEMM_AVX2(
sum7 = adds_epi16(sum7, maddubs_epi16(a_positive, sign_epi8(b[7], a)));
#endif
}
-
+#endif
// For INTGEMM_SSSE3 without AVX
INTGEMM_SSSE3 inline static void InnerINTGEMM_SSSE3(
diff --git a/intgemm/sse2_gemm.h b/intgemm/sse2_gemm.h
index cd49efe..cd855a6 100644
--- a/intgemm/sse2_gemm.h
+++ b/intgemm/sse2_gemm.h
@@ -9,7 +9,7 @@
// 8 bit is in ssse3_gemm.h
namespace intgemm {
-namespace sse2 {
+namespace SSE2 {
INTGEMM_SSE2 inline __m128i QuantizerGrab(const float *input, const __m128 quant_mult_reg) {
return kernels::quantize(loadu_ps<__m128>(input), quant_mult_reg);
@@ -80,5 +80,5 @@ struct Kernels16 {
static const CPUType kUses = CPUType::SSE2;
};
-} // namespace sse2
+} // namespace SSE2
} // namespace intgemm
diff --git a/intgemm/ssse3_gemm.h b/intgemm/ssse3_gemm.h
index 865fe12..db403bd 100644
--- a/intgemm/ssse3_gemm.h
+++ b/intgemm/ssse3_gemm.h
@@ -11,7 +11,7 @@
// 16-bit is in sse2_gemm.h
namespace intgemm {
-namespace ssse3 {
+namespace SSSE3 {
INTGEMM_SSSE3 inline __m128i QuantizerGrab(const float *input, const __m128 quant_mult_reg) {
return kernels::quantize(loadu_ps<__m128>(input), quant_mult_reg);
@@ -131,12 +131,12 @@ struct Kernels8 {
static const Index kBTileRow = 16;
static const Index kBTileCol = 8;
- INTGEMM_PREPARE_B_8(INTGEMM_SSSE3, ssse3::QuantizeTile8)
+ INTGEMM_PREPARE_B_8(INTGEMM_SSSE3, SSSE3::QuantizeTile8)
INTGEMM_PREPARE_B_QUANTIZED_TRANSPOSED(INTGEMM_SSSE3, int8_t)
INTGEMM_PREPARE_B_TRANSPOSED(INTGEMM_SSSE3, QuantizeTile8, int8_t)
INTGEMM_SSSE3 static void SelectColumnsB(const int8_t *input, int8_t *output, Index rows, const Index *cols_begin, const Index *cols_end) {
- ssse3::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows, cols_begin, cols_end);
+ SSSE3::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows, cols_begin, cols_end);
}
INTGEMM_MULTIPLY8(__m128i, INTGEMM_SSSE3, CPUType::SSE2)
@@ -150,5 +150,5 @@ struct Kernels8 {
static const CPUType kUses = CPUType::SSSE3;
};
-} // namespace ssse3
+} // namespace SSSE3
} // namespace intgemm
diff --git a/intgemm/stats.h b/intgemm/stats.h
index 6f9eda2..9573c4b 100644
--- a/intgemm/stats.h
+++ b/intgemm/stats.h
@@ -32,12 +32,14 @@ INTGEMM_SSE2 static inline float AddFloat32(__m128 a) {
return *reinterpret_cast<float*>(&a);
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
INTGEMM_AVX2 static inline float MaxFloat32(__m256 a) {
return MaxFloat32(max_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1)));
}
INTGEMM_AVX2 static inline float AddFloat32(__m256 a) {
return AddFloat32(add_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1)));
}
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
// Find the maximum float.
@@ -61,9 +63,11 @@ constexpr int32_t kFloatAbsoluteMask = 0x7fffffff;
#include "stats.inl"
#undef INTGEMM_THIS_IS_SSE2
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
#define INTGEMM_THIS_IS_AVX2
#include "stats.inl"
#undef INTGEMM_THIS_IS_AVX2
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
#define INTGEMM_THIS_IS_AVX512DQ
diff --git a/intgemm/stats.inl b/intgemm/stats.inl
index d6a850e..68a5b8e 100644
--- a/intgemm/stats.inl
+++ b/intgemm/stats.inl
@@ -1,12 +1,12 @@
/* This file is included multiple times, once per architecture. */
#if defined(INTGEMM_THIS_IS_AVX512DQ)
-#define INTGEMM_ARCH avx512bw
+#define INTGEMM_ARCH AVX512BW
#define INTGEMM_TARGET INTGEMM_AVX512DQ
#elif defined(INTGEMM_THIS_IS_AVX2)
-#define INTGEMM_ARCH avx2
+#define INTGEMM_ARCH AVX2
#define INTGEMM_TARGET INTGEMM_AVX2
#elif defined(INTGEMM_THIS_IS_SSE2)
-#define INTGEMM_ARCH sse2
+#define INTGEMM_ARCH SSE2
#define INTGEMM_TARGET INTGEMM_SSE2
#else
#error Included with unexpected architecture
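stats.inl, like kernels/implementations.inl, is compiled once per architecture: the including header defines a switch macro, the .inl maps it onto a namespace (INTGEMM_ARCH) and a target attribute (INTGEMM_TARGET), and with this patch the AVX2 pass is simply skipped when the compiler cannot target AVX2. A stripped-down sketch of the same include-twice pattern; the file and macro names here are made up:

    // body.inl -- textually included once per architecture
    namespace ARCH_NAMESPACE {
    inline const char *Name() { return ARCH_STRING; }
    } // namespace ARCH_NAMESPACE

    // main.cc -- expands the body twice, once per namespace
    #include <cstdio>
    #define ARCH_NAMESPACE SSE2
    #define ARCH_STRING "sse2"
    #include "body.inl"
    #undef ARCH_NAMESPACE
    #undef ARCH_STRING
    #define ARCH_NAMESPACE AVX2
    #define ARCH_STRING "avx2"
    #include "body.inl"
    #undef ARCH_NAMESPACE
    #undef ARCH_STRING
    int main() { std::printf("%s %s\n", SSE2::Name(), AVX2::Name()); }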
diff --git a/intgemm/types.h b/intgemm/types.h
index da0429f..81b38af 100644
--- a/intgemm/types.h
+++ b/intgemm/types.h
@@ -1,10 +1,26 @@
#pragma once
#include <exception>
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
#include <immintrin.h>
+#endif
+#include <emmintrin.h>
-#if defined(_MSC_VER)
+#if defined(_MSC_VER) || defined(__INTEL_COMPILER)
/* MSVC does not appear to have target attributes but is also fine with just
* using intrinsics anywhere.
+ *
+ * The Intel compiler has a bug whereby constructors with target attributes do
+ * not link. Like this program doesn't compile with icpc:
+ * class Foo {
+ * public:
+ * __attribute__ ((target ("avx2"))) Foo() {}
+ * };
+ * int main() { Foo a; }
+ *
+ * It appears to be erroneously activating function multiversioning when only
+ * one version of a constructor with target attributes is defined. Normal
+ * methods with one target attribute work fine. The Intel compiler also allows
+ * intrinsics without any target attributes so we just leave them blank.
*/
#define INTGEMM_SSE2
#define INTGEMM_SSSE3
@@ -14,23 +30,14 @@
#define INTGEMM_AVX512DQ
#define INTGEMM_AVX512VNNI
#else
- /* gcc, clang, and Intel compiler */
+ /* gcc and clang take lists of all the flavors */
#define INTGEMM_SSE2 __attribute__ ((target ("sse2")))
#define INTGEMM_SSSE3 __attribute__ ((target ("ssse3")))
#define INTGEMM_AVX2 __attribute__ ((target ("avx2")))
- #if defined(__INTEL_COMPILER)
- /* Intel compiler might not have AVX512 flavors but lets you use them anyway */
- #define INTGEMM_AVX512F __attribute__ ((target ("avx512f")))
- #define INTGEMM_AVX512BW __attribute__ ((target ("avx512f")))
- #define INTGEMM_AVX512DQ __attribute__ ((target ("avx512f")))
- #define INTGEMM_AVX512VNNI __attribute__ ((target ("avx512f")))
- #else
- /* gcc and clang take lists of all the flavors */
- #define INTGEMM_AVX512F __attribute__ ((target ("avx512f")))
- #define INTGEMM_AVX512BW __attribute__ ((target ("avx512f,avx512bw,avx512dq")))
- #define INTGEMM_AVX512DQ __attribute__ ((target ("avx512f,avx512bw,avx512dq")))
- #define INTGEMM_AVX512VNNI __attribute__ ((target ("avx512f,avx512bw,avx512dq,avx512vnni")))
- #endif
+ #define INTGEMM_AVX512F __attribute__ ((target ("avx512f")))
+ #define INTGEMM_AVX512BW __attribute__ ((target ("avx512f,avx512bw,avx512dq")))
+ #define INTGEMM_AVX512DQ __attribute__ ((target ("avx512f,avx512bw,avx512dq")))
+ #define INTGEMM_AVX512VNNI __attribute__ ((target ("avx512f,avx512bw,avx512dq,avx512vnni")))
#endif
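These per-ISA macros are what let a single translation unit hold SSE2 through AVX512VNNI code paths without raising the global -m flags: each function carries the attribute for the instruction set it needs, and the runtime dispatcher only calls the ones the CPU supports. A hedged example of the intended use; the function name is invented, _mm512_add_epi32 is a standard AVX-512F intrinsic, and the snippet assumes a compiler where INTGEMM_COMPILER_SUPPORTS_AVX512BW is defined:

    #include <immintrin.h>
    #include "intgemm/types.h"  // provides the INTGEMM_* target macros in this tree

    // Illustrative only: compiles with AVX-512 codegen even if the rest of the file targets SSE2.
    INTGEMM_AVX512BW static inline __m512i AddExample(__m512i a, __m512i b) {
      return _mm512_add_epi32(a, b);
    }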
namespace intgemm {
@@ -51,11 +58,11 @@ typedef unsigned int Index;
// If you want to detect the CPU and dispatch yourself, here's what to use:
enum class CPUType {
UNSUPPORTED = 0,
- SSE2,
- SSSE3,
- AVX2,
- AVX512BW,
- AVX512VNNI
+ SSE2 = 1,
+ SSSE3 = 2,
+ AVX2 = 3,
+ AVX512BW = 4,
+ AVX512VNNI = 5
};
// Running CPU type. This is defined in intgemm.cc (as the dispatcher).
@@ -67,28 +74,30 @@ struct MeanStd {
};
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
-namespace avx512vnni {
+namespace AVX512VNNI {
typedef __m512i Register;
typedef __m512 FRegister;
-} // namespace avx512vnni
+} // namespace AVX512VNNI
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
-namespace avx512bw {
+namespace AVX512BW {
typedef __m512i Register;
typedef __m512 FRegister;
-} // namespace avx512bw
+} // namespace AVX512BW
#endif
-namespace avx2 {
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
+namespace AVX2 {
typedef __m256i Register;
typedef __m256 FRegister;
-} // namespace avx2
-namespace ssse3 {
+} // namespace AVX2
+#endif
+namespace SSSE3 {
typedef __m128i Register;
typedef __m128 FRegister;
-} // namespace ssse3
-namespace sse2 {
+} // namespace SSSE3
+namespace SSE2 {
typedef __m128i Register;
typedef __m128 FRegister;
-} // namespace sse2
+} // namespace SSE2
} // namespace intgemm
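Numbering the enumerators explicitly, and keeping them ordered from weakest to strongest, is what makes both the table lookup in ChooseCPU and the ordinary comparisons used throughout the tests valid. A minimal sketch of user-side dispatch on the runtime-detected kCPU:

    #include <cstdio>
    #include "intgemm/intgemm.h"

    int main() {
      using intgemm::CPUType;
      if (intgemm::kCPU >= CPUType::AVX2) {
        std::printf("256-bit (or wider) kernels selected\n");
      } else if (intgemm::kCPU >= CPUType::SSE2) {
        std::printf("128-bit kernels selected\n");
      } else {
        std::printf("unsupported CPU\n");
      }
      return 0;
    }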
diff --git a/intgemm/vec_traits.h b/intgemm/vec_traits.h
index 86265b2..948dae1 100644
--- a/intgemm/vec_traits.h
+++ b/intgemm/vec_traits.h
@@ -18,11 +18,13 @@ template <> struct vector_s<CPUType::SSSE3, int16_t> { using type = __m128i; };
template <> struct vector_s<CPUType::SSSE3, int> { using type = __m128i; };
template <> struct vector_s<CPUType::SSSE3, float> { using type = __m128; };
template <> struct vector_s<CPUType::SSSE3, double> { using type = __m128d; };
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template <> struct vector_s<CPUType::AVX2, int8_t> { using type = __m256i; };
template <> struct vector_s<CPUType::AVX2, int16_t> { using type = __m256i; };
template <> struct vector_s<CPUType::AVX2, int> { using type = __m256i; };
template <> struct vector_s<CPUType::AVX2, float> { using type = __m256; };
template <> struct vector_s<CPUType::AVX2, double> { using type = __m256d; };
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template <> struct vector_s<CPUType::AVX512BW, int8_t> { using type = __m512i; };
template <> struct vector_s<CPUType::AVX512BW, int16_t> { using type = __m512i; };
diff --git a/test/add127_test.cc b/test/add127_test.cc
index b7ce49b..c31732c 100644
--- a/test/add127_test.cc
+++ b/test/add127_test.cc
@@ -282,196 +282,209 @@ template <class Routine> void TestMultiplyShiftInt(Index A_rows, Index width, In
// Bias
TEST_CASE("PrepareBias SSSE3", "[Add127]") {
if (kCPU < CPUType::SSSE3) return;
- TestPrepareBias<ssse3::Kernels8>(256,256);
- TestPrepareBias<ssse3::Kernels8>(2048,256);
- TestPrepareBias<ssse3::Kernels8>(512,512);
+ TestPrepareBias<SSSE3::Kernels8>(256,256);
+ TestPrepareBias<SSSE3::Kernels8>(2048,256);
+ TestPrepareBias<SSSE3::Kernels8>(512,512);
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE("PrepareBias AVX2", "[Add127]") {
if (kCPU < CPUType::AVX2) return;
- TestPrepareBias<avx2::Kernels8>(256,256);
- TestPrepareBias<avx2::Kernels8>(2048,256);
- TestPrepareBias<avx2::Kernels8>(512,512);
+ TestPrepareBias<AVX2::Kernels8>(256,256);
+ TestPrepareBias<AVX2::Kernels8>(2048,256);
+ TestPrepareBias<AVX2::Kernels8>(512,512);
}
+#endif
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE("PrepareBias AVX512F", "[Add127]") {
if (kCPU < CPUType::AVX512BW) return;
- #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- TestPrepareBias<avx512bw::Kernels8>(256,256);
- TestPrepareBias<avx512bw::Kernels8>(2048,256);
- TestPrepareBias<avx512bw::Kernels8>(512,512);
- #endif
+ TestPrepareBias<AVX512BW::Kernels8>(256,256);
+ TestPrepareBias<AVX512BW::Kernels8>(2048,256);
+ TestPrepareBias<AVX512BW::Kernels8>(512,512);
}
+#endif
//A
TEST_CASE("PrepareA SSSE3", "[Add127]") {
if (kCPU < CPUType::SSSE3) return;
- TestPrepareA<ssse3::Kernels8>(64,64);
- TestPrepareA<ssse3::Kernels8>(256,256);
- TestPrepareA<ssse3::Kernels8>(512,512);
- TestPrepareA<ssse3::Kernels8>(2048,256);
+ TestPrepareA<SSSE3::Kernels8>(64,64);
+ TestPrepareA<SSSE3::Kernels8>(256,256);
+ TestPrepareA<SSSE3::Kernels8>(512,512);
+ TestPrepareA<SSSE3::Kernels8>(2048,256);
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE("PrepareA AVX2", "[Add127]") {
if (kCPU < CPUType::AVX2) return;
- TestPrepareA<avx2::Kernels8>(64,64);
- TestPrepareA<avx2::Kernels8>(256,256);
- TestPrepareA<avx2::Kernels8>(512,512);
- TestPrepareA<avx2::Kernels8>(2048,256);
+ TestPrepareA<AVX2::Kernels8>(64,64);
+ TestPrepareA<AVX2::Kernels8>(256,256);
+ TestPrepareA<AVX2::Kernels8>(512,512);
+ TestPrepareA<AVX2::Kernels8>(2048,256);
}
+#endif
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE("PrepareA AVX512F", "[Add127]") {
if (kCPU < CPUType::AVX512BW) return;
- #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- TestPrepareA<avx512bw::Kernels8>(64,64);
- TestPrepareA<avx512bw::Kernels8>(256,256);
- TestPrepareA<avx512bw::Kernels8>(512,512);
- TestPrepareA<avx512bw::Kernels8>(2048,256);
- #endif
+ TestPrepareA<AVX512BW::Kernels8>(64,64);
+ TestPrepareA<AVX512BW::Kernels8>(256,256);
+ TestPrepareA<AVX512BW::Kernels8>(512,512);
+ TestPrepareA<AVX512BW::Kernels8>(2048,256);
}
+#endif
// Multiply
TEST_CASE ("Multiply SSSE3 8bit Shift with bias", "[Add127]") {
if (kCPU < CPUType::SSSE3) return;
- TestMultiplyBiasNew<ssse3::Kernels8>(1, 64, 8, 0.11f, 0.1f, 0.06f, 0.05f);
- TestMultiplyBiasNew<ssse3::Kernels8>(8, 256, 256, 0.45f, 0.54f, 0.17f, 0.16f);
- TestMultiplyBiasNew<ssse3::Kernels8>(8, 2048, 256, 1.7f, 1.7f, 0.46f, 0.43f);
- TestMultiplyBiasNew<ssse3::Kernels8>(320, 256, 256, 0.56f, 0.64f, 0.16f, 0.15f);
- TestMultiplyBiasNew<ssse3::Kernels8>(472, 256, 256, 0.46f, 0.62f, 0.17f, 0.16f);
- TestMultiplyBiasNew<ssse3::Kernels8>(248, 256, 256, 0.48f, 0.64f, 0.16f, 0.15f);
- TestMultiplyBiasNew<ssse3::Kernels8>(200, 256, 256, 0.55f, 0.74f, 0.17f, 0.16f);
+ TestMultiplyBiasNew<SSSE3::Kernels8>(1, 64, 8, 0.11f, 0.1f, 0.06f, 0.05f);
+ TestMultiplyBiasNew<SSSE3::Kernels8>(8, 256, 256, 0.45f, 0.54f, 0.17f, 0.16f);
+ TestMultiplyBiasNew<SSSE3::Kernels8>(8, 2048, 256, 1.7f, 1.7f, 0.46f, 0.43f);
+ TestMultiplyBiasNew<SSSE3::Kernels8>(320, 256, 256, 0.56f, 0.64f, 0.16f, 0.15f);
+ TestMultiplyBiasNew<SSSE3::Kernels8>(472, 256, 256, 0.46f, 0.62f, 0.17f, 0.16f);
+ TestMultiplyBiasNew<SSSE3::Kernels8>(248, 256, 256, 0.48f, 0.64f, 0.16f, 0.15f);
+ TestMultiplyBiasNew<SSSE3::Kernels8>(200, 256, 256, 0.55f, 0.74f, 0.17f, 0.16f);
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE ("Multiply AVX2 8bit Shift with bias", "[Add127]") {
if (kCPU < CPUType::AVX2) return;
- TestMultiplyBiasNew<avx2::Kernels8>(1, 64, 8, 0.11f, 0.11f, 0.06f, 0.05f);
- TestMultiplyBiasNew<avx2::Kernels8>(8, 256, 256, 0.49f, 0.54f, 0.17f, 0.16f);
- TestMultiplyBiasNew<avx2::Kernels8>(8, 2048, 256, 1.57f, 1.66f, 0.46f, 0.46f);
- TestMultiplyBiasNew<avx2::Kernels8>(320, 256, 256, 0.49f, 0.64f, 0.16f, 0.15f);
- TestMultiplyBiasNew<avx2::Kernels8>(472, 256, 256, 0.46f, 0.62f, 0.17f, 0.16f);
- TestMultiplyBiasNew<avx2::Kernels8>(248, 256, 256, 0.48f, 0.64f, 0.16f, 0.15f);
- TestMultiplyBiasNew<avx2::Kernels8>(200, 256, 256, 0.55f, 0.74f, 0.17f, 0.16f);
+ TestMultiplyBiasNew<AVX2::Kernels8>(1, 64, 8, 0.11f, 0.11f, 0.06f, 0.05f);
+ TestMultiplyBiasNew<AVX2::Kernels8>(8, 256, 256, 0.49f, 0.54f, 0.17f, 0.16f);
+ TestMultiplyBiasNew<AVX2::Kernels8>(8, 2048, 256, 1.57f, 1.66f, 0.46f, 0.46f);
+ TestMultiplyBiasNew<AVX2::Kernels8>(320, 256, 256, 0.49f, 0.64f, 0.16f, 0.15f);
+ TestMultiplyBiasNew<AVX2::Kernels8>(472, 256, 256, 0.46f, 0.62f, 0.17f, 0.16f);
+ TestMultiplyBiasNew<AVX2::Kernels8>(248, 256, 256, 0.48f, 0.64f, 0.16f, 0.15f);
+ TestMultiplyBiasNew<AVX2::Kernels8>(200, 256, 256, 0.55f, 0.74f, 0.17f, 0.16f);
}
+#endif
+
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE ("Multiply AVX512F 8bit Shift with bias", "[Add127]") {
if (kCPU < CPUType::AVX512BW) return;
- TestMultiplyBiasNew<avx512bw::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.001f);
- TestMultiplyBiasNew<avx512bw::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.001f);
- TestMultiplyBiasNew<avx512bw::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.001f);
- TestMultiplyBiasNew<avx512bw::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
- TestMultiplyBiasNew<avx512bw::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.001f);
- TestMultiplyBiasNew<avx512bw::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
- TestMultiplyBiasNew<avx512bw::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512BW::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.001f);
+ TestMultiplyBiasNew<AVX512BW::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512BW::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.001f);
+ TestMultiplyBiasNew<AVX512BW::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512BW::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512BW::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512BW::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.001f);
}
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
TEST_CASE ("Multiply AVX512VNNI 8bit Shift with bias", "[Add127]") {
if (kCPU < CPUType::AVX512VNNI) return;
- TestMultiplyBiasNew<avx512vnni::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.001f);
- TestMultiplyBiasNew<avx512vnni::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.001f);
- TestMultiplyBiasNew<avx512vnni::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.001f);
- TestMultiplyBiasNew<avx512vnni::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
- TestMultiplyBiasNew<avx512vnni::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.001f);
- TestMultiplyBiasNew<avx512vnni::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
- TestMultiplyBiasNew<avx512vnni::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512VNNI::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.001f);
+ TestMultiplyBiasNew<AVX512VNNI::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512VNNI::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.001f);
+ TestMultiplyBiasNew<AVX512VNNI::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512VNNI::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512VNNI::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
+ TestMultiplyBiasNew<AVX512VNNI::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.001f);
}
#endif
//Multiply old vs new
TEST_CASE ("Multiply SSSE3 8bit Shift vs nonshift", "[Add127]") {
if (kCPU < CPUType::SSSE3) return;
- TestMultiplyShiftNonShift<ssse3::Kernels8>(1, 64, 8, 0.00001f, 0.1f, 0.06f, 0.00001f);
- TestMultiplyShiftNonShift<ssse3::Kernels8>(8, 256, 256, 0.00001f, 0.54f, 0.17f, 0.00001f);
- TestMultiplyShiftNonShift<ssse3::Kernels8>(8, 2048, 256, 17.9f, 1.7f, 0.46f, 4.2f); //Big difference here because the non-shift version is very bad
- TestMultiplyShiftNonShift<ssse3::Kernels8>(320, 256, 256, 1.2f, 0.64f, 0.16f, 0.006f);
- TestMultiplyShiftNonShift<ssse3::Kernels8>(472, 256, 256, 1.1f, 0.62f, 0.17f, 0.006f);
- TestMultiplyShiftNonShift<ssse3::Kernels8>(248, 256, 256, 0.9f, 0.64f, 0.16f, 0.007f);
- TestMultiplyShiftNonShift<ssse3::Kernels8>(200, 256, 256, 1, 0.74f, 0.17f, 0.006f);
+ TestMultiplyShiftNonShift<SSSE3::Kernels8>(1, 64, 8, 0.00001f, 0.1f, 0.06f, 0.00001f);
+ TestMultiplyShiftNonShift<SSSE3::Kernels8>(8, 256, 256, 0.00001f, 0.54f, 0.17f, 0.00001f);
+ TestMultiplyShiftNonShift<SSSE3::Kernels8>(8, 2048, 256, 17.9f, 1.7f, 0.46f, 4.2f); //Big difference here because the non-shift version is very bad
+ TestMultiplyShiftNonShift<SSSE3::Kernels8>(320, 256, 256, 1.2f, 0.64f, 0.16f, 0.006f);
+ TestMultiplyShiftNonShift<SSSE3::Kernels8>(472, 256, 256, 1.1f, 0.62f, 0.17f, 0.006f);
+ TestMultiplyShiftNonShift<SSSE3::Kernels8>(248, 256, 256, 0.9f, 0.64f, 0.16f, 0.007f);
+ TestMultiplyShiftNonShift<SSSE3::Kernels8>(200, 256, 256, 1, 0.74f, 0.17f, 0.006f);
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE ("Multiply AVX2 8bit Shift vs nonshift", "[Add127]") {
if (kCPU < CPUType::AVX2) return;
- TestMultiplyShiftNonShift<avx2::Kernels8>(1, 64, 8, 0.00001f, 0.11f, 0.06f, 0.00001f);
- TestMultiplyShiftNonShift<avx2::Kernels8>(8, 256, 256, 0.00001f, 0.54f, 0.17f, 0.00001f);
- TestMultiplyShiftNonShift<avx2::Kernels8>(8, 2048, 256, 9.4f, 1.66f, 0.46f, 1.67f); //Big difference here because the non-shift version is very bad
- TestMultiplyShiftNonShift<avx2::Kernels8>(320, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
- TestMultiplyShiftNonShift<avx2::Kernels8>(472, 256, 256, 0.0001f, 0.62f, 0.17f, 0.0001f);
- TestMultiplyShiftNonShift<avx2::Kernels8>(248, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
- TestMultiplyShiftNonShift<avx2::Kernels8>(200, 256, 256, 0.0001f, 0.74f, 0.17f, 0.0001f);
+ TestMultiplyShiftNonShift<AVX2::Kernels8>(1, 64, 8, 0.00001f, 0.11f, 0.06f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX2::Kernels8>(8, 256, 256, 0.00001f, 0.54f, 0.17f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX2::Kernels8>(8, 2048, 256, 9.4f, 1.66f, 0.46f, 1.67f); //Big difference here because the non-shift version is very bad
+ TestMultiplyShiftNonShift<AVX2::Kernels8>(320, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
+ TestMultiplyShiftNonShift<AVX2::Kernels8>(472, 256, 256, 0.0001f, 0.62f, 0.17f, 0.0001f);
+ TestMultiplyShiftNonShift<AVX2::Kernels8>(248, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
+ TestMultiplyShiftNonShift<AVX2::Kernels8>(200, 256, 256, 0.0001f, 0.74f, 0.17f, 0.0001f);
}
+#endif
+
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE ("Multiply AVX512F 8bit Shift vs nonshift", "[Add127]") {
if (kCPU < CPUType::AVX512BW) return;
- TestMultiplyShiftNonShift<avx512bw::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.001f);
- TestMultiplyShiftNonShift<avx512bw::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.001f);
- TestMultiplyShiftNonShift<avx512bw::Kernels8>(8, 2048, 256, 3.51f, 0.61f, 0.17f, 0.3f);
- TestMultiplyShiftNonShift<avx512bw::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
- TestMultiplyShiftNonShift<avx512bw::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.001f);
- TestMultiplyShiftNonShift<avx512bw::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
- TestMultiplyShiftNonShift<avx512bw::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.001f);
+ TestMultiplyShiftNonShift<AVX512BW::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.001f);
+ TestMultiplyShiftNonShift<AVX512BW::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.001f);
+ TestMultiplyShiftNonShift<AVX512BW::Kernels8>(8, 2048, 256, 3.51f, 0.61f, 0.17f, 0.3f);
+ TestMultiplyShiftNonShift<AVX512BW::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
+ TestMultiplyShiftNonShift<AVX512BW::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.001f);
+ TestMultiplyShiftNonShift<AVX512BW::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.001f);
+ TestMultiplyShiftNonShift<AVX512BW::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.001f);
}
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
TEST_CASE ("Multiply AVX512VNNI 8bit Shift vs nonshift", "[Add127]") {
if (kCPU < CPUType::AVX512VNNI) return;
- TestMultiplyShiftNonShift<avx512vnni::Kernels8>(1, 64, 8, 0.00001f, 0.05f, 0.03f, 0.00001f);
- TestMultiplyShiftNonShift<avx512vnni::Kernels8>(8, 256, 256, 0.00001f, 0.22f, 0.06f, 0.00001f);
- TestMultiplyShiftNonShift<avx512vnni::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.0001f);
- TestMultiplyShiftNonShift<avx512vnni::Kernels8>(320, 256, 256, 0.00001f, 0.27f, 0.06f, 0.00001f);
- TestMultiplyShiftNonShift<avx512vnni::Kernels8>(472, 256, 256, 0.00001f, 0.33f, 0.06f, 0.00001f);
- TestMultiplyShiftNonShift<avx512vnni::Kernels8>(248, 256, 256, 0.00001f, 0.27f, 0.06f, 0.00001f);
- TestMultiplyShiftNonShift<avx512vnni::Kernels8>(200, 256, 256, 0.00001f, 0.28f, 0.06f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(1, 64, 8, 0.00001f, 0.05f, 0.03f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(8, 256, 256, 0.00001f, 0.22f, 0.06f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.0001f);
+ TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(320, 256, 256, 0.00001f, 0.27f, 0.06f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(472, 256, 256, 0.00001f, 0.33f, 0.06f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(248, 256, 256, 0.00001f, 0.27f, 0.06f, 0.00001f);
+ TestMultiplyShiftNonShift<AVX512VNNI::Kernels8>(200, 256, 256, 0.00001f, 0.28f, 0.06f, 0.00001f);
}
#endif
//Multiply Shift vs int shift implementation
TEST_CASE ("Multiply SSSE3 8bit Shift vs Int", "[Add127]") {
if (kCPU < CPUType::SSSE3) return;
- TestMultiplyShiftInt<ssse3::Kernels8>(1, 64, 8, 0.0001f, 0.1f, 0.06f, 0.0001f);
- TestMultiplyShiftInt<ssse3::Kernels8>(8, 256, 256, 0.0001f, 0.54f, 0.17f, 0.0001f);
- TestMultiplyShiftInt<ssse3::Kernels8>(8, 2048, 256, 0.0001f, 1.7f, 0.46f, 0.0001f);
- TestMultiplyShiftInt<ssse3::Kernels8>(320, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
- TestMultiplyShiftInt<ssse3::Kernels8>(472, 256, 256, 0.0001f, 0.62f, 0.17f, 0.0001f);
- TestMultiplyShiftInt<ssse3::Kernels8>(248, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
- TestMultiplyShiftInt<ssse3::Kernels8>(200, 256, 256, 0.0001f, 0.74f, 0.17f, 0.0001f);
+ TestMultiplyShiftInt<SSSE3::Kernels8>(1, 64, 8, 0.0001f, 0.1f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<SSSE3::Kernels8>(8, 256, 256, 0.0001f, 0.54f, 0.17f, 0.0001f);
+ TestMultiplyShiftInt<SSSE3::Kernels8>(8, 2048, 256, 0.0001f, 1.7f, 0.46f, 0.0001f);
+ TestMultiplyShiftInt<SSSE3::Kernels8>(320, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
+ TestMultiplyShiftInt<SSSE3::Kernels8>(472, 256, 256, 0.0001f, 0.62f, 0.17f, 0.0001f);
+ TestMultiplyShiftInt<SSSE3::Kernels8>(248, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
+ TestMultiplyShiftInt<SSSE3::Kernels8>(200, 256, 256, 0.0001f, 0.74f, 0.17f, 0.0001f);
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE ("Multiply AVX2 8bit Shift vs Int", "[Add127]") {
if (kCPU < CPUType::AVX2) return;
- TestMultiplyShiftInt<avx2::Kernels8>(1, 64, 8, 0.0001f, 0.11f, 0.06f, 0.0001f);
- TestMultiplyShiftInt<avx2::Kernels8>(8, 256, 256, 0.0001f, 0.54f, 0.17f, 0.0001f);
- TestMultiplyShiftInt<avx2::Kernels8>(8, 2048, 256, 0.0001f, 1.66f, 0.46f, 0.0001f);
- TestMultiplyShiftInt<avx2::Kernels8>(320, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
- TestMultiplyShiftInt<avx2::Kernels8>(472, 256, 256, 0.0001f, 0.62f, 0.17f, 0.0001f);
- TestMultiplyShiftInt<avx2::Kernels8>(248, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
- TestMultiplyShiftInt<avx2::Kernels8>(200, 256, 256, 0.0001f, 0.74f, 0.17f, 0.0001f);
+ TestMultiplyShiftInt<AVX2::Kernels8>(1, 64, 8, 0.0001f, 0.11f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX2::Kernels8>(8, 256, 256, 0.0001f, 0.54f, 0.17f, 0.0001f);
+ TestMultiplyShiftInt<AVX2::Kernels8>(8, 2048, 256, 0.0001f, 1.66f, 0.46f, 0.0001f);
+ TestMultiplyShiftInt<AVX2::Kernels8>(320, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
+ TestMultiplyShiftInt<AVX2::Kernels8>(472, 256, 256, 0.0001f, 0.62f, 0.17f, 0.0001f);
+ TestMultiplyShiftInt<AVX2::Kernels8>(248, 256, 256, 0.0001f, 0.64f, 0.16f, 0.0001f);
+ TestMultiplyShiftInt<AVX2::Kernels8>(200, 256, 256, 0.0001f, 0.74f, 0.17f, 0.0001f);
}
+#endif
+
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE ("Multiply AVX512F 8bit Shift vs Int", "[Add127]") {
if (kCPU < CPUType::AVX512BW) return;
- TestMultiplyShiftInt<avx512bw::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.0001f);
- TestMultiplyShiftInt<avx512bw::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.0001f);
- TestMultiplyShiftInt<avx512bw::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.0001f);
- TestMultiplyShiftInt<avx512bw::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f);
- TestMultiplyShiftInt<avx512bw::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.0001f);
- TestMultiplyShiftInt<avx512bw::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f);
- TestMultiplyShiftInt<avx512bw::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512BW::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.0001f);
+ TestMultiplyShiftInt<AVX512BW::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512BW::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.0001f);
+ TestMultiplyShiftInt<AVX512BW::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512BW::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512BW::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512BW::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.0001f);
}
#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
TEST_CASE ("Multiply AVX512VNNI 8bit Shift vs Int", "[Add127]") {
if (kCPU < CPUType::AVX512VNNI) return;
- TestMultiplyShiftInt<avx512vnni::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.0001f);
- TestMultiplyShiftInt<avx512vnni::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.0001f);
- TestMultiplyShiftInt<avx512vnni::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.0001f);
- TestMultiplyShiftInt<avx512vnni::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f);
- TestMultiplyShiftInt<avx512vnni::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.0001f);
- TestMultiplyShiftInt<avx512vnni::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f);
- TestMultiplyShiftInt<avx512vnni::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512VNNI::Kernels8>(1, 64, 8, 0.0001f, 0.05f, 0.03f, 0.0001f);
+ TestMultiplyShiftInt<AVX512VNNI::Kernels8>(8, 256, 256, 0.0001f, 0.22f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512VNNI::Kernels8>(8, 2048, 256, 0.0001f, 0.61f, 0.17f, 0.0001f);
+ TestMultiplyShiftInt<AVX512VNNI::Kernels8>(320, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512VNNI::Kernels8>(472, 256, 256, 0.0001f, 0.33f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512VNNI::Kernels8>(248, 256, 256, 0.0001f, 0.27f, 0.06f, 0.0001f);
+ TestMultiplyShiftInt<AVX512VNNI::Kernels8>(200, 256, 256, 0.0001f, 0.28f, 0.06f, 0.0001f);
}
#endif
diff --git a/test/kernels/add_bias_test.cc b/test/kernels/add_bias_test.cc
index 492c669..b9e5fd9 100644
--- a/test/kernels/add_bias_test.cc
+++ b/test/kernels/add_bias_test.cc
@@ -37,6 +37,7 @@ KERNEL_TEST_CASE("add_bias/int SSE2") { return kernel_add_bias_test<CPUType::SSE
KERNEL_TEST_CASE("add_bias/float SSE2") { return kernel_add_bias_test<CPUType::SSE2, float>(); }
KERNEL_TEST_CASE("add_bias/double SSE2") { return kernel_add_bias_test<CPUType::SSE2, double>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, int8_t>();
template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, int16_t>();
template INTGEMM_AVX2 void kernel_add_bias_test<CPUType::AVX2, int>();
@@ -47,6 +48,7 @@ KERNEL_TEST_CASE("add_bias/int16 AVX2") { return kernel_add_bias_test<CPUType::A
KERNEL_TEST_CASE("add_bias/int AVX2") { return kernel_add_bias_test<CPUType::AVX2, int>(); }
KERNEL_TEST_CASE("add_bias/float AVX2") { return kernel_add_bias_test<CPUType::AVX2, float>(); }
KERNEL_TEST_CASE("add_bias/double AVX2") { return kernel_add_bias_test<CPUType::AVX2, double>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_add_bias_test<CPUType::AVX512BW, int8_t>();
diff --git a/test/kernels/bitwise_not_test.cc b/test/kernels/bitwise_not_test.cc
index e908c43..6c28c95 100644
--- a/test/kernels/bitwise_not_test.cc
+++ b/test/kernels/bitwise_not_test.cc
@@ -28,8 +28,10 @@ void kernel_bitwise_not_test() {
template INTGEMM_SSE2 void kernel_bitwise_not_test<CPUType::SSE2>();
KERNEL_TEST_CASE("bitwise_not SSE2") { return kernel_bitwise_not_test<CPUType::SSE2>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_bitwise_not_test<CPUType::AVX2>();
KERNEL_TEST_CASE("bitwise_not AVX2") { return kernel_bitwise_not_test<CPUType::AVX2>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_bitwise_not_test<CPUType::AVX512BW>();
diff --git a/test/kernels/downcast_test.cc b/test/kernels/downcast_test.cc
index 5f9db66..0f2ccd0 100644
--- a/test/kernels/downcast_test.cc
+++ b/test/kernels/downcast_test.cc
@@ -30,8 +30,10 @@ void kernel_downcast32to8_test() {
template INTGEMM_SSE2 void kernel_downcast32to8_test<CPUType::SSE2>();
KERNEL_TEST_CASE("downcast32to8 SSE2") { return kernel_downcast32to8_test<CPUType::SSE2>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_downcast32to8_test<CPUType::AVX2>();
KERNEL_TEST_CASE("downcast32to8 AVX2") { return kernel_downcast32to8_test<CPUType::AVX2>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_downcast32to8_test<CPUType::AVX512BW>();
@@ -60,8 +62,10 @@ void kernel_downcast32to16_test() {
template INTGEMM_SSE2 void kernel_downcast32to16_test<CPUType::SSE2>();
KERNEL_TEST_CASE("downcast32to16 SSE2") { return kernel_downcast32to16_test<CPUType::SSE2>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_downcast32to16_test<CPUType::AVX2>();
KERNEL_TEST_CASE("downcast32to16 AVX2") { return kernel_downcast32to16_test<CPUType::AVX2>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_downcast32to16_test<CPUType::AVX512BW>();
@@ -90,8 +94,10 @@ void kernel_downcast16to8_test() {
template INTGEMM_SSE2 void kernel_downcast16to8_test<CPUType::SSE2>();
KERNEL_TEST_CASE("downcast16to8 SSE2") { return kernel_downcast16to8_test<CPUType::SSE2>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_downcast16to8_test<CPUType::AVX2>();
KERNEL_TEST_CASE("downcast16to8 AVX2") { return kernel_downcast16to8_test<CPUType::AVX2>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_downcast16to8_test<CPUType::AVX512BW>();
diff --git a/test/kernels/exp_test.cc b/test/kernels/exp_test.cc
index 838e228..9f535f2 100644
--- a/test/kernels/exp_test.cc
+++ b/test/kernels/exp_test.cc
@@ -25,8 +25,10 @@ void kernel_exp_approx_taylor_test() {
CHECK_EPS(output[i], exp(input[i]), 0.001f);
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_exp_approx_taylor_test<CPUType::AVX2>();
KERNEL_TEST_CASE("exp_approx_taylor AVX2") { return kernel_exp_approx_taylor_test<CPUType::AVX2>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_exp_approx_taylor_test<CPUType::AVX512BW>();
diff --git a/test/kernels/floor_test.cc b/test/kernels/floor_test.cc
index 2659c3f..9b7a214 100644
--- a/test/kernels/floor_test.cc
+++ b/test/kernels/floor_test.cc
@@ -28,8 +28,10 @@ void kernel_floor_test() {
template INTGEMM_SSE2 void kernel_floor_test<CPUType::SSE2>();
KERNEL_TEST_CASE("floor SSE2") { return kernel_floor_test<CPUType::SSE2>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_floor_test<CPUType::AVX2>();
KERNEL_TEST_CASE("floor AVX2") { return kernel_floor_test<CPUType::AVX2>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_floor_test<CPUType::AVX512BW>();
diff --git a/test/kernels/multiply_test.cc b/test/kernels/multiply_test.cc
index 029e3ac..fc1a51e 100644
--- a/test/kernels/multiply_test.cc
+++ b/test/kernels/multiply_test.cc
@@ -38,6 +38,7 @@ KERNEL_TEST_CASE("multiply/int SSE2") { return kernel_multiply_test<CPUType::SSE
KERNEL_TEST_CASE("multiply/float SSE2") { return kernel_multiply_test<CPUType::SSE2, float>(); }
KERNEL_TEST_CASE("multiply/double SSE2") { return kernel_multiply_test<CPUType::SSE2, double>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int8_t>();
template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int16_t>();
template INTGEMM_AVX2 void kernel_multiply_test<CPUType::AVX2, int>();
@@ -48,6 +49,7 @@ KERNEL_TEST_CASE("multiply/int16 AVX2") { return kernel_multiply_test<CPUType::A
KERNEL_TEST_CASE("multiply/int AVX2") { return kernel_multiply_test<CPUType::AVX2, int>(); }
KERNEL_TEST_CASE("multiply/float AVX2") { return kernel_multiply_test<CPUType::AVX2, float>(); }
KERNEL_TEST_CASE("multiply/double AVX2") { return kernel_multiply_test<CPUType::AVX2, double>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_multiply_test<CPUType::AVX512BW, int8_t>();
diff --git a/test/kernels/quantize_test.cc b/test/kernels/quantize_test.cc
index ae3c068..93280f7 100644
--- a/test/kernels/quantize_test.cc
+++ b/test/kernels/quantize_test.cc
@@ -28,8 +28,10 @@ void kernel_quantize_test() {
template INTGEMM_SSE2 void kernel_quantize_test<CPUType::SSE2>();
KERNEL_TEST_CASE("quantize SSE2") { return kernel_quantize_test<CPUType::SSE2>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_quantize_test<CPUType::AVX2>();
KERNEL_TEST_CASE("quantize AVX2") { return kernel_quantize_test<CPUType::AVX2>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_quantize_test<CPUType::AVX512BW>();
diff --git a/test/kernels/relu_test.cc b/test/kernels/relu_test.cc
index 6fcef98..8fd30ae 100644
--- a/test/kernels/relu_test.cc
+++ b/test/kernels/relu_test.cc
@@ -36,6 +36,7 @@ KERNEL_TEST_CASE("relu/int SSE2") { return kernel_relu_test<CPUType::SSE2, int>(
KERNEL_TEST_CASE("relu/float SSE2") { return kernel_relu_test<CPUType::SSE2, float>(); }
KERNEL_TEST_CASE("relu/double SSE2") { return kernel_relu_test<CPUType::SSE2, double>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int8_t>();
template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int16_t>();
template INTGEMM_AVX2 void kernel_relu_test<CPUType::AVX2, int>();
@@ -46,6 +47,7 @@ KERNEL_TEST_CASE("relu/int16 AVX2") { return kernel_relu_test<CPUType::AVX2, int
KERNEL_TEST_CASE("relu/int AVX2") { return kernel_relu_test<CPUType::AVX2, int>(); }
KERNEL_TEST_CASE("relu/float AVX2") { return kernel_relu_test<CPUType::AVX2, float>(); }
KERNEL_TEST_CASE("relu/double AVX2") { return kernel_relu_test<CPUType::AVX2, double>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_relu_test<CPUType::AVX512BW, int8_t>();
diff --git a/test/kernels/rescale_test.cc b/test/kernels/rescale_test.cc
index 280b513..13937ed 100644
--- a/test/kernels/rescale_test.cc
+++ b/test/kernels/rescale_test.cc
@@ -30,8 +30,10 @@ void kernel_rescale_test() {
template INTGEMM_SSE2 void kernel_rescale_test<CPUType::SSE2>();
KERNEL_TEST_CASE("rescale SSE2") { return kernel_rescale_test<CPUType::SSE2>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_rescale_test<CPUType::AVX2>();
KERNEL_TEST_CASE("rescale AVX2") { return kernel_rescale_test<CPUType::AVX2>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_rescale_test<CPUType::AVX512BW>();
diff --git a/test/kernels/sigmoid_test.cc b/test/kernels/sigmoid_test.cc
index af9dad1..7827593 100644
--- a/test/kernels/sigmoid_test.cc
+++ b/test/kernels/sigmoid_test.cc
@@ -32,8 +32,10 @@ void kernel_sigmoid_test() {
CHECK_EPS(output[i], sigmoid_ref(input[i]), 0.001f);
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_sigmoid_test<CPUType::AVX2>();
KERNEL_TEST_CASE("sigmoid AVX2") { return kernel_sigmoid_test<CPUType::AVX2>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_sigmoid_test<CPUType::AVX512BW>();
diff --git a/test/kernels/tanh_test.cc b/test/kernels/tanh_test.cc
index e2c36f5..1d00042 100644
--- a/test/kernels/tanh_test.cc
+++ b/test/kernels/tanh_test.cc
@@ -25,8 +25,10 @@ void kernel_tanh_test() {
CHECK_EPS(output[i], tanh(input[i]), 0.001f);
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_tanh_test<CPUType::AVX2>();
KERNEL_TEST_CASE("tanh AVX2") { return kernel_tanh_test<CPUType::AVX2>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_tanh_test<CPUType::AVX512BW>();
diff --git a/test/kernels/unquantize_test.cc b/test/kernels/unquantize_test.cc
index ee4bc80..edfafa5 100644
--- a/test/kernels/unquantize_test.cc
+++ b/test/kernels/unquantize_test.cc
@@ -28,8 +28,10 @@ void kernel_unquantize_test() {
template INTGEMM_SSE2 void kernel_unquantize_test<CPUType::SSE2>();
KERNEL_TEST_CASE("unquantize SSE2") { return kernel_unquantize_test<CPUType::SSE2>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_unquantize_test<CPUType::AVX2>();
KERNEL_TEST_CASE("unquantize AVX2") { return kernel_unquantize_test<CPUType::AVX2>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_unquantize_test<CPUType::AVX512BW>();
diff --git a/test/kernels/upcast_test.cc b/test/kernels/upcast_test.cc
index 92be1bd..0733922 100644
--- a/test/kernels/upcast_test.cc
+++ b/test/kernels/upcast_test.cc
@@ -33,8 +33,10 @@ void kernel_upcast8to16_test() {
template INTGEMM_SSE2 void kernel_upcast8to16_test<CPUType::SSE2>();
KERNEL_TEST_CASE("upcast8to16 SSE2") { return kernel_upcast8to16_test<CPUType::SSE2>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_upcast8to16_test<CPUType::AVX2>();
KERNEL_TEST_CASE("upcast8to16 AVX2") { return kernel_upcast8to16_test<CPUType::AVX2>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_upcast8to16_test<CPUType::AVX512BW>();
@@ -65,8 +67,10 @@ void kernel_upcast16to32_test() {
template INTGEMM_SSE2 void kernel_upcast16to32_test<CPUType::SSE2>();
KERNEL_TEST_CASE("upcast16to32 SSE2") { return kernel_upcast16to32_test<CPUType::SSE2>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_upcast16to32_test<CPUType::AVX2>();
KERNEL_TEST_CASE("upcast16to32 AVX2") { return kernel_upcast16to32_test<CPUType::AVX2>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_upcast16to32_test<CPUType::AVX512BW>();
@@ -100,8 +104,10 @@ void kernel_upcast8to32_test() {
template INTGEMM_SSE2 void kernel_upcast8to32_test<CPUType::SSE2>();
KERNEL_TEST_CASE("upcast8to32 SSE2") { return kernel_upcast8to32_test<CPUType::SSE2>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_upcast8to32_test<CPUType::AVX2>();
KERNEL_TEST_CASE("upcast8to32 AVX2") { return kernel_upcast8to32_test<CPUType::AVX2>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_upcast8to32_test<CPUType::AVX512BW>();
diff --git a/test/kernels/write_test.cc b/test/kernels/write_test.cc
index c263fca..a136a86 100644
--- a/test/kernels/write_test.cc
+++ b/test/kernels/write_test.cc
@@ -36,6 +36,7 @@ KERNEL_TEST_CASE("write/int SSE2") { return kernel_write_test<CPUType::SSE2, int
KERNEL_TEST_CASE("write/float SSE2") { return kernel_write_test<CPUType::SSE2, float>(); }
KERNEL_TEST_CASE("write/double SSE2") { return kernel_write_test<CPUType::SSE2, double>(); }
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, int8_t>();
template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, int16_t>();
template INTGEMM_AVX2 void kernel_write_test<CPUType::AVX2, int>();
@@ -46,6 +47,7 @@ KERNEL_TEST_CASE("write/int16 AVX2") { return kernel_write_test<CPUType::AVX2, i
KERNEL_TEST_CASE("write/int AVX2") { return kernel_write_test<CPUType::AVX2, int>(); }
KERNEL_TEST_CASE("write/float AVX2") { return kernel_write_test<CPUType::AVX2, float>(); }
KERNEL_TEST_CASE("write/double AVX2") { return kernel_write_test<CPUType::AVX2, double>(); }
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
template INTGEMM_AVX512BW void kernel_write_test<CPUType::AVX512BW, int8_t>();
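The kernel-test changes above all apply the same two-level gating: an #ifdef INTGEMM_COMPILER_SUPPORTS_AVX2 block decides at build time whether the AVX2 instantiation and its KERNEL_TEST_CASE exist at all, while the test body still checks kCPU at run time before exercising the kernel. Below is a minimal, self-contained sketch of how those two guards interact; the macro name follows the diff, but the enum values, the fixed kCPU value, and example_avx2_test are simplified stand-ins, not intgemm code.

#include <cstdio>

enum class CPUType { UNSUPPORTED, SSE2, SSSE3, AVX2, AVX512BW, AVX512VNNI };
static const CPUType kCPU = CPUType::AVX2;       // intgemm fills this in from runtime CPU detection

#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2            // build-time: can this compiler emit AVX2 code at all?
static void example_avx2_test() {
  if (kCPU < CPUType::AVX2) return;              // run-time: does this machine actually have AVX2?
  std::puts("AVX2 kernel test would run here");
}
#endif

int main() {
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
  example_avx2_test();
#else
  std::puts("AVX2 tests compiled out");
#endif
}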
diff --git a/test/multiply_test.cc b/test/multiply_test.cc
index 6c16edd..f72758f 100644
--- a/test/multiply_test.cc
+++ b/test/multiply_test.cc
@@ -20,7 +20,10 @@
namespace intgemm {
-INTGEMM_SSE2 TEST_CASE("Transpose 16", "[transpose]") {
+#ifndef __INTEL_COMPILER
+INTGEMM_SSE2
+#endif
+TEST_CASE("Transpose 16", "[transpose]") {
if (kCPU < CPUType::SSE2) return;
const unsigned N = 8;
AlignedVector<int16_t> input(N * N);
@@ -38,7 +41,10 @@ INTGEMM_SSE2 TEST_CASE("Transpose 16", "[transpose]") {
}
}
-INTGEMM_SSSE3 TEST_CASE("Transpose 8", "[transpose]") {
+#ifndef __INTEL_COMPILER
+INTGEMM_SSSE3
+#endif
+TEST_CASE("Transpose 8", "[transpose]") {
if (kCPU < CPUType::SSSE3) return;
const unsigned N = 16;
AlignedVector<int8_t> input(N * N);
@@ -82,33 +88,35 @@ template <class Routine> void TestPrepare(Index rows = 32, Index cols = 16) {
PrintMatrix(reference.begin(), rows, cols) << "Routine" << '\n' << PrintMatrix(test.begin(), rows, cols));
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE("Prepare AVX512", "[prepare]") {
if (kCPU < CPUType::AVX512BW) return;
-#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- TestPrepare<avx512bw::Kernels8>(64, 8);
- TestPrepare<avx512bw::Kernels8>(256, 32);
- TestPrepare<avx512bw::Kernels16>(64, 8);
- TestPrepare<avx512bw::Kernels16>(256, 32);
-#endif
+ TestPrepare<AVX512BW::Kernels8>(64, 8);
+ TestPrepare<AVX512BW::Kernels8>(256, 32);
+ TestPrepare<AVX512BW::Kernels16>(64, 8);
+ TestPrepare<AVX512BW::Kernels16>(256, 32);
}
+#endif
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE("Prepare AVX2", "[prepare]") {
if (kCPU < CPUType::AVX2) return;
- TestPrepare<avx2::Kernels8>(64, 32);
- TestPrepare<avx2::Kernels16>(64, 32);
+ TestPrepare<AVX2::Kernels8>(64, 32);
+ TestPrepare<AVX2::Kernels16>(64, 32);
}
+#endif
TEST_CASE("Prepare SSSE3", "[prepare]") {
if (kCPU < CPUType::SSSE3) return;
- TestPrepare<ssse3::Kernels8>(16, 8);
- TestPrepare<ssse3::Kernels8>(32, 16);
- TestPrepare<ssse3::Kernels8>(32, 32);
+ TestPrepare<SSSE3::Kernels8>(16, 8);
+ TestPrepare<SSSE3::Kernels8>(32, 16);
+ TestPrepare<SSSE3::Kernels8>(32, 32);
}
TEST_CASE("Prepare SSE2", "[prepare]") {
if (kCPU < CPUType::SSE2) return;
- TestPrepare<sse2::Kernels16>(8, 8);
- TestPrepare<sse2::Kernels16>(32, 32);
+ TestPrepare<SSE2::Kernels16>(8, 8);
+ TestPrepare<SSE2::Kernels16>(32, 32);
}
template <class Routine> void TestSelectColumnsB(Index rows = 64, Index cols = 16) {
@@ -147,30 +155,32 @@ template <class Routine> void TestSelectColumnsB(Index rows = 64, Index cols = 1
PrintMatrix(ref.begin(), rows, kSelectCols) << PrintMatrix(test.begin(), rows, kSelectCols));
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE("SelectColumnsB AVX512", "[select]") {
if (kCPU < CPUType::AVX512BW) return;
-#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- TestSelectColumnsB<avx512bw::Kernels8>();
- TestSelectColumnsB<avx512bw::Kernels16>(256, 256);
-#endif
+ TestSelectColumnsB<AVX512BW::Kernels8>();
+ TestSelectColumnsB<AVX512BW::Kernels16>(256, 256);
}
+#endif
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE("SelectColumnsB AVX2", "[select]") {
if (kCPU < CPUType::AVX2) return;
- TestSelectColumnsB<avx2::Kernels8>(256, 256);
- TestSelectColumnsB<avx2::Kernels16>(256, 256);
+ TestSelectColumnsB<AVX2::Kernels8>(256, 256);
+ TestSelectColumnsB<AVX2::Kernels16>(256, 256);
}
+#endif
TEST_CASE("SelectColumnsB SSSE3", "[select]") {
if (kCPU < CPUType::SSSE3) return;
- TestSelectColumnsB<ssse3::Kernels8>();
- TestSelectColumnsB<ssse3::Kernels8>(256, 256);
+ TestSelectColumnsB<SSSE3::Kernels8>();
+ TestSelectColumnsB<SSSE3::Kernels8>(256, 256);
}
TEST_CASE("SelectColumnsB SSE2", "[select]") {
if (kCPU < CPUType::SSE2) return;
- TestSelectColumnsB<sse2::Kernels16>();
- TestSelectColumnsB<sse2::Kernels16>(256, 256);
+ TestSelectColumnsB<SSE2::Kernels16>();
+ TestSelectColumnsB<SSE2::Kernels16>(256, 256);
}
template <class Register> void TestMax() {
@@ -215,20 +225,22 @@ template <float (*Backend) (const float *, const float *)> void TestMaxAbsolute(
TEST_CASE("MaxAbsolute SSE2", "[max]") {
if (kCPU < CPUType::SSE2) return;
- TestMaxAbsolute<sse2::MaxAbsolute>();
+ TestMaxAbsolute<SSE2::MaxAbsolute>();
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE("MaxAbsolute AVX2", "[max]") {
if (kCPU < CPUType::AVX2) return;
- TestMaxAbsolute<avx2::MaxAbsolute>();
+ TestMaxAbsolute<AVX2::MaxAbsolute>();
}
+#endif
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE("MaxAbsolute AVX512BW", "[max]") {
if (kCPU < CPUType::AVX512BW) return;
- #ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- TestMaxAbsolute<avx512bw::MaxAbsolute>();
- #endif
+ TestMaxAbsolute<AVX512BW::MaxAbsolute>();
}
+#endif
// Based on https://arxiv.org/abs/1705.01991
@@ -303,6 +315,57 @@ template <class Routine> void TestMultiply(Index A_rows, Index width, Index B_co
int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}
+template <class Routine> void TestMultiplyRelu(Index A_rows, Index width, Index B_cols,
+ float int_tolerance=.1, float float_tolerance=1, float MSE_float_tolerance=0, float MSE_int_tolerance=0) {
+ using Integer = typename Routine::Integer;
+ std::ostringstream info;
+ info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+ // Initialize A and B.
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+
+ float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
+ float unquant_mult = 1.0f / (quant_mult*quant_mult);
+
+ AlignedVector<Integer> A_prep(A.size());
+ AlignedVector<Integer> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+ OMPParallelWrap<callbacks::UnquantizeAndWriteRelu, Routine>(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndWriteRelu(unquant_mult, test_C.begin()));
+ // Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::Sequence(
+ // callbacks::Unquantize(unquant_mult),
+ // callbacks::Write<float>(test_C.begin())
+ // ));
+
+ AlignedVector<Integer> B_quant(B.size());
+ Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+ AlignedVector<float> slowint_C(test_C.size());
+ // Assuming A is just quantization here.
+ references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo&) {
+ float ret = std::max(0.0f, sum * unquant_mult);
+ return ret;
+ });
+
+ AlignedVector<float> float_C(test_C.size());
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo&) {
+ return static_cast<float>(std::max(0.0,sum));
+ });
+
+ CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
 //Code duplication may be avoided through some use of variadic templates, as the different WriteC symbols
 //require different numbers of arguments. I don't think the refactoring is worth it.
template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index B_cols,
@@ -326,7 +389,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
for (auto& it : bias) {
it = dist(gen);
}
-
+
float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
float unquant_mult = 1.0f / (quant_mult*quant_mult);
@@ -356,147 +419,342 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index
int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
}
+template <class Routine> void TestMultiplyBiasRelu(Index A_rows, Index width, Index B_cols,
+ float int_tolerance = 0.1f, float float_tolerance = 1.0f, float MSE_float_tolerance = 0.0f, float MSE_int_tolerance = 0.0f) {
+ using Integer = typename Routine::Integer;
+ std::ostringstream info;
+ info << Routine::kName << "\t" << A_rows << '\t' << width << '\t' << B_cols << '\n';
+
+ // Initialize A and B.
+ AlignedVector<float> A(A_rows * width);
+ AlignedVector<float> B(width * B_cols);
+ AlignedVector<float> bias(B_cols);
+ std::mt19937 gen;
+ std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
+ for (auto& it : A) {
+ it = dist(gen);
+ }
+ for (auto& it : B) {
+ it = dist(gen);
+ }
+ for (auto& it : bias) {
+ it = dist(gen);
+ }
+
+ float quant_mult = (sizeof(Integer) == 2) ? 1024 : 64;
+ float unquant_mult = 1.0f / (quant_mult*quant_mult);
+
+ AlignedVector<Integer> A_prep(A.size());
+ AlignedVector<Integer> B_prep(B.size());
+ Routine::PrepareA(A.begin(), A_prep.begin(), quant_mult, A_rows, width);
+ Routine::PrepareB(B.begin(), B_prep.begin(), quant_mult, width, B_cols);
+
+ AlignedVector<float> test_C(A_rows * B_cols);
+
+ Routine::Multiply(A_prep.begin(), B_prep.begin(), A_rows, width, B_cols, callbacks::UnquantizeAndAddBiasAndWriteRelu(unquant_mult, bias.begin(), test_C.begin()));
+
+ AlignedVector<Integer> B_quant(B.size());
+ Routine::Quantize(B.begin(), B_quant.begin(), quant_mult, static_cast<Index>(B.size()));
+ AlignedVector<float> slowint_C(test_C.size());
+ // Assuming A is just quantization here.
+ references::Multiply(A_prep.begin(), B_quant.begin(), slowint_C.begin(), A_rows, width, B_cols, [&](int32_t sum, const callbacks::OutputBufferInfo& info) {
+ return std::max(0.0f, sum * unquant_mult + bias[info.col_idx]);
+ });
+
+ AlignedVector<float> float_C(test_C.size());
+ references::Multiply(A.begin(), B.begin(), float_C.begin(), A_rows, width, B_cols, [&](double sum, const callbacks::OutputBufferInfo& info) {
+ return std::max(0.0f, static_cast<float>(sum) + bias[info.col_idx]);
+ });
+
+ CompareMSE(float_C.begin(), slowint_C.begin(), test_C.begin(), test_C.size(), info.str(),
+ int_tolerance, float_tolerance, MSE_float_tolerance, MSE_int_tolerance);
+}
+
TEST_CASE ("Multiply SSE2 16bit", "[multiply]") {
if (kCPU < CPUType::SSE2) return;
- TestMultiply<sse2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
- TestMultiply<sse2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
- TestMultiply<sse2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
- TestMultiply<sse2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
- TestMultiply<sse2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
- TestMultiply<sse2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiply<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
+TEST_CASE ("Multiply SSE2 16bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::SSE2) return;
+ TestMultiplyRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiplyRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::SSE2) return;
- TestMultiplyBias<sse2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
- TestMultiplyBias<sse2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
- TestMultiplyBias<sse2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
- TestMultiplyBias<sse2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
- TestMultiplyBias<sse2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
- TestMultiplyBias<sse2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiplyBias<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
+TEST_CASE ("Multiply SSE2 16bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::SSE2) return;
+ TestMultiplyBiasRelu<SSE2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<SSE2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") {
if (kCPU < CPUType::SSSE3) return;
- TestMultiply<ssse3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
- TestMultiply<ssse3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
- TestMultiply<ssse3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
- TestMultiply<ssse3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
- TestMultiply<ssse3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
- TestMultiply<ssse3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+ TestMultiply<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+ TestMultiply<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+ TestMultiply<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+ TestMultiply<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+ TestMultiply<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+ TestMultiply<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+}
+
+TEST_CASE ("Multiply SSSE3 8bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestMultiplyRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+ TestMultiplyRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+ TestMultiplyRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+ TestMultiplyRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+ TestMultiplyRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+ TestMultiplyRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
}
TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::SSSE3) return;
- TestMultiplyBias<ssse3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
- TestMultiplyBias<ssse3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
- TestMultiplyBias<ssse3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
- TestMultiplyBias<ssse3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
- TestMultiplyBias<ssse3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
- TestMultiplyBias<ssse3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+ TestMultiplyBias<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+ TestMultiplyBias<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+ TestMultiplyBias<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+ TestMultiplyBias<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+ TestMultiplyBias<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+ TestMultiplyBias<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
+}
+
+TEST_CASE ("Multiply SSSE3 8bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::SSSE3) return;
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 256, 256, 1.2f, 1.2f, 0.064f, 0.026f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(8, 2048, 256, 33, 33, 4.4f, 4.4f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(320, 256, 256, 1.9f, 1.9f, 0.1f, 0.01f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(472, 256, 256, 2.1f, 2.1f, 0.1f, 0.011f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(248, 256, 256, 1.7f, 1.7f, 0.1f, 0.012f);
+ TestMultiplyBiasRelu<SSSE3::Kernels8>(200, 256, 256, 1.8f, 1.9f, 0.1f, 0.011f);
}
+
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE ("Multiply AVX2 8bit", "[multiply]") {
if (kCPU < CPUType::AVX2) return;
- TestMultiply<avx2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
- TestMultiply<avx2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
- TestMultiply<avx2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
- TestMultiply<avx2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
- TestMultiply<avx2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
- TestMultiply<avx2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
+ TestMultiply<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
+ TestMultiply<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
+ TestMultiply<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
+ TestMultiply<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
+ TestMultiply<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
+ TestMultiply<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
+}
+
+TEST_CASE ("Multiply AVX2 8bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
+ TestMultiplyRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
}
TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX2) return;
- TestMultiplyBias<avx2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
- TestMultiplyBias<avx2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
- TestMultiplyBias<avx2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
- TestMultiplyBias<avx2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
- TestMultiplyBias<avx2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
- TestMultiplyBias<avx2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBias<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBias<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
+ TestMultiplyBias<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBias<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBias<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBias<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
+}
+
+TEST_CASE ("Multiply AVX2 8bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyBiasRelu<AVX2::Kernels8>(8, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(8, 2048, 256, 19, 19, 1.8f, 1.8f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(320, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(472, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(248, 256, 256, .1f, 1, 0.1f);
+ TestMultiplyBiasRelu<AVX2::Kernels8>(200, 256, 256, .1f, 1, 0.1f);
}
TEST_CASE ("Multiply AVX2 16bit", "[multiply]") {
if (kCPU < CPUType::AVX2) return;
- TestMultiply<avx2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
- TestMultiply<avx2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
- TestMultiply<avx2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
- TestMultiply<avx2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
- TestMultiply<avx2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
- TestMultiply<avx2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiply<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+
+TEST_CASE ("Multiply AVX2 16bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiplyRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX2) return;
- TestMultiplyBias<avx2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
- TestMultiplyBias<avx2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
- TestMultiplyBias<avx2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
- TestMultiplyBias<avx2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
- TestMultiplyBias<avx2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
- TestMultiplyBias<avx2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiplyBias<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
+TEST_CASE ("Multiply AVX2 16bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX2) return;
+ TestMultiplyBiasRelu<AVX2::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(8, 2048, 256, .1f, 1, 0.02f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX2::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+}
+#endif
+
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE ("Multiply AVX512 8bit", "[multiply]") {
if (kCPU < CPUType::AVX512BW) return;
- TestMultiply<avx512bw::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
- TestMultiply<avx512bw::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
- TestMultiply<avx512bw::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
- TestMultiply<avx512bw::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
- TestMultiply<avx512bw::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
- TestMultiply<avx512bw::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ TestMultiply<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiply<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+ TestMultiply<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiply<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiply<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiply<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ }
+
+ TEST_CASE ("Multiply AVX512 8bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}
TEST_CASE ("Multiply AVX512 8bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX512BW) return;
- TestMultiplyBias<avx512bw::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
- TestMultiplyBias<avx512bw::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
- TestMultiplyBias<avx512bw::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
- TestMultiplyBias<avx512bw::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
- TestMultiplyBias<avx512bw::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
- TestMultiplyBias<avx512bw::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ TestMultiplyBias<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiplyBias<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+ TestMultiplyBias<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiplyBias<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBias<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBias<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ }
+
+ TEST_CASE ("Multiply AVX512 8bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(8, 2048, 256, 3.7f, 4, 0.37f, 0.33f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512VNNI
TEST_CASE ("Multiply AVX512VNNI 8bit", "[multiply]") {
if (kCPU < CPUType::AVX512VNNI) return;
- TestMultiply<avx512vnni::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
- TestMultiply<avx512vnni::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
- TestMultiply<avx512vnni::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
- TestMultiply<avx512vnni::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
- TestMultiply<avx512vnni::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
- TestMultiply<avx512vnni::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ TestMultiply<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiply<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
+ TestMultiply<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiply<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiply<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiply<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ }
+
+ TEST_CASE ("Multiply AVX512VNNI 8bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX512VNNI) return;
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}
TEST_CASE ("Multiply AVX512VNNI 8bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX512VNNI) return;
- TestMultiplyBias<avx512vnni::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
- TestMultiplyBias<avx512vnni::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
- TestMultiplyBias<avx512vnni::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
- TestMultiplyBias<avx512vnni::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
- TestMultiplyBias<avx512vnni::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
- TestMultiplyBias<avx512vnni::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ TestMultiplyBias<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiplyBias<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
+ TestMultiplyBias<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiplyBias<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBias<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBias<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
+ }
+
+ TEST_CASE ("Multiply AVX512VNNI 8bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX512VNNI) return;
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 256, 256, 0, 0.25f, 0.062f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(8, 2048, 256, 0, 0.55f, 0.25f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(320, 256, 256, 0, 0.26f, 0.059f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(472, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(248, 256, 256, 0, 0.29f, 0.059f);
+ TestMultiplyBiasRelu<AVX512VNNI::Kernels8>(200, 256, 256, 0, 0.28f, 0.06f);
}
#endif
TEST_CASE ("Multiply AVX512 16bit", "[multiply]") {
if (kCPU < CPUType::AVX512BW) return;
- TestMultiply<avx512bw::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
- TestMultiply<avx512bw::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
- TestMultiply<avx512bw::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
- TestMultiply<avx512bw::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
- TestMultiply<avx512bw::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
- TestMultiply<avx512bw::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
+ TestMultiply<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiply<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+ }
+
+ TEST_CASE ("Multiply AVX512 16bit with relu", "[multiply_relu]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
+
TEST_CASE ("Multiply AVX512 16bit with bias", "[biased_multiply]") {
if (kCPU < CPUType::AVX512BW) return;
- TestMultiplyBias<avx512bw::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
- TestMultiplyBias<avx512bw::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
- TestMultiplyBias<avx512bw::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
- TestMultiplyBias<avx512bw::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
- TestMultiplyBias<avx512bw::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
- TestMultiplyBias<avx512bw::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
+ TestMultiplyBias<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBias<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
+ }
+
+ TEST_CASE ("Multiply AVX512 16bit with bias and relu", "[biased_multiply_relu]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(8, 2048, 256, .1f, 1, 0.011f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(320, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(472, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(248, 256, 256, .1f, 1, 0.01f);
+ TestMultiplyBiasRelu<AVX512BW::Kernels16>(200, 256, 256, .1f, 1, 0.01f);
}
#endif
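For reference, the new [multiply_relu] and [biased_multiply_relu] cases above compare the SIMD callback output against the scalar computation max(0, sum * unquant_mult [+ bias]). Because both A and B are quantized by quant_mult, every int32 dot product carries a factor of quant_mult squared, which unquant_mult = 1 / quant_mult^2 removes before the ReLU clamp. The following self-contained sketch works through that arithmetic on one product; the values are illustrative and not taken from the tests.

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  const float quant_mult = 64.0f;                             // 8-bit path, as in the tests above
  const float unquant_mult = 1.0f / (quant_mult * quant_mult);

  // Quantize two sample values the same way the tests quantize A and B.
  const float a = 0.5f, b = -0.25f;
  const int32_t qa = static_cast<int32_t>(a * quant_mult);    // 32
  const int32_t qb = static_cast<int32_t>(b * quant_mult);    // -16
  const int32_t sum = qa * qb;                                // -512, carries quant_mult^2

  const float relu_ref = std::max(0.0f, sum * unquant_mult);  // equals max(0, a*b) up to rounding
  assert(relu_ref == 0.0f);                                   // a*b is negative, so ReLU clamps to 0
  return 0;
}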
diff --git a/test/prepare_b_quantized_transposed.cc b/test/prepare_b_quantized_transposed.cc
index 1437e0a..defe9a0 100644
--- a/test/prepare_b_quantized_transposed.cc
+++ b/test/prepare_b_quantized_transposed.cc
@@ -62,31 +62,33 @@ TEST_CASE("PrepareBQuantizedTransposed SSE2", "") {
if (kCPU < CPUType::SSE2)
return;
- CHECK(TestMany<sse2::Kernels16>(32, 128));
+ CHECK(TestMany<SSE2::Kernels16>(32, 128));
}
TEST_CASE("PrepareBQuantizedTransposed SSSE3", "") {
if (kCPU < CPUType::SSSE3)
return;
- CHECK(TestMany<ssse3::Kernels8>(32, 128));
+ CHECK(TestMany<SSSE3::Kernels8>(32, 128));
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE("PrepareBQuantizedTransposed AVX2", "") {
if (kCPU < CPUType::AVX2)
return;
- CHECK(TestMany<avx2::Kernels8>(32, 128));
- CHECK(TestMany<avx2::Kernels16>(32, 128));
+ CHECK(TestMany<AVX2::Kernels8>(32, 128));
+ CHECK(TestMany<AVX2::Kernels16>(32, 128));
}
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE("PrepareBQuantizedTransposed AVX512", "") {
if (kCPU < CPUType::AVX512BW)
return;
- CHECK(TestMany<avx512bw::Kernels8>(64, 128));
- CHECK(TestMany<avx512bw::Kernels16>(64, 128));
+ CHECK(TestMany<AVX512BW::Kernels8>(64, 128));
+ CHECK(TestMany<AVX512BW::Kernels16>(64, 128));
}
#endif
diff --git a/test/prepare_b_transposed.cc b/test/prepare_b_transposed.cc
index bc35138..1c11fbe 100644
--- a/test/prepare_b_transposed.cc
+++ b/test/prepare_b_transposed.cc
@@ -63,32 +63,34 @@ TEST_CASE("PrepareBTransposed SSE2", "") {
if (kCPU < CPUType::SSE2)
return;
- CHECK(TestMany<sse2::Kernels16>(4, 128, 2.0f));
+ CHECK(TestMany<SSE2::Kernels16>(4, 128, 2.0f));
}
TEST_CASE("PrepareBTransposed SSSE3", "") {
if (kCPU < CPUType::SSSE3)
return;
- CHECK(TestMany<ssse3::Kernels8>(4, 128, 2.0f));
+ CHECK(TestMany<SSSE3::Kernels8>(4, 128, 2.0f));
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE("PrepareBTransposed AVX2", "") {
if (kCPU < CPUType::AVX2)
return;
- CHECK(TestMany<avx2::Kernels8>(8, 128, 2.0f));
- CHECK(TestMany<avx2::Kernels16>(8, 128, 2.0f));
+ CHECK(TestMany<AVX2::Kernels8>(8, 128, 2.0f));
+ CHECK(TestMany<AVX2::Kernels16>(8, 128, 2.0f));
}
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- TEST_CASE("PrepareBTransposed AVX512", "") {
- if (kCPU < CPUType::AVX512BW)
- return;
+TEST_CASE("PrepareBTransposed AVX512", "") {
+ if (kCPU < CPUType::AVX512BW)
+ return;
- CHECK(TestMany<avx512bw::Kernels8>(16, 128, 2.0f));
- CHECK(TestMany<avx512bw::Kernels16>(16, 128, 2.0f));
- }
+ CHECK(TestMany<AVX512BW::Kernels8>(16, 128, 2.0f));
+ CHECK(TestMany<AVX512BW::Kernels16>(16, 128, 2.0f));
+}
#endif
}
diff --git a/test/quantize_test.cc b/test/quantize_test.cc
index 550ec66..622ff71 100644
--- a/test/quantize_test.cc
+++ b/test/quantize_test.cc
@@ -120,74 +120,78 @@ template <class Backend> void TestMany(std::size_t grow) {
TEST_CASE ("Quantize SSE2", "[quantize]") {
if (kCPU < CPUType::SSE2) return;
- TestMany<sse2::Kernels16>(8);
+ TestMany<SSE2::Kernels16>(8);
}
TEST_CASE ("Quantize SSSE3", "[quantize]") {
if (kCPU < CPUType::SSSE3) return;
- TestMany<ssse3::Kernels8>(1);
+ TestMany<SSSE3::Kernels8>(1);
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE ("Quantize AVX2", "[quantize]") {
if (kCPU < CPUType::AVX2) return;
- TestMany<avx2::Kernels8>(1);
- TestMany<avx2::Kernels16>(16);
+ TestMany<AVX2::Kernels8>(1);
+ TestMany<AVX2::Kernels16>(16);
}
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
- TEST_CASE ("Quantize AVX512", "[quantize]") {
- if (kCPU < CPUType::AVX512BW) return;
- TestMany<avx512bw::Kernels8>(1);
- TestMany<avx512bw::Kernels16>(16);
- }
+TEST_CASE ("Quantize AVX512", "[quantize]") {
+ if (kCPU < CPUType::AVX512BW) return;
+ TestMany<AVX512BW::Kernels8>(1);
+ TestMany<AVX512BW::Kernels16>(16);
+}
#endif
TEST_CASE("QuantizeStd SSSE3", "[VectorMeanStd]") {
if (kCPU < CPUType::SSSE3) return;
- testVectorMeanStd<sse2::VectorMeanStd>(64);
- testVectorMeanStd<sse2::VectorMeanStd>(64, true);
- testVectorMeanStd<sse2::VectorMeanStd>(256);
- testVectorMeanStd<sse2::VectorMeanStd>(256, true);
- testVectorMeanStd<sse2::VectorMeanStd>(2048);
- testVectorMeanStd<sse2::VectorMeanStd>(2048, true);
- testVectorMeanStd<sse2::VectorMeanStd>(65536);
- testVectorMeanStd<sse2::VectorMeanStd>(65536, true);
- testVectorMeanStd<sse2::VectorMeanStd>(81920);
- testVectorMeanStd<sse2::VectorMeanStd>(81920, true);
- testVectorMeanStd<sse2::VectorMeanStd>(120832);
- testVectorMeanStd<sse2::VectorMeanStd>(120832, true);
+ testVectorMeanStd<SSE2::VectorMeanStd>(64);
+ testVectorMeanStd<SSE2::VectorMeanStd>(64, true);
+ testVectorMeanStd<SSE2::VectorMeanStd>(256);
+ testVectorMeanStd<SSE2::VectorMeanStd>(256, true);
+ testVectorMeanStd<SSE2::VectorMeanStd>(2048);
+ testVectorMeanStd<SSE2::VectorMeanStd>(2048, true);
+ testVectorMeanStd<SSE2::VectorMeanStd>(65536);
+ testVectorMeanStd<SSE2::VectorMeanStd>(65536, true);
+ testVectorMeanStd<SSE2::VectorMeanStd>(81920);
+ testVectorMeanStd<SSE2::VectorMeanStd>(81920, true);
+ testVectorMeanStd<SSE2::VectorMeanStd>(120832);
+ testVectorMeanStd<SSE2::VectorMeanStd>(120832, true);
}
+#ifdef INTGEMM_COMPILER_SUPPORTS_AVX2
TEST_CASE("QuantizeStd AVX2", "[VectorMeanStd]") {
if (kCPU < CPUType::AVX2) return;
- testVectorMeanStd<avx2::VectorMeanStd>(64);
- testVectorMeanStd<avx2::VectorMeanStd>(64, true);
- testVectorMeanStd<avx2::VectorMeanStd>(256);
- testVectorMeanStd<avx2::VectorMeanStd>(256, true);
- testVectorMeanStd<avx2::VectorMeanStd>(2048);
- testVectorMeanStd<avx2::VectorMeanStd>(2048, true);
- testVectorMeanStd<avx2::VectorMeanStd>(65536);
- testVectorMeanStd<avx2::VectorMeanStd>(65536, true);
- testVectorMeanStd<avx2::VectorMeanStd>(81920);
- testVectorMeanStd<avx2::VectorMeanStd>(81920, true);
- testVectorMeanStd<avx2::VectorMeanStd>(120832);
- testVectorMeanStd<avx2::VectorMeanStd>(120832, true);
+ testVectorMeanStd<AVX2::VectorMeanStd>(64);
+ testVectorMeanStd<AVX2::VectorMeanStd>(64, true);
+ testVectorMeanStd<AVX2::VectorMeanStd>(256);
+ testVectorMeanStd<AVX2::VectorMeanStd>(256, true);
+ testVectorMeanStd<AVX2::VectorMeanStd>(2048);
+ testVectorMeanStd<AVX2::VectorMeanStd>(2048, true);
+ testVectorMeanStd<AVX2::VectorMeanStd>(65536);
+ testVectorMeanStd<AVX2::VectorMeanStd>(65536, true);
+ testVectorMeanStd<AVX2::VectorMeanStd>(81920);
+ testVectorMeanStd<AVX2::VectorMeanStd>(81920, true);
+ testVectorMeanStd<AVX2::VectorMeanStd>(120832);
+ testVectorMeanStd<AVX2::VectorMeanStd>(120832, true);
}
+#endif
#ifdef INTGEMM_COMPILER_SUPPORTS_AVX512BW
TEST_CASE("QuantizeStd AVX512BW", "[VectorMeanStd]") {
if (kCPU < CPUType::AVX512BW) return;
- testVectorMeanStd<avx512bw::VectorMeanStd>(64);
- testVectorMeanStd<avx512bw::VectorMeanStd>(64, true);
- testVectorMeanStd<avx512bw::VectorMeanStd>(256);
- testVectorMeanStd<avx512bw::VectorMeanStd>(256, true);
- testVectorMeanStd<avx512bw::VectorMeanStd>(2048);
- testVectorMeanStd<avx512bw::VectorMeanStd>(2048, true);
- testVectorMeanStd<avx512bw::VectorMeanStd>(65536);
- testVectorMeanStd<avx512bw::VectorMeanStd>(65536, true);
- testVectorMeanStd<avx512bw::VectorMeanStd>(81920);
- testVectorMeanStd<avx512bw::VectorMeanStd>(81920, true);
- testVectorMeanStd<avx512bw::VectorMeanStd>(120832);
- testVectorMeanStd<avx512bw::VectorMeanStd>(120832, true);
+ testVectorMeanStd<AVX512BW::VectorMeanStd>(64);
+ testVectorMeanStd<AVX512BW::VectorMeanStd>(64, true);
+ testVectorMeanStd<AVX512BW::VectorMeanStd>(256);
+ testVectorMeanStd<AVX512BW::VectorMeanStd>(256, true);
+ testVectorMeanStd<AVX512BW::VectorMeanStd>(2048);
+ testVectorMeanStd<AVX512BW::VectorMeanStd>(2048, true);
+ testVectorMeanStd<AVX512BW::VectorMeanStd>(65536);
+ testVectorMeanStd<AVX512BW::VectorMeanStd>(65536, true);
+ testVectorMeanStd<AVX512BW::VectorMeanStd>(81920);
+ testVectorMeanStd<AVX512BW::VectorMeanStd>(81920, true);
+ testVectorMeanStd<AVX512BW::VectorMeanStd>(120832);
+ testVectorMeanStd<AVX512BW::VectorMeanStd>(120832, true);
}
#endif
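Each of the test files above parameterizes its checks on a Routine policy class rather than on a CPU flag: the renamed uppercase namespaces (SSE2::Kernels16, SSSE3::Kernels8, AVX2::Kernels8, AVX512BW::Kernels16, and so on) expose the same static interface, so one template covers every ISA. Here is a minimal sketch of how such a template consumes that interface, using a hypothetical MockKernels16 in place of a real backend; only kName and Integer are exercised, since those are the members the diff's TestMultiply* helpers rely on.

#include <cstdint>
#include <iostream>
#include <sstream>

// Hypothetical stand-in for SSE2::Kernels16, AVX2::Kernels8, AVX512BW::Kernels16, ...
struct MockKernels16 {
  using Integer = int16_t;                         // Routine::Integer selects the 8- vs 16-bit path
  static constexpr const char *kName = "Mock16";
};

template <class Routine> void DescribeRoutine() {
  std::ostringstream info;
  // Mirrors how TestMultiply* builds its diagnostic string and picks quant_mult from the integer width.
  float quant_mult = (sizeof(typename Routine::Integer) == 2) ? 1024.0f : 64.0f;
  info << Routine::kName << " quant_mult=" << quant_mult;
  std::cout << info.str() << '\n';
}

int main() { DescribeRoutine<MockKernels16>(); }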