Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Benoit Jacob <benoitjacob@google.com> 2021-03-10 06:18:04 +0300
committer Copybara-Service <copybara-worker@google.com> 2021-03-10 06:18:22 +0300
commit b0e97e627db281e2f2d79ac84333a93452173ccd (patch)
tree d45d61489913cd94fc24196cadedf3189d72f8cc
parent 54774a7a2cf85963777289193629d4bd42de4a59 (diff)
rollback hopefully fixing some application crash
PiperOrigin-RevId: 361951187
-rw-r--r-- ruy/apply_multiplier_test.cc 5
-rw-r--r-- ruy/mul_params.h 46
-rw-r--r-- ruy/mul_params_test.cc 2
-rw-r--r-- ruy/ruy.h 8
4 files changed, 22 insertions, 39 deletions
diff --git a/ruy/apply_multiplier_test.cc b/ruy/apply_multiplier_test.cc
index 2df80d7..ff4cb2c 100644
--- a/ruy/apply_multiplier_test.cc
+++ b/ruy/apply_multiplier_test.cc
@@ -104,14 +104,9 @@ void TestApplyMultiplier(const MulParams<AccumScalar, DstScalar>& mul_params,
TEST(ApplyMultiplierTest, ApplyMultiplierUniform) {
MulParams<std::int32_t, std::int8_t> mul_params;
- // Test that default values give a multiplication by 1.
- TestApplyMultiplier(mul_params, 0, 1000, 1000);
mul_params.set_multiplier_fixedpoint(1 << 30);
mul_params.set_multiplier_exponent(-1);
TestApplyMultiplier(mul_params, 0, 1000, 250);
- mul_params.set_multiplier_fixedpoint(1 << 25);
- mul_params.set_multiplier_exponent(3);
- TestApplyMultiplier(mul_params, 0, 1000, 125);
}
TEST(ApplyMultiplierTest, ApplyMultiplierPerChannel) {
diff --git a/ruy/mul_params.h b/ruy/mul_params.h
index d5aa27b..3535a75 100644
--- a/ruy/mul_params.h
+++ b/ruy/mul_params.h
@@ -103,14 +103,9 @@ class MulParams final {
// The bias vector data, if not null.
const AccumScalar* bias() const { return storage_.bias; }
void set_bias(const AccumScalar* ptr) { storage_.bias = ptr; }
- // Only for non-floating-point cases. The fixed-point part of the multiplier
- // by which accumulators are multiplied before being casted to the destination
- // type. This is a fixed-point quantity with 0 integer bits. Since
- // (as explained in the class comment) AccumScalar must be std::int32_t,
- // that means that the fixed-point format is Q0.31. For example,
- // a multiplier_fixedpoint value of 2^30 has the effect of multiplying
- // by one half (1/2). More generally, the effect is to multiply by
- // (multiplier_fixedpoint / (2^31)).
+ // Only for non-floating-point cases. The fixed-point part (i.e. the mantissa)
+ // of the multiplier by which accumulators are multiplied before being casted
+ // to the destination type.
AccumScalar multiplier_fixedpoint() const {
return storage_.perchannel ? 0 : storage_.multiplier_fixedpoint;
}
@@ -132,10 +127,9 @@ class MulParams final {
// `multiplier_exponent` are disabled and `multiplier_fixedpoint_perchannel`
// and `multiplier_exponent_perchannel` are used instead.
//
- // This must point to a buffer of as many values as there are rows or columns
- // in the destination matrix, whichever is the channels dimension. Each
- // channel of the destination matrix will use the corresponding buffer element
- // instead of multiplier_fixedpoint.
+ // This must point to a buffer of as many values as there are rows in the
+ // destination matrix. Each row of the destination matrix will use the
+ // corresponding buffer element instead of multiplier_fixedpoint.
const AccumScalar* multiplier_fixedpoint_perchannel() const {
return storage_.perchannel ? storage_.multiplier_fixedpoint_perchannel
: nullptr;
@@ -205,6 +199,16 @@ class MulParams final {
detail::MulParamsStorage<AccumScalar, DstScalar> storage_;
void set_perchannel(bool perchannel) {
+ if (storage_.perchannel == perchannel) {
+ return;
+ }
+ if (perchannel) {
+ RUY_DCHECK_EQ(storage_.multiplier_fixedpoint, 0);
+ RUY_DCHECK_EQ(storage_.multiplier_exponent, 0);
+ } else {
+ RUY_DCHECK_EQ(storage_.multiplier_fixedpoint_perchannel, nullptr);
+ RUY_DCHECK_EQ(storage_.multiplier_exponent_perchannel, nullptr);
+ }
storage_.perchannel = perchannel;
}
};
@@ -240,24 +244,16 @@ template <typename DstScalar>
struct MulParamsStorage<std::int32_t, DstScalar> final {
using AccumScalar = std::int32_t;
static_assert(std::is_integral<DstScalar>::value, "");
- static_assert(sizeof(DstScalar) <= sizeof(AccumScalar) / 2, "");
+ static_assert(sizeof(DstScalar) < sizeof(AccumScalar), "");
const AccumScalar* bias = nullptr;
union {
- const AccumScalar* multiplier_fixedpoint_perchannel;
- // Let the default multiplier be effecively a multiplication by 1, so that
- // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
- // 1 is not exactly representable in fixedpoint with 0 integer bits, but
- // using the highest representable value is a sufficiently good
- // approximation: since this specialization of MulParams is for the case
- // where DstScalar is at least 2x narrower than MulScalar, the values
- // for which there would be a difference will get saturated anyway.
- AccumScalar multiplier_fixedpoint = std::numeric_limits<AccumScalar>::max();
+ const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
+ AccumScalar multiplier_fixedpoint;
};
union {
- const int* multiplier_exponent_perchannel;
- // See the above comment about the default value of multiplier_fixedpoint.
- int multiplier_exponent = 0;
+ const int* multiplier_exponent_perchannel = nullptr;
+ int multiplier_exponent;
};
DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
diff --git a/ruy/mul_params_test.cc b/ruy/mul_params_test.cc
index 4bc9f87..feb7dbb 100644
--- a/ruy/mul_params_test.cc
+++ b/ruy/mul_params_test.cc
@@ -31,7 +31,7 @@ TEST(MulParamsTest, SpecClassSanity) {
MulParamsType mul_params;
EXPECT_EQ(mul_params.bias(), nullptr);
- EXPECT_EQ(mul_params.multiplier_fixedpoint(), std::numeric_limits<std::int32_t>::max());
+ EXPECT_EQ(mul_params.multiplier_fixedpoint(), 0);
EXPECT_EQ(mul_params.multiplier_exponent(), 0);
EXPECT_EQ(mul_params.multiplier_fixedpoint_perchannel(), nullptr);
EXPECT_EQ(mul_params.multiplier_exponent_perchannel(), nullptr);
diff --git a/ruy/ruy.h b/ruy/ruy.h
index ddbe192..3cf7bdd 100644
--- a/ruy/ruy.h
+++ b/ruy/ruy.h
@@ -93,14 +93,6 @@ void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
// (e.g. the number of CPU cores in typical scenarios). At least ruy forces
// each invocation to make an explicit decision here, there is no automatic
// detection of the best number of threads to use in ruy.
-//
-// Constraints on the template parameters:
-// * If DstScalar is floating-point then AccumScalar must also be.
-// * If DstScalar is integral then AccumScalar must be std::int32_t.
-// Please refer to MulParams' class comment for more information. When
-// DstScalar is integral and is narrower than AccumScalar, additional
-// MulParams fields must be set to control the scaling of internal accumulators
-// before the final saturating cast to the DstScalar type.
template <typename LhsScalar, typename RhsScalar, typename AccumScalar,
typename DstScalar>
void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,