diff options
author | Ruy Contributors <ruy-eng@google.com> | 2022-09-13 23:15:40 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2022-09-13 23:35:04 +0300 |
commit | cdad646878ecb8139fe40f4884e5cbd709311728 (patch) | |
tree | a7aa805c49c64358c2074caf5849128ef68a250b | |
parent | 97ebb72aa0655c0af98896b317476a5d0dacad9c (diff) |
Create API to determine how many threads to use
test_474105012
PiperOrigin-RevId: 474105012
-rw-r--r-- | ruy/BUILD | 9 | ||||
-rw-r--r-- | ruy/ctx.cc | 7 | ||||
-rw-r--r-- | ruy/ctx.h | 3 | ||||
-rw-r--r-- | ruy/ctx_impl.h | 3 | ||||
-rw-r--r-- | ruy/multithread_context.h | 33 | ||||
-rw-r--r-- | ruy/trmul.cc | 6 |
6 files changed, 61 insertions, 0 deletions
@@ -436,6 +436,13 @@ cc_library( ) cc_library( + name = "multithread_context", + hdrs = ["multithread_context.h"], + copts = ruy_copts(), + visibility = ["//visibility:public"], +) + +cc_library( name = "matrix", hdrs = ["matrix.h"], copts = ruy_copts(), @@ -903,6 +910,7 @@ cc_library( ":check_macros", ":cpuinfo", ":have_built_path_for", + ":multithread_context", ":path", ":performance_advisory", ":platform", @@ -969,6 +977,7 @@ cc_library( ":mat", ":matrix", ":mul_params", + ":multithread_context", ":opt_set", ":side_pair", ":size_util", @@ -26,6 +26,7 @@ limitations under the License. #include "ruy/path.h" #include "ruy/performance_advisory.h" #include "ruy/platform.h" +#include "ruy/multithread_context.h" #include "ruy/prepacked_cache.h" #include "ruy/trace.h" @@ -56,6 +57,12 @@ bool Ctx::performance_advisory(PerformanceAdvisory advisory) const { return (impl().performance_advisory_ & advisory) != PerformanceAdvisory::kNone; } +void Ctx::set_num_threads_strategy(NumThreadsStrategy strategy) { + mutable_impl()->num_threads_strategy_ = strategy; +} +NumThreadsStrategy Ctx::num_threads_strategy() const { + return impl().num_threads_strategy_; +} void Ctx::SetRuntimeEnabledPaths(Path paths) { if (paths == Path::kNone) { @@ -32,6 +32,7 @@ class CpuInfo; enum class Path : std::uint8_t; enum class Tuning; enum class PerformanceAdvisory; +enum class NumThreadsStrategy : std::uint8_t; // Ctx is the internal context class used throughout ruy code. Whereas Context // is exposed to users, Ctx is internal to ruy. As many of ruy's internal @@ -53,6 +54,8 @@ class Ctx /* not final, subclassed by CtxImpl */ { void clear_performance_advisories(); void set_performance_advisory(PerformanceAdvisory advisory); bool performance_advisory(PerformanceAdvisory advisory) const; + void set_num_threads_strategy(NumThreadsStrategy strategy); + NumThreadsStrategy num_threads_strategy() const; // Returns the set of Path's that are available. 
By default, this is based on // runtime detection of CPU features, as well as on which code paths were diff --git a/ruy/ctx_impl.h b/ruy/ctx_impl.h index 0a07ef6..4cc0dea 100644 --- a/ruy/ctx_impl.h +++ b/ruy/ctx_impl.h @@ -29,6 +29,7 @@ limitations under the License. #include "ruy/path.h" #include "ruy/performance_advisory.h" #include "ruy/prepacked_cache.h" +#include "ruy/multithread_context.h" #include "ruy/thread_pool.h" #include "ruy/tune.h" @@ -63,6 +64,8 @@ class CtxImpl final : public Ctx { Tuning explicit_tuning_ = Tuning::kAuto; ThreadPool thread_pool_; int max_num_threads_ = 1; + NumThreadsStrategy num_threads_strategy_ = + NumThreadsStrategy::kSmartHeuristic; // Allocator for main thread work before invoking the threadpool. // Our simple Allocator does not allow reserving/allocating more blocks // while it's already in committed state, so the main thread needs both diff --git a/ruy/multithread_context.h b/ruy/multithread_context.h new file mode 100644 index 0000000..a38a72c --- /dev/null +++ b/ruy/multithread_context.h @@ -0,0 +1,33 @@ +/* Copyright 2022 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef RUY_RUY_MULTITHREAD_CONTEXT_H_ +#define RUY_RUY_MULTITHREAD_CONTEXT_H_ + +#include <cstdint> + +namespace ruy { + +enum class NumThreadsStrategy : std::uint8_t { + // kSmartHeuristic means using smart heuristic logic that has been optimized + // for cubic ColxRowxDepth matrix multiplication. + kSmartHeuristic, + // kEnforce means using ctx->max_num_thread() for multi-thread computing. + kEnforce +}; + +} // namespace ruy + +#endif // RUY_RUY_MULTITHREAD_CONTEXT_H_ diff --git a/ruy/trmul.cc b/ruy/trmul.cc index 7af4bf0..d406fe8 100644 --- a/ruy/trmul.cc +++ b/ruy/trmul.cc @@ -35,6 +35,7 @@ limitations under the License. #include "ruy/mat.h" #include "ruy/matrix.h" #include "ruy/mul_params.h" +#include "ruy/multithread_context.h" #include "ruy/opt_set.h" #include "ruy/profiler/instrumentation.h" #include "ruy/side_pair.h" @@ -259,6 +260,11 @@ int GetTentativeThreadCount(Ctx* ctx, int rows, int cols, int depth) { // in this Mul (product of the 3 sizes). // Be defensive here by explicitly promoting operands to int64 to avoid the // pitfall of `int64 result = x * y;` overflowing as x and y are still narrow. + if (ctx->num_threads_strategy() == NumThreadsStrategy::kEnforce) { + return ctx->max_num_threads(); + } + RUY_CHECK_EQ(ctx->num_threads_strategy(), + NumThreadsStrategy::kSmartHeuristic); const std::int64_t rows_i64 = rows; const std::int64_t cols_i64 = cols; const std::int64_t depth_i64 = depth; |