diff options
author | Benoit Jacob <benoitjacob@google.com> | 2022-03-25 05:09:40 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2022-03-25 05:10:03 +0300 |
commit | 7ef39c5745a61f43071e699c6a96da41701ae59f (patch) | |
tree | cbe581e535ccf5f8737ca599ac59468ac7beb8c8 | |
parent | cf14b2b0ea27045c8c323a4ec11d771be3d2926a (diff) |
Fix an integer overflow, and take some extra defensive steps.
PiperOrigin-RevId: 437140449
-rw-r--r-- | ruy/trmul.cc | 25 |
1 file changed, 19 insertions, 6 deletions
diff --git a/ruy/trmul.cc b/ruy/trmul.cc index f8651b7..7af4bf0 100644 --- a/ruy/trmul.cc +++ b/ruy/trmul.cc @@ -21,6 +21,7 @@ limitations under the License. #include <atomic> #include <cstdint> #include <cstring> +#include <limits> #include <memory> #include <vector> @@ -256,12 +257,24 @@ int GetTentativeThreadCount(Ctx* ctx, int rows, int cols, int depth) { // Empirically determined rule for reasonable number of // threads to use. This is proportional to the number of arithmetic ops // in this Mul (product of the 3 sizes). - const int64_t total_number_of_elements = rows * cols * depth; - static constexpr int64_t kDivisor = 32768; - int tentative_thread_count = std::max( - 1, std::min(static_cast<int>(total_number_of_elements / kDivisor), - ctx->max_num_threads())); - RUY_TRACE_INFO(GET_TENTATIVE_THREAD_COUNT); + // Be defensive here by explicitly promoting operands to int64 to avoid the + // pitfall of `int64 result = x * y;` overflowing as x and y are still narrow. + const std::int64_t rows_i64 = rows; + const std::int64_t cols_i64 = cols; + const std::int64_t depth_i64 = depth; + const std::int64_t problem_size = rows_i64 * cols_i64 * depth_i64; + // Division is cheap when the denominator is constant + static constexpr std::int64_t kSizePerAdditionalThread = 32768; + std::int64_t tentative_thread_count = problem_size / kSizePerAdditionalThread; + // tentative_thread_count is still an int64, still not necessarily in the + // range of type int. It probably is as long as kSizePerAdditionalThread is + // large, but imagine that that constant might change in the future. + tentative_thread_count = std::max<std::int64_t>(tentative_thread_count, 1); + tentative_thread_count = + std::min<std::int64_t>(tentative_thread_count, ctx->max_num_threads()); + // now tentative_thread_count must be in the range of type int, because + // ctx->max_num_threads() is. + RUY_DCHECK_LE(tentative_thread_count, std::numeric_limits<int>::max()); return tentative_thread_count; } |