diff options
author | Ruy Contributors <ruy-eng@google.com> | 2022-09-14 22:56:04 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2022-09-14 22:56:29 +0300 |
commit | 3286a34cc8de6149ac6844107dfdffac91531e72 (patch) | |
tree | dcc7f0f6e9f9bc60bd369944481d3b6a673c53d6 | |
parent | 97ebb72aa0655c0af98896b317476a5d0dacad9c (diff) |
PiperOrigin-RevId: 474367386
-rw-r--r-- | ruy/BUILD | 12 | ||||
-rw-r--r-- | ruy/context.cc | 7 | ||||
-rw-r--r-- | ruy/context.h | 5 | ||||
-rw-r--r-- | ruy/context_test.cc | 5 | ||||
-rw-r--r-- | ruy/ctx.cc | 7 | ||||
-rw-r--r-- | ruy/ctx.h | 3 | ||||
-rw-r--r-- | ruy/ctx_impl.h | 2 | ||||
-rw-r--r-- | ruy/ctx_test.cc | 9 | ||||
-rw-r--r-- | ruy/strategy_controls.h | 34 | ||||
-rw-r--r-- | ruy/trmul.cc | 5 |
10 files changed, 89 insertions, 0 deletions
@@ -436,6 +436,13 @@ cc_library( ) cc_library( + name = "strategy_controls", + hdrs = ["strategy_controls.h"], + copts = ruy_copts(), + visibility = ["//visibility:public"], +) + +cc_library( name = "matrix", hdrs = ["matrix.h"], copts = ruy_copts(), @@ -859,6 +866,7 @@ cc_library( ":performance_advisory", ":platform", ":prepacked_cache", + ":strategy_controls", ":thread_pool", ":tune", ], @@ -874,6 +882,7 @@ cc_test( ":path", ":platform", ":prepacked_cache", + ":strategy_controls", ":tune", ], ) @@ -907,6 +916,7 @@ cc_library( ":performance_advisory", ":platform", ":prepacked_cache", + ":strategy_controls", ":thread_pool", ":trace", ":tune", @@ -937,6 +947,7 @@ cc_test( ":gtest_wrapper", ":path", ":platform", + ":strategy_controls", ], ) @@ -972,6 +983,7 @@ cc_library( ":opt_set", ":side_pair", ":size_util", + ":strategy_controls", ":thread_pool", ":trace", ":trmul_params", diff --git a/ruy/context.cc b/ruy/context.cc index 342ce52..ec651f9 100644 --- a/ruy/context.cc +++ b/ruy/context.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "ruy/ctx.h" #include "ruy/ctx_impl.h" +#include "ruy/strategy_controls.h" #include "ruy/path.h" #include "ruy/performance_advisory.h" #include "ruy/prepacked_cache.h" @@ -44,6 +45,12 @@ int Context::max_num_threads() const { return ctx().max_num_threads(); } void Context::set_max_num_threads(int value) { mutable_ctx()->set_max_num_threads(value); } +NumThreadsStrategy Context::num_threads_strategy() const { + return ctx().num_threads_strategy(); +} +void Context::set_num_threads_strategy(NumThreadsStrategy strategy) { + mutable_ctx()->set_num_threads_strategy(strategy); +} void Context::ClearPrepackedCache() { mutable_ctx()->ClearPrepackedCache(); } diff --git a/ruy/context.h b/ruy/context.h index f148f0f..16f40e7 100644 --- a/ruy/context.h +++ b/ruy/context.h @@ -28,6 +28,7 @@ class ThreadPool; enum class Path : std::uint8_t; enum class Tuning; enum class PerformanceAdvisory; +enum class NumThreadsStrategy : std::uint8_t; // A Context holds runtime information used by Ruy. It holds runtime resources // such as the workers thread pool and the allocator (which holds buffers for @@ -71,6 +72,10 @@ class Context final { int max_num_threads() const; void set_max_num_threads(int value); + // Controls the logic to determine how many threads to use. + NumThreadsStrategy num_threads_strategy() const; + void set_num_threads_strategy(NumThreadsStrategy strategy); + // Returns true of the last ruy::Mul using this Context flagged the specified // `advisory`. This is reset by each ruy::Mul call. bool performance_advisory(PerformanceAdvisory advisory) const; diff --git a/ruy/context_test.cc b/ruy/context_test.cc index 4e69e65..6497c77 100644 --- a/ruy/context_test.cc +++ b/ruy/context_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "ruy/context.h" #include "ruy/gtest_wrapper.h" +#include "ruy/strategy_controls.h" #include "ruy/path.h" #include "ruy/prepacked_cache.h" #include "ruy/tune.h" @@ -30,10 +31,14 @@ TEST(ContextTest, ContextClassSanity) { EXPECT_EQ(&context.thread_pool(), context.mutable_thread_pool()); EXPECT_NE(context.mutable_thread_pool(), nullptr); EXPECT_EQ(context.max_num_threads(), 1); + EXPECT_EQ(context.num_threads_strategy(), NumThreadsStrategy::kDefault); context.set_explicit_tuning(Tuning::kGeneric); context.set_max_num_threads(2); + context.set_num_threads_strategy(NumThreadsStrategy::kForceMaxNumThreads); EXPECT_EQ(context.explicit_tuning(), Tuning::kGeneric); EXPECT_EQ(context.max_num_threads(), 2); + EXPECT_EQ(context.num_threads_strategy(), + NumThreadsStrategy::kForceMaxNumThreads); } } // namespace @@ -26,6 +26,7 @@ limitations under the License. #include "ruy/path.h" #include "ruy/performance_advisory.h" #include "ruy/platform.h" +#include "ruy/strategy_controls.h" #include "ruy/prepacked_cache.h" #include "ruy/trace.h" @@ -56,6 +57,12 @@ bool Ctx::performance_advisory(PerformanceAdvisory advisory) const { return (impl().performance_advisory_ & advisory) != PerformanceAdvisory::kNone; } +void Ctx::set_num_threads_strategy(NumThreadsStrategy strategy) { + mutable_impl()->num_threads_strategy_ = strategy; +} +NumThreadsStrategy Ctx::num_threads_strategy() const { + return impl().num_threads_strategy_; +} void Ctx::SetRuntimeEnabledPaths(Path paths) { if (paths == Path::kNone) { @@ -32,6 +32,7 @@ class CpuInfo; enum class Path : std::uint8_t; enum class Tuning; enum class PerformanceAdvisory; +enum class NumThreadsStrategy : std::uint8_t; // Ctx is the internal context class used throughout ruy code. Whereas Context // is exposed to users, Ctx is internal to ruy. 
As many of ruy's internal @@ -53,6 +54,8 @@ class Ctx /* not final, subclassed by CtxImpl */ { void clear_performance_advisories(); void set_performance_advisory(PerformanceAdvisory advisory); bool performance_advisory(PerformanceAdvisory advisory) const; + void set_num_threads_strategy(NumThreadsStrategy strategy); + NumThreadsStrategy num_threads_strategy() const; // Returns the set of Path's that are available. By default, this is based on // runtime detection of CPU features, as well as on which code paths were diff --git a/ruy/ctx_impl.h b/ruy/ctx_impl.h index 0a07ef6..be64553 100644 --- a/ruy/ctx_impl.h +++ b/ruy/ctx_impl.h @@ -29,6 +29,7 @@ limitations under the License. #include "ruy/path.h" #include "ruy/performance_advisory.h" #include "ruy/prepacked_cache.h" +#include "ruy/strategy_controls.h" #include "ruy/thread_pool.h" #include "ruy/tune.h" @@ -63,6 +64,7 @@ class CtxImpl final : public Ctx { Tuning explicit_tuning_ = Tuning::kAuto; ThreadPool thread_pool_; int max_num_threads_ = 1; + NumThreadsStrategy num_threads_strategy_ = NumThreadsStrategy::kDefault; // Allocator for main thread work before invoking the threadpool. // Our simple Allocator does not allow reserving/allocating more blocks // while it's already in committed state, so the main thread needs both diff --git a/ruy/ctx_test.cc b/ruy/ctx_test.cc index e55dcfc..c40f2d6 100644 --- a/ruy/ctx_test.cc +++ b/ruy/ctx_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "ruy/ctx_impl.h" #include "ruy/gtest_wrapper.h" +#include "ruy/strategy_controls.h" #include "ruy/path.h" #include "ruy/platform.h" @@ -67,6 +68,14 @@ TEST(ContextInternalTest, ThreadSpecificResources) { } } +TEST(ContextInternalTest, SetNumThreadsStrategy) { + CtxImpl ctx; + EXPECT_EQ(ctx.num_threads_strategy(), NumThreadsStrategy::kDefault); + ctx.set_num_threads_strategy(NumThreadsStrategy::kForceMaxNumThreads); + EXPECT_EQ(ctx.num_threads_strategy(), + NumThreadsStrategy::kForceMaxNumThreads); +} + } // namespace } // namespace ruy diff --git a/ruy/strategy_controls.h b/ruy/strategy_controls.h new file mode 100644 index 0000000..629c2b8 --- /dev/null +++ b/ruy/strategy_controls.h @@ -0,0 +1,34 @@ +/* Copyright 2022 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef RUY_RUY_STRATEGY_CONTROLS_H_ +#define RUY_RUY_STRATEGY_CONTROLS_H_ + +#include <cstdint> + +namespace ruy { + +enum class NumThreadsStrategy : std::uint8_t { + // kDefault means using smart heuristic logic that has been optimized + // for cubic ColxRowxDepth matrix multiplication. + kDefault, + // kForceMaxNumThreads means using ctx->max_num_threads() + // for multi-thread computing. 
+ kForceMaxNumThreads +}; + +} // namespace ruy + +#endif // RUY_RUY_STRATEGY_CONTROLS_H_ diff --git a/ruy/trmul.cc b/ruy/trmul.cc index 7af4bf0..2ff519f 100644 --- a/ruy/trmul.cc +++ b/ruy/trmul.cc @@ -35,6 +35,7 @@ limitations under the License. #include "ruy/mat.h" #include "ruy/matrix.h" #include "ruy/mul_params.h" +#include "ruy/strategy_controls.h" #include "ruy/opt_set.h" #include "ruy/profiler/instrumentation.h" #include "ruy/side_pair.h" @@ -259,6 +260,10 @@ int GetTentativeThreadCount(Ctx* ctx, int rows, int cols, int depth) { // in this Mul (product of the 3 sizes). // Be defensive here by explicitly promoting operands to int64 to avoid the // pitfall of `int64 result = x * y;` overflowing as x and y are still narrow. + if (ctx->num_threads_strategy() == NumThreadsStrategy::kForceMaxNumThreads) { + return ctx->max_num_threads(); + } + RUY_CHECK_EQ(ctx->num_threads_strategy(), NumThreadsStrategy::kDefault); const std::int64_t rows_i64 = rows; const std::int64_t cols_i64 = cols; const std::int64_t depth_i64 = depth; |