diff options
-rw-r--r-- | ruy/BUILD | 26 | ||||
-rw-r--r-- | ruy/build_defs.bzl | 2 | ||||
-rw-r--r-- | ruy/ctx.cc | 4 | ||||
-rw-r--r-- | ruy/ctx_test.cc | 4 | ||||
-rw-r--r-- | ruy/have_built_path_for.h | 2 | ||||
-rw-r--r-- | ruy/have_built_path_for_avx2_fma.cc (renamed from ruy/have_built_path_for_avx2.cc) | 10 | ||||
-rw-r--r-- | ruy/kernel_avx2_fma.cc (renamed from ruy/kernel_avx2.cc) | 16 | ||||
-rw-r--r-- | ruy/kernel_x86.h | 13 | ||||
-rw-r--r-- | ruy/pack_avx2_fma.cc (renamed from ruy/pack_avx2.cc) | 16 | ||||
-rw-r--r-- | ruy/pack_x86.h | 14 | ||||
-rw-r--r-- | ruy/path.h | 8 | ||||
-rw-r--r-- | ruy/platform.h | 33 | ||||
-rw-r--r-- | ruy/test.h | 2 |
13 files changed, 64 insertions, 86 deletions
@@ -1,7 +1,7 @@ # Ruy is not BLAS load("@bazel_skylib//lib:selects.bzl", "selects") -load(":build_defs.bzl", "ruy_copts", "ruy_copts_avx2", "ruy_copts_avx512") +load(":build_defs.bzl", "ruy_copts", "ruy_copts_avx2_fma", "ruy_copts_avx512") load(":build_defs.oss.bzl", "ruy_linkopts_thread_standard_library") load(":ruy_test_ext.oss.bzl", "ruy_test_ext_defines", "ruy_test_ext_deps") load(":ruy_test.bzl", "ruy_benchmark", "ruy_test") @@ -624,14 +624,14 @@ cc_library( ) cc_library( - name = "kernel_avx2", + name = "kernel_avx2_fma", srcs = [ - "kernel_avx2.cc", + "kernel_avx2_fma.cc", ], hdrs = [ "kernel_x86.h", ], - copts = ruy_copts() + ruy_copts_avx2(), + copts = ruy_copts() + ruy_copts_avx2_fma(), deps = [ ":check_macros", ":kernel_common", @@ -646,14 +646,14 @@ cc_library( ) cc_library( - name = "pack_avx2", + name = "pack_avx2_fma", srcs = [ - "pack_avx2.cc", + "pack_avx2_fma.cc", ], hdrs = [ "pack_x86.h", ], - copts = ruy_copts() + ruy_copts_avx2(), + copts = ruy_copts() + ruy_copts_avx2_fma(), deps = [ ":check_macros", ":mat", @@ -667,14 +667,14 @@ cc_library( ) cc_library( - name = "have_built_path_for_avx2", + name = "have_built_path_for_avx2_fma", srcs = [ - "have_built_path_for_avx2.cc", + "have_built_path_for_avx2_fma.cc", ], hdrs = [ "have_built_path_for.h", ], - copts = ruy_copts() + ruy_copts_avx2(), + copts = ruy_copts() + ruy_copts_avx2_fma(), deps = [ ":opt_set", ":platform", @@ -691,7 +691,7 @@ cc_library( ":apply_multiplier", ":check_macros", ":kernel_arm", # fixdeps: keep - ":kernel_avx2", # fixdeps: keep + ":kernel_avx2_fma", # fixdeps: keep ":kernel_avx512", # fixdeps: keep ":kernel_common", ":mat", @@ -719,7 +719,7 @@ cc_library( ":matrix", ":opt_set", ":pack_arm", # fixdeps: keep - ":pack_avx2", # fixdeps: keep + ":pack_avx2_fma", # fixdeps: keep ":pack_avx512", # fixdeps: keep ":pack_common", ":path", @@ -735,7 +735,7 @@ cc_library( "have_built_path_for.h", ], deps = [ - ":have_built_path_for_avx2", + ":have_built_path_for_avx2_fma", ":have_built_path_for_avx512", ":platform", ], diff --git a/ruy/build_defs.bzl b/ruy/build_defs.bzl index 594c4d9..a36942b 100644 --- a/ruy/build_defs.bzl +++ b/ruy/build_defs.bzl @@ -63,7 +63,7 @@ def ruy_copts_avx512(): "//conditions:default": [], }) -def ruy_copts_avx2(): +def ruy_copts_avx2_fma(): return select({ "//ruy:x86_64": ["-mavx2", "-mfma"], "//conditions:default": [], @@ -111,8 +111,8 @@ Path DetectRuntimeSupportedPaths(Path paths_to_detect, CpuInfo* cpuinfo) { #elif RUY_PLATFORM_X86 // x86 SIMD paths currently require both runtime detection, and detection of // whether we're building the path at all. - maybe_add(Path::kAvx2, - [=]() { return HaveBuiltPathForAvx2() && cpuinfo->Avx2(); }); + maybe_add(Path::kAvx2Fma, + [=]() { return HaveBuiltPathForAvx2Fma() && cpuinfo->Avx2(); }); maybe_add(Path::kAvx512, [=]() { return HaveBuiltPathForAvx512() && cpuinfo->Avx512(); }); #else diff --git a/ruy/ctx_test.cc b/ruy/ctx_test.cc index f57ec64..e55dcfc 100644 --- a/ruy/ctx_test.cc +++ b/ruy/ctx_test.cc @@ -33,9 +33,9 @@ TEST(ContextInternalTest, EnabledPathsGeneral) { #if RUY_PLATFORM_X86 TEST(ContextInternalTest, EnabledPathsX86Explicit) { CtxImpl ctx; - ctx.SetRuntimeEnabledPaths(Path::kAvx2); + ctx.SetRuntimeEnabledPaths(Path::kAvx2Fma); const auto ruy_paths = ctx.GetRuntimeEnabledPaths(); - EXPECT_EQ(ruy_paths, Path::kStandardCpp | Path::kAvx2); + EXPECT_EQ(ruy_paths, Path::kStandardCpp | Path::kAvx2Fma); } #endif // RUY_PLATFORM_X86 diff --git a/ruy/have_built_path_for.h b/ruy/have_built_path_for.h index 94761a7..60e98e1 100644 --- a/ruy/have_built_path_for.h +++ b/ruy/have_built_path_for.h @@ -21,7 +21,7 @@ limitations under the License. namespace ruy { #if RUY_PLATFORM_X86 -bool HaveBuiltPathForAvx2(); +bool HaveBuiltPathForAvx2Fma(); bool HaveBuiltPathForAvx512(); #endif // RUY_PLATFORM_X86 diff --git a/ruy/have_built_path_for_avx2.cc b/ruy/have_built_path_for_avx2_fma.cc index 4a6bbe5..03e8f8d 100644 --- a/ruy/have_built_path_for_avx2.cc +++ b/ruy/have_built_path_for_avx2_fma.cc @@ -21,15 +21,15 @@ namespace ruy { #if RUY_PLATFORM_X86 // IMPORTANT: // These patterns must match those in the pack and kernel cc files. -#if !(RUY_PLATFORM_AVX2 && RUY_OPT(ASM)) +#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)) -bool HaveBuiltPathForAvx2() { return false; } +bool HaveBuiltPathForAvx2Fma() { return false; } -#else // RUY_PLATFORM_AVX2 && RUY_OPT(ASM) +#else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) -bool HaveBuiltPathForAvx2() { return true; } +bool HaveBuiltPathForAvx2Fma() { return true; } -#endif // RUY_PLATFORM_AVX2 && RUY_OPT(ASM) +#endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) #endif // RUY_PLATFORM_X86 } // namespace ruy diff --git a/ruy/kernel_avx2.cc b/ruy/kernel_avx2_fma.cc index 60c6e6b..3c261a5 100644 --- a/ruy/kernel_avx2.cc +++ b/ruy/kernel_avx2_fma.cc @@ -24,13 +24,13 @@ limitations under the License. #include "ruy/platform.h" #include "ruy/profiler/instrumentation.h" -#if RUY_PLATFORM_AVX2 && RUY_OPT(ASM) +#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) #include <immintrin.h> // IWYU pragma: keep #endif namespace ruy { -#if !(RUY_PLATFORM_AVX2 && RUY_OPT(ASM)) +#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)) void Kernel8bitAvx2(const KernelParams8bit<8, 8>&) { // CPU-ID-based checks should disable the path that would reach this point. @@ -52,7 +52,7 @@ void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>&) { RUY_DCHECK(false); } -#else // RUY_PLATFORM_AVX2 && RUY_OPT(ASM) +#else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) static constexpr int kAvx8bitBlockSize = 8; static constexpr int kAvx8bitInnerSize = 4; @@ -393,7 +393,7 @@ inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) { } // namespace void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label("Kernel kAvx2 8-bit"); + profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit"); const std::int8_t splitter_idx_data[32] = { 0, 1, 4, 5, 8, 9, 12, 13, // 2, 3, 6, 7, 10, 11, 14, 15, // @@ -1185,7 +1185,7 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) { } // NOLINT(readability/fn_size) void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { - profiler::ScopeLabel label("Kernel kAvx2 8-bit GEMV"); + profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit GEMV"); RUY_DCHECK_EQ(params.dst_cols, 1); RUY_DCHECK_EQ(params.last_col, 0); @@ -1450,7 +1450,7 @@ void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { } // NOLINT(readability/fn_size) void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label("Kernel kAvx2 float"); + profiler::ScopeLabel label("Kernel kAvx2Fma float"); // As parameters are defined, we need to scale by sizeof(float). const std::int64_t lhs_stride = params.lhs_stride >> 2; @@ -1603,7 +1603,7 @@ void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) { } void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) { - profiler::ScopeLabel label("Kernel kAvx2 float GEMV"); + profiler::ScopeLabel label("Kernel kAvx2Fma float GEMV"); RUY_DCHECK_EQ(params.dst_cols, 1); RUY_DCHECK_EQ(params.last_col, 0); @@ -1707,6 +1707,6 @@ void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) { } // End handling of residual rows. } -#endif // RUY_PLATFORM_AVX2 && RUY_OPT(ASM) +#endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) } // namespace ruy diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h index 2def91f..641fbb6 100644 --- a/ruy/kernel_x86.h +++ b/ruy/kernel_x86.h @@ -30,8 +30,8 @@ namespace ruy { #if RUY_PLATFORM_X86 -RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx2) -RUY_INHERIT_KERNEL(Path::kAvx2, Path::kAvx512) +RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx2Fma) +RUY_INHERIT_KERNEL(Path::kAvx2Fma, Path::kAvx512) void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params); void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params); @@ -87,8 +87,9 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params); void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params); template <typename DstScalar> -struct Kernel<Path::kAvx2, std::int8_t, std::int8_t, std::int32_t, DstScalar> { - static constexpr Path kPath = Path::kAvx2; +struct Kernel<Path::kAvx2Fma, std::int8_t, std::int8_t, std::int32_t, + DstScalar> { + static constexpr Path kPath = Path::kAvx2Fma; Tuning tuning = Tuning::kAuto; using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; @@ -112,8 +113,8 @@ void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params); void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params); template <> -struct Kernel<Path::kAvx2, float, float, float, float> { - static constexpr Path kPath = Path::kAvx2; +struct Kernel<Path::kAvx2Fma, float, float, float, float> { + static constexpr Path kPath = Path::kAvx2Fma; Tuning tuning = Tuning::kAuto; using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>; diff --git a/ruy/pack_avx2.cc b/ruy/pack_avx2_fma.cc index 01c477e..ebb0b33 100644 --- a/ruy/pack_avx2.cc +++ b/ruy/pack_avx2_fma.cc @@ -23,13 +23,13 @@ limitations under the License. #include "ruy/platform.h" #include "ruy/profiler/instrumentation.h" -#if RUY_PLATFORM_AVX2 && RUY_OPT(INTRINSICS) +#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(INTRINSICS) #include <immintrin.h> // IWYU pragma: keep #endif namespace ruy { -#if !(RUY_PLATFORM_AVX2 && RUY_OPT(ASM)) +#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)) void Pack8bitAvx2(const std::int8_t*, std::int8_t, const std::int8_t*, int, int, int, std::int8_t*, std::int32_t*) { @@ -42,16 +42,16 @@ void PackFloatAvx2(const float*, const float*, int, int, int, float*) { RUY_DCHECK(false); } -#else // RUY_PLATFORM_AVX2 && RUY_OPT(ASM) +#else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM) // The first int8_t template parameter is arbitrary: this routine is common to // all 8-bit source matrix types. using PackImpl8bitAvx2 = - PackImpl<Path::kAvx2, FixedKernelLayout<Order::kColMajor, 4, 8>, + PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>, std::int8_t, std::int8_t, std::int32_t>; using PackImplFloatAvx2 = - PackImpl<Path::kAvx2, FixedKernelLayout<Order::kRowMajor, 1, 8>, float, + PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>, float, float, float>; namespace { @@ -752,7 +752,7 @@ void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, const std::int8_t* zerobuf, int src_stride, int remaining_src_cols, int src_rows, std::int8_t* packed_ptr, std::int32_t* sums_ptr) { - profiler::ScopeLabel label("Pack kAvx2 8bit"); + profiler::ScopeLabel label("Pack kAvx2Fma 8bit"); using Layout = PackImpl8bitAvx2::Layout; RUY_DCHECK_EQ(Layout::kCols, 8); @@ -789,7 +789,7 @@ void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, int remaining_src_cols, int src_rows, float* packed_ptr) { - profiler::ScopeLabel label("Pack kAvx2 float"); + profiler::ScopeLabel label("Pack kAvx2Fma float"); static constexpr int kPackCols = 8; // Source cols packed together. static constexpr int kPackRows = 8; // Short input is padded. float trailing_buf[(kPackRows - 1) * kPackCols]; @@ -807,6 +807,6 @@ void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, } } -#endif // RUY_PLATFORM_AVX2 && RUY_OPT(INTRINSICS) +#endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(INTRINSICS) } // namespace ruy diff --git a/ruy/pack_x86.h b/ruy/pack_x86.h index 10aea82..93e8904 100644 --- a/ruy/pack_x86.h +++ b/ruy/pack_x86.h @@ -33,11 +33,11 @@ namespace ruy { #if RUY_PLATFORM_X86 -RUY_INHERIT_PACK(Path::kStandardCpp, Path::kAvx2) -RUY_INHERIT_PACK(Path::kAvx2, Path::kAvx512) +RUY_INHERIT_PACK(Path::kStandardCpp, Path::kAvx2Fma) +RUY_INHERIT_PACK(Path::kAvx2Fma, Path::kAvx512) template <> -struct PackedTypeImpl<Path::kAvx2, std::uint8_t> { +struct PackedTypeImpl<Path::kAvx2Fma, std::uint8_t> { using Type = std::int8_t; }; template <> @@ -53,8 +53,8 @@ void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor, std::int32_t* sums_ptr); template <typename Scalar> -struct PackImpl<Path::kAvx2, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar, - std::int8_t, std::int32_t> { +struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>, + Scalar, std::int8_t, std::int32_t> { static_assert(std::is_same<Scalar, std::int8_t>::value || std::is_same<Scalar, std::uint8_t>::value, ""); @@ -98,8 +98,8 @@ void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride, int remaining_src_cols, int src_rows, float* packed_ptr); template <> -struct PackImpl<Path::kAvx2, FixedKernelLayout<Order::kRowMajor, 1, 8>, float, - float, float> { +struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>, + float, float, float> { using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>; static void Run(Tuning, const Mat<float>& src_matrix, PMat<float>* packed_matrix, int start_col, int end_col) { @@ -77,9 +77,11 @@ enum class Path : std::uint8_t { #endif // RUY_PLATFORM_ARM #if RUY_PLATFORM_X86 - // Optimized for AVX2. - kAvx2 = 0x10, + // Optimized for AVX2+FMA. + // Compiled with -mavx2 -mfma. + kAvx2Fma = 0x10, // Optimized for AVX-512. + // Compiled with -mavx512f -mavx512vl -mavx512cd -mavx512bw -mavx512dq. kAvx512 = 0x20, #endif // RUY_PLATFORM_X86 }; @@ -143,7 +145,7 @@ constexpr Path kExtraArchPaths = Path::kNone; constexpr Path kDefaultArchPaths = Path::kNeon; constexpr Path kExtraArchPaths = Path::kNone; #elif RUY_PLATFORM_X86 -constexpr Path kDefaultArchPaths = Path::kAvx2 | Path::kAvx512; +constexpr Path kDefaultArchPaths = Path::kAvx2Fma | Path::kAvx512; constexpr Path kExtraArchPaths = Path::kNone; #else constexpr Path kDefaultArchPaths = Path::kNone; diff --git a/ruy/platform.h b/ruy/platform.h index 2f9cbb3..7421613 100644 --- a/ruy/platform.h +++ b/ruy/platform.h @@ -132,36 +132,11 @@ limitations under the License. #define RUY_PLATFORM_AVX512 0 #endif -#if RUY_PLATFORM_X86_ENHANCEMENTS && RUY_PLATFORM_X86 && defined(__AVX2__) -#define RUY_PLATFORM_AVX2 1 +#if RUY_PLATFORM_X86_ENHANCEMENTS && RUY_PLATFORM_X86 && defined(__AVX2__) && \ + defined(__FMA__) +#define RUY_PLATFORM_AVX2_FMA 1 #else -#define RUY_PLATFORM_AVX2 0 -#endif - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// Note does not check for LZCNT or POPCNT. -#if defined(RUY_ENABLE_SSE_ENHANCEMENTS) && RUY_PLATFORM_X86_ENHANCEMENTS && \ - RUY_PLATFORM_X86 && defined(__SSE4_2__) && defined(__FMA__) -#define RUY_PLATFORM_SSE42 1 -#else -#define RUY_PLATFORM_SSE42 0 -#endif - -// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder. -// Optimization is not finished. In particular the dimensions of the kernel -// blocks can be changed as desired. -// -// Note that defined(__AVX512VBMI2__) can be false for compilation with -// -march=cascadelake. -// TODO(b/146646451) Check if we should also gate on defined(__AVX512VBMI2__). -#if defined(RUY_ENABLE_VNNI_ENHANCEMENTS) && RUY_PLATFORM_AVX512 && \ - defined(__AVX512VNNI__) -#define RUY_PLATFORM_AVX_VNNI 1 -#else -#define RUY_PLATFORM_AVX_VNNI 0 +#define RUY_PLATFORM_AVX2_FMA 0 #endif // Detect Emscripten, typically Wasm. @@ -105,7 +105,7 @@ inline const char* PathName(Path path) { RUY_PATHNAME_CASE(kNeon) RUY_PATHNAME_CASE(kNeonDotprod) #elif RUY_PLATFORM_X86 - RUY_PATHNAME_CASE(kAvx2) + RUY_PATHNAME_CASE(kAvx2Fma) RUY_PATHNAME_CASE(kAvx512) #endif default: |