Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRuy Contributors <ruy-eng@google.com>2022-09-13 23:15:40 +0300
committerCopybara-Service <copybara-worker@google.com>2022-09-13 23:35:04 +0300
commitcdad646878ecb8139fe40f4884e5cbd709311728 (patch)
treea7aa805c49c64358c2074caf5849128ef68a250b
parent97ebb72aa0655c0af98896b317476a5d0dacad9c (diff)
Create API to determine how many threads to usetest_474105012
PiperOrigin-RevId: 474105012
-rw-r--r--ruy/BUILD9
-rw-r--r--ruy/ctx.cc7
-rw-r--r--ruy/ctx.h3
-rw-r--r--ruy/ctx_impl.h3
-rw-r--r--ruy/multithread_context.h33
-rw-r--r--ruy/trmul.cc6
6 files changed, 61 insertions, 0 deletions
diff --git a/ruy/BUILD b/ruy/BUILD
index 4c52c79..8d89c5b 100644
--- a/ruy/BUILD
+++ b/ruy/BUILD
@@ -436,6 +436,13 @@ cc_library(
)
cc_library(
+ name = "multithread_context",
+ hdrs = ["multithread_context.h"],
+ copts = ruy_copts(),
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
name = "matrix",
hdrs = ["matrix.h"],
copts = ruy_copts(),
@@ -903,6 +910,7 @@ cc_library(
":check_macros",
":cpuinfo",
":have_built_path_for",
+ ":multithread_context",
":path",
":performance_advisory",
":platform",
@@ -969,6 +977,7 @@ cc_library(
":mat",
":matrix",
":mul_params",
+ ":multithread_context",
":opt_set",
":side_pair",
":size_util",
diff --git a/ruy/ctx.cc b/ruy/ctx.cc
index 0ef098d..2abc0f1 100644
--- a/ruy/ctx.cc
+++ b/ruy/ctx.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "ruy/path.h"
#include "ruy/performance_advisory.h"
#include "ruy/platform.h"
+#include "ruy/multithread_context.h"
#include "ruy/prepacked_cache.h"
#include "ruy/trace.h"
@@ -56,6 +57,12 @@ bool Ctx::performance_advisory(PerformanceAdvisory advisory) const {
return (impl().performance_advisory_ & advisory) !=
PerformanceAdvisory::kNone;
}
+void Ctx::set_num_threads_strategy(NumThreadsStrategy strategy) {
+ mutable_impl()->num_threads_strategy_ = strategy;
+}
+NumThreadsStrategy Ctx::num_threads_strategy() const {
+ return impl().num_threads_strategy_;
+}
void Ctx::SetRuntimeEnabledPaths(Path paths) {
if (paths == Path::kNone) {
diff --git a/ruy/ctx.h b/ruy/ctx.h
index df9dee2..f576a90 100644
--- a/ruy/ctx.h
+++ b/ruy/ctx.h
@@ -32,6 +32,7 @@ class CpuInfo;
enum class Path : std::uint8_t;
enum class Tuning;
enum class PerformanceAdvisory;
+enum class NumThreadsStrategy : std::uint8_t;
// Ctx is the internal context class used throughout ruy code. Whereas Context
// is exposed to users, Ctx is internal to ruy. As many of ruy's internal
@@ -53,6 +54,8 @@ class Ctx /* not final, subclassed by CtxImpl */ {
void clear_performance_advisories();
void set_performance_advisory(PerformanceAdvisory advisory);
bool performance_advisory(PerformanceAdvisory advisory) const;
+ void set_num_threads_strategy(NumThreadsStrategy strategy);
+ NumThreadsStrategy num_threads_strategy() const;
// Returns the set of Path's that are available. By default, this is based on
// runtime detection of CPU features, as well as on which code paths were
diff --git a/ruy/ctx_impl.h b/ruy/ctx_impl.h
index 0a07ef6..4cc0dea 100644
--- a/ruy/ctx_impl.h
+++ b/ruy/ctx_impl.h
@@ -29,6 +29,7 @@ limitations under the License.
#include "ruy/path.h"
#include "ruy/performance_advisory.h"
#include "ruy/prepacked_cache.h"
+#include "ruy/multithread_context.h"
#include "ruy/thread_pool.h"
#include "ruy/tune.h"
@@ -63,6 +64,8 @@ class CtxImpl final : public Ctx {
Tuning explicit_tuning_ = Tuning::kAuto;
ThreadPool thread_pool_;
int max_num_threads_ = 1;
+ NumThreadsStrategy num_threads_strategy_ =
+ NumThreadsStrategy::kSmartHeuristic;
// Allocator for main thread work before invoking the threadpool.
// Our simple Allocator does not allow reserving/allocating more blocks
// while it's already in committed state, so the main thread needs both
diff --git a/ruy/multithread_context.h b/ruy/multithread_context.h
new file mode 100644
index 0000000..a38a72c
--- /dev/null
+++ b/ruy/multithread_context.h
@@ -0,0 +1,33 @@
+/* Copyright 2022 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef RUY_RUY_MULTITHREAD_CONTEXT_H_
+#define RUY_RUY_MULTITHREAD_CONTEXT_H_
+
+#include <cstdint>
+
+namespace ruy {
+
+enum class NumThreadsStrategy : std::uint8_t {
+ // kSmartHeuristic means using smart heuristic logic that has been optimized
+ // for cubic ColxRowxDepth matrix multiplication.
+ kSmartHeuristic,
+ // kEnforce means using ctx->max_num_thread() for multi-thread computing.
+ kEnforce
+};
+
+} // namespace ruy
+
+#endif // RUY_RUY_MULTITHREAD_CONTEXT_H_
diff --git a/ruy/trmul.cc b/ruy/trmul.cc
index 7af4bf0..d406fe8 100644
--- a/ruy/trmul.cc
+++ b/ruy/trmul.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
+#include "ruy/multithread_context.h"
#include "ruy/opt_set.h"
#include "ruy/profiler/instrumentation.h"
#include "ruy/side_pair.h"
@@ -259,6 +260,11 @@ int GetTentativeThreadCount(Ctx* ctx, int rows, int cols, int depth) {
// in this Mul (product of the 3 sizes).
// Be defensive here by explicitly promoting operands to int64 to avoid the
// pitfall of `int64 result = x * y;` overflowing as x and y are still narrow.
+ if (ctx->num_threads_strategy() == NumThreadsStrategy::kEnforce) {
+ return ctx->max_num_threads();
+ }
+ RUY_CHECK_EQ(ctx->num_threads_strategy(),
+ NumThreadsStrategy::kSmartHeuristic);
const std::int64_t rows_i64 = rows;
const std::int64_t cols_i64 = cols;
const std::int64_t depth_i64 = depth;