From b0e97e627db281e2f2d79ac84333a93452173ccd Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Tue, 9 Mar 2021 19:18:04 -0800 Subject: rollback hopefully fixing some application crash PiperOrigin-RevId: 361951187 --- ruy/apply_multiplier_test.cc | 5 ----- ruy/mul_params.h | 46 ++++++++++++++++++++------------------------ ruy/mul_params_test.cc | 2 +- ruy/ruy.h | 8 -------- 4 files changed, 22 insertions(+), 39 deletions(-) diff --git a/ruy/apply_multiplier_test.cc b/ruy/apply_multiplier_test.cc index 2df80d7..ff4cb2c 100644 --- a/ruy/apply_multiplier_test.cc +++ b/ruy/apply_multiplier_test.cc @@ -104,14 +104,9 @@ void TestApplyMultiplier(const MulParams& mul_params, TEST(ApplyMultiplierTest, ApplyMultiplierUniform) { MulParams mul_params; - // Test that default values give a multiplication by 1. - TestApplyMultiplier(mul_params, 0, 1000, 1000); mul_params.set_multiplier_fixedpoint(1 << 30); mul_params.set_multiplier_exponent(-1); TestApplyMultiplier(mul_params, 0, 1000, 250); - mul_params.set_multiplier_fixedpoint(1 << 25); - mul_params.set_multiplier_exponent(3); - TestApplyMultiplier(mul_params, 0, 1000, 125); } TEST(ApplyMultiplierTest, ApplyMultiplierPerChannel) { diff --git a/ruy/mul_params.h b/ruy/mul_params.h index d5aa27b..3535a75 100644 --- a/ruy/mul_params.h +++ b/ruy/mul_params.h @@ -103,14 +103,9 @@ class MulParams final { // The bias vector data, if not null. const AccumScalar* bias() const { return storage_.bias; } void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; } - // Only for non-floating-point cases. The fixed-point part of the multiplier - // by which accumulators are multiplied before being casted to the destination - // type. This is a fixed-point quantity with 0 integer bits. Since - // (as explained in the class comment) AccumScalar must be std::int32_t, - // that means that the fixed-point format is Q0.31. For example, - // a multiplier_fixedpoint value of 2^30 has the effect of multiplying - // by one half (1/2). More generally, the effect is to multiply by - // (multiplier_fixedpoint / (2^31)). + // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa) + // of the multiplier by which accumulators are multiplied before being casted + // to the destination type. AccumScalar multiplier_fixedpoint() const { return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint; } @@ -132,10 +127,9 @@ class MulParams final { // `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel` // and `multiplier_exponent_perchannel` are used instead. // - // This must point to a buffer of as many values as there are rows or columns - // in the destination matrix, whichever is the channels dimension. Each - // channel of the destination matrix will use the corresponding buffer element - // instead of multiplier_fixedpoint. + // This must point to a buffer of as many values as there are rows in the + // destination matrix. Each row of the destination matrix will use the + // corresponding buffer element instead of multiplier_fixedpoint. const AccumScalar* multiplier_fixedpoint_perchannel() const { return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel : nullptr; @@ -205,6 +199,16 @@ class MulParams final { detail::MulParamsStorage storage_; void set_perchannel(bool perchannel) { + if (storage_.perchannel == perchannel) { + return; + } + if (perchannel) { + RUY_DCHECK_EQ(storage_.multiplier_fixedpoint, 0); + RUY_DCHECK_EQ(storage_.multiplier_exponent, 0); + } else { + RUY_DCHECK_EQ(storage_.multiplier_fixedpoint_perchannel, nullptr); + RUY_DCHECK_EQ(storage_.multiplier_exponent_perchannel, nullptr); + } storage_.perchannel = perchannel; } }; @@ -240,24 +244,16 @@ template struct MulParamsStorage final { using AccumScalar = std::int32_t; static_assert(std::is_integral::value, ""); - static_assert(sizeof(DstScalar) <= sizeof(AccumScalar) / 2, ""); + static_assert(sizeof(DstScalar) < sizeof(AccumScalar), ""); const AccumScalar* bias = nullptr; union { - const AccumScalar* multiplier_fixedpoint_perchannel; - // Let the default multiplier be effecively a multiplication by 1, so that - // the matmul behaves as a (saturating) plain integer matmul. Unfortunately - // 1 is not exactly representable in fixedpoint with 0 integer bits, but - // using the highest representable value is a sufficiently good - // approximation: since this specialization of MulParams is for the case - // where DstScalar is at least 2x narrower than MulScalar, the values - // for which there would be a difference will get saturated anyway. - AccumScalar multiplier_fixedpoint = std::numeric_limits::max(); + const AccumScalar* multiplier_fixedpoint_perchannel = nullptr; + AccumScalar multiplier_fixedpoint; }; union { - const int* multiplier_exponent_perchannel; - // See the above comment about the default value of multiplier_fixedpoint. - int multiplier_exponent = 0; + const int* multiplier_exponent_perchannel = nullptr; + int multiplier_exponent; }; DstScalar clamp_min = std::numeric_limits::lowest(); DstScalar clamp_max = std::numeric_limits::max(); diff --git a/ruy/mul_params_test.cc b/ruy/mul_params_test.cc index 4bc9f87..feb7dbb 100644 --- a/ruy/mul_params_test.cc +++ b/ruy/mul_params_test.cc @@ -31,7 +31,7 @@ TEST(MulParamsTest, SpecClassSanity) { MulParamsType mul_params; EXPECT_EQ(mul_params.bias(), nullptr); - EXPECT_EQ(mul_params.multiplier_fixedpoint(), std::numeric_limits::max()); + EXPECT_EQ(mul_params.multiplier_fixedpoint(), 0); EXPECT_EQ(mul_params.multiplier_exponent(), 0); EXPECT_EQ(mul_params.multiplier_fixedpoint_perchannel(), nullptr); EXPECT_EQ(mul_params.multiplier_exponent_perchannel(), nullptr); diff --git a/ruy/ruy.h b/ruy/ruy.h index ddbe192..3cf7bdd 100644 --- a/ruy/ruy.h +++ b/ruy/ruy.h @@ -93,14 +93,6 @@ void Mul(const Matrix& lhs, const Matrix& rhs, // (e.g. the number of CPU cores in typical scenarios). At least ruy forces // each invocation to make an explicit decision here, there is no automatic // detection of the best number of threads to use in ruy. -// -// Constraints on the template parameters: -// * If DstScalar is floating-point then AccumScalar must also be. -// * If DstScalar is integral then AccumScalar must be std::int32_t. -// Please refer to MulParams' class comment for more information. When -// DstScalar is integral and is narrower than AccumScalar, additional -// MulParams fields must be set to control the scaling of internal accumulators -// before the final saturating cast to the DstScalar type. template void Mul(const Matrix& lhs, const Matrix& rhs, -- cgit v1.2.3