Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenoit Jacob <benoitjacob@google.com>2022-03-25 05:09:40 +0300
committerCopybara-Service <copybara-worker@google.com>2022-03-25 05:10:03 +0300
commit7ef39c5745a61f43071e699c6a96da41701ae59f (patch)
treecbe581e535ccf5f8737ca599ac59468ac7beb8c8
parentcf14b2b0ea27045c8c323a4ec11d771be3d2926a (diff)
Fix an integer overflow, and take some extra defensive steps.
PiperOrigin-RevId: 437140449
-rw-r--r--ruy/trmul.cc25
1 files changed, 19 insertions, 6 deletions
diff --git a/ruy/trmul.cc b/ruy/trmul.cc
index f8651b7..7af4bf0 100644
--- a/ruy/trmul.cc
+++ b/ruy/trmul.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <atomic>
#include <cstdint>
#include <cstring>
+#include <limits>
#include <memory>
#include <vector>
@@ -256,12 +257,24 @@ int GetTentativeThreadCount(Ctx* ctx, int rows, int cols, int depth) {
// Empirically determined rule for reasonable number of
// threads to use. This is proportional to the number of arithmetic ops
// in this Mul (product of the 3 sizes).
- const int64_t total_number_of_elements = rows * cols * depth;
- static constexpr int64_t kDivisor = 32768;
- int tentative_thread_count = std::max(
- 1, std::min(static_cast<int>(total_number_of_elements / kDivisor),
- ctx->max_num_threads()));
- RUY_TRACE_INFO(GET_TENTATIVE_THREAD_COUNT);
+ // Be defensive here by explicitly promoting operands to int64 to avoid the
+ // pitfall of `int64 result = x * y;` overflowing as x and y are still narrow.
+ const std::int64_t rows_i64 = rows;
+ const std::int64_t cols_i64 = cols;
+ const std::int64_t depth_i64 = depth;
+ const std::int64_t problem_size = rows_i64 * cols_i64 * depth_i64;
+ // Division is cheap when the denominator is constant
+ static constexpr std::int64_t kSizePerAdditionalThread = 32768;
+ std::int64_t tentative_thread_count = problem_size / kSizePerAdditionalThread;
+ // tentative_thread_count is still an int64, still not necessarily in the
+ // range of type int. It probably is as long as kSizePerAdditionalThread is
+ // large, but imagine that that constant might change in the future.
+ tentative_thread_count = std::max<std::int64_t>(tentative_thread_count, 1);
+ tentative_thread_count =
+ std::min<std::int64_t>(tentative_thread_count, ctx->max_num_threads());
+ // now tentative_thread_count must be in the range of type int, because
+ // ctx->max_num_threads() is.
+ RUY_DCHECK_LE(tentative_thread_count, std::numeric_limits<int>::max());
return tentative_thread_count;
}