github.com/google/ruy.git
Diffstat (limited to 'ruy/profiler')
-rw-r--r--  ruy/profiler/BUILD                          52
-rw-r--r--  ruy/profiler/README.md                     149
-rw-r--r--  ruy/profiler/instrumentation.cc            130
-rw-r--r--  ruy/profiler/instrumentation.h             203
-rw-r--r--  ruy/profiler/profiler.cc                   109
-rw-r--r--  ruy/profiler/profiler.h                    106
-rw-r--r--  ruy/profiler/test.cc                       167
-rw-r--r--  ruy/profiler/test_instrumented_library.cc   59
-rw-r--r--  ruy/profiler/test_instrumented_library.h    23
-rw-r--r--  ruy/profiler/treeview.cc                   248
-rw-r--r--  ruy/profiler/treeview.h                    130
11 files changed, 1376 insertions(+), 0 deletions(-)
diff --git a/ruy/profiler/BUILD b/ruy/profiler/BUILD
new file mode 100644
index 0000000..b0af802
--- /dev/null
+++ b/ruy/profiler/BUILD
@@ -0,0 +1,52 @@
+# A minimalistic profiler sampling pseudo-stacks
+
+package(
+ default_visibility = ["//visibility:public"],
+ licenses = ["notice"], # Apache 2.0
+)
+
+config_setting(
+ name = "ruy_profiler",
+ define_values = {"ruy_profiler": "true"},
+)
+
+cc_library(
+ name = "instrumentation",
+ srcs = ["instrumentation.cc"],
+ hdrs = ["instrumentation.h"],
+ defines = select({
+ ":ruy_profiler": ["RUY_PROFILER"],
+ "//conditions:default": [],
+ }),
+)
+
+cc_library(
+ name = "profiler",
+ srcs = [
+ "profiler.cc",
+ "treeview.cc",
+ ],
+ hdrs = [
+ "profiler.h",
+ "treeview.h",
+ ],
+ deps = [":instrumentation"],
+)
+
+cc_library(
+ name = "test_instrumented_library",
+ testonly = True,
+ srcs = ["test_instrumented_library.cc"],
+ hdrs = ["test_instrumented_library.h"],
+ deps = [":instrumentation"],
+)
+
+cc_test(
+ name = "test",
+ srcs = ["test.cc"],
+ deps = [
+ ":profiler",
+ ":test_instrumented_library",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/ruy/profiler/README.md b/ruy/profiler/README.md
new file mode 100644
index 0000000..8d79025
--- /dev/null
+++ b/ruy/profiler/README.md
@@ -0,0 +1,149 @@
+# A minimalistic profiler sampling pseudo-stacks
+
+## Overview
+
+The present directory is the "ruy profiler". As a time profiler, it allows
+one to measure where code is spending time.
+
+Contrary to most typical profilers, what it samples is not real call stacks,
+but "pseudo-stacks": simple data structures constructed from within the
+program being profiled. Using this profiler requires manually instrumenting
+code to construct such pseudo-stack information.
+
+Another unusual characteristic of this profiler is that it uses only the C++11
+standard library. It does not use any non-portable features; in particular, it
+does not rely on signal handlers. The sampling is performed by a thread, the
+"profiler thread".
+
+A discussion of pros/cons of this approach is appended below.
+
+## How to use this profiler
+
+### How to instrument code
+
+An example of instrumented code is given in `test_instrumented_library.cc`.
+
+Code is instrumented by constructing `ScopeLabel` objects. These are RAII
+helpers, ensuring that the thread pseudo-stack contains the label during their
+lifetime. In the most common use case, one would construct such an object at the
+start of a function, so that its scope is the function scope and it allows to
+measure how much time is spent in this function.
+
+```c++
+#include "ruy/profiler/instrumentation.h"
+
+...
+
+void SomeFunction() {
+  ruy::profiler::ScopeLabel function_label("SomeFunction");
+ ... do something ...
+}
+```
+
+A `ScopeLabel` may, however, have any scope; for instance:
+
+```c++
+if (some_case) {
+  ruy::profiler::ScopeLabel extra_work_label("Some more work");
+ ... do some more work ...
+}
+```
+
+The string passed to the `ScopeLabel` constructor must be a pointer to a
+literal string (a `const char*` pointer). The profiler assumes that these
+pointers stay valid until the profile is finalized.
+
+However, that literal string may be a `printf` format string, and labels may
+have up to 4 parameters, of type `int`. For example:
+
+```c++
+void SomeFunction(int size) {
+  ruy::profiler::ScopeLabel function_label("SomeFunction (size=%d)", size);
+  ... do something ...
+}
+```
+
+### How to run the profiler
+
+Profiling instrumentation is a no-op unless the preprocessor token
+`RUY_PROFILER` is defined, so defining it is the first step when actually
+profiling. When building with Bazel, the preferred way to enable that is to pass
+this flag on the Bazel command line:
+
+```
+--define=ruy_profiler=true
+```
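+
+When not building with Bazel, defining the same preprocessor token directly
+should be sufficient; for instance, with a GCC/Clang-style compiler driver
+(an assumption about your build system, not something this directory
+prescribes):
+
+```
+CXXFLAGS += -DRUY_PROFILER
+```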
+
+To actually profile a code scope, it is enough to construct a `ScopeProfile`
+object, also a RAII helper. It will start the profiler on construction, and on
+destruction it will terminate the profiler and report the profile treeview on
+standard output by default. Example:
+
+```c++
+void SomeProfiledBenchmark() {
+ ruy::profiling::ScopeProfile profile;
+
+ CallSomeInstrumentedCode();
+}
+```
+
+An example is provided by the `:test` target in the present directory. Run it
+with `--define=ruy_profiler=true` as explained above:
+
+```
+bazel run -c opt \
+ --define=ruy_profiler=true \
+    //ruy/profiler:test
+```
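+
+This prints one treeview per sampled thread. Following the format used by
+`Print` in `treeview.cc`, the output has roughly the following shape (the
+numbers and nesting below are purely illustrative):
+
+```
+Profile (1 threads):
+
+Thread 0 (1098 samples)
+
+* 99.73% MergeSort (size=1048576)
+  * 49.17% MergeSortRecurse (level=0, size=1048576)
+  ...
+```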
+
+The default behavior of dumping the treeview on standard output may be
+overridden by calling `SetUserTreeView` on the `ScopeProfile` object with a
+pointer to a `TreeView`. This causes the treeview to be stored in that
+`TreeView` object, where it may be accessed and manipulated using the
+functions declared in `treeview.h`. The aforementioned `:test` provides
+examples for doing so.
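+
+For example, adapting the pattern used in `test.cc`:
+
+```c++
+ruy::profiler::TreeView treeview;
+{
+  ruy::profiler::ScopeProfile profile;
+  profile.SetUserTreeView(&treeview);
+  CallSomeInstrumentedCode();
+}
+// `treeview` was populated when `profile` went out of scope; it can now be
+// queried and printed using the functions declared in treeview.h.
+ruy::profiler::Print(treeview);
+```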
+
+## Advantages and drawbacks
+
+Compared to a traditional profiler, e.g. Linux's "perf", the present kind of
+profiler has the following drawbacks:
+
+* Requires manual instrumentation of code being profiled.
+* Substantial overhead, modifying the performance characteristics of the code
+ being measured.
+* Questionable accuracy.
+
+But also the following advantages:
+
+* Profiling can be driven from within a benchmark program, allowing the entire
+ profiling procedure to be a single command line.
+* Not relying on symbol information removes exposure to toolchain details and
+  means less hassle in some build environments, especially embedded/mobile
+  (single command line to run and profile, no symbols files required).
+* Fully portable (all of this is standard C++11).
+* Fully testable (see `:test`). Profiling becomes just another feature of the
+ code like any other.
+* Customized instrumentation can result in easier to read treeviews (only
+ relevant functions, and custom labels may be more readable than function
+ names).
+* Parametrized/formatted labels make it possible to do things that aren't
+  possible with call-stack-sampling profilers. For example, break down a
+  profile where much time is being spent in matrix multiplications, by the
+  various matrix multiplication shapes involved; see the sketch just after
+  this list.
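+
+A minimal sketch of such a shape-parametrized label (the function and its
+parameters here are hypothetical, purely for illustration):
+
+```c++
+void MatMul(int rows, int depth, int cols) {
+  // One treeview entry per distinct (rows, depth, cols) shape encountered,
+  // each with its own weight.
+  ruy::profiler::ScopeLabel label("MatMul (%dx%dx%d)", rows, depth, cols);
+  // ... actual matrix multiplication work ...
+}
+```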
+
+The philosophy underlying this profiler is that software performance depends on
+software engineers profiling often, and a key factor limiting that in practice
+is how difficult or cumbersome profiling is with more serious profilers such as
+Linux's "perf", especially in embedded/mobile development: multiple command
+lines are involved to copy symbol files to devices, retrieve profile data from
+the device, etc. In that context, it is useful to make profiling as easy as
+benchmarking, even on embedded targets, even if the price to pay for that is
+lower accuracy, higher overhead, and some intrusive instrumentation
+requirement.
+
+Another key aspect determining what profiling approach is suitable for a given
+context is whether one already has a priori knowledge of where much of the
+time is likely being spent. When one has such a priori knowledge, it is
+feasible to instrument the known possibly-critical code as per the present
+approach. On the other hand, in situations where one doesn't have such a
+priori knowledge, a real profiler such as Linux's "perf" makes it possible to
+get a profile of real call stacks right away, using just the symbol
+information generated by the toolchain.
diff --git a/ruy/profiler/instrumentation.cc b/ruy/profiler/instrumentation.cc
new file mode 100644
index 0000000..f03f667
--- /dev/null
+++ b/ruy/profiler/instrumentation.cc
@@ -0,0 +1,130 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/profiler/instrumentation.h"
+
+#ifdef RUY_PROFILER
+
+#include <cstdlib>  // for abort
+#include <cstring>  // for memcpy
+#include <string>
+
+namespace ruy {
+namespace profiler {
+
+void Label::operator=(const Label& other) {
+ format_ = other.format_;
+ args_count_ = other.args_count_;
+ for (int i = 0; i < args_count_; i++) {
+ args_[i] = other.args_[i];
+ }
+}
+
+bool Label::operator==(const Label& other) const {
+ if (std::string(format_) != std::string(other.format_)) {
+ return false;
+ }
+ if (args_count_ != other.args_count_) {
+ return false;
+ }
+ for (int i = 0; i < args_count_; i++) {
+ if (args_[i] != other.args_[i]) {
+ return false;
+ }
+ }
+ return true;
+}
+
+std::string Label::Formatted() const {
+ static constexpr int kBufSize = 256;
+ char buf[kBufSize];
+ if (args_count_ == 0) {
+ return format_;
+ }
+ if (args_count_ == 1) {
+ snprintf(buf, kBufSize, format_, args_[0]);
+ } else if (args_count_ == 2) {
+ snprintf(buf, kBufSize, format_, args_[0], args_[1]);
+ } else if (args_count_ == 3) {
+ snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2]);
+ } else if (args_count_ == 4) {
+ snprintf(buf, kBufSize, format_, args_[0], args_[1], args_[2], args_[3]);
+ } else {
+ abort();
+ }
+ return buf;
+}
+
+namespace detail {
+
+std::mutex* GlobalsMutex() {
+ static std::mutex mutex;
+ return &mutex;
+}
+
+bool& GlobalIsProfilerRunning() {
+ static bool b;
+ return b;
+}
+
+std::vector<ThreadStack*>* GlobalAllThreadStacks() {
+ static std::vector<ThreadStack*> all_stacks;
+ return &all_stacks;
+}
+
+ThreadStack* ThreadLocalThreadStack() {
+ thread_local static ThreadStack thread_stack;
+ return &thread_stack;
+}
+
+ThreadStack::ThreadStack() {
+ std::lock_guard<std::mutex> lock(*GlobalsMutex());
+ static std::uint32_t global_next_thread_stack_id = 0;
+ stack_.id = global_next_thread_stack_id++;
+ GlobalAllThreadStacks()->push_back(this);
+}
+
+ThreadStack::~ThreadStack() {
+ std::lock_guard<std::mutex> lock(*GlobalsMutex());
+ std::vector<ThreadStack*>* all_stacks = GlobalAllThreadStacks();
+ for (auto it = all_stacks->begin(); it != all_stacks->end(); ++it) {
+ if (*it == this) {
+ all_stacks->erase(it);
+ return;
+ }
+ }
+}
+
+int GetBufferSize(const Stack& stack) {
+ return sizeof(stack.id) + sizeof(stack.size) +
+ stack.size * sizeof(stack.labels[0]);
+}
+
+void CopyToBuffer(const Stack& stack, char* dst) {
+ memcpy(dst, &stack.id, sizeof(stack.id));
+ dst += sizeof(stack.id);
+ memcpy(dst, &stack.size, sizeof(stack.size));
+ dst += sizeof(stack.size);
+ memcpy(dst, stack.labels, stack.size * sizeof(stack.labels[0]));
+}
+
+void ReadFromBuffer(const char* src, Stack* stack) {
+ memcpy(&stack->id, src, sizeof(stack->id));
+ src += sizeof(stack->id);
+ memcpy(&stack->size, src, sizeof(stack->size));
+ src += sizeof(stack->size);
+ memcpy(stack->labels, src, stack->size * sizeof(stack->labels[0]));
+}
+
+} // namespace detail
+} // namespace profiler
+} // namespace ruy
+
+#endif  // RUY_PROFILER
diff --git a/ruy/profiler/instrumentation.h b/ruy/profiler/instrumentation.h
new file mode 100644
index 0000000..a9046d4
--- /dev/null
+++ b/ruy/profiler/instrumentation.h
@@ -0,0 +1,203 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_
+
+#ifdef RUY_PROFILER
+#include <cstdint>
+#include <cstdio>
+#include <mutex>
+#include <string>
+#include <vector>
+#endif
+
+namespace ruy {
+namespace profiler {
+
+#ifdef RUY_PROFILER
+
+// A label is how a code scope is annotated to appear in profiles.
+// The stacks that are sampled by the profiler are stacks of such labels.
+// A label consists of a literal string, plus optional integer arguments.
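+// For example, Label("SomeFunction (size=%d)", size) stores the format string
+// pointer plus the integer argument; Formatted() later renders the final
+// string for display.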
+class Label {
+ public:
+ Label() {}
+ template <typename... Args>
+ explicit Label(Args... args) {
+ Set(args...);
+ }
+ void Set(const char* format) {
+ format_ = format;
+ args_count_ = 0;
+ }
+ template <typename... Args>
+ void Set(const char* format, Args... args) {
+ format_ = format;
+ args_count_ = sizeof...(args);
+ SetArgs(0, args...);
+ }
+
+ void operator=(const Label& other);
+
+ bool operator==(const Label& other) const;
+
+ std::string Formatted() const;
+ const char* format() const { return format_; }
+
+ private:
+ void SetArgs(int position, int arg0) { args_[position] = arg0; }
+
+ template <typename... Args>
+ void SetArgs(int position, int arg0, Args... args) {
+ SetArgs(position, arg0);
+ SetArgs(position + 1, args...);
+ }
+
+ static constexpr int kMaxArgs = 4;
+ const char* format_ = nullptr;
+ int args_count_ = 0;
+ int args_[kMaxArgs];
+};
+
+namespace detail {
+
+// Forward-declaration, see class ThreadStack below.
+class ThreadStack;
+
+bool& GlobalIsProfilerRunning();
+
+// Returns the global vector of pointers to all stacks, there being one stack
+// per thread executing instrumented code.
+std::vector<ThreadStack*>* GlobalAllThreadStacks();
+
+// Returns the mutex to be locked around any access to GlobalAllThreadStacks().
+std::mutex* GlobalsMutex();
+
+// Returns the thread-local stack, specific to the current thread.
+ThreadStack* ThreadLocalThreadStack();
+
+// This 'stack' is what may be more appropriately called a 'pseudostack':
+// It contains Label entries that are 'manually' entered by instrumentation
+// code. It's unrelated to real call stacks.
+struct Stack {
+ std::uint32_t id = 0;
+ static constexpr int kMaxSize = 64;
+ int size = 0;
+ Label labels[kMaxSize];
+};
+
+// Returns the buffer byte size required by CopyToBuffer.
+int GetBufferSize(const Stack& stack);
+
+// Copies this Stack into a byte buffer, called a 'sample'. The layout is
+// simply the raw bytes of `id`, then `size`, then the first `size` labels.
+void CopyToBuffer(const Stack& stack, char* dst);
+
+// Populates this Stack from an existing sample buffer, typically
+// produced by CopyToBuffer.
+void ReadFromBuffer(const char* src, Stack* stack);
+
+// ThreadStack is meant to be used as a thread-local singleton, assigning to
+// each thread a Stack object holding its pseudo-stack of profile labels,
+// plus a mutex allowing to synchronize accesses to this pseudo-stack between
+// this thread and a possible profiler thread sampling it.
+class ThreadStack {
+ public:
+ ThreadStack();
+ ~ThreadStack();
+
+ const Stack& stack() const { return stack_; }
+
+ // Returns the mutex to lock around any access to this stack. Each stack is
+ // accessed by potentially two threads: the thread that it belongs to
+ // (which calls Push and Pop) and the profiler thread during profiling
+  // (which samples it via CopyToBuffer).
+ std::mutex& Mutex() const { return mutex_; }
+
+ // Pushes a new label on the top of this Stack.
+ template <typename... Args>
+ void Push(Args... args) {
+ // This mutex locking is needed to guard against race conditions as both
+ // the current thread and the profiler thread may be concurrently accessing
+ // this stack. In addition to that, this mutex locking also serves the other
+ // purpose of acting as a barrier (of compiler code reordering, of runtime
+ // CPU instruction reordering, and of memory access reordering), which
+ // gives a measure of correctness to this profiler. The downside is some
+    // latency. As this lock will be uncontended most of the time, the cost
+    // should be roughly that of a sequentially-consistent atomic access,
+ // comparable to an access to the level of CPU data cache that is shared
+ // among all cores, typically 60 cycles on current ARM CPUs, plus side
+ // effects from barrier instructions.
+ std::lock_guard<std::mutex> lock(mutex_);
+ // Avoid overrunning the stack, even in 'release' builds. This profiling
+ // instrumentation code should not ship in release builds anyway, the
+ // overhead of this check is negligible, and overrunning a stack array would
+ // be bad.
+ if (stack_.size >= Stack::kMaxSize) {
+ abort();
+ }
+ stack_.labels[stack_.size++].Set(args...);
+ }
+
+ // Pops the top-most label from this Stack.
+ void Pop() {
+ // See the comment in Push about this lock. While it would be tempting to
+ // try to remove this lock and just atomically decrement size_ with a
+ // store-release, that would not necessarily be a substitute for all of the
+ // purposes that this lock serves, or if it was done carefully to serve all
+ // of the same purposes, then that wouldn't be faster than this (mostly
+ // uncontended) lock.
+ std::lock_guard<std::mutex> lock(mutex_);
+ stack_.size--;
+ }
+
+ private:
+ mutable std::mutex mutex_;
+ Stack stack_;
+};
+
+} // namespace detail
+
+// RAII user-facing way to construct Labels associated with their life scope
+// and get them pushed to / popped from the current thread stack.
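+//
+// Usage example (see also the README in this directory):
+//   void SomeFunction() {
+//     ruy::profiler::ScopeLabel function_label("SomeFunction");
+//     // ... do something ...
+//   }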
+class ScopeLabel {
+ public:
+ template <typename... Args>
+ ScopeLabel(Args... args) : thread_stack_(detail::ThreadLocalThreadStack()) {
+ thread_stack_->Push(args...);
+ }
+
+ ~ScopeLabel() { thread_stack_->Pop(); }
+
+ private:
+ detail::ThreadStack* thread_stack_;
+};
+
+#else // no RUY_PROFILER
+
+class ScopeLabel {
+ public:
+ template <typename... Args>
+ explicit ScopeLabel(Args...) {}
+
+ // This destructor is needed to consistently silence clang's -Wunused-variable
+ // which seems to trigger semi-randomly.
+ ~ScopeLabel() {}
+};
+
+#endif
+
+} // namespace profiler
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_INSTRUMENTATION_H_
diff --git a/ruy/profiler/profiler.cc b/ruy/profiler/profiler.cc
new file mode 100644
index 0000000..ae3a2e2
--- /dev/null
+++ b/ruy/profiler/profiler.cc
@@ -0,0 +1,109 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/profiler/profiler.h"
+
+#ifdef RUY_PROFILER
+#include <atomic>
+#include <chrono> // NOLINT
+#include <cstdio>
+#include <cstdlib>
+#include <thread> // NOLINT
+#include <vector>
+#endif
+
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/profiler/treeview.h"
+
+namespace ruy {
+namespace profiler {
+
+#ifdef RUY_PROFILER
+
+ScopeProfile::ScopeProfile() { Start(); }
+ScopeProfile::ScopeProfile(bool enable) {
+ if (enable) {
+ Start();
+ }
+}
+ScopeProfile::~ScopeProfile() {
+ if (!thread_) {
+ return;
+ }
+ finishing_.store(true);
+ thread_->join();
+ Finish();
+}
+
+void ScopeProfile::Start() {
+ {
+ std::lock_guard<std::mutex> lock(*detail::GlobalsMutex());
+ if (detail::GlobalIsProfilerRunning()) {
+ fprintf(stderr, "FATAL: profiler already running!\n");
+ abort();
+ }
+ detail::GlobalIsProfilerRunning() = true;
+ }
+ finishing_ = false;
+ thread_.reset(new std::thread(&ScopeProfile::ThreadFunc, this));
+}
+
+void ScopeProfile::ThreadFunc() {
+ while (!finishing_.load()) {
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ std::lock_guard<std::mutex> lock(*detail::GlobalsMutex());
+ auto* thread_stacks = detail::GlobalAllThreadStacks();
+ for (detail::ThreadStack* thread_stack : *thread_stacks) {
+ Sample(*thread_stack);
+ }
+ }
+}
+
+void ScopeProfile::Sample(const detail::ThreadStack& thread_stack) {
+ std::lock_guard<std::mutex> lock(thread_stack.Mutex());
+ // Drop empty stacks.
+ // This ensures that profiles aren't polluted by uninteresting threads.
+ if (thread_stack.stack().size == 0) {
+ return;
+ }
+ int sample_size = detail::GetBufferSize(thread_stack.stack());
+ int old_buf_size = samples_buf_.size();
+ samples_buf_.resize(old_buf_size + sample_size);
+ detail::CopyToBuffer(thread_stack.stack(),
+ samples_buf_.data() + old_buf_size);
+}
+
+void ScopeProfile::Finish() {
+ {
+ std::lock_guard<std::mutex> lock(*detail::GlobalsMutex());
+ if (!detail::GlobalIsProfilerRunning()) {
+ fprintf(stderr, "FATAL: profiler is not running!\n");
+ abort();
+ }
+ detail::GlobalIsProfilerRunning() = false;
+ }
+ if (user_treeview_) {
+ user_treeview_->Populate(samples_buf_);
+ } else {
+ TreeView treeview;
+ treeview.Populate(samples_buf_);
+ Print(treeview);
+ }
+}
+
+#endif // RUY_PROFILER
+
+} // namespace profiler
+} // namespace ruy
diff --git a/ruy/profiler/profiler.h b/ruy/profiler/profiler.h
new file mode 100644
index 0000000..b68ca90
--- /dev/null
+++ b/ruy/profiler/profiler.h
@@ -0,0 +1,106 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_
+
+#include <cstdio>
+
+#ifdef RUY_PROFILER
+#include <atomic>
+#include <chrono>
+#include <memory>
+#include <thread>
+#include <vector>
+#endif
+
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/profiler/treeview.h"
+
+namespace ruy {
+namespace profiler {
+
+#ifdef RUY_PROFILER
+
+// RAII user-facing way to create a profiler and let it profile a code scope,
+// and print out an ASCII/MarkDown treeview upon leaving the scope.
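+//
+// Usage sketch (see the README in this directory):
+//   {
+//     ruy::profiler::ScopeProfile profile;
+//     CallSomeInstrumentedCode();
+//   }  // <- the profile treeview is printed upon leaving this scope.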
+class ScopeProfile {
+ public:
+ // Default constructor, unconditionally profiling.
+ ScopeProfile();
+
+  // Constructor allowing one to choose at runtime whether to profile.
+ explicit ScopeProfile(bool enable);
+
+ // Destructor. It's where the profile is reported.
+ ~ScopeProfile();
+
+  // See the user_treeview_ member.
+ void SetUserTreeView(TreeView* treeview) { user_treeview_ = treeview; }
+
+ private:
+ void Start();
+
+ // Thread entry point function for the profiler thread. This thread is
+ // created on construction.
+ void ThreadFunc();
+
+ // Record a stack as a sample.
+ void Sample(const detail::ThreadStack& stack);
+
+ // Finalize the profile. Called on destruction.
+ // If user_treeview_ is non-null, it will receive the treeview.
+ // Otherwise the treeview will just be printed.
+ void Finish();
+
+ // Buffer where samples are recorded during profiling.
+ std::vector<char> samples_buf_;
+
+ // Used to synchronize thread termination.
+ std::atomic<bool> finishing_;
+
+ // Underlying profiler thread, which will perform the sampling.
+ // This profiler approach relies on a thread rather than on signals.
+ std::unique_ptr<std::thread> thread_;
+
+ // TreeView to populate upon destruction. If left null (the default),
+ // a temporary treeview will be used and dumped on stdout. The user
+ // may override that by passing their own TreeView object for other
+ // output options or to directly inspect the TreeView.
+ TreeView* user_treeview_ = nullptr;
+};
+
+#else // no RUY_PROFILER
+
+struct ScopeProfile {
+ ScopeProfile() {
+#ifdef GEMMLOWP_PROFILING
+ fprintf(
+ stderr,
+ "\n\n\n**********\n\nWARNING:\n\nLooks like you defined "
+ "GEMMLOWP_PROFILING, but this code has been ported to the new ruy "
+ "profiler replacing the old gemmlowp profiler. You should now be "
+ "defining RUY_PROFILER and not GEMMLOWP_PROFILING. When building using "
+ "Bazel, just pass --define=ruy_profiler=true.\n\n**********\n\n\n");
+#endif
+ }
+ explicit ScopeProfile(bool) {}
+};
+
+#endif
+
+} // namespace profiler
+} // namespace ruy
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_PROFILER_H_
diff --git a/ruy/profiler/test.cc b/ruy/profiler/test.cc
new file mode 100644
index 0000000..e94840b
--- /dev/null
+++ b/ruy/profiler/test.cc
@@ -0,0 +1,167 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <chrono>
+#include <random>
+#include <thread>
+
+#include "testing/base/public/gunit.h"
+#include "ruy/profiler/profiler.h"
+#include "ruy/profiler/test_instrumented_library.h"
+#include "ruy/profiler/treeview.h"
+
+namespace ruy {
+namespace profiler {
+namespace {
+
+void DoSomeMergeSort(int size) {
+ std::vector<int> data(size);
+
+ std::default_random_engine engine;
+ for (auto& val : data) {
+ val = engine();
+ }
+
+ MergeSort(size, data.data());
+}
+
+// The purpose of this basic test is to cover the basic path that will be taken
+// by a majority of users, not inspecting treeviews but just implicitly printing
+// them on stdout, and to have this test enabled even when RUY_PROFILER is not
+// defined, so that we have coverage for the non-RUY_PROFILER case.
+TEST(ProfilerTest, MergeSortSingleThreadBasicTestEvenWithoutProfiler) {
+ {
+ ScopeProfile profile;
+ DoSomeMergeSort(1 << 20);
+ }
+}
+
+#ifdef RUY_PROFILER
+
+TEST(ProfilerTest, MergeSortSingleThread) {
+ TreeView treeview;
+ {
+ ScopeProfile profile;
+ profile.SetUserTreeView(&treeview);
+ DoSomeMergeSort(1 << 20);
+ }
+ Print(treeview);
+ EXPECT_EQ(treeview.thread_roots().size(), 1);
+ const auto& thread_root = *treeview.thread_roots().begin()->second;
+ EXPECT_EQ(DepthOfTreeBelow(thread_root), 22);
+ EXPECT_GE(
+ WeightBelowNodeMatchingUnformatted(thread_root, "Merging sorted halves"),
+ 0.1 * thread_root.weight);
+ EXPECT_GE(WeightBelowNodeMatchingFormatted(
+ thread_root, "MergeSortRecurse (level=20, size=1)"),
+ 0.01 * thread_root.weight);
+
+ TreeView treeview_collapsed;
+ CollapseNodesMatchingUnformatted(treeview, 5, "MergeSort (size=%d)",
+ &treeview_collapsed);
+ Print(treeview_collapsed);
+ const auto& collapsed_thread_root =
+ *treeview_collapsed.thread_roots().begin()->second;
+ EXPECT_EQ(DepthOfTreeBelow(collapsed_thread_root), 6);
+ EXPECT_EQ(
+ WeightBelowNodeMatchingUnformatted(thread_root, "MergeSort (size=%d)"),
+ WeightBelowNodeMatchingUnformatted(collapsed_thread_root,
+ "MergeSort (size=%d)"));
+}
+
+TEST(ProfilerTest, MemcpyFourThreads) {
+ TreeView treeview;
+ {
+ ScopeProfile profile;
+ profile.SetUserTreeView(&treeview);
+ std::vector<std::unique_ptr<std::thread>> threads;
+ for (int i = 0; i < 4; i++) {
+ threads.emplace_back(new std::thread([i]() {
+ ScopeLabel thread_label("worker thread #%d", i);
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ ScopeLabel some_more_work_label("some more work");
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ }));
+ }
+ for (int i = 0; i < 4; i++) {
+ threads[i]->join();
+ }
+ }
+ Print(treeview);
+  // Since empty stacks are dropped during sampling and the current thread
+  // hasn't created any ScopeLabel, only the 4 worker threads should be
+  // recorded.
+ EXPECT_EQ(treeview.thread_roots().size(), 4);
+ for (const auto& thread_root : treeview.thread_roots()) {
+ const TreeView::Node& root_node = *thread_root.second;
+ // The root node may have 1 or 2 children depending on whether there is
+ // an "[other]" child.
+ EXPECT_GE(root_node.children.size(), 1);
+ EXPECT_LE(root_node.children.size(), 2);
+ const TreeView::Node& child_node = *root_node.children[0];
+ EXPECT_EQ(child_node.label.format(), "worker thread #%d");
+ // There must be 2 children, since roughly half the time will be in
+ // "some more work" leaving the other half in "[other]".
+ EXPECT_EQ(child_node.children.size(), 2);
+ const TreeView::Node& child_child_node = *child_node.children[0];
+    // Since we sample every millisecond and the threads run for >= 2000
+    // milliseconds, the "worker thread #%d" label should get roughly 2000
+    // samples. This is not very rigorous, as we're depending on the profiler
+    // thread getting scheduled, so to avoid this test being flaky, we use a
+    // much more conservative value of 500, one quarter of the nominal 2000.
+ EXPECT_GE(child_node.weight, 500);
+ // Likewise, allow up to four times more than the normal value 2000.
+ EXPECT_LE(child_node.weight, 8000);
+ // Roughly half of time should be spent under the "some more work" label.
+ float some_more_work_percentage =
+ 100.f * child_child_node.weight / child_node.weight;
+ EXPECT_GE(some_more_work_percentage, 40.0f);
+ EXPECT_LE(some_more_work_percentage, 60.0f);
+ }
+}
+
+TEST(ProfilerTest, OneThreadAfterAnother) {
+ TreeView treeview;
+ {
+ ScopeProfile profile;
+ profile.SetUserTreeView(&treeview);
+ {
+ std::thread thread([]() {
+ ScopeLabel thread_label("thread 0");
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ });
+ thread.join();
+ }
+ {
+ std::thread thread([]() {
+ ScopeLabel thread_label("thread 1");
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000));
+ });
+ thread.join();
+ }
+ }
+ Print(treeview);
+ EXPECT_EQ(treeview.thread_roots().size(), 2);
+}
+
+#endif // RUY_PROFILER
+
+} // namespace
+} // namespace profiler
+} // namespace ruy
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/ruy/profiler/test_instrumented_library.cc b/ruy/profiler/test_instrumented_library.cc
new file mode 100644
index 0000000..b017ea9
--- /dev/null
+++ b/ruy/profiler/test_instrumented_library.cc
@@ -0,0 +1,59 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "ruy/profiler/test_instrumented_library.h"
+
+#include <vector>
+
+#include "ruy/profiler/instrumentation.h"
+
+namespace {
+
+void MergeSortRecurse(int level, int size, int* data, int* workspace) {
+ ruy::profiler::ScopeLabel function_label(
+ "MergeSortRecurse (level=%d, size=%d)", level, size);
+ if (size <= 1) {
+ return;
+ }
+ int half_size = size / 2;
+ MergeSortRecurse(level + 1, half_size, data, workspace);
+ MergeSortRecurse(level + 1, size - half_size, data + half_size,
+ workspace + half_size);
+
+ ruy::profiler::ScopeLabel merging_sorted_halves_label(
+ "Merging sorted halves");
+ int dst_index = 0;
+ int left_index = 0;
+ int right_index = half_size;
+ while (dst_index < size) {
+ int val;
+ if (left_index < half_size &&
+ ((right_index >= size) || data[left_index] < data[right_index])) {
+ val = data[left_index++];
+ } else {
+ val = data[right_index++];
+ }
+ workspace[dst_index++] = val;
+ }
+ for (int i = 0; i < size; i++) {
+ data[i] = workspace[i];
+ }
+}
+
+} // namespace
+
+void MergeSort(int size, int* data) {
+ ruy::profiler::ScopeLabel function_label("MergeSort (size=%d)", size);
+ std::vector<int> workspace(size);
+ MergeSortRecurse(0, size, data, workspace.data());
+}
diff --git a/ruy/profiler/test_instrumented_library.h b/ruy/profiler/test_instrumented_library.h
new file mode 100644
index 0000000..53d204e
--- /dev/null
+++ b/ruy/profiler/test_instrumented_library.h
@@ -0,0 +1,23 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_
+
+#include "ruy/profiler/instrumentation.h"
+
+void MergeSort(int size, int* data);
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TEST_INSTRUMENTED_LIBRARY_H_
diff --git a/ruy/profiler/treeview.cc b/ruy/profiler/treeview.cc
new file mode 100644
index 0000000..48d922a
--- /dev/null
+++ b/ruy/profiler/treeview.cc
@@ -0,0 +1,248 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef RUY_PROFILER
+
+#include "ruy/profiler/treeview.h"
+
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <cstdio>
+#include <functional>
+#include <memory>
+#include <vector>
+
+namespace ruy {
+namespace profiler {
+
+namespace {
+
+void SortNode(TreeView::Node* node) {
+ using NodePtr = std::unique_ptr<TreeView::Node>;
+ std::sort(node->children.begin(), node->children.end(),
+ [](const NodePtr& n1, const NodePtr& n2) {
+ return n1->weight > n2->weight;
+ });
+ for (const auto& child : node->children) {
+ SortNode(child.get());
+ }
+}
+
+// Records a stack i.e. a sample in a treeview, by incrementing the weights
+// of matching existing nodes and/or by creating new nodes as needed,
+// recursively, below the given node.
+void AddStack(const detail::Stack& stack, TreeView::Node* node, int level) {
+ node->weight++;
+ if (stack.size == level) {
+ return;
+ }
+ TreeView::Node* child_to_add_to = nullptr;
+ for (const auto& child : node->children) {
+ if (child->label == stack.labels[level]) {
+ child_to_add_to = child.get();
+ break;
+ }
+ }
+  if (!child_to_add_to) {
+    // Note: emplace_back returns void in C++11 (which this code targets),
+    // so use back() to get the newly added node.
+    node->children.emplace_back(new TreeView::Node);
+    child_to_add_to = node->children.back().get();
+    child_to_add_to->label = stack.labels[level];
+  }
+ AddStack(stack, child_to_add_to, level + 1);
+}
+
+// Recursively populates the treeview below the given node with 'other'
+// entries documenting for each node the difference between its weight and the
+// sum of its children's weight.
+void AddOther(TreeView::Node* node) {
+ int top_level_children_weight = 0;
+ for (const auto& child : node->children) {
+ AddOther(child.get());
+ top_level_children_weight += child->weight;
+ }
+  if (top_level_children_weight != 0 &&
+      top_level_children_weight != node->weight) {
+    node->children.emplace_back(new TreeView::Node);
+    const auto& new_child = node->children.back();
+    new_child->label = Label("[other]");
+    new_child->weight = node->weight - top_level_children_weight;
+  }
+}
+
+} // namespace
+
+void TreeView::Populate(const std::vector<char>& samples_buf_) {
+ thread_roots_.clear();
+ // Populate the treeview with regular nodes coming from samples.
+ const char* buf_ptr = samples_buf_.data();
+ const char* const buf_ptr_end = buf_ptr + samples_buf_.size();
+ while (buf_ptr < buf_ptr_end) {
+ detail::Stack stack;
+ detail::ReadFromBuffer(buf_ptr, &stack);
+ // Empty stacks should have been dropped during sampling.
+ assert(stack.size > 0);
+ buf_ptr += GetBufferSize(stack);
+ const int id = stack.id;
+ if (!thread_roots_[id]) {
+ thread_roots_[id].reset(new Node);
+ }
+ AddStack(stack, thread_roots_[id].get(), 0);
+ }
+ // Populate the treeview with additional 'other' nodes, sort, and set
+ // root labels.
+ for (const auto& thread_root : thread_roots_) {
+ std::uint32_t id = thread_root.first;
+ Node* root = thread_root.second.get();
+ AddOther(root);
+ SortNode(root);
+ root->label.Set("Thread %x (%d samples)", id, root->weight);
+ }
+}
+
+// Recursively prints the treeview below the given node. The 'root' node
+// argument is only needed to compute weights ratios, with the root ratio
+// as denominator.
+void PrintTreeBelow(const TreeView::Node& node, const TreeView::Node& root,
+ int level) {
+ if (&node == &root) {
+ printf("%s\n\n", node.label.Formatted().c_str());
+ } else {
+ for (int i = 1; i < level; i++) {
+ printf(" ");
+ }
+ printf("* %.2f%% %s\n", 100.0f * node.weight / root.weight,
+ node.label.Formatted().c_str());
+ }
+ for (const auto& child : node.children) {
+ PrintTreeBelow(*child, root, level + 1);
+ }
+}
+
+void Print(const TreeView& treeview) {
+ printf("\n");
+ printf("Profile (%d threads):\n\n",
+ static_cast<int>(treeview.thread_roots().size()));
+ for (const auto& thread_root : treeview.thread_roots()) {
+ const TreeView::Node& root = *thread_root.second;
+ PrintTreeBelow(root, root, 0);
+ printf("\n");
+ }
+}
+
+int DepthOfTreeBelow(const TreeView::Node& node) {
+ if (node.children.empty()) {
+ return 0;
+ } else {
+ int max_child_depth = 0;
+ for (const auto& child : node.children) {
+ max_child_depth = std::max(max_child_depth, DepthOfTreeBelow(*child));
+ }
+ return 1 + max_child_depth;
+ }
+}
+
+int WeightBelowNodeMatchingFunction(
+ const TreeView::Node& node,
+ const std::function<bool(const Label&)>& match) {
+ int weight = 0;
+ if (match(node.label)) {
+ weight += node.weight;
+ }
+ for (const auto& child : node.children) {
+ weight += WeightBelowNodeMatchingFunction(*child, match);
+ }
+ return weight;
+}
+
+int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node,
+ const std::string& format) {
+ return WeightBelowNodeMatchingFunction(
+ node, [&format](const Label& label) { return label.format() == format; });
+}
+
+int WeightBelowNodeMatchingFormatted(const TreeView::Node& node,
+ const std::string& formatted) {
+ return WeightBelowNodeMatchingFunction(
+ node, [&formatted](const Label& label) {
+ return label.Formatted() == formatted;
+ });
+}
+
+void CollapseNode(const TreeView::Node& node_in, int depth,
+ TreeView::Node* node_out) {
+ node_out->label = node_in.label;
+ node_out->weight = node_in.weight;
+ node_out->children.clear();
+ if (depth > 0) {
+ for (const auto& child_in : node_in.children) {
+ auto* child_out = new TreeView::Node;
+ node_out->children.emplace_back(child_out);
+ CollapseNode(*child_in, depth - 1, child_out);
+ }
+ }
+}
+
+void CollapseSubnodesMatchingFunction(
+ const TreeView::Node& node_in, int depth,
+ const std::function<bool(const Label&)>& match, TreeView::Node* node_out) {
+ if (match(node_in.label)) {
+ CollapseNode(node_in, depth, node_out);
+ } else {
+ node_out->label = node_in.label;
+ node_out->weight = node_in.weight;
+ node_out->children.clear();
+
+ for (const auto& child_in : node_in.children) {
+ auto* child_out = new TreeView::Node;
+ node_out->children.emplace_back(child_out);
+ CollapseSubnodesMatchingFunction(*child_in, depth, match, child_out);
+ }
+ }
+}
+
+void CollapseNodesMatchingFunction(
+ const TreeView& treeview_in, int depth,
+ const std::function<bool(const Label&)>& match, TreeView* treeview_out) {
+ treeview_out->mutable_thread_roots()->clear();
+ for (const auto& thread_root_in : treeview_in.thread_roots()) {
+ std::uint32_t id = thread_root_in.first;
+ const auto& root_in = *thread_root_in.second;
+ auto* root_out = new TreeView::Node;
+ treeview_out->mutable_thread_roots()->emplace(id, root_out);
+ CollapseSubnodesMatchingFunction(root_in, depth, match, root_out);
+ }
+}
+
+void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth,
+ const std::string& format,
+ TreeView* treeview_out) {
+ CollapseNodesMatchingFunction(
+ treeview_in, depth,
+ [&format](const Label& label) { return label.format() == format; },
+ treeview_out);
+}
+
+void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth,
+ const std::string& formatted,
+ TreeView* treeview_out) {
+ CollapseNodesMatchingFunction(
+ treeview_in, depth,
+ [&formatted](const Label& label) {
+ return label.Formatted() == formatted;
+ },
+ treeview_out);
+}
+
+} // namespace profiler
+} // namespace ruy
+
+#endif // RUY_PROFILER
diff --git a/ruy/profiler/treeview.h b/ruy/profiler/treeview.h
new file mode 100644
index 0000000..e34b4f9
--- /dev/null
+++ b/ruy/profiler/treeview.h
@@ -0,0 +1,130 @@
+/* Copyright 2020 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_
+
+#ifdef RUY_PROFILER
+
+#include <cstdint>
+#include <functional>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "ruy/profiler/instrumentation.h"
+
+namespace ruy {
+namespace profiler {
+
+// A tree view of a profile.
+class TreeView {
+ public:
+ struct Node {
+ std::vector<std::unique_ptr<Node>> children;
+ Label label;
+ int weight = 0;
+ };
+
+ void Populate(const std::vector<char>& samples_buf_);
+
+ // Intentionally an *ordered* map so that threads are enumerated
+ // in an order that's consistent and typically putting the 'main thread'
+ // first.
+ using ThreadRootsMap = std::map<std::uint32_t, std::unique_ptr<Node>>;
+
+ const ThreadRootsMap& thread_roots() const { return thread_roots_; }
+ ThreadRootsMap* mutable_thread_roots() { return &thread_roots_; }
+
+ private:
+ ThreadRootsMap thread_roots_;
+};
+
+/* Below are API functions for manipulating and printing treeviews. */
+
+// Prints the treeview to stdout.
+void Print(const TreeView& treeview);
+
+// Recursively prints the treeview below the given node on stdout. The `root`
+// node serves as the denominator for weight percentages, and `level` is the
+// current indentation depth.
+void PrintTreeBelow(const TreeView::Node& node, const TreeView::Node& root,
+                    int level);
+
+// Returns the tree depth below the given node.
+int DepthOfTreeBelow(const TreeView::Node& node);
+
+// Returns the sum of weights of nodes below the given node and filtered by
+// the `match` predicate.
+int WeightBelowNodeMatchingFunction(
+ const TreeView::Node& node, const std::function<bool(const Label&)>& match);
+
+// Returns the sum of weights of nodes below the given node and whose
+// unformatted label (i.e. raw format string) matches the given `format` string.
+//
+// This allows one to aggregate nodes whose labels differ only by parameter
+// values.
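+// For example, matching the format string "MergeSort (size=%d)" aggregates
+// the weights of all MergeSort nodes regardless of their actual size
+// arguments (this pattern is exercised in test.cc).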
+int WeightBelowNodeMatchingUnformatted(const TreeView::Node& node,
+ const std::string& format);
+
+// Returns the sum of weights of nodes below the given node and whose formatted
+// label matches the `formatted` string.
+//
+// In the case of nodes with parametrized labels, this makes it possible to
+// count only nodes with specific parameter values. For that purpose, one may
+// also use WeightBelowNodeMatchingFunction directly, with a `match` predicate
+// comparing raw integer parameter values directly, instead of going through
+// formatted strings.
+int WeightBelowNodeMatchingFormatted(const TreeView::Node& node,
+ const std::string& formatted);
+
+// Produces a `node_out` that is a copy of `node_in` but with tree depth below
+// it clamped at `depth`, with further subtrees aggregated into single leaf
+// nodes.
+void CollapseNode(const TreeView::Node& node_in, int depth,
+ TreeView::Node* node_out);
+
+// Calls CollapseNode with the given `depth` on every subnode filtered by the
+// `match` predicate. Note that this does NOT limit the tree depth below
+// `node_out` to `depth`, since each collapsed node below `node_out` may be
+// arbitrarily far below it and `depth` is only used as the collapsing depth
+// at that point.
+void CollapseSubnodesMatchingFunction(
+ const TreeView::Node& node_in, int depth,
+ const std::function<bool(const Label&)>& match, TreeView::Node* node_out);
+
+// Calls CollapseNode with the given `depth` on every node filtered by the
+// `match` predicate. Note that this does NOT limit the tree depth below
+// `node_out` to `depth`, since each collapsed node below `node_out` may be
+// arbitrarily far below it and `depth` is only used as the collapsing depth
+// at that point.
+void CollapseNodesMatchingFunction(
+ const TreeView& treeview_in, int depth,
+ const std::function<bool(const Label&)>& match, TreeView* treeview_out);
+
+// Special case of CollapseNodesMatchingFunction matching unformatted labels,
+// i.e. raw format strings.
+// See the comment on WeightBelowNodeMatchingUnformatted.
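+// For example (as in test.cc), collapsing with depth=5 on the format string
+// "MergeSort (size=%d)" truncates each matching MergeSort subtree to at most
+// 5 levels below it, aggregating anything deeper into leaf nodes.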
+void CollapseNodesMatchingUnformatted(const TreeView& treeview_in, int depth,
+ const std::string& format,
+ TreeView* treeview_out);
+
+// Special case of CollapseNodesMatchingFunction matching formatted labels.
+// See the comment on WeightBelowNodeMatchingFormatted.
+void CollapseNodesMatchingFormatted(const TreeView& treeview_in, int depth,
+ const std::string& formatted,
+ TreeView* treeview_out);
+
+} // namespace profiler
+} // namespace ruy
+
+#endif // RUY_PROFILER
+
+#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PROFILER_TREEVIEW_H_