Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2021-02-09 21:49:17 +0300
committerCopybara-Service <copybara-worker@google.com>2021-02-09 21:49:37 +0300
commitbe760b63149d8205dfb3ca66d78a049dc1ab7772 (patch)
tree45b033783360ed59b9efbfd8bd2dfdb4fc8bdb72
parent287015c8ea2b2bbc7780f85650263a92518dcd37 (diff)
Simplify quantized multiplier
Alter sequence to a single rounded scaling with normal rounded shift. Double rounding and symmetric rounding are removed compared to reference. Double rounding seems unnecessary and can complicate implementations. Moreover, symmetric rounding also adds implementation complexity. For NEON the new sequence can be translated to VQDMULH + VRSHR. Closes https://github.com/google/ruy/pull/227 COPYBARA_INTEGRATE_REVIEW=https://github.com/google/ruy/pull/227 from GeorgeARM:mul_pr dec00bd87a8815fdad79d302494430aa63522752 PiperOrigin-RevId: 356539687
-rw-r--r--ruy/apply_multiplier.cc95
-rw-r--r--ruy/apply_multiplier_test.cc5
-rw-r--r--ruy/kernel_arm32.cc45
-rw-r--r--ruy/kernel_arm64.cc426
-rw-r--r--ruy/kernel_common.h5
-rw-r--r--ruy/mul_params.h12
6 files changed, 267 insertions, 321 deletions
diff --git a/ruy/apply_multiplier.cc b/ruy/apply_multiplier.cc
index 1b7df5c..19bfd88 100644
--- a/ruy/apply_multiplier.cc
+++ b/ruy/apply_multiplier.cc
@@ -22,76 +22,47 @@ limitations under the License.
namespace ruy {
namespace detail {
-namespace {
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Warning: this code is not meant to be bit-exact-normative.
// Please refer to the class comment of ruy::MulParams, in mul_params.h.
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-// Copied from gemmlowp/fixedpoint.
-// Similar to the ARM64 instruction, SQRDMULH. The name of this function
-// is copied from the name of that instruction.
-// Implements a fixed-point multiplication on values in Q0.31 format, i.e.
-// the int32 values represent real numbers in [-1, 1), the int32 value -2^31
-// represents the real number -1. The 'doubling' part of the name refers to
-// the fact that this returns (up to correct rounding) a*b/2^31, not a*b/2^32
-// as just 'high mul' would suggest.
-std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a, std::int32_t b) {
- bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
- std::int64_t a_64(a);
- std::int64_t b_64(b);
- std::int64_t ab_64 = a_64 * b_64;
- std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
- std::int32_t ab_x2_high32 =
- static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
- return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
-}
-
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-// Warning: this code is not meant to be bit-exact-normative.
-// Please refer to the class comment of ruy::MulParams, in mul_params.h.
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-// Returns numerator/2^exponent, rounding to nearest, breaking ties
-// upwards. That particular tie-breaking behavior is not important in practice.
-// It happens to be cheap to implement in hardware and therefore, commonplace.
-// In particular, it matches the behavior of ARM NEON rounding right shifts
-// (RSHL with negative shift amount). By contrast, breaking ties away-from-zero
-// or to-nearest-even is a little more expensive and less commonplace in SIMD
-// hardware.
-std::int32_t RoundingRightShift(std::int32_t numerator, int exponent) {
- // According to
- // https://en.cppreference.com/w/cpp/language/operator_arithmetic ,
- // since C++20, "The value of a >> b is a/2^b rounded down (in other words,
- // right shift on signed a is arithmetic right shift)". While we currently
- // target C++14/17, this makes it reasonable to assume that the
- // implementation-defined behavior of a>>b with a<0 has converged to this
- // behavior on current compilers even in C++14/17 modes.
- RUY_DCHECK_GE(exponent, 0);
- RUY_DCHECK_LE(exponent, 31);
- const std::int32_t nudge = (exponent > 0) ? (1 << (exponent - 1)) : 0;
- // if numerator + nudge would overflow, do the computation as if it were 2^31.
- if (numerator > std::numeric_limits<std::int32_t>::max() - nudge) {
- RUY_DCHECK_GE(exponent, 1); // This can't happen with exponent==0.
- return 1 << (31 - exponent);
- }
- return (numerator + nudge) >> exponent;
-}
-
-} // namespace
-
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-// Warning: this code is not meant to be bit-exact-normative.
-// Please refer to the class comment of ruy::MulParams, in mul_params.h.
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-// Copied from TF Lite code.
+// Simplified multiplier application function
+//
+// Double rounding and symmetric rounding are removed compared to reference.
+// Double rounding seems unnecessary and can complicate implementations.
+// Symmetric rounding also adds implementation complexity.
+//
+// Composed of a single rounding shift right and can lead to more HW
+// friendly implementations.
+//
+// On NEON this can be translated to a SQDMULH + rounding shift right sequence.
+// The use of SQDMULH rather than SQRDMULH gives a result that is
+// equivalent to a single rounded shift since the truncating shift of SQDMULH
+// can be combined with the rounding right shift via the formula (for k>=1):
+// ((x>>31)+(1<<(k-1)))>>k = (x + (1<<(30+k))>>(31+k)
+//
+// Preconditions:
+// - quantized_multiplier >= 0
+// - shift is -31 to +7 (negative for right shift)
std::int32_t MultiplyByQuantizedMultiplier(std::int32_t x,
std::int32_t quantized_multiplier,
int shift) {
- int left_shift = shift > 0 ? shift : 0;
- int right_shift = shift > 0 ? 0 : -shift;
- return RoundingRightShift(SaturatingRoundingDoublingHighMul(
- x * (1 << left_shift), quantized_multiplier),
- right_shift);
+ RUY_CHECK_GE(shift, -31);
+ RUY_CHECK_LE(shift, 7);
+
+ int total_shift = 31 - shift;
+
+ std::int64_t x_64(x);
+ std::int64_t quantized_multiplier_64(quantized_multiplier);
+ std::int64_t round = (int64_t)1 << (total_shift - 1);
+ int64_t result = x_64 * quantized_multiplier_64 + round;
+ result = result >> total_shift;
+
+ RUY_DCHECK_GE(result, std::numeric_limits<std::int32_t>::lowest());
+ RUY_DCHECK_LE(result, std::numeric_limits<std::int32_t>::max());
+
+ return static_cast<std::int32_t>(result);
}
} // namespace detail
diff --git a/ruy/apply_multiplier_test.cc b/ruy/apply_multiplier_test.cc
index cae6888..2df80d7 100644
--- a/ruy/apply_multiplier_test.cc
+++ b/ruy/apply_multiplier_test.cc
@@ -67,10 +67,7 @@ TEST(ApplyMultiplierTest, RoundingRightShift) {
TestMultiplyByQuantizedMultiplier(1000, max_int32, -1, 500);
TestMultiplyByQuantizedMultiplier(1000, max_int32, -2, 250);
TestMultiplyByQuantizedMultiplier(1000, max_int32, -3, 125);
- // This 63 value comes from rounding 62.5, which is a tie.
- // As a positive value, it does not distinguish between 'upward'
- // and 'away from zero' tie-breaking behavior.
- TestMultiplyByQuantizedMultiplier(1000, max_int32, -4, 63);
+ TestMultiplyByQuantizedMultiplier(1000, max_int32, -4, 62);
TestMultiplyByQuantizedMultiplier(1000, max_int32, -5, 31);
TestMultiplyByQuantizedMultiplier(1000, max_int32, -6, 16);
TestMultiplyByQuantizedMultiplier(-1000, max_int32, -1, -500);
diff --git a/ruy/kernel_arm32.cc b/ruy/kernel_arm32.cc
index 4ab58b2..b20f668 100644
--- a/ruy/kernel_arm32.cc
+++ b/ruy/kernel_arm32.cc
@@ -1023,47 +1023,47 @@ void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) {
"vld1.32 {d12}, [r2]!\n" // 2 values of multiplier_fixedpoint
"tst r6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
- RUY_MAKE_ZERO(q8)
+ "vmvn.i32 q8, #0\n"
"bne 8f\n"
// Case where channels are rows.
// Load the remaining 2 bias values, since we're on the width-4 side
// of this 4x2 kernel.
"vld1.32 {d21}, [r1]\n" // 2 more values of multiplier_exponent
"vld1.32 {d13}, [r2]\n" // 2 more values of multiplier_fixedpoint
- "vmax.s32 q11, q10, q8\n"
- "vmin.s32 q10, q10, q8\n"
+ "vmin.s32 q11, q10, q8\n"
+ "vsub.s32 q10, q10, q11\n"
// Apply the positive exponent part of the multiplier.
- "vshl.s32 q14, q14, q11\n"
- "vshl.s32 q15, q15, q11\n"
+ "vshl.s32 q14, q14, q10\n"
+ "vshl.s32 q15, q15, q10\n"
// Apply the fixed-point part of the multiplier.
- "vqrdmulh.s32 q14, q14, q6\n"
- "vqrdmulh.s32 q15, q15, q6\n"
+ "vqdmulh.s32 q14, q14, q6\n"
+ "vqdmulh.s32 q15, q15, q6\n"
// Apply the negative exponent part of the multiplier.
- "vrshl.s32 q14, q14, q10\n"
- "vrshl.s32 q15, q15, q10\n"
+ "vrshl.s32 q14, q14, q11\n"
+ "vrshl.s32 q15, q15, q11\n"
"b 9f\n"
"8:\n"
// Case where channels are columns.
- "vmax.s32 d22, d20, d16\n"
- "vmin.s32 d20, d20, d16\n"
+ "vmin.s32 d22, d20, d16\n"
+ "vsub.s32 d20, d20, d22\n"
// Apply the positive exponent part of the multiplier.
- "vdup.32 q12, d22[0]\n"
- "vdup.32 q13, d22[1]\n"
+ "vdup.32 q12, d20[0]\n"
+ "vdup.32 q13, d20[1]\n"
"vshl.s32 q14, q14, q12\n"
"vshl.s32 q15, q15, q13\n"
// Apply the fixed-point part of the multiplier.
- "vqrdmulh.s32 q14, q14, d12[0]\n"
- "vqrdmulh.s32 q15, q15, d12[1]\n"
+ "vqdmulh.s32 q14, q14, d12[0]\n"
+ "vqdmulh.s32 q15, q15, d12[1]\n"
// Apply the negative exponent part of the multiplier.
- "vdup.32 q12, d20[0]\n"
- "vdup.32 q13, d20[1]\n"
+ "vdup.32 q12, d22[0]\n"
+ "vdup.32 q13, d22[1]\n"
"vrshl.s32 q14, q14, q12\n"
"vrshl.s32 q15, q15, q13\n"
@@ -1961,14 +1961,13 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) {
"vld1.32 {q10}, [r1]\n"
- RUY_MAKE_ZERO(q8)
- "vmax.s32 q12, q10, q8\n"
+ "vmvn.i32 q8, #0\n"
+ "vmin.s32 q13, q10, q8\n"
+ "vsub.s32 q12, q10, q13\n"
// Apply the positive exponent part of the multiplier.
"vshl.s32 q14, q14, q12\n"
- "vmin.s32 q12, q10, q8\n"
-
// Load fixed point part of the multiplier
"ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
// r6 has flags, r4 has row
@@ -1978,10 +1977,10 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) {
"vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint
// Apply the fixed-point part of the multiplier.
- "vqrdmulh.s32 q14, q14, q10\n"
+ "vqdmulh.s32 q14, q14, q10\n"
// Apply the negative exponent part of the multiplier.
- "vrshl.s32 q14, q14, q12\n"
+ "vrshl.s32 q14, q14, q13\n"
"ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
"cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
diff --git a/ruy/kernel_arm64.cc b/ruy/kernel_arm64.cc
index ea6bcb2..fe65d9c 100644
--- a/ruy/kernel_arm64.cc
+++ b/ruy/kernel_arm64.cc
@@ -400,7 +400,7 @@ void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) {
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
- RUY_MAKE_ZERO(v8)
+ "mvni v8.4s, #0\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ins v13.h[4], w4\n" // dst_zero_point
@@ -533,8 +533,8 @@ void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) {
"ld1 {v14.4s}, [x1]\n"
- "smax v12.4s, v14.4s, v8.4s\n"
- "smin v11.4s, v14.4s, v8.4s\n"
+ "smin v11.4s, v8.4s, v14.4s\n"
+ "sub v12.4s, v14.4s, v11.4s\n"
// Jump based on channel dimension.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
@@ -548,10 +548,10 @@ void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) {
"sshl v19.4s, v19.4s, v12.4s\n"
// Apply the fixed-point part of the multiplier.
- "sqrdmulh v16.4s, v16.4s, v15.4s\n"
- "sqrdmulh v17.4s, v17.4s, v15.4s\n"
- "sqrdmulh v18.4s, v18.4s, v15.4s\n"
- "sqrdmulh v19.4s, v19.4s, v15.4s\n"
+ "sqdmulh v16.4s, v16.4s, v15.4s\n"
+ "sqdmulh v17.4s, v17.4s, v15.4s\n"
+ "sqdmulh v18.4s, v18.4s, v15.4s\n"
+ "sqdmulh v19.4s, v19.4s, v15.4s\n"
// Apply the negative exponent part of the multiplier.
"srshl v16.4s, v16.4s, v11.4s\n"
@@ -574,10 +574,10 @@ void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) {
"sshl v19.4s, v19.4s, v23.4s\n"
// Apply the fixed-point part of the multiplier.
- "sqrdmulh v16.4s, v16.4s, v15.s[0]\n"
- "sqrdmulh v17.4s, v17.4s, v15.s[1]\n"
- "sqrdmulh v18.4s, v18.4s, v15.s[2]\n"
- "sqrdmulh v19.4s, v19.4s, v15.s[3]\n"
+ "sqdmulh v16.4s, v16.4s, v15.s[0]\n"
+ "sqdmulh v17.4s, v17.4s, v15.s[1]\n"
+ "sqdmulh v18.4s, v18.4s, v15.s[2]\n"
+ "sqdmulh v19.4s, v19.4s, v15.s[3]\n"
// Apply the negative exponent part of the multiplier.
"dup v20.4s, v11.s[0]\n"
@@ -1348,7 +1348,7 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params) {
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
- RUY_MAKE_ZERO(v8)
+ "mvni v8.4s, #0\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ins v13.h[4], w4\n" // dst_zero_point
@@ -1441,18 +1441,17 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params) {
"ld1 {v14.4s}, [x1]\n"
- "smax v12.4s, v14.4s, v8.4s\n"
+ "smin v11.4s, v8.4s, v14.4s\n"
+ "sub v12.4s, v14.4s, v11.4s\n"
// Apply the positive exponent part of the multiplier.
"sshl v16.4s, v16.4s, v12.4s\n"
- "smin v12.4s, v14.4s, v8.4s\n"
-
// Apply the fixed-point part of the multiplier.
- "sqrdmulh v16.4s, v16.4s, v15.4s\n"
+ "sqdmulh v16.4s, v16.4s, v15.4s\n"
// Apply the negative exponent part of the multiplier.
- "srshl v16.4s, v16.4s, v12.4s\n"
+ "srshl v16.4s, v16.4s, v11.4s\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
@@ -2146,7 +2145,7 @@ void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) {
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
- RUY_MAKE_ZERO(v8)
+ "mvni v8.4s, #0\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ins v13.h[4], w4\n" // dst_zero_point
@@ -2284,9 +2283,9 @@ void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) {
"ld1 {v14.4s}, [x1]\n"
- "smax v12.4s, v14.4s, v8.4s\n"
+ "smin v11.4s, v8.4s, v14.4s\n"
"ldr x1, [%[lhs_ptr], #8]\n"
- "smin v11.4s, v14.4s, v8.4s\n"
+ "sub v12.4s, v14.4s, v11.4s\n"
// Jump based on channel dimension.
"tst w6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
@@ -2307,16 +2306,16 @@ void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) {
// Apply the fixed-point part of the multiplier.
"ins v0.d[1], x1\n"
"ldr x1, [%[rhs_ptr], #8]\n"
- "sqrdmulh v16.4s, v16.4s, v15.4s\n"
+ "sqdmulh v16.4s, v16.4s, v15.4s\n"
"ins v1.d[1], x2\n"
"ldr x2, [%[rhs_ptr], #24]\n"
- "sqrdmulh v17.4s, v17.4s, v15.4s\n"
+ "sqdmulh v17.4s, v17.4s, v15.4s\n"
"ins v2.d[1], x3\n"
"ldr x3, [%[rhs_ptr], #40]\n"
- "sqrdmulh v18.4s, v18.4s, v15.4s\n"
+ "sqdmulh v18.4s, v18.4s, v15.4s\n"
"ins v3.d[1], x4\n"
"ldr x4, [%[rhs_ptr], #56]\n"
- "sqrdmulh v19.4s, v19.4s, v15.4s\n"
+ "sqdmulh v19.4s, v19.4s, v15.4s\n"
// Apply the negative exponent part of the multiplier.
"srshl v16.4s, v16.4s, v11.4s\n"
@@ -2349,13 +2348,13 @@ void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) {
"ldr x3, [%[rhs_ptr], #40]\n"
// Apply the fixed-point part of the multiplier.
- "sqrdmulh v16.4s, v16.4s, v15.s[0]\n"
+ "sqdmulh v16.4s, v16.4s, v15.s[0]\n"
"ins v3.d[1], x4\n"
- "sqrdmulh v17.4s, v17.4s, v15.s[1]\n"
+ "sqdmulh v17.4s, v17.4s, v15.s[1]\n"
"ldr x4, [%[rhs_ptr], #56]\n"
- "sqrdmulh v18.4s, v18.4s, v15.s[2]\n"
+ "sqdmulh v18.4s, v18.4s, v15.s[2]\n"
"dup v20.4s, v11.s[0]\n"
- "sqrdmulh v19.4s, v19.4s, v15.s[3]\n"
+ "sqdmulh v19.4s, v19.4s, v15.s[3]\n"
// Apply the negative exponent part of the multiplier.
"dup v21.4s, v11.s[1]\n"
@@ -3351,7 +3350,7 @@ void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) {
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
- RUY_MAKE_ZERO(v8)
+ "mvni v8.4s, #0\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"dup v9.4s, w3\n" // create prod_zp_depth_vec
@@ -3523,10 +3522,10 @@ void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) {
"ldr q9, [x1]\n"
"ldr q10, [x1, #16]\n"
// Separate positive and negative exponents
- "smax v11.4s, v9.4s, v8.4s\n"
- "smax v12.4s, v10.4s, v8.4s\n"
- "smin v9.4s, v9.4s, v8.4s\n"
- "smin v10.4s, v10.4s, v8.4s\n"
+ "smin v11.4s, v8.4s, v9.4s\n"
+ "smin v12.4s, v8.4s, v10.4s\n"
+ "sub v9.4s, v9.4s, v11.4s\n"
+ "sub v10.4s, v10.4s, v12.4s\n"
// Compute the multiplier_fixedpoint pointer
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
@@ -3541,74 +3540,70 @@ void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) {
"bne 8f\n"
// Case where channels are rows
- "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n"
- "beq 10f\n"
// Apply the positive exponent part of the multiplier.
- "sshl v16.4s, v16.4s, v11.4s\n"
- "sshl v17.4s, v17.4s, v12.4s\n"
- "sshl v18.4s, v18.4s, v11.4s\n"
- "sshl v19.4s, v19.4s, v12.4s\n"
- "sshl v20.4s, v20.4s, v11.4s\n"
- "sshl v21.4s, v21.4s, v12.4s\n"
- "sshl v22.4s, v22.4s, v11.4s\n"
- "sshl v23.4s, v23.4s, v12.4s\n"
- "sshl v24.4s, v24.4s, v11.4s\n"
- "sshl v25.4s, v25.4s, v12.4s\n"
- "sshl v26.4s, v26.4s, v11.4s\n"
- "sshl v27.4s, v27.4s, v12.4s\n"
- "sshl v28.4s, v28.4s, v11.4s\n"
- "sshl v29.4s, v29.4s, v12.4s\n"
- "sshl v30.4s, v30.4s, v11.4s\n"
- "sshl v31.4s, v31.4s, v12.4s\n"
+ "sshl v16.4s, v16.4s, v9.4s\n"
+ "sshl v17.4s, v17.4s, v10.4s\n"
+ "sshl v18.4s, v18.4s, v9.4s\n"
+ "sshl v19.4s, v19.4s, v10.4s\n"
+ "sshl v20.4s, v20.4s, v9.4s\n"
+ "sshl v21.4s, v21.4s, v10.4s\n"
+ "sshl v22.4s, v22.4s, v9.4s\n"
+ "sshl v23.4s, v23.4s, v10.4s\n"
+ "sshl v24.4s, v24.4s, v9.4s\n"
+ "sshl v25.4s, v25.4s, v10.4s\n"
+ "sshl v26.4s, v26.4s, v9.4s\n"
+ "sshl v27.4s, v27.4s, v10.4s\n"
+ "sshl v28.4s, v28.4s, v9.4s\n"
+ "sshl v29.4s, v29.4s, v10.4s\n"
+ "sshl v30.4s, v30.4s, v9.4s\n"
+ "sshl v31.4s, v31.4s, v10.4s\n"
"10:\n"
// Apply the fixed-point part of the multiplier.
- "sqrdmulh v16.4s, v16.4s, v14.4s\n"
- "sqrdmulh v17.4s, v17.4s, v15.4s\n"
- "sqrdmulh v18.4s, v18.4s, v14.4s\n"
- "sqrdmulh v19.4s, v19.4s, v15.4s\n"
- "sqrdmulh v20.4s, v20.4s, v14.4s\n"
- "sqrdmulh v21.4s, v21.4s, v15.4s\n"
- "sqrdmulh v22.4s, v22.4s, v14.4s\n"
- "sqrdmulh v23.4s, v23.4s, v15.4s\n"
- "sqrdmulh v24.4s, v24.4s, v14.4s\n"
- "sqrdmulh v25.4s, v25.4s, v15.4s\n"
- "sqrdmulh v26.4s, v26.4s, v14.4s\n"
- "sqrdmulh v27.4s, v27.4s, v15.4s\n"
- "sqrdmulh v28.4s, v28.4s, v14.4s\n"
- "sqrdmulh v29.4s, v29.4s, v15.4s\n"
- "sqrdmulh v30.4s, v30.4s, v14.4s\n"
- "sqrdmulh v31.4s, v31.4s, v15.4s\n"
+ "sqdmulh v16.4s, v16.4s, v14.4s\n"
+ "sqdmulh v17.4s, v17.4s, v15.4s\n"
+ "sqdmulh v18.4s, v18.4s, v14.4s\n"
+ "sqdmulh v19.4s, v19.4s, v15.4s\n"
+ "sqdmulh v20.4s, v20.4s, v14.4s\n"
+ "sqdmulh v21.4s, v21.4s, v15.4s\n"
+ "sqdmulh v22.4s, v22.4s, v14.4s\n"
+ "sqdmulh v23.4s, v23.4s, v15.4s\n"
+ "sqdmulh v24.4s, v24.4s, v14.4s\n"
+ "sqdmulh v25.4s, v25.4s, v15.4s\n"
+ "sqdmulh v26.4s, v26.4s, v14.4s\n"
+ "sqdmulh v27.4s, v27.4s, v15.4s\n"
+ "sqdmulh v28.4s, v28.4s, v14.4s\n"
+ "sqdmulh v29.4s, v29.4s, v15.4s\n"
+ "sqdmulh v30.4s, v30.4s, v14.4s\n"
+ "sqdmulh v31.4s, v31.4s, v15.4s\n"
// Apply the negative exponent part of the multiplier.
- "srshl v16.4s, v16.4s, v9.4s\n"
- "srshl v17.4s, v17.4s, v10.4s\n"
- "srshl v18.4s, v18.4s, v9.4s\n"
- "srshl v19.4s, v19.4s, v10.4s\n"
- "srshl v20.4s, v20.4s, v9.4s\n"
- "srshl v21.4s, v21.4s, v10.4s\n"
- "srshl v22.4s, v22.4s, v9.4s\n"
- "srshl v23.4s, v23.4s, v10.4s\n"
- "srshl v24.4s, v24.4s, v9.4s\n"
- "srshl v25.4s, v25.4s, v10.4s\n"
- "srshl v26.4s, v26.4s, v9.4s\n"
- "srshl v27.4s, v27.4s, v10.4s\n"
- "srshl v28.4s, v28.4s, v9.4s\n"
- "srshl v29.4s, v29.4s, v10.4s\n"
- "srshl v30.4s, v30.4s, v9.4s\n"
- "srshl v31.4s, v31.4s, v10.4s\n"
+ "srshl v16.4s, v16.4s, v11.4s\n"
+ "srshl v17.4s, v17.4s, v12.4s\n"
+ "srshl v18.4s, v18.4s, v11.4s\n"
+ "srshl v19.4s, v19.4s, v12.4s\n"
+ "srshl v20.4s, v20.4s, v11.4s\n"
+ "srshl v21.4s, v21.4s, v12.4s\n"
+ "srshl v22.4s, v22.4s, v11.4s\n"
+ "srshl v23.4s, v23.4s, v12.4s\n"
+ "srshl v24.4s, v24.4s, v11.4s\n"
+ "srshl v25.4s, v25.4s, v12.4s\n"
+ "srshl v26.4s, v26.4s, v11.4s\n"
+ "srshl v27.4s, v27.4s, v12.4s\n"
+ "srshl v28.4s, v28.4s, v11.4s\n"
+ "srshl v29.4s, v29.4s, v12.4s\n"
+ "srshl v30.4s, v30.4s, v11.4s\n"
+ "srshl v31.4s, v31.4s, v12.4s\n"
"b 9f\n"
"8:\n"
// Case where channels are columns
- "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n"
- "beq 11f\n"
// Apply the positive exponent part of the multiplier.
- "dup v4.4s, v11.s[0]\n"
- "dup v5.4s, v11.s[1]\n"
- "dup v6.4s, v11.s[2]\n"
- "dup v7.4s, v11.s[3]\n"
+ "dup v4.4s, v9.s[0]\n"
+ "dup v5.4s, v9.s[1]\n"
+ "dup v6.4s, v9.s[2]\n"
+ "dup v7.4s, v9.s[3]\n"
"sshl v16.4s, v16.4s, v4.4s\n"
"sshl v17.4s, v17.4s, v4.4s\n"
"sshl v18.4s, v18.4s, v5.4s\n"
@@ -3617,10 +3612,10 @@ void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) {
"sshl v21.4s, v21.4s, v6.4s\n"
"sshl v22.4s, v22.4s, v7.4s\n"
"sshl v23.4s, v23.4s, v7.4s\n"
- "dup v4.4s, v12.s[0]\n"
- "dup v5.4s, v12.s[1]\n"
- "dup v6.4s, v12.s[2]\n"
- "dup v7.4s, v12.s[3]\n"
+ "dup v4.4s, v10.s[0]\n"
+ "dup v5.4s, v10.s[1]\n"
+ "dup v6.4s, v10.s[2]\n"
+ "dup v7.4s, v10.s[3]\n"
"sshl v24.4s, v24.4s, v4.4s\n"
"sshl v25.4s, v25.4s, v4.4s\n"
"sshl v26.4s, v26.4s, v5.4s\n"
@@ -3632,28 +3627,28 @@ void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) {
"11:\n"
// Apply the fixed-point part of the multiplier.
- "sqrdmulh v16.4s, v16.4s, v14.s[0]\n"
- "sqrdmulh v17.4s, v17.4s, v14.s[0]\n"
- "sqrdmulh v18.4s, v18.4s, v14.s[1]\n"
- "sqrdmulh v19.4s, v19.4s, v14.s[1]\n"
- "sqrdmulh v20.4s, v20.4s, v14.s[2]\n"
- "sqrdmulh v21.4s, v21.4s, v14.s[2]\n"
- "sqrdmulh v22.4s, v22.4s, v14.s[3]\n"
- "sqrdmulh v23.4s, v23.4s, v14.s[3]\n"
- "sqrdmulh v24.4s, v24.4s, v15.s[0]\n"
- "sqrdmulh v25.4s, v25.4s, v15.s[0]\n"
- "sqrdmulh v26.4s, v26.4s, v15.s[1]\n"
- "sqrdmulh v27.4s, v27.4s, v15.s[1]\n"
- "sqrdmulh v28.4s, v28.4s, v15.s[2]\n"
- "sqrdmulh v29.4s, v29.4s, v15.s[2]\n"
- "sqrdmulh v30.4s, v30.4s, v15.s[3]\n"
- "sqrdmulh v31.4s, v31.4s, v15.s[3]\n"
+ "sqdmulh v16.4s, v16.4s, v14.s[0]\n"
+ "sqdmulh v17.4s, v17.4s, v14.s[0]\n"
+ "sqdmulh v18.4s, v18.4s, v14.s[1]\n"
+ "sqdmulh v19.4s, v19.4s, v14.s[1]\n"
+ "sqdmulh v20.4s, v20.4s, v14.s[2]\n"
+ "sqdmulh v21.4s, v21.4s, v14.s[2]\n"
+ "sqdmulh v22.4s, v22.4s, v14.s[3]\n"
+ "sqdmulh v23.4s, v23.4s, v14.s[3]\n"
+ "sqdmulh v24.4s, v24.4s, v15.s[0]\n"
+ "sqdmulh v25.4s, v25.4s, v15.s[0]\n"
+ "sqdmulh v26.4s, v26.4s, v15.s[1]\n"
+ "sqdmulh v27.4s, v27.4s, v15.s[1]\n"
+ "sqdmulh v28.4s, v28.4s, v15.s[2]\n"
+ "sqdmulh v29.4s, v29.4s, v15.s[2]\n"
+ "sqdmulh v30.4s, v30.4s, v15.s[3]\n"
+ "sqdmulh v31.4s, v31.4s, v15.s[3]\n"
// Apply the negative exponent part of the multiplier.
- "dup v4.4s, v9.s[0]\n"
- "dup v5.4s, v9.s[1]\n"
- "dup v6.4s, v9.s[2]\n"
- "dup v7.4s, v9.s[3]\n"
+ "dup v4.4s, v11.s[0]\n"
+ "dup v5.4s, v11.s[1]\n"
+ "dup v6.4s, v11.s[2]\n"
+ "dup v7.4s, v11.s[3]\n"
"srshl v16.4s, v16.4s, v4.4s\n"
"srshl v17.4s, v17.4s, v4.4s\n"
"srshl v18.4s, v18.4s, v5.4s\n"
@@ -3662,10 +3657,10 @@ void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) {
"srshl v21.4s, v21.4s, v6.4s\n"
"srshl v22.4s, v22.4s, v7.4s\n"
"srshl v23.4s, v23.4s, v7.4s\n"
- "dup v4.4s, v10.s[0]\n"
- "dup v5.4s, v10.s[1]\n"
- "dup v6.4s, v10.s[2]\n"
- "dup v7.4s, v10.s[3]\n"
+ "dup v4.4s, v12.s[0]\n"
+ "dup v5.4s, v12.s[1]\n"
+ "dup v6.4s, v12.s[2]\n"
+ "dup v7.4s, v12.s[3]\n"
"srshl v24.4s, v24.4s, v4.4s\n"
"srshl v25.4s, v25.4s, v4.4s\n"
"srshl v26.4s, v26.4s, v5.4s\n"
@@ -4570,7 +4565,7 @@ void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params) {
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
- RUY_MAKE_ZERO(v8)
+ "mvni v8.4s, #0\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ins v13.h[4], w4\n" // dst_zero_point
@@ -4662,24 +4657,22 @@ void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params) {
"ldr q9, [x1]\n"
"ldr q10, [x1, #16]\n"
- "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n"
- "beq 403f\n"
- "smax v11.4s, v9.4s, v8.4s\n"
- "smax v12.4s, v10.4s, v8.4s\n"
+ "smin v11.4s, v8.4s, v9.4s\n"
+ "smin v12.4s, v8.4s, v10.4s\n"
+ "sub v9.4s, v9.4s, v11.4s\n"
+ "sub v10.4s, v10.4s, v12.4s\n"
+
// Apply the positive exponent part of the multiplier.
- "sshl v16.4s, v16.4s, v11.4s\n"
- "sshl v17.4s, v17.4s, v12.4s\n"
+ "sshl v16.4s, v16.4s, v9.4s\n"
+ "sshl v17.4s, v17.4s, v10.4s\n"
"403:\n"
"ldr q14, [x4]\n" // multiplier_fixedpoint
"ldr q15, [x4, #16]\n" // multiplier_fixedpoint
- "smin v11.4s, v9.4s, v8.4s\n"
- "smin v12.4s, v10.4s, v8.4s\n"
-
// Apply the fixed-point part of the multiplier.
- "sqrdmulh v16.4s, v16.4s, v14.4s\n"
- "sqrdmulh v17.4s, v17.4s, v15.4s\n"
+ "sqdmulh v16.4s, v16.4s, v14.4s\n"
+ "sqdmulh v17.4s, v17.4s, v15.4s\n"
// Apply the negative exponent part of the multiplier.
"srshl v16.4s, v16.4s, v11.4s\n"
@@ -5295,7 +5288,7 @@ void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) {
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
- RUY_MAKE_ZERO(v8)
+ "mvni v8.4s, #0\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
@@ -5469,10 +5462,10 @@ void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) {
"ldr q9, [x1]\n"
"ldr q10, [x1, #16]\n"
// Separate positive and negative exponents
- "smax v11.4s, v9.4s, v8.4s\n"
- "smax v12.4s, v10.4s, v8.4s\n"
- "smin v9.4s, v9.4s, v8.4s\n"
- "smin v10.4s, v10.4s, v8.4s\n"
+ "smin v11.4s, v8.4s, v9.4s\n"
+ "smin v12.4s, v8.4s, v10.4s\n"
+ "sub v9.4s, v9.4s, v11.4s\n"
+ "sub v10.4s, v10.4s, v12.4s\n"
// Compute the multiplier_fixedpoint pointer
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
@@ -5487,26 +5480,23 @@ void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) {
"bne 8f\n"
// Case where channels are rows
- "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n"
- "beq 10f\n"
-
// Apply the positive exponent part of the multiplier.
- "sshl v16.4s, v16.4s, v11.4s\n"
- "sshl v17.4s, v17.4s, v12.4s\n"
- "sshl v18.4s, v18.4s, v11.4s\n"
- "sshl v19.4s, v19.4s, v12.4s\n"
- "sshl v20.4s, v20.4s, v11.4s\n"
- "sshl v21.4s, v21.4s, v12.4s\n"
- "sshl v22.4s, v22.4s, v11.4s\n"
- "sshl v23.4s, v23.4s, v12.4s\n"
- "sshl v24.4s, v24.4s, v11.4s\n"
- "sshl v25.4s, v25.4s, v12.4s\n"
- "sshl v26.4s, v26.4s, v11.4s\n"
- "sshl v27.4s, v27.4s, v12.4s\n"
- "sshl v28.4s, v28.4s, v11.4s\n"
- "sshl v29.4s, v29.4s, v12.4s\n"
- "sshl v30.4s, v30.4s, v11.4s\n"
- "sshl v31.4s, v31.4s, v12.4s\n"
+ "sshl v16.4s, v16.4s, v9.4s\n"
+ "sshl v17.4s, v17.4s, v10.4s\n"
+ "sshl v18.4s, v18.4s, v9.4s\n"
+ "sshl v19.4s, v19.4s, v10.4s\n"
+ "sshl v20.4s, v20.4s, v9.4s\n"
+ "sshl v21.4s, v21.4s, v10.4s\n"
+ "sshl v22.4s, v22.4s, v9.4s\n"
+ "sshl v23.4s, v23.4s, v10.4s\n"
+ "sshl v24.4s, v24.4s, v9.4s\n"
+ "sshl v25.4s, v25.4s, v10.4s\n"
+ "sshl v26.4s, v26.4s, v9.4s\n"
+ "sshl v27.4s, v27.4s, v10.4s\n"
+ "sshl v28.4s, v28.4s, v9.4s\n"
+ "sshl v29.4s, v29.4s, v10.4s\n"
+ "sshl v30.4s, v30.4s, v9.4s\n"
+ "sshl v31.4s, v31.4s, v10.4s\n"
"10:\n"
// Apply the fixed-point part of the multiplier.
@@ -5517,77 +5507,75 @@ void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) {
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
// in the rest of the work on the current block.
"ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
- "sqrdmulh v16.4s, v16.4s, v14.4s\n"
+ "sqdmulh v16.4s, v16.4s, v14.4s\n"
"ldr x1, [%[lhs_ptr]], #8\n"
- "sqrdmulh v17.4s, v17.4s, v15.4s\n"
+ "sqdmulh v17.4s, v17.4s, v15.4s\n"
"ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
- "sqrdmulh v18.4s, v18.4s, v14.4s\n"
+ "sqdmulh v18.4s, v18.4s, v14.4s\n"
"ldr x2, [%[lhs_ptr]], #8\n"
- "sqrdmulh v19.4s, v19.4s, v15.4s\n"
+ "sqdmulh v19.4s, v19.4s, v15.4s\n"
"ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
- "sqrdmulh v20.4s, v20.4s, v14.4s\n"
+ "sqdmulh v20.4s, v20.4s, v14.4s\n"
"ldr x5, [%[rhs_ptr]], #8\n"
- "sqrdmulh v21.4s, v21.4s, v15.4s\n"
+ "sqdmulh v21.4s, v21.4s, v15.4s\n"
"ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
- "sqrdmulh v22.4s, v22.4s, v14.4s\n"
+ "sqdmulh v22.4s, v22.4s, v14.4s\n"
"ldr x6, [%[rhs_ptr]], #8\n"
- "sqrdmulh v23.4s, v23.4s, v15.4s\n"
- "sqrdmulh v24.4s, v24.4s, v14.4s\n"
- "sqrdmulh v25.4s, v25.4s, v15.4s\n"
- "sqrdmulh v26.4s, v26.4s, v14.4s\n"
- "sqrdmulh v27.4s, v27.4s, v15.4s\n"
- "sqrdmulh v28.4s, v28.4s, v14.4s\n"
- "sqrdmulh v29.4s, v29.4s, v15.4s\n"
- "sqrdmulh v30.4s, v30.4s, v14.4s\n"
- "sqrdmulh v31.4s, v31.4s, v15.4s\n"
+ "sqdmulh v23.4s, v23.4s, v15.4s\n"
+ "sqdmulh v24.4s, v24.4s, v14.4s\n"
+ "sqdmulh v25.4s, v25.4s, v15.4s\n"
+ "sqdmulh v26.4s, v26.4s, v14.4s\n"
+ "sqdmulh v27.4s, v27.4s, v15.4s\n"
+ "sqdmulh v28.4s, v28.4s, v14.4s\n"
+ "sqdmulh v29.4s, v29.4s, v15.4s\n"
+ "sqdmulh v30.4s, v30.4s, v14.4s\n"
+ "sqdmulh v31.4s, v31.4s, v15.4s\n"
// Apply the negative exponent part of the multiplier.
- "srshl v16.4s, v16.4s, v9.4s\n"
- "srshl v17.4s, v17.4s, v10.4s\n"
- "srshl v18.4s, v18.4s, v9.4s\n"
- "srshl v19.4s, v19.4s, v10.4s\n"
- "srshl v20.4s, v20.4s, v9.4s\n"
- "srshl v21.4s, v21.4s, v10.4s\n"
- "srshl v22.4s, v22.4s, v9.4s\n"
- "srshl v23.4s, v23.4s, v10.4s\n"
- "srshl v24.4s, v24.4s, v9.4s\n"
- "srshl v25.4s, v25.4s, v10.4s\n"
+ "srshl v16.4s, v16.4s, v11.4s\n"
+ "srshl v17.4s, v17.4s, v12.4s\n"
+ "srshl v18.4s, v18.4s, v11.4s\n"
+ "srshl v19.4s, v19.4s, v12.4s\n"
+ "srshl v20.4s, v20.4s, v11.4s\n"
+ "srshl v21.4s, v21.4s, v12.4s\n"
+ "srshl v22.4s, v22.4s, v11.4s\n"
+ "srshl v23.4s, v23.4s, v12.4s\n"
+ "srshl v24.4s, v24.4s, v11.4s\n"
+ "srshl v25.4s, v25.4s, v12.4s\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
- "srshl v26.4s, v26.4s, v9.4s\n"
+ "srshl v26.4s, v26.4s, v11.4s\n"
"ins v13.h[4], w4\n" // dst_zero_point
- "srshl v27.4s, v27.4s, v10.4s\n"
+ "srshl v27.4s, v27.4s, v12.4s\n"
"ins v0.d[1], x1\n"
- "srshl v28.4s, v28.4s, v9.4s\n"
+ "srshl v28.4s, v28.4s, v11.4s\n"
"ins v1.d[1], x2\n"
- "srshl v29.4s, v29.4s, v10.4s\n"
+ "srshl v29.4s, v29.4s, v12.4s\n"
"ins v2.d[1], x5\n"
- "srshl v30.4s, v30.4s, v9.4s\n"
+ "srshl v30.4s, v30.4s, v11.4s\n"
"ins v3.d[1], x6\n"
- "srshl v31.4s, v31.4s, v10.4s\n"
+ "srshl v31.4s, v31.4s, v12.4s\n"
"b 9f\n"
"8:\n"
// Case where channels are columns
- "tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n"
- "beq 11f\n"
// Apply the positive exponent part of the multiplier.
- "dup v4.4s, v11.s[0]\n"
- "dup v5.4s, v11.s[1]\n"
+ "dup v4.4s, v9.s[0]\n"
+ "dup v5.4s, v9.s[1]\n"
"sshl v16.4s, v16.4s, v4.4s\n"
- "dup v6.4s, v11.s[2]\n"
+ "dup v6.4s, v9.s[2]\n"
"sshl v17.4s, v17.4s, v4.4s\n"
- "dup v7.4s, v11.s[3]\n"
+ "dup v7.4s, v9.s[3]\n"
"sshl v18.4s, v18.4s, v5.4s\n"
- "dup v4.4s, v12.s[0]\n"
+ "dup v4.4s, v10.s[0]\n"
"sshl v19.4s, v19.4s, v5.4s\n"
- "dup v5.4s, v12.s[1]\n"
+ "dup v5.4s, v10.s[1]\n"
"sshl v20.4s, v20.4s, v6.4s\n"
"sshl v21.4s, v21.4s, v6.4s\n"
- "dup v6.4s, v12.s[2]\n"
+ "dup v6.4s, v10.s[2]\n"
"sshl v22.4s, v22.4s, v7.4s\n"
"sshl v23.4s, v23.4s, v7.4s\n"
- "dup v7.4s, v12.s[3]\n"
+ "dup v7.4s, v10.s[3]\n"
"sshl v24.4s, v24.4s, v4.4s\n"
"sshl v25.4s, v25.4s, v4.4s\n"
"sshl v26.4s, v26.4s, v5.4s\n"
@@ -5606,47 +5594,47 @@ void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) {
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
// in the rest of the work on the current block.
"ld1 {v0.8b}, [%[lhs_ptr]], #8\n"
- "sqrdmulh v16.4s, v16.4s, v14.s[0]\n"
+ "sqdmulh v16.4s, v16.4s, v14.s[0]\n"
"ldr x1, [%[lhs_ptr]], #8\n"
- "sqrdmulh v17.4s, v17.4s, v14.s[0]\n"
+ "sqdmulh v17.4s, v17.4s, v14.s[0]\n"
"ld1 {v1.8b}, [%[lhs_ptr]], #8\n"
- "sqrdmulh v18.4s, v18.4s, v14.s[1]\n"
+ "sqdmulh v18.4s, v18.4s, v14.s[1]\n"
"ldr x2, [%[lhs_ptr]], #8\n"
- "sqrdmulh v19.4s, v19.4s, v14.s[1]\n"
+ "sqdmulh v19.4s, v19.4s, v14.s[1]\n"
"ld1 {v2.8b}, [%[rhs_ptr]], #8\n"
- "sqrdmulh v20.4s, v20.4s, v14.s[2]\n"
+ "sqdmulh v20.4s, v20.4s, v14.s[2]\n"
"ldr x5, [%[rhs_ptr]], #8\n"
- "sqrdmulh v21.4s, v21.4s, v14.s[2]\n"
+ "sqdmulh v21.4s, v21.4s, v14.s[2]\n"
"ld1 {v3.8b}, [%[rhs_ptr]], #8\n"
- "sqrdmulh v22.4s, v22.4s, v14.s[3]\n"
+ "sqdmulh v22.4s, v22.4s, v14.s[3]\n"
"ldr x6, [%[rhs_ptr]], #8\n"
- "sqrdmulh v23.4s, v23.4s, v14.s[3]\n"
- "dup v4.4s, v9.s[0]\n"
- "sqrdmulh v24.4s, v24.4s, v15.s[0]\n"
- "dup v5.4s, v9.s[1]\n"
- "sqrdmulh v25.4s, v25.4s, v15.s[0]\n"
- "dup v6.4s, v9.s[2]\n"
- "sqrdmulh v26.4s, v26.4s, v15.s[1]\n"
- "dup v7.4s, v9.s[3]\n"
- "sqrdmulh v27.4s, v27.4s, v15.s[1]\n"
- "sqrdmulh v28.4s, v28.4s, v15.s[2]\n"
- "sqrdmulh v29.4s, v29.4s, v15.s[2]\n"
- "sqrdmulh v30.4s, v30.4s, v15.s[3]\n"
- "sqrdmulh v31.4s, v31.4s, v15.s[3]\n"
+ "sqdmulh v23.4s, v23.4s, v14.s[3]\n"
+ "dup v4.4s, v11.s[0]\n"
+ "sqdmulh v24.4s, v24.4s, v15.s[0]\n"
+ "dup v5.4s, v11.s[1]\n"
+ "sqdmulh v25.4s, v25.4s, v15.s[0]\n"
+ "dup v6.4s, v11.s[2]\n"
+ "sqdmulh v26.4s, v26.4s, v15.s[1]\n"
+ "dup v7.4s, v11.s[3]\n"
+ "sqdmulh v27.4s, v27.4s, v15.s[1]\n"
+ "sqdmulh v28.4s, v28.4s, v15.s[2]\n"
+ "sqdmulh v29.4s, v29.4s, v15.s[2]\n"
+ "sqdmulh v30.4s, v30.4s, v15.s[3]\n"
+ "sqdmulh v31.4s, v31.4s, v15.s[3]\n"
// Apply the negative exponent part of the multiplier.
"srshl v16.4s, v16.4s, v4.4s\n"
"srshl v17.4s, v17.4s, v4.4s\n"
- "dup v4.4s, v10.s[0]\n"
+ "dup v4.4s, v12.s[0]\n"
"srshl v18.4s, v18.4s, v5.4s\n"
"srshl v19.4s, v19.4s, v5.4s\n"
- "dup v5.4s, v10.s[1]\n"
+ "dup v5.4s, v12.s[1]\n"
"srshl v20.4s, v20.4s, v6.4s\n"
"srshl v21.4s, v21.4s, v6.4s\n"
- "dup v6.4s, v10.s[2]\n"
+ "dup v6.4s, v12.s[2]\n"
"srshl v22.4s, v22.4s, v7.4s\n"
"srshl v23.4s, v23.4s, v7.4s\n"
- "dup v7.4s, v10.s[3]\n"
+ "dup v7.4s, v12.s[3]\n"
"srshl v24.4s, v24.4s, v4.4s\n"
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"srshl v25.4s, v25.4s, v4.4s\n"
diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h
index 4d92a00..9509b8f 100644
--- a/ruy/kernel_common.h
+++ b/ruy/kernel_common.h
@@ -175,16 +175,13 @@ void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
params->dst_zero_point = dst->zero_point;
params->depth = depth;
params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
+ params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
if (mul_params.multiplier_fixedpoint_perchannel()) {
- params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
params->multiplier_fixedpoint =
mul_params.multiplier_fixedpoint_perchannel();
params->multiplier_exponent = mul_params.multiplier_exponent_perchannel();
} else {
- if (mul_params.multiplier_exponent() > 0) {
- params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
- }
params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf;
params->multiplier_exponent = params->multiplier_exponent_buf;
for (int i = 0; i < LhsCols; i++) {
diff --git a/ruy/mul_params.h b/ruy/mul_params.h
index 9bdbfa4..d5aa27b 100644
--- a/ruy/mul_params.h
+++ b/ruy/mul_params.h
@@ -62,15 +62,9 @@ struct MulParamsStorage;
// accumulators before casting them to the DstScalar type. The default
// values are such that the effective multiplier is 1 (no scaling).
//
-// In the latter case (DstScalar integral and narrower than std::int32_t),
-// the multiplier effect on accumulator values is as follows:
-//
-// 1. Left shift by max(0, multiplier_exponent).
-// 2. Fixed-point multiplication by multiplier_fixedpoint in Q0.31 format.
-// 3. Rounding right shift by max(0, -multiplier_exponent).
-//
-// Reference code for this can be found in the implementation of
-// ruy::ApplyMultiplier. If you look there, you'll find warnings like this:
+// For the latter case (DstScalar integral and narrower than std::int32_t),
+// reference code can be found in the implementation of ruy::ApplyMultiplier.
+// If you look there, you'll find warnings like this:
//
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
// Warning: this code is not meant to be bit-exact-normative.