diff options
author | Benoit Jacob <benoitjacob@google.com> | 2021-03-10 06:37:44 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2021-03-10 06:38:06 +0300 |
commit | 939449243eb36e5b668cc00a1c936f2b1ad4dc27 (patch) | |
tree | e40f3aa5849a31e132c064b6773bde317899507b | |
parent | b0e97e627db281e2f2d79ac84333a93452173ccd (diff) |
Simplify some code and add release assertions to help debug a crash in an application.
PiperOrigin-RevId: 361953871
-rw-r--r-- | ruy/kernel_common.h | 7 | ||||
-rw-r--r-- | ruy/mul_params.h | 24 | ||||
-rw-r--r-- | ruy/trmul_params.h | 4 |
3 files changed, 26 insertions, 9 deletions
diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h index 9509b8f..cff243b 100644 --- a/ruy/kernel_common.h +++ b/ruy/kernel_common.h @@ -177,6 +177,8 @@ void MakeKernelParams8bit(const PMat<std::int8_t>& lhs, params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth; params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT; if (mul_params.multiplier_fixedpoint_perchannel()) { + // Temporary release-assert to debug some crashes in an application. + RUY_CHECK(mul_params.multiplier_exponent_perchannel()); params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL; params->multiplier_fixedpoint = mul_params.multiplier_fixedpoint_perchannel(); @@ -200,6 +202,11 @@ void MakeKernelParams8bit(const PMat<std::int8_t>& lhs, params->dst_type_id = DstTypeId<DstScalar>::kValue; params->dst_base_ptr = dst->data.get() + start_col * dst->layout.stride + start_row; + + // Temporary release-asserts to debug some crashes in an application. + RUY_CHECK(params->multiplier_fixedpoint); + RUY_CHECK(params->multiplier_exponent); + RUY_CHECK(params->bias); } template <int LhsCols, int RhsCols> diff --git a/ruy/mul_params.h b/ruy/mul_params.h index 3535a75..42a5700 100644 --- a/ruy/mul_params.h +++ b/ruy/mul_params.h @@ -247,14 +247,22 @@ struct MulParamsStorage<std::int32_t, DstScalar> final { static_assert(sizeof(DstScalar) < sizeof(AccumScalar), ""); const AccumScalar* bias = nullptr; - union { - const AccumScalar* multiplier_fixedpoint_perchannel = nullptr; - AccumScalar multiplier_fixedpoint; - }; - union { - const int* multiplier_exponent_perchannel = nullptr; - int multiplier_exponent; - }; + // union { // This used to be a union, temporarily flattened to debug a crash + const AccumScalar* multiplier_fixedpoint_perchannel = nullptr; + // Let the default multiplier be effecively a multiplication by 1, so that + // the matmul behaves as a (saturating) plain integer matmul. Unfortunately + // 1 is not exactly representable in fixedpoint with 0 integer bits, but + // using the highest representable value is a sufficiently good + // approximation: since this specialization of MulParams is for the case + // where DstScalar is at least 2x narrower than MulScalar, the values + // for which there would be a difference will get saturated anyway. + AccumScalar multiplier_fixedpoint = 0; + //}; + // union { // This used to be a union, temporarily flattened to debug a crash + const int* multiplier_exponent_perchannel = nullptr; + // See the above comment about the default value of multiplier_fixedpoint. + int multiplier_exponent = 0; + // }; DstScalar clamp_min = std::numeric_limits<DstScalar>::lowest(); DstScalar clamp_max = std::numeric_limits<DstScalar>::max(); ChannelDimension channel_dimension = ChannelDimension::kRow; diff --git a/ruy/trmul_params.h b/ruy/trmul_params.h index e68d909..486a6c6 100644 --- a/ruy/trmul_params.h +++ b/ruy/trmul_params.h @@ -53,7 +53,9 @@ constexpr int kMaxMulParamsSize = kMaxMulParamsSizeQuantizedIntegerCase)); // OK to adjust as needed, but we want to avoid unnecessarily inflating that. -static_assert(kMaxMulParamsSize <= 32, ""); +// Temporarily bumped from 32 to 48 as part of temporarily not using unions +// in MulParams. +static_assert(kMaxMulParamsSize <= 48, ""); // Type-erased data needed for implementing TrMul. struct TrMulParams { |