Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ruy/BUILD26
-rw-r--r--ruy/build_defs.bzl2
-rw-r--r--ruy/ctx.cc4
-rw-r--r--ruy/ctx_test.cc4
-rw-r--r--ruy/have_built_path_for.h2
-rw-r--r--ruy/have_built_path_for_avx2_fma.cc (renamed from ruy/have_built_path_for_avx2.cc)10
-rw-r--r--ruy/kernel_avx2_fma.cc (renamed from ruy/kernel_avx2.cc)16
-rw-r--r--ruy/kernel_x86.h13
-rw-r--r--ruy/pack_avx2_fma.cc (renamed from ruy/pack_avx2.cc)16
-rw-r--r--ruy/pack_x86.h14
-rw-r--r--ruy/path.h8
-rw-r--r--ruy/platform.h33
-rw-r--r--ruy/test.h2
13 files changed, 64 insertions, 86 deletions
diff --git a/ruy/BUILD b/ruy/BUILD
index f595139..a8e6717 100644
--- a/ruy/BUILD
+++ b/ruy/BUILD
@@ -1,7 +1,7 @@
# Ruy is not BLAS
load("@bazel_skylib//lib:selects.bzl", "selects")
-load(":build_defs.bzl", "ruy_copts", "ruy_copts_avx2", "ruy_copts_avx512")
+load(":build_defs.bzl", "ruy_copts", "ruy_copts_avx2_fma", "ruy_copts_avx512")
load(":build_defs.oss.bzl", "ruy_linkopts_thread_standard_library")
load(":ruy_test_ext.oss.bzl", "ruy_test_ext_defines", "ruy_test_ext_deps")
load(":ruy_test.bzl", "ruy_benchmark", "ruy_test")
@@ -624,14 +624,14 @@ cc_library(
)
cc_library(
- name = "kernel_avx2",
+ name = "kernel_avx2_fma",
srcs = [
- "kernel_avx2.cc",
+ "kernel_avx2_fma.cc",
],
hdrs = [
"kernel_x86.h",
],
- copts = ruy_copts() + ruy_copts_avx2(),
+ copts = ruy_copts() + ruy_copts_avx2_fma(),
deps = [
":check_macros",
":kernel_common",
@@ -646,14 +646,14 @@ cc_library(
)
cc_library(
- name = "pack_avx2",
+ name = "pack_avx2_fma",
srcs = [
- "pack_avx2.cc",
+ "pack_avx2_fma.cc",
],
hdrs = [
"pack_x86.h",
],
- copts = ruy_copts() + ruy_copts_avx2(),
+ copts = ruy_copts() + ruy_copts_avx2_fma(),
deps = [
":check_macros",
":mat",
@@ -667,14 +667,14 @@ cc_library(
)
cc_library(
- name = "have_built_path_for_avx2",
+ name = "have_built_path_for_avx2_fma",
srcs = [
- "have_built_path_for_avx2.cc",
+ "have_built_path_for_avx2_fma.cc",
],
hdrs = [
"have_built_path_for.h",
],
- copts = ruy_copts() + ruy_copts_avx2(),
+ copts = ruy_copts() + ruy_copts_avx2_fma(),
deps = [
":opt_set",
":platform",
@@ -691,7 +691,7 @@ cc_library(
":apply_multiplier",
":check_macros",
":kernel_arm", # fixdeps: keep
- ":kernel_avx2", # fixdeps: keep
+ ":kernel_avx2_fma", # fixdeps: keep
":kernel_avx512", # fixdeps: keep
":kernel_common",
":mat",
@@ -719,7 +719,7 @@ cc_library(
":matrix",
":opt_set",
":pack_arm", # fixdeps: keep
- ":pack_avx2", # fixdeps: keep
+ ":pack_avx2_fma", # fixdeps: keep
":pack_avx512", # fixdeps: keep
":pack_common",
":path",
@@ -735,7 +735,7 @@ cc_library(
"have_built_path_for.h",
],
deps = [
- ":have_built_path_for_avx2",
+ ":have_built_path_for_avx2_fma",
":have_built_path_for_avx512",
":platform",
],
diff --git a/ruy/build_defs.bzl b/ruy/build_defs.bzl
index 594c4d9..a36942b 100644
--- a/ruy/build_defs.bzl
+++ b/ruy/build_defs.bzl
@@ -63,7 +63,7 @@ def ruy_copts_avx512():
"//conditions:default": [],
})
-def ruy_copts_avx2():
+def ruy_copts_avx2_fma():
return select({
"//ruy:x86_64": ["-mavx2", "-mfma"],
"//conditions:default": [],
diff --git a/ruy/ctx.cc b/ruy/ctx.cc
index 6f17f4f..0411ea7 100644
--- a/ruy/ctx.cc
+++ b/ruy/ctx.cc
@@ -111,8 +111,8 @@ Path DetectRuntimeSupportedPaths(Path paths_to_detect, CpuInfo* cpuinfo) {
#elif RUY_PLATFORM_X86
// x86 SIMD paths currently require both runtime detection, and detection of
// whether we're building the path at all.
- maybe_add(Path::kAvx2,
- [=]() { return HaveBuiltPathForAvx2() && cpuinfo->Avx2(); });
+ maybe_add(Path::kAvx2Fma,
+ [=]() { return HaveBuiltPathForAvx2Fma() && cpuinfo->Avx2(); });
maybe_add(Path::kAvx512,
[=]() { return HaveBuiltPathForAvx512() && cpuinfo->Avx512(); });
#else
diff --git a/ruy/ctx_test.cc b/ruy/ctx_test.cc
index f57ec64..e55dcfc 100644
--- a/ruy/ctx_test.cc
+++ b/ruy/ctx_test.cc
@@ -33,9 +33,9 @@ TEST(ContextInternalTest, EnabledPathsGeneral) {
#if RUY_PLATFORM_X86
TEST(ContextInternalTest, EnabledPathsX86Explicit) {
CtxImpl ctx;
- ctx.SetRuntimeEnabledPaths(Path::kAvx2);
+ ctx.SetRuntimeEnabledPaths(Path::kAvx2Fma);
const auto ruy_paths = ctx.GetRuntimeEnabledPaths();
- EXPECT_EQ(ruy_paths, Path::kStandardCpp | Path::kAvx2);
+ EXPECT_EQ(ruy_paths, Path::kStandardCpp | Path::kAvx2Fma);
}
#endif // RUY_PLATFORM_X86
diff --git a/ruy/have_built_path_for.h b/ruy/have_built_path_for.h
index 94761a7..60e98e1 100644
--- a/ruy/have_built_path_for.h
+++ b/ruy/have_built_path_for.h
@@ -21,7 +21,7 @@ limitations under the License.
namespace ruy {
#if RUY_PLATFORM_X86
-bool HaveBuiltPathForAvx2();
+bool HaveBuiltPathForAvx2Fma();
bool HaveBuiltPathForAvx512();
#endif // RUY_PLATFORM_X86
diff --git a/ruy/have_built_path_for_avx2.cc b/ruy/have_built_path_for_avx2_fma.cc
index 4a6bbe5..03e8f8d 100644
--- a/ruy/have_built_path_for_avx2.cc
+++ b/ruy/have_built_path_for_avx2_fma.cc
@@ -21,15 +21,15 @@ namespace ruy {
#if RUY_PLATFORM_X86
// IMPORTANT:
// These patterns must match those in the pack and kernel cc files.
-#if !(RUY_PLATFORM_AVX2 && RUY_OPT(ASM))
+#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM))
-bool HaveBuiltPathForAvx2() { return false; }
+bool HaveBuiltPathForAvx2Fma() { return false; }
-#else // RUY_PLATFORM_AVX2 && RUY_OPT(ASM)
+#else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
-bool HaveBuiltPathForAvx2() { return true; }
+bool HaveBuiltPathForAvx2Fma() { return true; }
-#endif // RUY_PLATFORM_AVX2 && RUY_OPT(ASM)
+#endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
#endif // RUY_PLATFORM_X86
} // namespace ruy
diff --git a/ruy/kernel_avx2.cc b/ruy/kernel_avx2_fma.cc
index 60c6e6b..3c261a5 100644
--- a/ruy/kernel_avx2.cc
+++ b/ruy/kernel_avx2_fma.cc
@@ -24,13 +24,13 @@ limitations under the License.
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
-#if RUY_PLATFORM_AVX2 && RUY_OPT(ASM)
+#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
#include <immintrin.h> // IWYU pragma: keep
#endif
namespace ruy {
-#if !(RUY_PLATFORM_AVX2 && RUY_OPT(ASM))
+#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM))
void Kernel8bitAvx2(const KernelParams8bit<8, 8>&) {
// CPU-ID-based checks should disable the path that would reach this point.
@@ -52,7 +52,7 @@ void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>&) {
RUY_DCHECK(false);
}
-#else // RUY_PLATFORM_AVX2 && RUY_OPT(ASM)
+#else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
static constexpr int kAvx8bitBlockSize = 8;
static constexpr int kAvx8bitInnerSize = 4;
@@ -393,7 +393,7 @@ inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
} // namespace
void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
- profiler::ScopeLabel label("Kernel kAvx2 8-bit");
+ profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit");
const std::int8_t splitter_idx_data[32] = {
0, 1, 4, 5, 8, 9, 12, 13, //
2, 3, 6, 7, 10, 11, 14, 15, //
@@ -1185,7 +1185,7 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
} // NOLINT(readability/fn_size)
void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
- profiler::ScopeLabel label("Kernel kAvx2 8-bit GEMV");
+ profiler::ScopeLabel label("Kernel kAvx2Fma 8-bit GEMV");
RUY_DCHECK_EQ(params.dst_cols, 1);
RUY_DCHECK_EQ(params.last_col, 0);
@@ -1450,7 +1450,7 @@ void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
} // NOLINT(readability/fn_size)
void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
- profiler::ScopeLabel label("Kernel kAvx2 float");
+ profiler::ScopeLabel label("Kernel kAvx2Fma float");
// As parameters are defined, we need to scale by sizeof(float).
const std::int64_t lhs_stride = params.lhs_stride >> 2;
@@ -1603,7 +1603,7 @@ void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
}
void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
- profiler::ScopeLabel label("Kernel kAvx2 float GEMV");
+ profiler::ScopeLabel label("Kernel kAvx2Fma float GEMV");
RUY_DCHECK_EQ(params.dst_cols, 1);
RUY_DCHECK_EQ(params.last_col, 0);
@@ -1707,6 +1707,6 @@ void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
} // End handling of residual rows.
}
-#endif // RUY_PLATFORM_AVX2 && RUY_OPT(ASM)
+#endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
} // namespace ruy
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index 2def91f..641fbb6 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -30,8 +30,8 @@ namespace ruy {
#if RUY_PLATFORM_X86
-RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx2)
-RUY_INHERIT_KERNEL(Path::kAvx2, Path::kAvx512)
+RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kAvx2Fma)
+RUY_INHERIT_KERNEL(Path::kAvx2Fma, Path::kAvx512)
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params);
void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params);
@@ -87,8 +87,9 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params);
void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params);
template <typename DstScalar>
-struct Kernel<Path::kAvx2, std::int8_t, std::int8_t, std::int32_t, DstScalar> {
- static constexpr Path kPath = Path::kAvx2;
+struct Kernel<Path::kAvx2Fma, std::int8_t, std::int8_t, std::int32_t,
+ DstScalar> {
+ static constexpr Path kPath = Path::kAvx2Fma;
Tuning tuning = Tuning::kAuto;
using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
@@ -112,8 +113,8 @@ void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params);
void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params);
template <>
-struct Kernel<Path::kAvx2, float, float, float, float> {
- static constexpr Path kPath = Path::kAvx2;
+struct Kernel<Path::kAvx2Fma, float, float, float, float> {
+ static constexpr Path kPath = Path::kAvx2Fma;
Tuning tuning = Tuning::kAuto;
using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
diff --git a/ruy/pack_avx2.cc b/ruy/pack_avx2_fma.cc
index 01c477e..ebb0b33 100644
--- a/ruy/pack_avx2.cc
+++ b/ruy/pack_avx2_fma.cc
@@ -23,13 +23,13 @@ limitations under the License.
#include "ruy/platform.h"
#include "ruy/profiler/instrumentation.h"
-#if RUY_PLATFORM_AVX2 && RUY_OPT(INTRINSICS)
+#if RUY_PLATFORM_AVX2_FMA && RUY_OPT(INTRINSICS)
#include <immintrin.h> // IWYU pragma: keep
#endif
namespace ruy {
-#if !(RUY_PLATFORM_AVX2 && RUY_OPT(ASM))
+#if !(RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM))
void Pack8bitAvx2(const std::int8_t*, std::int8_t, const std::int8_t*, int, int,
int, std::int8_t*, std::int32_t*) {
@@ -42,16 +42,16 @@ void PackFloatAvx2(const float*, const float*, int, int, int, float*) {
RUY_DCHECK(false);
}
-#else // RUY_PLATFORM_AVX2 && RUY_OPT(ASM)
+#else // RUY_PLATFORM_AVX2_FMA && RUY_OPT(ASM)
// The first int8_t template parameter is arbitrary: this routine is common to
// all 8-bit source matrix types.
using PackImpl8bitAvx2 =
- PackImpl<Path::kAvx2, FixedKernelLayout<Order::kColMajor, 4, 8>,
+ PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
std::int8_t, std::int8_t, std::int32_t>;
using PackImplFloatAvx2 =
- PackImpl<Path::kAvx2, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
+ PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
float, float>;
namespace {
@@ -752,7 +752,7 @@ void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
const std::int8_t* zerobuf, int src_stride,
int remaining_src_cols, int src_rows, std::int8_t* packed_ptr,
std::int32_t* sums_ptr) {
- profiler::ScopeLabel label("Pack kAvx2 8bit");
+ profiler::ScopeLabel label("Pack kAvx2Fma 8bit");
using Layout = PackImpl8bitAvx2::Layout;
RUY_DCHECK_EQ(Layout::kCols, 8);
@@ -789,7 +789,7 @@ void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride,
int remaining_src_cols, int src_rows, float* packed_ptr) {
- profiler::ScopeLabel label("Pack kAvx2 float");
+ profiler::ScopeLabel label("Pack kAvx2Fma float");
static constexpr int kPackCols = 8; // Source cols packed together.
static constexpr int kPackRows = 8; // Short input is padded.
float trailing_buf[(kPackRows - 1) * kPackCols];
@@ -807,6 +807,6 @@ void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride,
}
}
-#endif // RUY_PLATFORM_AVX2 && RUY_OPT(INTRINSICS)
+#endif // RUY_PLATFORM_AVX2_FMA && RUY_OPT(INTRINSICS)
} // namespace ruy
diff --git a/ruy/pack_x86.h b/ruy/pack_x86.h
index 10aea82..93e8904 100644
--- a/ruy/pack_x86.h
+++ b/ruy/pack_x86.h
@@ -33,11 +33,11 @@ namespace ruy {
#if RUY_PLATFORM_X86
-RUY_INHERIT_PACK(Path::kStandardCpp, Path::kAvx2)
-RUY_INHERIT_PACK(Path::kAvx2, Path::kAvx512)
+RUY_INHERIT_PACK(Path::kStandardCpp, Path::kAvx2Fma)
+RUY_INHERIT_PACK(Path::kAvx2Fma, Path::kAvx512)
template <>
-struct PackedTypeImpl<Path::kAvx2, std::uint8_t> {
+struct PackedTypeImpl<Path::kAvx2Fma, std::uint8_t> {
using Type = std::int8_t;
};
template <>
@@ -53,8 +53,8 @@ void Pack8bitAvx2(const std::int8_t* src_ptr, std::int8_t input_xor,
std::int32_t* sums_ptr);
template <typename Scalar>
-struct PackImpl<Path::kAvx2, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
- std::int8_t, std::int32_t> {
+struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kColMajor, 4, 8>,
+ Scalar, std::int8_t, std::int32_t> {
static_assert(std::is_same<Scalar, std::int8_t>::value ||
std::is_same<Scalar, std::uint8_t>::value,
"");
@@ -98,8 +98,8 @@ void PackFloatAvx2(const float* src_ptr, const float* zerobuf, int src_stride,
int remaining_src_cols, int src_rows, float* packed_ptr);
template <>
-struct PackImpl<Path::kAvx2, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
- float, float> {
+struct PackImpl<Path::kAvx2Fma, FixedKernelLayout<Order::kRowMajor, 1, 8>,
+ float, float, float> {
using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
static void Run(Tuning, const Mat<float>& src_matrix,
PMat<float>* packed_matrix, int start_col, int end_col) {
diff --git a/ruy/path.h b/ruy/path.h
index 94d3089..a3cd939 100644
--- a/ruy/path.h
+++ b/ruy/path.h
@@ -77,9 +77,11 @@ enum class Path : std::uint8_t {
#endif // RUY_PLATFORM_ARM
#if RUY_PLATFORM_X86
- // Optimized for AVX2.
- kAvx2 = 0x10,
+ // Optimized for AVX2+FMA.
+ // Compiled with -mavx2 -mfma.
+ kAvx2Fma = 0x10,
// Optimized for AVX-512.
+ // Compiled with -mavx512f -mavx512vl -mavx512cd -mavx512bw -mavx512dq.
kAvx512 = 0x20,
#endif // RUY_PLATFORM_X86
};
@@ -143,7 +145,7 @@ constexpr Path kExtraArchPaths = Path::kNone;
constexpr Path kDefaultArchPaths = Path::kNeon;
constexpr Path kExtraArchPaths = Path::kNone;
#elif RUY_PLATFORM_X86
-constexpr Path kDefaultArchPaths = Path::kAvx2 | Path::kAvx512;
+constexpr Path kDefaultArchPaths = Path::kAvx2Fma | Path::kAvx512;
constexpr Path kExtraArchPaths = Path::kNone;
#else
constexpr Path kDefaultArchPaths = Path::kNone;
diff --git a/ruy/platform.h b/ruy/platform.h
index 2f9cbb3..7421613 100644
--- a/ruy/platform.h
+++ b/ruy/platform.h
@@ -132,36 +132,11 @@ limitations under the License.
#define RUY_PLATFORM_AVX512 0
#endif
-#if RUY_PLATFORM_X86_ENHANCEMENTS && RUY_PLATFORM_X86 && defined(__AVX2__)
-#define RUY_PLATFORM_AVX2 1
+#if RUY_PLATFORM_X86_ENHANCEMENTS && RUY_PLATFORM_X86 && defined(__AVX2__) && \
+ defined(__FMA__)
+#define RUY_PLATFORM_AVX2_FMA 1
#else
-#define RUY_PLATFORM_AVX2 0
-#endif
-
-// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
-// Optimization is not finished. In particular the dimensions of the kernel
-// blocks can be changed as desired.
-//
-// Note does not check for LZCNT or POPCNT.
-#if defined(RUY_ENABLE_SSE_ENHANCEMENTS) && RUY_PLATFORM_X86_ENHANCEMENTS && \
- RUY_PLATFORM_X86 && defined(__SSE4_2__) && defined(__FMA__)
-#define RUY_PLATFORM_SSE42 1
-#else
-#define RUY_PLATFORM_SSE42 0
-#endif
-
-// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
-// Optimization is not finished. In particular the dimensions of the kernel
-// blocks can be changed as desired.
-//
-// Note that defined(__AVX512VBMI2__) can be false for compilation with
-// -march=cascadelake.
-// TODO(b/146646451) Check if we should also gate on defined(__AVX512VBMI2__).
-#if defined(RUY_ENABLE_VNNI_ENHANCEMENTS) && RUY_PLATFORM_AVX512 && \
- defined(__AVX512VNNI__)
-#define RUY_PLATFORM_AVX_VNNI 1
-#else
-#define RUY_PLATFORM_AVX_VNNI 0
+#define RUY_PLATFORM_AVX2_FMA 0
#endif
// Detect Emscripten, typically Wasm.
diff --git a/ruy/test.h b/ruy/test.h
index 2f4e1a8..9e34d36 100644
--- a/ruy/test.h
+++ b/ruy/test.h
@@ -105,7 +105,7 @@ inline const char* PathName(Path path) {
RUY_PATHNAME_CASE(kNeon)
RUY_PATHNAME_CASE(kNeonDotprod)
#elif RUY_PLATFORM_X86
- RUY_PATHNAME_CASE(kAvx2)
+ RUY_PATHNAME_CASE(kAvx2Fma)
RUY_PATHNAME_CASE(kAvx512)
#endif
default: