diff options
-rw-r--r-- | avx2_gemm.h | 8 | ||||
-rw-r--r-- | avx512_gemm.h | 8 | ||||
-rw-r--r-- | intgemm.cc | 2 | ||||
-rw-r--r-- | postprocess.h | 24 | ||||
-rw-r--r-- | postprocess_pipeline.h | 16 | ||||
-rw-r--r-- | sse2_gemm.h | 4 | ||||
-rw-r--r-- | ssse3_gemm.h | 4 | ||||
-rw-r--r-- | test/multiply_test.cc | 50 | ||||
-rw-r--r-- | test/pipeline_test.cc | 8 | ||||
-rw-r--r-- | test/quantize_test.cc | 8 | ||||
-rw-r--r-- | test/relu_test.cc | 12 | ||||
-rw-r--r-- | types.h | 8 |
12 files changed, 79 insertions, 73 deletions
diff --git a/avx2_gemm.h b/avx2_gemm.h index a03ff09..1482090 100644 --- a/avx2_gemm.h +++ b/avx2_gemm.h @@ -80,11 +80,11 @@ struct AVX2_16bit { avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows * 2, cols_begin, cols_end); } - INTGEMM_MULTIPLY16(__m256i, INTGEMM_AVX2, CPUType::CPU_AVX2) + INTGEMM_MULTIPLY16(__m256i, INTGEMM_AVX2, CPUType::AVX2) constexpr static const char *const kName = "16-bit INTGEMM_AVX2"; - static const CPUType kUses = CPU_AVX2; + static const CPUType kUses = CPUType::AVX2; }; namespace avx2 { @@ -169,11 +169,11 @@ struct AVX2_8bit { avx2::SelectColumnsOfB((const __m256i*)input, (__m256i*)output, rows, cols_begin, cols_end); } - INTGEMM_MULTIPLY8(__m256i, INTGEMM_AVX2, CPUType::CPU_AVX2) + INTGEMM_MULTIPLY8(__m256i, INTGEMM_AVX2, CPUType::AVX2) constexpr static const char *const kName = "8-bit INTGEMM_AVX2"; - static const CPUType kUses = CPU_AVX2; + static const CPUType kUses = CPUType::AVX2; }; } // namespace intgemm diff --git a/avx512_gemm.h b/avx512_gemm.h index 4d6f0db..043dfae 100644 --- a/avx512_gemm.h +++ b/avx512_gemm.h @@ -166,11 +166,11 @@ struct AVX512_16bit { } /* Only INTGEMM_AVX512F is necessary but due to GCC 5.4 bug we have to set INTGEMM_AVX512BW */ - INTGEMM_MULTIPLY16(__m512i, INTGEMM_AVX512BW, CPUType::CPU_AVX2) + INTGEMM_MULTIPLY16(__m512i, INTGEMM_AVX512BW, CPUType::AVX2) constexpr static const char *const kName = "16-bit AVX512"; - static const CPUType kUses = CPU_AVX512BW; + static const CPUType kUses = CPUType::AVX512BW; }; struct AVX512_8bit { @@ -227,7 +227,7 @@ struct AVX512_8bit { assert(reinterpret_cast<uintptr_t>(A) % sizeof(Integer) == 0); assert(reinterpret_cast<uintptr_t>(B) % sizeof(Integer) == 0); // There's 8 results for INTGEMM_AVX2 to handle. - auto inited_pipeline = InitPostprocessPipeline<CPUType::CPU_AVX2>(pipeline); + auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline); const int simd_width = width / sizeof(Integer); const Integer *B0_col = reinterpret_cast<const Integer*>(B); // Added for AVX512. @@ -333,7 +333,7 @@ struct AVX512_8bit { constexpr static const char *const kName = "8-bit AVX512"; - static const CPUType kUses = CPU_AVX512BW; + static const CPUType kUses = CPUType::AVX512BW; }; } // namespace intgemm @@ -22,7 +22,7 @@ void (*Int8::SelectColumnsB)(const int8_t *input, int8_t *output, Index rows, co const char *const Int8::kName = ChooseCPU(AVX512_8bit::kName, AVX2_8bit::kName, SSSE3_8bit::kName, Unsupported_8bit::kName, Unsupported_8bit::kName); -const CPUType kCPU = ChooseCPU(CPU_AVX512BW, CPU_AVX2, CPU_SSSE3, CPU_SSE2, CPU_UNSUPPORTED); +const CPUType kCPU = ChooseCPU(CPUType::AVX512BW, CPUType::AVX2, CPUType::SSSE3, CPUType::SSE2, CPUType::UNSUPPORTED); float (*MaxAbsolute)(const float *begin, const float *end) = ChooseCPU(avx512f::MaxAbsolute, avx2::MaxAbsolute, sse2::MaxAbsolute, sse2::MaxAbsolute, Unsupported_MaxAbsolute); diff --git a/postprocess.h b/postprocess.h index 76755b3..0855548 100644 --- a/postprocess.h +++ b/postprocess.h @@ -18,7 +18,7 @@ public: }; template <> -class PostprocessImpl<Unquantize, CPUType::CPU_SSE2> { +class PostprocessImpl<Unquantize, CPUType::SSE2> { public: using InputRegister = RegisterPair128i; using OutputRegister = RegisterPair128; @@ -39,7 +39,7 @@ private: }; template <> -class PostprocessImpl<Unquantize, CPUType::CPU_AVX2> { +class PostprocessImpl<Unquantize, CPUType::AVX2> { public: using InputRegister = __m256i; using OutputRegister = __m256; @@ -57,7 +57,7 @@ private: }; template <> -class PostprocessImpl<Unquantize, CPUType::CPU_AVX512BW> { +class PostprocessImpl<Unquantize, CPUType::AVX512BW> { public: using InputRegister = __m512i; using OutputRegister = __m512; @@ -80,7 +80,7 @@ private: class Identity {}; template <> -class PostprocessImpl<Identity, CPUType::CPU_SSE2> { +class PostprocessImpl<Identity, CPUType::SSE2> { public: using InputRegister = RegisterPair128i; using OutputRegister = RegisterPair128i; @@ -93,7 +93,7 @@ public: }; template <> -class PostprocessImpl<Identity, CPUType::CPU_AVX2> { +class PostprocessImpl<Identity, CPUType::AVX2> { public: using InputRegister = __m256i; using OutputRegister = __m256i; @@ -106,7 +106,7 @@ public: }; template <> -class PostprocessImpl<Identity, CPUType::CPU_AVX512BW> { +class PostprocessImpl<Identity, CPUType::AVX512BW> { public: using InputRegister = __m512i; using OutputRegister = __m512i; @@ -130,7 +130,7 @@ public: }; template <> -class PostprocessImpl<AddBias, CPUType::CPU_SSE2> { +class PostprocessImpl<AddBias, CPUType::SSE2> { public: using InputRegister = RegisterPair128; using OutputRegister = RegisterPair128; @@ -151,7 +151,7 @@ private: }; template <> -class PostprocessImpl<AddBias, CPUType::CPU_AVX2> { +class PostprocessImpl<AddBias, CPUType::AVX2> { public: using InputRegister = __m256; using OutputRegister = __m256; @@ -173,7 +173,7 @@ private: class ReLU {}; template <> -class PostprocessImpl<ReLU, CPUType::CPU_SSE2> { +class PostprocessImpl<ReLU, CPUType::SSE2> { public: using InputRegister = RegisterPair128; using OutputRegister = RegisterPair128; @@ -190,10 +190,10 @@ public: }; template <> -class PostprocessImpl<ReLU, CPUType::CPU_SSSE3> : public PostprocessImpl<ReLU, CPUType::CPU_SSE2> {}; +class PostprocessImpl<ReLU, CPUType::SSSE3> : public PostprocessImpl<ReLU, CPUType::SSE2> {}; template <> -class PostprocessImpl<ReLU, CPUType::CPU_AVX2> { +class PostprocessImpl<ReLU, CPUType::AVX2> { public: using InputRegister = __m256; using OutputRegister = __m256; @@ -207,7 +207,7 @@ public: }; template <> -class PostprocessImpl<ReLU, CPUType::CPU_AVX512BW> { +class PostprocessImpl<ReLU, CPUType::AVX512BW> { public: using InputRegister = __m512; using OutputRegister = __m512; diff --git a/postprocess_pipeline.h b/postprocess_pipeline.h index 01f3f11..ad26ac5 100644 --- a/postprocess_pipeline.h +++ b/postprocess_pipeline.h @@ -71,10 +71,10 @@ struct RunPostprocessPipelineImpl; } \ }; -RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSE2, CPUType::CPU_SSE2) -RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSSE3, CPUType::CPU_SSSE3) -RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX2, CPUType::CPU_AVX2) -RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::CPU_AVX512BW) +RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSE2, CPUType::SSE2) +RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_SSSE3, CPUType::SSSE3) +RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX2, CPUType::AVX2) +RUN_POSTPROCESS_PIPELINE_IMPL_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::AVX512BW) } // anonymous namespace @@ -105,9 +105,9 @@ constexpr InitedPostprocessPipeline<CpuType, Stages...> InitPostprocessPipeline( const std::tuple<PostprocessImpl<Stages, cpu_type>...> inited_pipeline; \ }; -INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSE2, CPUType::CPU_SSE2) -INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSSE3, CPUType::CPU_SSSE3) -INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX2, CPUType::CPU_AVX2) -INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::CPU_AVX512BW) +INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSE2, CPUType::SSE2) +INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_SSSE3, CPUType::SSSE3) +INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX2, CPUType::AVX2) +INITED_POSTPROCESS_PIPELINE_INSERT_IMPL(INTGEMM_AVX512BW, CPUType::AVX512BW) } diff --git a/sse2_gemm.h b/sse2_gemm.h index dfccc5c..6b7e698 100644 --- a/sse2_gemm.h +++ b/sse2_gemm.h @@ -72,11 +72,11 @@ struct SSE2_16bit { //TODO #DEFINE sse2::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows * 2, cols_begin, cols_end); } - INTGEMM_MULTIPLY16(__m128i, INTGEMM_SSE2, CPUType::CPU_SSE2) + INTGEMM_MULTIPLY16(__m128i, INTGEMM_SSE2, CPUType::SSE2) constexpr static const char *const kName = "16-bit INTGEMM_SSE2"; - static const CPUType kUses = CPU_SSE2; + static const CPUType kUses = CPUType::SSE2; }; } // namespace intgemm diff --git a/ssse3_gemm.h b/ssse3_gemm.h index 4e12b90..6038889 100644 --- a/ssse3_gemm.h +++ b/ssse3_gemm.h @@ -95,11 +95,11 @@ struct SSSE3_8bit { ssse3::SelectColumnsOfB((const __m128i*)input, (__m128i*)output, rows, cols_begin, cols_end); } - INTGEMM_MULTIPLY8(__m128i, INTGEMM_SSSE3, CPUType::CPU_SSE2) + INTGEMM_MULTIPLY8(__m128i, INTGEMM_SSSE3, CPUType::SSE2) constexpr static const char *const kName = "8-bit INTGEMM_SSSE3"; - static const CPUType kUses = CPU_SSSE3; + static const CPUType kUses = CPUType::SSSE3; }; } // namespace intgemm diff --git a/test/multiply_test.cc b/test/multiply_test.cc index 00256a8..82062fe 100644 --- a/test/multiply_test.cc +++ b/test/multiply_test.cc @@ -52,7 +52,7 @@ template <class V> void SlowTranspose(const V *from, V *to, Index rows, Index co } INTGEMM_SSE2 TEST_CASE("Transpose 16", "[transpose]") { - if (kCPU < CPU_SSE2) return; + if (kCPU < CPUType::SSE2) return; const unsigned N = 8; AlignedVector<int16_t> input(N * N); std::iota(input.begin(), input.end(), 0); @@ -70,7 +70,7 @@ INTGEMM_SSE2 TEST_CASE("Transpose 16", "[transpose]") { } INTGEMM_SSSE3 TEST_CASE("Transpose 8", "[transpose]") { - if (kCPU < CPU_SSSE3) return; + if (kCPU < CPUType::SSSE3) return; const unsigned N = 16; AlignedVector<int8_t> input(N * N); std::iota(input.begin(), input.end(), 0); @@ -125,7 +125,7 @@ template <class Routine> void TestPrepare(Index rows = 32, Index cols = 16) { } TEST_CASE("Prepare AVX512", "[prepare]") { - if (kCPU < CPU_AVX512BW) return; + if (kCPU < CPUType::AVX512BW) return; #ifndef INTGEMM_NO_AVX512 TestPrepare<AVX512_8bit>(64, 8); TestPrepare<AVX512_8bit>(256, 32); @@ -135,20 +135,20 @@ TEST_CASE("Prepare AVX512", "[prepare]") { } TEST_CASE("Prepare AVX2", "[prepare]") { - if (kCPU < CPU_AVX2) return; + if (kCPU < CPUType::AVX2) return; TestPrepare<AVX2_8bit>(64, 32); TestPrepare<AVX2_16bit>(64, 32); } TEST_CASE("Prepare SSSE3", "[prepare]") { - if (kCPU < CPU_SSSE3) return; + if (kCPU < CPUType::SSSE3) return; TestPrepare<SSSE3_8bit>(16, 8); TestPrepare<SSSE3_8bit>(32, 16); TestPrepare<SSSE3_8bit>(32, 32); } TEST_CASE("Prepare SSE2", "[prepare]") { - if (kCPU < CPU_SSE2) return; + if (kCPU < CPUType::SSE2) return; TestPrepare<SSE2_16bit>(8, 8); TestPrepare<SSE2_16bit>(32, 32); } @@ -190,7 +190,7 @@ template <class Routine> void TestSelectColumnsB(Index rows = 64, Index cols = 1 } TEST_CASE("SelectColumnsB AVX512", "[select]") { - if (kCPU < CPU_AVX512BW) return; + if (kCPU < CPUType::AVX512BW) return; #ifndef INTGEMM_NO_AVX512 TestSelectColumnsB<AVX512_8bit>(); TestSelectColumnsB<AVX512_16bit>(256, 256); @@ -198,19 +198,19 @@ TEST_CASE("SelectColumnsB AVX512", "[select]") { } TEST_CASE("SelectColumnsB AVX2", "[select]") { - if (kCPU < CPU_AVX2) return; + if (kCPU < CPUType::AVX2) return; TestSelectColumnsB<AVX2_8bit>(256, 256); TestSelectColumnsB<AVX2_16bit>(256, 256); } TEST_CASE("SelectColumnsB SSSE3", "[select]") { - if (kCPU < CPU_SSSE3) return; + if (kCPU < CPUType::SSSE3) return; TestSelectColumnsB<SSSE3_8bit>(); TestSelectColumnsB<SSSE3_8bit>(256, 256); } TEST_CASE("SelectColumnsB SSE2", "[select]") { - if (kCPU < CPU_SSE2) return; + if (kCPU < CPUType::SSE2) return; TestSelectColumnsB<SSE2_16bit>(); TestSelectColumnsB<SSE2_16bit>(256, 256); } @@ -254,17 +254,17 @@ template <float (*Backend) (const float *, const float *)> void TestMaxAbsolute( } TEST_CASE("MaxAbsolute SSE2", "[max]") { - if (kCPU < CPU_SSE2) return; + if (kCPU < CPUType::SSE2) return; TestMaxAbsolute<sse2::MaxAbsolute>(); } TEST_CASE("MaxAbsolute AVX2", "[max]") { - if (kCPU < CPU_AVX2) return; + if (kCPU < CPUType::AVX2) return; TestMaxAbsolute<avx2::MaxAbsolute>(); } TEST_CASE("MaxAbsolute AVX512F", "[max]") { - if (kCPU < CPU_AVX512BW) return; + if (kCPU < CPUType::AVX512BW) return; #ifndef INTGEMM_NO_AVX512 TestMaxAbsolute<avx512f::MaxAbsolute>(); #endif @@ -432,7 +432,7 @@ template <class Routine> void TestMultiplyBias(Index A_rows, Index width, Index } TEST_CASE ("Multiply SSE2 16bit", "[multiply]") { - if (kCPU < CPU_SSE2) return; + if (kCPU < CPUType::SSE2) return; TestMultiply<SSE2_16bit>(8, 256, 256, .1, 1, 0.01); TestMultiply<SSE2_16bit>(8, 2048, 256, .1, 1, 0.02); TestMultiply<SSE2_16bit>(320, 256, 256, .1, 1, 0.01); @@ -442,7 +442,7 @@ TEST_CASE ("Multiply SSE2 16bit", "[multiply]") { } TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") { - if (kCPU < CPU_SSE2) return; + if (kCPU < CPUType::SSE2) return; TestMultiplyBias<SSE2_16bit>(8, 256, 256, .1, 1, 0.01); TestMultiplyBias<SSE2_16bit>(8, 2048, 256, .1, 1, 0.02); TestMultiplyBias<SSE2_16bit>(320, 256, 256, .1, 1, 0.01); @@ -452,7 +452,7 @@ TEST_CASE ("Multiply SSE2 16bit with bias", "[biased_multiply]") { } TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") { - if (kCPU < CPU_SSSE3) return; + if (kCPU < CPUType::SSSE3) return; TestMultiply<SSSE3_8bit>(8, 256, 256, 1.2, 1.2, 0.064, 0.026); TestMultiply<SSSE3_8bit>(8, 2048, 256, 33, 33, 4.4, 4.4); TestMultiply<SSSE3_8bit>(320, 256, 256, 1.9, 1.9, 0.1, 0.01); @@ -462,7 +462,7 @@ TEST_CASE ("Multiply SSSE3 8bit", "[multiply]") { } TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") { - if (kCPU < CPU_SSSE3) return; + if (kCPU < CPUType::SSSE3) return; TestMultiplyBias<SSSE3_8bit>(8, 256, 256, 1.2, 1.2, 0.064, 0.026); TestMultiplyBias<SSSE3_8bit>(8, 2048, 256, 33, 33, 4.4, 4.4); TestMultiplyBias<SSSE3_8bit>(320, 256, 256, 1.9, 1.9, 0.1, 0.01); @@ -472,7 +472,7 @@ TEST_CASE ("Multiply SSSE3 8bit with bias", "[biased_multiply]") { } TEST_CASE ("Multiply AVX2 8bit", "[multiply]") { - if (kCPU < CPU_AVX2) return; + if (kCPU < CPUType::AVX2) return; TestMultiply<AVX2_8bit>(8, 256, 256, .1, 1, 0.1); TestMultiply<AVX2_8bit>(8, 2048, 256, 19, 19, 1.8, 1.8); TestMultiply<AVX2_8bit>(320, 256, 256, .1, 1, 0.1); @@ -482,7 +482,7 @@ TEST_CASE ("Multiply AVX2 8bit", "[multiply]") { } TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") { - if (kCPU < CPU_AVX2) return; + if (kCPU < CPUType::AVX2) return; TestMultiplyBias<AVX2_8bit>(8, 256, 256, .1, 1, 0.1); TestMultiplyBias<AVX2_8bit>(8, 2048, 256, 19, 19, 1.8, 1.8); TestMultiplyBias<AVX2_8bit>(320, 256, 256, .1, 1, 0.1); @@ -492,7 +492,7 @@ TEST_CASE ("Multiply AVX2 8bit with bias", "[biased_multiply]") { } TEST_CASE ("Multiply AVX2 16bit", "[multiply]") { - if (kCPU < CPU_AVX2) return; + if (kCPU < CPUType::AVX2) return; TestMultiply<AVX2_16bit>(8, 256, 256, .1, 1, 0.01); TestMultiply<AVX2_16bit>(8, 2048, 256, .1, 1, 0.02); TestMultiply<AVX2_16bit>(320, 256, 256, .1, 1, 0.01); @@ -502,7 +502,7 @@ TEST_CASE ("Multiply AVX2 16bit", "[multiply]") { } TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") { - if (kCPU < CPU_AVX2) return; + if (kCPU < CPUType::AVX2) return; TestMultiplyBias<AVX2_16bit>(8, 256, 256, .1, 1, 0.01); TestMultiplyBias<AVX2_16bit>(8, 2048, 256, .1, 1, 0.02); TestMultiplyBias<AVX2_16bit>(320, 256, 256, .1, 1, 0.01); @@ -513,7 +513,7 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") { #ifndef INTGEMM_NO_AVX512 TEST_CASE ("Multiply AVX512 8bit", "[multiply]") { - if (kCPU < CPU_AVX512BW) return; + if (kCPU < CPUType::AVX512BW) return; TestMultiply<AVX512_8bit>(8, 256, 256, .1, 1, 0.062); TestMultiply<AVX512_8bit>(8, 2048, 256, 4.2, 4, 0.41, 0.37); TestMultiply<AVX512_8bit>(320, 256, 256, .1, 1, 0.06); @@ -523,7 +523,7 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") { } TEST_CASE ("Multiply AVX512 8bit with bias", "[biased_multiply]") { - if (kCPU < CPU_AVX512BW) return; + if (kCPU < CPUType::AVX512BW) return; TestMultiplyBias<AVX512_8bit>(8, 256, 256, .1, 1, 0.062); TestMultiplyBias<AVX512_8bit>(8, 2048, 256, 4.2, 4, 0.41, 0.37); TestMultiplyBias<AVX512_8bit>(320, 256, 256, .1, 1, 0.06); @@ -533,7 +533,7 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") { } TEST_CASE ("Multiply AVX512 16bit", "[multiply]") { - if (kCPU < CPU_AVX512BW) return; + if (kCPU < CPUType::AVX512BW) return; TestMultiply<AVX512_16bit>(8, 256, 256, .1, 1, 0.01); TestMultiply<AVX512_16bit>(8, 2048, 256, .1, 1, 0.011); TestMultiply<AVX512_16bit>(320, 256, 256, .1, 1, 0.01); @@ -543,7 +543,7 @@ TEST_CASE ("Multiply AVX2 16bit with bias", "[biased_multiply]") { } TEST_CASE ("Multiply AVX512 16bit with bias", "[biased_multiply]") { - if (kCPU < CPU_AVX512BW) return; + if (kCPU < CPUType::AVX512BW) return; TestMultiplyBias<AVX512_16bit>(8, 256, 256, .1, 1, 0.01); TestMultiplyBias<AVX512_16bit>(8, 2048, 256, .1, 1, 0.011); TestMultiplyBias<AVX512_16bit>(320, 256, 256, .1, 1, 0.01); diff --git a/test/pipeline_test.cc b/test/pipeline_test.cc index 389b9d7..1b8c21d 100644 --- a/test/pipeline_test.cc +++ b/test/pipeline_test.cc @@ -6,7 +6,7 @@ namespace intgemm { INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") { - if (kCPU < CPU_AVX2) + if (kCPU < CPUType::AVX2) return; __m256i input; @@ -19,7 +19,7 @@ INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") { std::fill(raw_output, raw_output + 8, 42); auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU()); - auto inited_pipeline = InitPostprocessPipeline<CPU_AVX2>(pipeline); + auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline); output = inited_pipeline.run(input, 0); CHECK(raw_output[0] == 0.0f); // input = -2 @@ -33,7 +33,7 @@ INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2", "Unquantize-ReLU") { } INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2 on whole buffer", "Unquantize-ReLU") { - if (kCPU < CPU_AVX2) + if (kCPU < CPUType::AVX2) return; __m256i input[2]; @@ -46,7 +46,7 @@ INTGEMM_AVX2 TEST_CASE("PostprocessPipeline AVX2 on whole buffer", "Unquantize-R std::fill(raw_output, raw_output + 16, 42); auto pipeline = CreatePostprocessPipeline(Unquantize(0.5f), ReLU()); - auto inited_pipeline = InitPostprocessPipeline<CPU_AVX2>(pipeline); + auto inited_pipeline = InitPostprocessPipeline<CPUType::AVX2>(pipeline); inited_pipeline.run(input, 2, output); CHECK(raw_output[0] == 0.f); // input = -8 diff --git a/test/quantize_test.cc b/test/quantize_test.cc index e68fc34..fb866f1 100644 --- a/test/quantize_test.cc +++ b/test/quantize_test.cc @@ -74,23 +74,23 @@ template <class Backend> bool TestMany() { } TEST_CASE ("Quantize SSE2", "[quantize]") { - if (kCPU < CPU_SSE2) return; + if (kCPU < CPUType::SSE2) return; CHECK(TestMany<SSE2_16bit>()); } TEST_CASE ("Quantize SSE3", "[quantize]") { - if (kCPU < CPU_SSSE3) return; + if (kCPU < CPUType::SSSE3) return; CHECK(TestMany<SSSE3_8bit>()); } TEST_CASE ("Quantize AVX2", "[quantize]") { - if (kCPU < CPU_AVX2) return; + if (kCPU < CPUType::AVX2) return; CHECK(TestMany<AVX2_8bit>()); CHECK(TestMany<AVX2_16bit>()); } #ifndef INTGEMM_NO_AVX512 TEST_CASE ("Quantize AVX512", "[quantize]") { - if (kCPU < CPU_AVX512BW) return; + if (kCPU < CPUType::AVX512BW) return; CHECK(TestMany<AVX512_8bit>()); CHECK(TestMany<AVX512_16bit>()); } diff --git a/test/relu_test.cc b/test/relu_test.cc index 0a72a29..183f415 100644 --- a/test/relu_test.cc +++ b/test/relu_test.cc @@ -6,7 +6,7 @@ namespace intgemm { INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) { - if (kCPU < CPU_SSE2) + if (kCPU < CPUType::SSE2) return; float raw_input[8]; @@ -16,7 +16,7 @@ INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) { input.pack0123 = *reinterpret_cast<__m128*>(raw_input); input.pack4567 = *reinterpret_cast<__m128*>(raw_input + 4); - auto postproc = PostprocessImpl<ReLU, CPUType::CPU_SSE2>(ReLU()); + auto postproc = PostprocessImpl<ReLU, CPUType::SSE2>(ReLU()); auto output = postproc.run(input, 0); auto raw_output = reinterpret_cast<float*>(&output); @@ -31,14 +31,14 @@ INTGEMM_SSE2 TEST_CASE("ReLU SSE2",) { } INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) { - if (kCPU < CPU_AVX2) + if (kCPU < CPUType::AVX2) return; float raw_input[8]; std::iota(raw_input, raw_input + 8, -4); auto input = *reinterpret_cast<__m256*>(raw_input); - auto postproc = PostprocessImpl<ReLU, CPUType::CPU_AVX2>(ReLU()); + auto postproc = PostprocessImpl<ReLU, CPUType::AVX2>(ReLU()); auto output = postproc.run(input, 0); auto raw_output = reinterpret_cast<float*>(&output); @@ -55,14 +55,14 @@ INTGEMM_AVX2 TEST_CASE("ReLU AVX2",) { #ifndef INTGEMM_NO_AVX512 INTGEMM_AVX512BW TEST_CASE("ReLU AVX512",) { - if (kCPU < CPU_AVX512BW) + if (kCPU < CPUType::AVX512BW) return; float raw_input[16]; std::iota(raw_input, raw_input + 16, -8); auto input = *reinterpret_cast<__m512*>(raw_input); - auto postproc = PostprocessImpl<ReLU, CPUType::CPU_AVX512BW>(ReLU()); + auto postproc = PostprocessImpl<ReLU, CPUType::AVX512BW>(ReLU()); auto output = postproc.run(input, 0); auto raw_output = reinterpret_cast<float*>(&output); @@ -33,7 +33,13 @@ class UnsupportedCPU : public std::exception { typedef unsigned int Index; // If you want to detect the CPU and dispatch yourself, here's what to use: -typedef enum {CPU_AVX512BW = 4, CPU_AVX2 = 3, CPU_SSSE3 = 2, CPU_SSE2 = 1, CPU_UNSUPPORTED = 0} CPUType; +enum class CPUType { + UNSUPPORTED = 0, + SSE2, + SSSE3, + AVX2, + AVX512BW, +}; // Running CPU type. This is defined in intgemm.cc (as the dispatcher). extern const CPUType kCPU; |