diff options
author | bjacob <jacob.benoit.1@gmail.com> | 2021-01-21 23:33:11 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2021-01-21 23:34:32 +0300 |
commit | 58e3051707b3a92dab0a72297183114a1d35f483 (patch) | |
tree | c4de371de040b400e57fdb0a3ba0eab2a7e5838c | |
parent | fad5a101143f2c27e0888ef0afe0e5230907a782 (diff) |
Change the default MulParams multiplier values to multiply by 1, not 0.
Multiplying by 0 by default is unfriendly to people getting familiar
with ruy, who then have to debug why their output values are all 0.
With a default of 1, tiny toy examples might output sane values;
anything beyond that will saturate, and seeing all saturated values will
be a hint that something needs to be set to rescale values.
Closes https://github.com/google/ruy/pull/248
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/ruy/pull/248 from bjacob:multiplier-default 3fb1152e899fffc1f9fa9103b533348599ca494f
PiperOrigin-RevId: 353077204
-rw-r--r-- | ruy/apply_multiplier.cc | 12 | ||||
-rw-r--r-- | ruy/apply_multiplier.h | 4 | ||||
-rw-r--r-- | ruy/apply_multiplier_test.cc | 5 | ||||
-rw-r--r-- | ruy/mul_params.h | 96 | ||||
-rw-r--r-- | ruy/mul_params_test.cc | 2 | ||||
-rw-r--r-- | ruy/ruy.h | 8 |
6 files changed, 105 insertions, 22 deletions
diff --git a/ruy/apply_multiplier.cc b/ruy/apply_multiplier.cc index 200d1c0..1b7df5c 100644 --- a/ruy/apply_multiplier.cc +++ b/ruy/apply_multiplier.cc @@ -24,6 +24,10 @@ namespace ruy { namespace detail { namespace { +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// Warning: this code is not meant to be bit-exact-normative. +// Please refer to the class comment of ruy::MulParams, in mul_params.h. +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // Copied from gemmlowp/fixedpoint. // Similar to the ARM64 instruction, SQRDMULH. The name of this function // is copied from the name of that instruction. @@ -43,6 +47,10 @@ std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, std::int32_t b) { return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// Warning: this code is not meant to be bit-exact-normative. +// Please refer to the class comment of ruy::MulParams, in mul_params.h. +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // Returns numerator/2^exponent, rounding to nearest, breaking ties // upwards. That particular tie-breaking behavior is not important in practice. // It happens to be cheap to implement in hardware and therefore, commonplace. @@ -71,6 +79,10 @@ std::int32_t RoundingRightShift(std::int32_t numerator, int exponent) { } // namespace +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// Warning: this code is not meant to be bit-exact-normative. +// Please refer to the class comment of ruy::MulParams, in mul_params.h. +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! // Copied from TF Lite code. 
std::int32_t MultiplyByQuantizedMultiplier(std::int32_t x, std::int32_t quantized_multiplier, diff --git a/ruy/apply_multiplier.h b/ruy/apply_multiplier.h index 5a210fc..120b990 100644 --- a/ruy/apply_multiplier.h +++ b/ruy/apply_multiplier.h @@ -14,6 +14,10 @@ limitations under the License. ==============================================================================*/ // Provides a reference (portable, non-optimized) ApplyMultiplier function. +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// Warning: this code is not meant to be bit-exact-normative. +// Please refer to the class comment of ruy::MulParams, in mul_params.h. +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! #ifndef RUY_RUY_APPLY_MULTIPLIER_H_ #define RUY_RUY_APPLY_MULTIPLIER_H_ diff --git a/ruy/apply_multiplier_test.cc b/ruy/apply_multiplier_test.cc index 900092a..cae6888 100644 --- a/ruy/apply_multiplier_test.cc +++ b/ruy/apply_multiplier_test.cc @@ -107,9 +107,14 @@ void TestApplyMultiplier(const MulParams<AccumScalar, DstScalar>& mul_params, TEST(ApplyMultiplierTest, ApplyMultiplierUniform) { MulParams<std::int32_t, std::int8_t> mul_params; + // Test that default values give a multiplication by 1. + TestApplyMultiplier(mul_params, 0, 1000, 1000); mul_params.set_multiplier_fixedpoint(1 << 30); mul_params.set_multiplier_exponent(-1); TestApplyMultiplier(mul_params, 0, 1000, 250); + mul_params.set_multiplier_fixedpoint(1 << 25); + mul_params.set_multiplier_exponent(3); + TestApplyMultiplier(mul_params, 0, 1000, 125); } TEST(ApplyMultiplierTest, ApplyMultiplierPerChannel) { diff --git a/ruy/mul_params.h b/ruy/mul_params.h index c28237e..9bdbfa4 100644 --- a/ruy/mul_params.h +++ b/ruy/mul_params.h @@ -50,6 +50,56 @@ struct MulParamsStorage; // AccumScalar: Accumulator type. The type of accumulators used to compute the // dot-products before being ultimately casted to the destination type. // DstScalar: The destination scalar type. 
+// +// Constraints on these template parameters (see also the ruy::Mul comment): +// * If DstScalar is floating-point then AccumScalar must also be. +// * If DstScalar is integral then AccumScalar must be std::int32_t. Moreover +// in that integral case, there is a mode switch: +// - If DstScalar is std::int32_t then the multiplier_* fields are all +// disabled, and ruy::Mul will just return raw (unscaled) accumulators. +// - If DstScalar is not std::int32_t then the multiplier_* fields are +// enabled, and ruy::Mul will use them to scale internal std::int32_t +// accumulators before casting them to the DstScalar type. The default +// values are such that the effective multiplier is 1 (no scaling). +// +// In the latter case (DstScalar integral and narrower than std::int32_t), +// the multiplier effect on accumulator values is as follows: +// +// 1. Left shift by max(0, multiplier_exponent). +// 2. Fixed-point multiplication by multiplier_fixedpoint in Q0.31 format. +// 3. Rounding right shift by max(0, -multiplier_exponent). +// +// Reference code for this can be found in the implementation of +// ruy::ApplyMultiplier. If you look there, you'll find warnings like this: +// +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// Warning: this code is not meant to be bit-exact-normative. +// Please refer to the class comment of ruy::MulParams, in mul_params.h. +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +// +// The explanation of this warning is that as of early 2021, we still don't know +// whether it is advisable to let this code as-is have normative value, or +// whether that would become advisable after some specific final change. +// +// Ruy's CPU backends (x86 and ARM) as of early 2021 happen to conform +// bit-exactly to this reference, but we also know that x86 could be faster if +// it didn't, and so could NEON-less ARM (such as Cortex-M) (see [2]). 
We don't +// know that this particular reference code is inherently better than other +// forms that could perform better on these architectures --- in fact, the +// alternative that was proposed in [2] as better performing on ARM Cortex-M +// is also inherently more accurate thanks to rounding only once, but it would +// perform worse on both ARM NEON, and x86. +// +// In fact, if we look at other hardware architectures beyond current Ruy +// targets, namely "hardware accelerators", it becomes clear that there is no +// hope for any form of this to be efficiently implementable simultaneously on +// all current relevant hardware. Indeed, some accelerators prefer to perform +// the multiplication in IEEE float32, others in IEEE float16, others in +// bfloat16, others in 16-bit fixed-point... +// +// See: +// [1] https://github.com/google/ruy/pull/227 +// [2] https://github.com/tensorflow/tensorflow/issues/25087 template <typename tAccumScalar, typename tDstScalar> class MulParams final { public: @@ -59,9 +109,14 @@ class MulParams final { // The bias vector data, if not null. const AccumScalar* bias() const { return storage_.bias; } void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; } - // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa) - // of the multiplier by which accumulators are multiplied before being casted - // to the destination type. + // Only for non-floating-point cases. The fixed-point part of the multiplier + // by which accumulators are multiplied before being casted to the destination + // type. This is a fixed-point quantity with 0 integer bits. Since + // (as explained in the class comment) AccumScalar must be std::int32_t, + // that means that the fixed-point format is Q0.31. For example, + // a multiplier_fixedpoint value of 2^30 has the effect of multiplying + // by one half (1/2). More generally, the effect is to multiply by + // (multiplier_fixedpoint / (2^31)). 
AccumScalar multiplier_fixedpoint() const { return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint; } @@ -83,9 +138,10 @@ class MulParams final { // `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel` // and `multiplier_exponent_perchannel` are used instead. // - // This must point to a buffer of as many values as there are rows in the - // destination matrix. Each row of the destination matrix will use the - // corresponding buffer element instead of multiplier_fixedpoint. + // This must point to a buffer of as many values as there are rows or columns + // in the destination matrix, whichever is the channels dimension. Each + // channel of the destination matrix will use the corresponding buffer element + // instead of multiplier_fixedpoint. const AccumScalar* multiplier_fixedpoint_perchannel() const { return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel : nullptr; @@ -155,16 +211,6 @@ class MulParams final { detail::MulParamsStorage<AccumScalar, DstScalar> storage_; void set_perchannel(bool perchannel) { - if (storage_.perchannel == perchannel) { - return; - } - if (perchannel) { - RUY_DCHECK_EQ(storage_.multiplier_fixedpoint, 0); - RUY_DCHECK_EQ(storage_.multiplier_exponent, 0); - } else { - RUY_DCHECK_EQ(storage_.multiplier_fixedpoint_perchannel, nullptr); - RUY_DCHECK_EQ(storage_.multiplier_exponent_perchannel, nullptr); - } storage_.perchannel = perchannel; } }; @@ -200,16 +246,24 @@ template <typename DstScalar> struct MulParamsStorage<std::int32_t, DstScalar> final { using AccumScalar = std::int32_t; static_assert(std::is_integral<DstScalar>::value, ""); - static_assert(sizeof(DstScalar) < sizeof(AccumScalar), ""); + static_assert(sizeof(DstScalar) <= sizeof(AccumScalar) / 2, ""); const AccumScalar* bias = nullptr; union { - const AccumScalar* multiplier_fixedpoint_perchannel = nullptr; - AccumScalar multiplier_fixedpoint; + const AccumScalar* multiplier_fixedpoint_perchannel; + // Let the default multiplier 
be effectively a multiplication by 1, so that + // the matmul behaves as a (saturating) plain integer matmul. Unfortunately + // 1 is not exactly representable in fixedpoint with 0 integer bits, but + // using the highest representable value is a sufficiently good + // approximation: since this specialization of MulParams is for the case + // where DstScalar is at least 2x narrower than AccumScalar, the values + // for which there would be a difference will get saturated anyway. + AccumScalar multiplier_fixedpoint = std::numeric_limits<AccumScalar>::max(); }; union { - const int* multiplier_exponent_perchannel = nullptr; - int multiplier_exponent; + const int* multiplier_exponent_perchannel; + // See the above comment about the default value of multiplier_fixedpoint. + int multiplier_exponent = 0; }; DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest(); DstScalar clamp_max = std::numeric_limits<DstScalar>::max(); diff --git a/ruy/mul_params_test.cc b/ruy/mul_params_test.cc index feb7dbb..4bc9f87 100644 --- a/ruy/mul_params_test.cc +++ b/ruy/mul_params_test.cc @@ -31,7 +31,7 @@ TEST(MulParamsTest, SpecClassSanity) { MulParamsType mul_params; EXPECT_EQ(mul_params.bias(), nullptr); - EXPECT_EQ(mul_params.multiplier_fixedpoint(), 0); + EXPECT_EQ(mul_params.multiplier_fixedpoint(), std::numeric_limits<std::int32_t>::max()); EXPECT_EQ(mul_params.multiplier_exponent(), 0); EXPECT_EQ(mul_params.multiplier_fixedpoint_perchannel(), nullptr); EXPECT_EQ(mul_params.multiplier_exponent_perchannel(), nullptr); @@ -93,6 +93,14 @@ void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs, // (e.g. the number of CPU cores in typical scenarios). At least ruy forces // each invocation to make an explicit decision here, there is no automatic // detection of the best number of threads to use in ruy. +// +// Constraints on the template parameters: +// * If DstScalar is floating-point then AccumScalar must also be. 
+// * If DstScalar is integral then AccumScalar must be std::int32_t. +// Please refer to MulParams' class comment for more information. When +// DstScalar is integral and is narrower than AccumScalar, additional +// MulParams fields must be set to control the scaling of internal accumulators +// before the final saturating cast to the DstScalar type. template <typename LhsScalar, typename RhsScalar, typename AccumScalar, typename DstScalar> void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs, |