Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorbjacob <jacob.benoit.1@gmail.com>2021-01-21 23:33:11 +0300
committerCopybara-Service <copybara-worker@google.com>2021-01-21 23:34:32 +0300
commit58e3051707b3a92dab0a72297183114a1d35f483 (patch)
treec4de371de040b400e57fdb0a3ba0eab2a7e5838c
parentfad5a101143f2c27e0888ef0afe0e5230907a782 (diff)
Change the default MulParams multiplier values to multiply by 1, not 0.
Multiplying by 0 by default is unfriendly to people getting familiar with ruy having to debug why their output values are all 0. With a default of 1, tiny toy examples might output sane values, anything beyond that will saturate, and seeing all saturated values will be a hint that something needs to be set to rescale values. Closes https://github.com/google/ruy/pull/248 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/ruy/pull/248 from bjacob:multiplier-default 3fb1152e899fffc1f9fa9103b533348599ca494f PiperOrigin-RevId: 353077204
-rw-r--r--ruy/apply_multiplier.cc12
-rw-r--r--ruy/apply_multiplier.h4
-rw-r--r--ruy/apply_multiplier_test.cc5
-rw-r--r--ruy/mul_params.h96
-rw-r--r--ruy/mul_params_test.cc2
-rw-r--r--ruy/ruy.h8
6 files changed, 105 insertions, 22 deletions
diff --git a/ruy/apply_multiplier.cc b/ruy/apply_multiplier.cc
index 200d1c0..1b7df5c 100644
--- a/ruy/apply_multiplier.cc
+++ b/ruy/apply_multiplier.cc
@@ -24,6 +24,10 @@ namespace ruy {
namespace detail {
namespace {
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+// Warning: this code is not meant to be bit-exact-normative.
+// Please refer to the class comment of ruy::MulParams, in mul_params.h.
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Copied from gemmlowp/fixedpoint.
// Similar to the ARM64 instruction, SQRDMULH. The name of this function
// is copied from the name of that instruction.
@@ -43,6 +47,10 @@ std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, std::int32_t b) {
return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+// Warning: this code is not meant to be bit-exact-normative.
+// Please refer to the class comment of ruy::MulParams, in mul_params.h.
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Returns numerator/2^exponent, rounding to nearest, breaking ties
// upwards. That particular tie-breaking behavior is not important in practice.
// It happens to be cheap to implement in hardware and therefore, commonplace.
@@ -71,6 +79,10 @@ std::int32_t RoundingRightShift(std::int32_t numerator, int exponent) {
} // namespace
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+// Warning: this code is not meant to be bit-exact-normative.
+// Please refer to the class comment of ruy::MulParams, in mul_params.h.
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Copied from TF Lite code.
std::int32_t MultiplyByQuantizedMultiplier(std::int32_t x,
std::int32_t quantized_multiplier,
diff --git a/ruy/apply_multiplier.h b/ruy/apply_multiplier.h
index 5a210fc..120b990 100644
--- a/ruy/apply_multiplier.h
+++ b/ruy/apply_multiplier.h
@@ -14,6 +14,10 @@ limitations under the License.
==============================================================================*/
// Provides a reference (portable, non-optimized) ApplyMultiplier function.
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+// Warning: this code is not meant to be bit-exact-normative.
+// Please refer to the class comment of ruy::MulParams, in mul_params.h.
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
#ifndef RUY_RUY_APPLY_MULTIPLIER_H_
#define RUY_RUY_APPLY_MULTIPLIER_H_
diff --git a/ruy/apply_multiplier_test.cc b/ruy/apply_multiplier_test.cc
index 900092a..cae6888 100644
--- a/ruy/apply_multiplier_test.cc
+++ b/ruy/apply_multiplier_test.cc
@@ -107,9 +107,14 @@ void TestApplyMultiplier(const MulParams<AccumScalar, DstScalar>& mul_params,
TEST(ApplyMultiplierTest, ApplyMultiplierUniform) {
MulParams<std::int32_t, std::int8_t> mul_params;
+ // Test that default values give a multiplication by 1.
+ TestApplyMultiplier(mul_params, 0, 1000, 1000);
mul_params.set_multiplier_fixedpoint(1 << 30);
mul_params.set_multiplier_exponent(-1);
TestApplyMultiplier(mul_params, 0, 1000, 250);
+ mul_params.set_multiplier_fixedpoint(1 << 25);
+ mul_params.set_multiplier_exponent(3);
+ TestApplyMultiplier(mul_params, 0, 1000, 125);
}
TEST(ApplyMultiplierTest, ApplyMultiplierPerChannel) {
diff --git a/ruy/mul_params.h b/ruy/mul_params.h
index c28237e..9bdbfa4 100644
--- a/ruy/mul_params.h
+++ b/ruy/mul_params.h
@@ -50,6 +50,56 @@ struct MulParamsStorage;
// AccumScalar: Accumulator type. The type of accumulators used to compute the
// dot-products before being ultimately casted to the destination type.
// DstScalar: The destination scalar type.
+//
+// Constraints on these template parameters (see also the ruy::Mul comment):
+// * If DstScalar is floating-point then AccumScalar must also be.
+// * If DstScalar is integral then AccumScalar must be std::int32_t. Moreover
+// in that integral case, there is a mode switch:
+// - If DstScalar is std::int32_t then the multiplier_* fields are all
+// disabled, and ruy::Mul will just return raw (unscaled) accumulators.
+// - If DstScalar is not std::int32_t then the multiplier_* fields are
+// enabled, and ruy::Mul will use them to scale internal std::int32_t
+// accumulators before casting them to the DstScalar type. The default
+// values are such that the effective multiplier is 1 (no scaling).
+//
+// In the latter case (DstScalar integral and narrower than std::int32_t),
+// the multiplier effect on accumulator values is as follows:
+//
+// 1. Left shift by max(0, multiplier_exponent).
+// 2. Fixed-point multiplication by multiplier_fixedpoint in Q0.31 format.
+// 3. Rounding right shift by max(0, -multiplier_exponent).
+//
+// Reference code for this can be found in the implementation of
+// ruy::ApplyMultiplier. If you look there, you'll find warnings like this:
+//
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+// Warning: this code is not meant to be bit-exact-normative.
+// Please refer to the class comment of ruy::MulParams, in mul_params.h.
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+//
+// The explanation of this warning is that as of early 2021, we still don't know
+// whether it is advisable to let this code as-is have normative value, or
+// whether that would become advisable after some specific final change.
+//
+// Ruy's CPU backends (x86 and ARM) as of early 2021 happen to conform
+// bit-exactly to this reference, but we also know that x86 could be faster if
+// it didn't, and so could NEON-less ARM (such as Cortex-M) (see [2]). We don't
+// know that this particular reference code is inherently better than other
+// forms that could perform better on these architectures --- in fact, the
+// alternative that was proposed in [2] as better performing on ARM Cortex-M
+// is also inherently more accurate thanks to rounding only once, but it would
+// perform worse on both ARM NEON, and x86.
+//
+// In fact, if we look at other hardware architectures beyond current Ruy
+// targets, namely "hardware accelerators", it becomes clear that there is no
+// hope for any form of this to be efficiently implementable simultaneously on
+// all current relevant hardware. Indeed, some accelerators prefer to perform
+// the multiplication in IEEE float32, others in IEEE float16, others in
+// bfloat16, others in 16-bit fixed-point...
+//
+// See:
+// [1] https://github.com/google/ruy/pull/227
+// [2] https://github.com/tensorflow/tensorflow/issues/25087
template <typename tAccumScalar, typename tDstScalar>
class MulParams final {
public:
@@ -59,9 +109,14 @@ class MulParams final {
// The bias vector data, if not null.
const AccumScalar* bias() const { return storage_.bias; }
void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; }
- // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
- // of the multiplier by which accumulators are multiplied before being casted
- // to the destination type.
+ // Only for non-floating-point cases. The fixed-point part of the multiplier
+ // by which accumulators are multiplied before being casted to the destination
+ // type. This is a fixed-point quantity with 0 integer bits. Since
+ // (as explained in the class comment) AccumScalar must be std::int32_t,
+ // that means that the fixed-point format is Q0.31. For example,
+ // a multiplier_fixedpoint value of 2^30 has the effect of multiplying
+ // by one half (1/2). More generally, the effect is to multiply by
+ // (multiplier_fixedpoint / (2^31)).
AccumScalar multiplier_fixedpoint() const {
return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint;
}
@@ -83,9 +138,10 @@ class MulParams final {
// `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel`
// and `multiplier_exponent_perchannel` are used instead.
//
- // This must point to a buffer of as many values as there are rows in the
- // destination matrix. Each row of the destination matrix will use the
- // corresponding buffer element instead of multiplier_fixedpoint.
+ // This must point to a buffer of as many values as there are rows or columns
+ // in the destination matrix, whichever is the channels dimension. Each
+ // channel of the destination matrix will use the corresponding buffer element
+ // instead of multiplier_fixedpoint.
const AccumScalar* multiplier_fixedpoint_perchannel() const {
return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel
: nullptr;
@@ -155,16 +211,6 @@ class MulParams final {
detail::MulParamsStorage<AccumScalar, DstScalar> storage_;
void set_perchannel(bool perchannel) {
- if (storage_.perchannel == perchannel) {
- return;
- }
- if (perchannel) {
- RUY_DCHECK_EQ(storage_.multiplier_fixedpoint, 0);
- RUY_DCHECK_EQ(storage_.multiplier_exponent, 0);
- } else {
- RUY_DCHECK_EQ(storage_.multiplier_fixedpoint_perchannel, nullptr);
- RUY_DCHECK_EQ(storage_.multiplier_exponent_perchannel, nullptr);
- }
storage_.perchannel = perchannel;
}
};
@@ -200,16 +246,24 @@ template <typename DstScalar>
struct MulParamsStorage<std::int32_t, DstScalar> final {
using AccumScalar = std::int32_t;
static_assert(std::is_integral<DstScalar>::value, "");
- static_assert(sizeof(DstScalar) < sizeof(AccumScalar), "");
+ static_assert(sizeof(DstScalar) <= sizeof(AccumScalar) / 2, "");
const AccumScalar* bias = nullptr;
union {
- const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
- AccumScalar multiplier_fixedpoint;
+ const AccumScalar* multiplier_fixedpoint_perchannel;
+ // Let the default multiplier be effecively a multiplication by 1, so that
+ // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
+ // 1 is not exactly representable in fixedpoint with 0 integer bits, but
+ // using the highest representable value is a sufficiently good
+ // approximation: since this specialization of MulParams is for the case
+ // where DstScalar is at least 2x narrower than MulScalar, the values
+ // for which there would be a difference will get saturated anyway.
+ AccumScalar multiplier_fixedpoint = std::numeric_limits<AccumScalar>::max();
};
union {
- const int* multiplier_exponent_perchannel = nullptr;
- int multiplier_exponent;
+ const int* multiplier_exponent_perchannel;
+ // See the above comment about the default value of multiplier_fixedpoint.
+ int multiplier_exponent = 0;
};
DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
diff --git a/ruy/mul_params_test.cc b/ruy/mul_params_test.cc
index feb7dbb..4bc9f87 100644
--- a/ruy/mul_params_test.cc
+++ b/ruy/mul_params_test.cc
@@ -31,7 +31,7 @@ TEST(MulParamsTest, SpecClassSanity) {
MulParamsType mul_params;
EXPECT_EQ(mul_params.bias(), nullptr);
- EXPECT_EQ(mul_params.multiplier_fixedpoint(), 0);
+ EXPECT_EQ(mul_params.multiplier_fixedpoint(), std::numeric_limits<std::int32_t>::max());
EXPECT_EQ(mul_params.multiplier_exponent(), 0);
EXPECT_EQ(mul_params.multiplier_fixedpoint_perchannel(), nullptr);
EXPECT_EQ(mul_params.multiplier_exponent_perchannel(), nullptr);
diff --git a/ruy/ruy.h b/ruy/ruy.h
index 3cf7bdd..ddbe192 100644
--- a/ruy/ruy.h
+++ b/ruy/ruy.h
@@ -93,6 +93,14 @@ void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
// (e.g. the number of CPU cores in typical scenarios). At least ruy forces
// each invocation to make an explicit decision here, there is no automatic
// detection of the best number of threads to use in ruy.
+//
+// Constraints on the template parameters:
+// * If DstScalar is floating-point then AccumScalar must also be.
+// * If DstScalar is integral then AccumScalar must be std::int32_t.
+// Please refer to MulParams' class comment for more information. When
+// DstScalar is integral and is narrower than AccumScalar, additional
+// MulParams fields must be set to control the scaling of internal accumulators
+// before the final saturating cast to the DstScalar type.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,