Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Karim Nosir <karimnosseir@google.com> 2022-03-24 04:27:45 +0300
committer: Copybara-Service <copybara-worker@google.com> 2022-03-24 04:28:14 +0300
commit: cf14b2b0ea27045c8c323a4ec11d771be3d2926a (patch)
tree: 4580e8acf8dc9c48c1852b847dd96db4c26ba58c
parent: 2d950b3bfa7ebfbe7a97ecb44b1cc4da5ac1d6f0 (diff)
Update GetTentativeThreadCount to use int64 types
PiperOrigin-RevId: 436879056
-rw-r--r-- ruy/trmul.cc 10
1 file changed, 5 insertions, 5 deletions
diff --git a/ruy/trmul.cc b/ruy/trmul.cc
index dbf5feb..f8651b7 100644
--- a/ruy/trmul.cc
+++ b/ruy/trmul.cc
@@ -256,11 +256,11 @@ int GetTentativeThreadCount(Ctx* ctx, int rows, int cols, int depth) {
// Empirically determined rule for reasonable number of
// threads to use. This is proportional to the number of arithmetic ops
// in this Mul (product of the 3 sizes).
- static constexpr int kDivisorLog2 = 15;
- const int guess_log2 = std::max(
- 0, ceil_log2(rows) + ceil_log2(cols) + ceil_log2(depth) - kDivisorLog2);
- int tentative_thread_count =
- std::min(1 << guess_log2, ctx->max_num_threads());
+ const int64_t total_number_of_elements = rows * cols * depth;
+ static constexpr int64_t kDivisor = 32768;
+ int tentative_thread_count = std::max(
+ 1, std::min(static_cast<int>(total_number_of_elements / kDivisor),
+ ctx->max_num_threads()));
RUY_TRACE_INFO(GET_TENTATIVE_THREAD_COUNT);
return tentative_thread_count;
}