Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Ruy Contributors <ruy-eng@google.com> 2022-09-14 22:56:04 +0300
committer: Copybara-Service <copybara-worker@google.com> 2022-09-14 22:56:29 +0300
commit: 3286a34cc8de6149ac6844107dfdffac91531e72 (patch)
tree: dcc7f0f6e9f9bc60bd369944481d3b6a673c53d6
parent: 97ebb72aa0655c0af98896b317476a5d0dacad9c (diff)
Create API to determine how many threads to use (HEAD, master)
PiperOrigin-RevId: 474367386
-rw-r--r-- ruy/BUILD 12
-rw-r--r-- ruy/context.cc 7
-rw-r--r-- ruy/context.h 5
-rw-r--r-- ruy/context_test.cc 5
-rw-r--r-- ruy/ctx.cc 7
-rw-r--r-- ruy/ctx.h 3
-rw-r--r-- ruy/ctx_impl.h 2
-rw-r--r-- ruy/ctx_test.cc 9
-rw-r--r-- ruy/strategy_controls.h 34
-rw-r--r-- ruy/trmul.cc 5
10 files changed, 89 insertions, 0 deletions
diff --git a/ruy/BUILD b/ruy/BUILD
index 4c52c79..2a6bf68 100644
--- a/ruy/BUILD
+++ b/ruy/BUILD
@@ -436,6 +436,13 @@ cc_library(
)
cc_library(
+ name = "strategy_controls",
+ hdrs = ["strategy_controls.h"],
+ copts = ruy_copts(),
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
name = "matrix",
hdrs = ["matrix.h"],
copts = ruy_copts(),
@@ -859,6 +866,7 @@ cc_library(
":performance_advisory",
":platform",
":prepacked_cache",
+ ":strategy_controls",
":thread_pool",
":tune",
],
@@ -874,6 +882,7 @@ cc_test(
":path",
":platform",
":prepacked_cache",
+ ":strategy_controls",
":tune",
],
)
@@ -907,6 +916,7 @@ cc_library(
":performance_advisory",
":platform",
":prepacked_cache",
+ ":strategy_controls",
":thread_pool",
":trace",
":tune",
@@ -937,6 +947,7 @@ cc_test(
":gtest_wrapper",
":path",
":platform",
+ ":strategy_controls",
],
)
@@ -972,6 +983,7 @@ cc_library(
":opt_set",
":side_pair",
":size_util",
+ ":strategy_controls",
":thread_pool",
":trace",
":trmul_params",
diff --git a/ruy/context.cc b/ruy/context.cc
index 342ce52..ec651f9 100644
--- a/ruy/context.cc
+++ b/ruy/context.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "ruy/ctx.h"
#include "ruy/ctx_impl.h"
+#include "ruy/strategy_controls.h"
#include "ruy/path.h"
#include "ruy/performance_advisory.h"
#include "ruy/prepacked_cache.h"
@@ -44,6 +45,12 @@ int Context::max_num_threads() const { return ctx().max_num_threads(); }
void Context::set_max_num_threads(int value) {
mutable_ctx()->set_max_num_threads(value);
}
+NumThreadsStrategy Context::num_threads_strategy() const {
+ return ctx().num_threads_strategy();
+}
+void Context::set_num_threads_strategy(NumThreadsStrategy strategy) {
+ mutable_ctx()->set_num_threads_strategy(strategy);
+}
void Context::ClearPrepackedCache() { mutable_ctx()->ClearPrepackedCache(); }
diff --git a/ruy/context.h b/ruy/context.h
index f148f0f..16f40e7 100644
--- a/ruy/context.h
+++ b/ruy/context.h
@@ -28,6 +28,7 @@ class ThreadPool;
enum class Path : std::uint8_t;
enum class Tuning;
enum class PerformanceAdvisory;
+enum class NumThreadsStrategy : std::uint8_t;
// A Context holds runtime information used by Ruy. It holds runtime resources
// such as the workers thread pool and the allocator (which holds buffers for
@@ -71,6 +72,10 @@ class Context final {
int max_num_threads() const;
void set_max_num_threads(int value);
+ // Controls the logic to determine how many threads to use.
+ NumThreadsStrategy num_threads_strategy() const;
+ void set_num_threads_strategy(NumThreadsStrategy strategy);
+
// Returns true if the last ruy::Mul using this Context flagged the specified
// `advisory`. This is reset by each ruy::Mul call.
bool performance_advisory(PerformanceAdvisory advisory) const;
diff --git a/ruy/context_test.cc b/ruy/context_test.cc
index 4e69e65..6497c77 100644
--- a/ruy/context_test.cc
+++ b/ruy/context_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "ruy/context.h"
#include "ruy/gtest_wrapper.h"
+#include "ruy/strategy_controls.h"
#include "ruy/path.h"
#include "ruy/prepacked_cache.h"
#include "ruy/tune.h"
@@ -30,10 +31,14 @@ TEST(ContextTest, ContextClassSanity) {
EXPECT_EQ(&context.thread_pool(), context.mutable_thread_pool());
EXPECT_NE(context.mutable_thread_pool(), nullptr);
EXPECT_EQ(context.max_num_threads(), 1);
+ EXPECT_EQ(context.num_threads_strategy(), NumThreadsStrategy::kDefault);
context.set_explicit_tuning(Tuning::kGeneric);
context.set_max_num_threads(2);
+ context.set_num_threads_strategy(NumThreadsStrategy::kForceMaxNumThreads);
EXPECT_EQ(context.explicit_tuning(), Tuning::kGeneric);
EXPECT_EQ(context.max_num_threads(), 2);
+ EXPECT_EQ(context.num_threads_strategy(),
+ NumThreadsStrategy::kForceMaxNumThreads);
}
} // namespace
diff --git a/ruy/ctx.cc b/ruy/ctx.cc
index 0ef098d..5d6afd4 100644
--- a/ruy/ctx.cc
+++ b/ruy/ctx.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "ruy/path.h"
#include "ruy/performance_advisory.h"
#include "ruy/platform.h"
+#include "ruy/strategy_controls.h"
#include "ruy/prepacked_cache.h"
#include "ruy/trace.h"
@@ -56,6 +57,12 @@ bool Ctx::performance_advisory(PerformanceAdvisory advisory) const {
return (impl().performance_advisory_ & advisory) !=
PerformanceAdvisory::kNone;
}
+void Ctx::set_num_threads_strategy(NumThreadsStrategy strategy) {
+ mutable_impl()->num_threads_strategy_ = strategy;
+}
+NumThreadsStrategy Ctx::num_threads_strategy() const {
+ return impl().num_threads_strategy_;
+}
void Ctx::SetRuntimeEnabledPaths(Path paths) {
if (paths == Path::kNone) {
diff --git a/ruy/ctx.h b/ruy/ctx.h
index df9dee2..f576a90 100644
--- a/ruy/ctx.h
+++ b/ruy/ctx.h
@@ -32,6 +32,7 @@ class CpuInfo;
enum class Path : std::uint8_t;
enum class Tuning;
enum class PerformanceAdvisory;
+enum class NumThreadsStrategy : std::uint8_t;
// Ctx is the internal context class used throughout ruy code. Whereas Context
// is exposed to users, Ctx is internal to ruy. As many of ruy's internal
@@ -53,6 +54,8 @@ class Ctx /* not final, subclassed by CtxImpl */ {
void clear_performance_advisories();
void set_performance_advisory(PerformanceAdvisory advisory);
bool performance_advisory(PerformanceAdvisory advisory) const;
+ void set_num_threads_strategy(NumThreadsStrategy strategy);
+ NumThreadsStrategy num_threads_strategy() const;
// Returns the set of Path's that are available. By default, this is based on
// runtime detection of CPU features, as well as on which code paths were
diff --git a/ruy/ctx_impl.h b/ruy/ctx_impl.h
index 0a07ef6..be64553 100644
--- a/ruy/ctx_impl.h
+++ b/ruy/ctx_impl.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "ruy/path.h"
#include "ruy/performance_advisory.h"
#include "ruy/prepacked_cache.h"
+#include "ruy/strategy_controls.h"
#include "ruy/thread_pool.h"
#include "ruy/tune.h"
@@ -63,6 +64,7 @@ class CtxImpl final : public Ctx {
Tuning explicit_tuning_ = Tuning::kAuto;
ThreadPool thread_pool_;
int max_num_threads_ = 1;
+ NumThreadsStrategy num_threads_strategy_ = NumThreadsStrategy::kDefault;
// Allocator for main thread work before invoking the threadpool.
// Our simple Allocator does not allow reserving/allocating more blocks
// while it's already in committed state, so the main thread needs both
diff --git a/ruy/ctx_test.cc b/ruy/ctx_test.cc
index e55dcfc..c40f2d6 100644
--- a/ruy/ctx_test.cc
+++ b/ruy/ctx_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "ruy/ctx_impl.h"
#include "ruy/gtest_wrapper.h"
+#include "ruy/strategy_controls.h"
#include "ruy/path.h"
#include "ruy/platform.h"
@@ -67,6 +68,14 @@ TEST(ContextInternalTest, ThreadSpecificResources) {
}
}
+TEST(ContextInternalTest, SetNumThreadsStrategy) {
+ CtxImpl ctx;
+ EXPECT_EQ(ctx.num_threads_strategy(), NumThreadsStrategy::kDefault);
+ ctx.set_num_threads_strategy(NumThreadsStrategy::kForceMaxNumThreads);
+ EXPECT_EQ(ctx.num_threads_strategy(),
+ NumThreadsStrategy::kForceMaxNumThreads);
+}
+
} // namespace
} // namespace ruy
diff --git a/ruy/strategy_controls.h b/ruy/strategy_controls.h
new file mode 100644
index 0000000..629c2b8
--- /dev/null
+++ b/ruy/strategy_controls.h
@@ -0,0 +1,34 @@
+/* Copyright 2022 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_STRATEGY_CONTROLS_H_
+#define RUY_RUY_STRATEGY_CONTROLS_H_
+
+#include <cstdint>
+
+namespace ruy {
+
+enum class NumThreadsStrategy : std::uint8_t {
+ // kDefault means using smart heuristic logic that has been optimized
+ // for cubic ColxRowxDepth matrix multiplication.
+ kDefault,
+ // kForceMaxNumThreads means using ctx->max_num_threads()
+ // for multi-thread computing.
+ kForceMaxNumThreads
+};
+
+} // namespace ruy
+
+#endif // RUY_RUY_STRATEGY_CONTROLS_H_
diff --git a/ruy/trmul.cc b/ruy/trmul.cc
index 7af4bf0..2ff519f 100644
--- a/ruy/trmul.cc
+++ b/ruy/trmul.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
+#include "ruy/strategy_controls.h"
#include "ruy/opt_set.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/side_pair.h"
@@ -259,6 +260,10 @@ int GetTentativeThreadCount(Ctx* ctx, int rows, int cols, int depth) {
// in this Mul (product of the 3 sizes).
// Be defensive here by explicitly promoting operands to int64 to avoid the
// pitfall of `int64 result = x * y;` overflowing as x and y are still narrow.
+ if (ctx->num_threads_strategy() == NumThreadsStrategy::kForceMaxNumThreads) {
+ return ctx->max_num_threads();
+ }
+ RUY_CHECK_EQ(ctx->num_threads_strategy(), NumThreadsStrategy::kDefault);
const std::int64_t rows_i64 = rows;
const std::int64_t cols_i64 = cols;
const std::int64_t depth_i64 = depth;