diff options
author | Karim Nosir <karimnosseir@google.com> | 2022-03-24 04:27:45 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2022-03-24 04:28:14 +0300 |
commit | cf14b2b0ea27045c8c323a4ec11d771be3d2926a (patch) | |
tree | 4580e8acf8dc9c48c1852b847dd96db4c26ba58c | |
parent | 2d950b3bfa7ebfbe7a97ecb44b1cc4da5ac1d6f0 (diff) |
Update GetTentativeThreadCount to use int64 types
PiperOrigin-RevId: 436879056
-rw-r--r-- | ruy/trmul.cc | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/ruy/trmul.cc b/ruy/trmul.cc
index dbf5feb..f8651b7 100644
--- a/ruy/trmul.cc
+++ b/ruy/trmul.cc
@@ -256,11 +256,11 @@ int GetTentativeThreadCount(Ctx* ctx, int rows, int cols, int depth) {
   // Empirically determined rule for reasonable number of
   // threads to use. This is proportional to the number of arithmetic ops
   // in this Mul (product of the 3 sizes).
-  static constexpr int kDivisorLog2 = 15;
-  const int guess_log2 = std::max(
-      0, ceil_log2(rows) + ceil_log2(cols) + ceil_log2(depth) - kDivisorLog2);
-  int tentative_thread_count =
-      std::min(1 << guess_log2, ctx->max_num_threads());
+  const int64_t total_number_of_elements = rows * cols * depth;
+  static constexpr int64_t kDivisor = 32768;
+  int tentative_thread_count = std::max(
+      1, std::min(static_cast<int>(total_number_of_elements / kDivisor),
+                  ctx->max_num_threads()));
   RUY_TRACE_INFO(GET_TENTATIVE_THREAD_COUNT);
   return tentative_thread_count;
 }