Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenoit Jacob <benoitjacob@google.com>2021-03-10 06:37:44 +0300
committerCopybara-Service <copybara-worker@google.com>2021-03-10 06:38:06 +0300
commit939449243eb36e5b668cc00a1c936f2b1ad4dc27 (patch)
treee40f3aa5849a31e132c064b6773bde317899507b
parentb0e97e627db281e2f2d79ac84333a93452173ccd (diff)
Simplify some code and add release assertions to help debug a crash in an application.
PiperOrigin-RevId: 361953871
-rw-r--r--ruy/kernel_common.h7
-rw-r--r--ruy/mul_params.h24
-rw-r--r--ruy/trmul_params.h4
3 files changed, 26 insertions, 9 deletions
diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h
index 9509b8f..cff243b 100644
--- a/ruy/kernel_common.h
+++ b/ruy/kernel_common.h
@@ -177,6 +177,8 @@ void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
if (mul_params.multiplier_fixedpoint_perchannel()) {
+ // Temporary release-assert to debug some crashes in an application.
+ RUY_CHECK(mul_params.multiplier_exponent_perchannel());
params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
params->multiplier_fixedpoint =
mul_params.multiplier_fixedpoint_perchannel();
@@ -200,6 +202,11 @@ void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
params->dst_type_id = DstTypeId<DstScalar>::kValue;
params->dst_base_ptr =
dst->data.get() + start_col * dst->layout.stride + start_row;
+
+ // Temporary release-asserts to debug some crashes in an application.
+ RUY_CHECK(params->multiplier_fixedpoint);
+ RUY_CHECK(params->multiplier_exponent);
+ RUY_CHECK(params->bias);
}
template <int LhsCols, int RhsCols>
diff --git a/ruy/mul_params.h b/ruy/mul_params.h
index 3535a75..42a5700 100644
--- a/ruy/mul_params.h
+++ b/ruy/mul_params.h
@@ -247,14 +247,22 @@ struct MulParamsStorage<std::int32_t, DstScalar> final {
static_assert(sizeof(DstScalar) < sizeof(AccumScalar), "");
const AccumScalar* bias = nullptr;
- union {
- const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
- AccumScalar multiplier_fixedpoint;
- };
- union {
- const int* multiplier_exponent_perchannel = nullptr;
- int multiplier_exponent;
- };
+ // union { // This used to be a union, temporarily flattened to debug a crash
+ const AccumScalar* multiplier_fixedpoint_perchannel = nullptr;
+ // Let the default multiplier be effecively a multiplication by 1, so that
+ // the matmul behaves as a (saturating) plain integer matmul. Unfortunately
+ // 1 is not exactly representable in fixedpoint with 0 integer bits, but
+ // using the highest representable value is a sufficiently good
+ // approximation: since this specialization of MulParams is for the case
+ // where DstScalar is at least 2x narrower than MulScalar, the values
+ // for which there would be a difference will get saturated anyway.
+ AccumScalar multiplier_fixedpoint = 0;
+ //};
+ // union { // This used to be a union, temporarily flattened to debug a crash
+ const int* multiplier_exponent_perchannel = nullptr;
+ // See the above comment about the default value of multiplier_fixedpoint.
+ int multiplier_exponent = 0;
+ // };
DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest();
DstScalar clamp_max = std::numeric_limits<DstScalar>::max();
ChannelDimension channel_dimension = ChannelDimension::kRow;
diff --git a/ruy/trmul_params.h b/ruy/trmul_params.h
index e68d909..486a6c6 100644
--- a/ruy/trmul_params.h
+++ b/ruy/trmul_params.h
@@ -53,7 +53,9 @@ constexpr int kMaxMulParamsSize =
kMaxMulParamsSizeQuantizedIntegerCase));
// OK to adjust as needed, but we want to avoid unnecessarily inflating that.
-static_assert(kMaxMulParamsSize <= 32, "");
+// Temporarily bumped from 32 to 48 as part of temporarily not using unions
+// in MulParams.
+static_assert(kMaxMulParamsSize <= 48, "");
// Type-erased data needed for implementing TrMul.
struct TrMulParams {