github.com/google/ruy.git
author     T.J. Alumbaugh <talumbau@google.com>           2020-10-19 21:06:47 +0300
committer  Copybara-Service <copybara-worker@google.com>  2020-10-19 21:07:09 +0300
commit     dd1102a6ce6ce501f92f6abd72c89ac59a95afeb (patch)
tree       0431d351a878c79aa477c82278f6020885e60f3c
parent     a28320aaf5fe2bd8a8aa9c777fc1264a9b49a14f (diff)
Update AVX, AVX2, AVX512 Rescale operations with Rounding Right Shift
PiperOrigin-RevId: 337892847
-rw-r--r--  ruy/kernel_avx.cc      | 254
-rw-r--r--  ruy/kernel_avx2_fma.cc | 145
-rw-r--r--  ruy/kernel_avx512.cc   | 152
3 files changed, 351 insertions(+), 200 deletions(-)
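
[Reviewer note] The change below replaces the old "shift the offset vector by right_shift" rescaling with an explicit rounding right shift applied after the fixed-point multiply. A minimal scalar sketch of the intended semantics, assuming gemmlowp-style round-half-up behavior (function names are illustrative, not ruy API; the power-of-two left shift for positive exponents is applied before the multiply and omitted here; arithmetic right shift of signed values is assumed):

#include <cstdint>

// Rounding divide by a power of two, saturating the (x + nudge) add.
int32_t RoundingRightShift(int32_t x, int exponent) {
  if (exponent <= 0) return x;
  const int32_t nudge = int32_t{1} << (exponent - 1);
  // If x + nudge would overflow int32, substitute 1 << (31 - exponent),
  // the value the exact 64-bit computation would produce.
  if (x > INT32_MAX - nudge) return int32_t{1} << (31 - exponent);
  return (x + nudge) >> exponent;
}

int32_t Rescale(int32_t accum, int32_t multiplier, int right_shift) {
  // Fixed-point multiply: take the 64-bit product, round at bit 31 by
  // adding a "half" (1 << 30), then truncate to the high word.
  const int64_t prod = static_cast<int64_t>(accum) * multiplier;
  const int32_t high =
      static_cast<int32_t>((prod + (int64_t{1} << 30)) >> 31);
  return RoundingRightShift(high, right_shift);
}
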
diff --git a/ruy/kernel_avx.cc b/ruy/kernel_avx.cc
index 21e8826..057b2d0 100644
--- a/ruy/kernel_avx.cc
+++ b/ruy/kernel_avx.cc
@@ -198,6 +198,10 @@ inline __m256i mm256_mullo_epi32(const __m256i& a, const __m256i& b) {
_mm_castps_si128(_mm_blend_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), imm))
// Defined as a macro since `imm` must be an immediate.
+#define BlendM128_epi64(a, b, imm) \
+ _mm_castpd_si128(_mm_blend_pd(_mm_castsi128_pd(a), _mm_castsi128_pd(b), imm))
+
+// Defined as a macro since `imm` must be an immediate.
#define mm256_blend_epi32(ans, a, b, imm) \
__m128i a_lo = _mm256_extractf128_si256(a, 0); \
__m128i a_hi = _mm256_extractf128_si256(a, 1); \
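
The new BlendM128_epi64 macro is the 64-bit-lane analogue of BlendM128_epi32: bit i of the immediate selects lane i from b, otherwise from a, via SSE4.1 _mm_blend_pd (a bitwise select, so arbitrary integer bit patterns are safe). A hypothetical standalone check (test scaffolding only, not part of the patch):

#include <smmintrin.h>  // SSE4.1
#include <cstdint>
#include <cstdio>

int main() {
  const __m128i a = _mm_set_epi64x(/*hi=*/11, /*lo=*/10);
  const __m128i b = _mm_set_epi64x(/*hi=*/21, /*lo=*/20);
  // imm = 2 (binary 10): lane 0 from a, lane 1 from b -> {10, 21}.
  const __m128i r = _mm_castpd_si128(
      _mm_blend_pd(_mm_castsi128_pd(a), _mm_castsi128_pd(b), 2));
  alignas(16) int64_t out[2];
  _mm_store_si128(reinterpret_cast<__m128i*>(out), r);
  std::printf("%lld %lld\n", (long long)out[0], (long long)out[1]);  // 10 21
  return 0;
}
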
@@ -278,56 +282,6 @@ inline __m256i mm256_sllv_epi64(const __m256i& a, const __m256i& b) {
#define PermuteM128_epi32(a, imm) \
_mm_castps_si128(_mm_permute_ps(_mm_castsi128_ps(a), imm));
-inline __m128i mm_srlv_epi32(const __m128i& a, const __m128i& b) {
- // shift all elements of a by first 32bits of b.
-
- __m128i res0 = _mm_srl_epi32(a, BlendM128_epi32(_mm_setzero_si128(), b, 1));
- // put bits 32-63 of b in the first slot.
- __m128i tmp1 = PermuteM128_epi32(b, 1);
- // put bits 32-63 of a in the first slot.
- __m128i a1 = PermuteM128_epi32(a, 1);
- // shift all elements of a by second 32bits of b.
- __m128i res1 =
- _mm_srl_epi32(a1, BlendM128_epi32(_mm_setzero_si128(), tmp1, 1));
- // put bits 64-95 of b in the first slot.
- __m128i tmp2 = PermuteM128_epi32(b, 2);
- // shift all elements of a by third 32bits of b.
- __m128i res2 =
- _mm_srl_epi32(a, BlendM128_epi32(_mm_setzero_si128(), tmp2, 1));
- // put bits 96-127 of b in the first slot.
- __m128i tmp3 = PermuteM128_epi32(b, 3);
- // put bits 96-127 of a in the third slot.
- __m128i a3 = PermuteM128_epi32(a, 48);
- // shift all elements of a3 by fourth 32bits of b.
- __m128i res3 =
- _mm_srl_epi32(a3, BlendM128_epi32(_mm_setzero_si128(), tmp3, 1));
- // Take bits 0-31 of res0, bits 0-31 of res1,
- // bits 64-95 of res2, and bits 64-95 of res3.
- // res0 _ _ _ 0
- // res1 _ _ _ 1
- // res2 _ 2 _ _
- // res3 _ 3 _ _
- // f_01 _ _ 1 0
- // f_23 _ _ 3 2
-
- __m128i f_01 = _mm_unpacklo_epi32(res0, res1);
- __m128i f_23 = _mm_unpacklo_epi32(res2, res3);
- // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi.
- return _mm_unpacklo_epi64(f_01, f_23);
-}
-
-template <Path path>
-inline __m256i mm256_srlv_epi32(const __m256i& a, const __m256i& b) {
- __m128i a_lo = _mm256_extractf128_si256(a, 0);
- __m128i a_hi = _mm256_extractf128_si256(a, 1);
- __m128i b_lo = _mm256_extractf128_si256(b, 0);
- __m128i b_hi = _mm256_extractf128_si256(b, 1);
- __m128i lo = mm_srlv_epi32(a_lo, b_lo);
- __m128i hi = mm_srlv_epi32(a_hi, b_hi);
- __m256i ans = _mm256_set_m128i(hi, lo);
- return ans;
-}
-
inline __m128i mm_sllv_epi32(const __m128i& a, const __m128i& b) {
// shift all elements of a by first 32bits of b.
__m128i res0 = _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), b, 1));
@@ -426,6 +380,54 @@ inline __m256i AddBiasEpi32(const __m256i& a, const int32_t* bias, int offset) {
return mm256_add_epi32<path>(a, bias0);
}
+__m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
+ const __m256i& mask) {
+ __m256 result =
+ _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
+ _mm256_castsi256_ps(mask));
+ return _mm256_castps_si256(result);
+}
+
+template <Path path>
+inline __m256i mm256_cmpgt_epi32(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+ __m128i b_lo = _mm256_extractf128_si256(b, 0);
+ __m128i b_hi = _mm256_extractf128_si256(b, 1);
+ __m128i lo = _mm_cmpgt_epi32(a_lo, b_lo);
+ __m128i hi = _mm_cmpgt_epi32(a_hi, b_hi);
+ return _mm256_set_m128i(hi, lo);
+}
+
+template <Path path>
+inline __m256i mm256_srav_epi32(const __m256i& a, const __m256i& b) {
+ __m128i a_lo = _mm256_extractf128_si256(a, 0);
+ __m128i a_hi = _mm256_extractf128_si256(a, 1);
+
+ __m128i r0 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 0));
+ __m128i r1 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 1));
+ __m128i r2 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 2));
+ __m128i r3 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 3));
+ __m128i r4 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 4));
+ __m128i r5 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 5));
+ __m128i r6 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 6));
+ __m128i r7 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 7));
+
+ // get element 0 from r0, element 1 from r1
+ __m128i r01 = BlendM128_epi32(r0, r1, 2);
+ // get element 2 from r2, element 3 from r3
+ __m128i r23 = BlendM128_epi32(r2, r3, 8);
+ // get element 0 from r4, element 1 from r5
+ __m128i r45 = BlendM128_epi32(r4, r5, 2);
+ // get element 2 from r6, element 3 from r7
+ __m128i r67 = BlendM128_epi32(r6, r7, 8);
+ // get lower 64 bits of r01, upper 64 bits of r23
+ __m128i r0123 = BlendM128_epi64(r01, r23, 2);
+ // get lower 64 bits of r45, upper 64 bits of r67
+ __m128i r4567 = BlendM128_epi64(r45, r67, 2);
+ return _mm256_set_m128i(r4567, r0123);
+}
+
// AVX doesn't have fused multiply-add so we define an inline function to be
// used in the common code following.
template <>
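
AVX (unlike AVX2) has no per-lane variable arithmetic shift, so the mm256_srav_epi32 added above performs one whole-vector shift per lane count (eight _mm_srai_epi32 calls) and blends the matching lanes back together. A scalar reference it can be checked against (an assumption: verification scaffolding, not ruy code):

#include <cstdint>

// Per-lane arithmetic right shift, matching vpsravd for counts in [0, 31].
void ScalarSrav(const int32_t a[8], const int32_t count[8], int32_t out[8]) {
  for (int i = 0; i < 8; ++i) {
    out[i] = a[i] >> count[i];
  }
}
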
@@ -735,47 +737,25 @@ void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector);
const __m256i right_shift =
intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector);
- const __m256i final_right_shift = intrin_utils::mm256_add_epi32<path>(
- right_shift, _mm256_set1_epi32(31));
+ const __m256i final_right_shift = _mm256_set1_epi32(31);
const __m256i final_right_shift_low =
intrin_utils::mm256_cvtepi32_epi64<path>(
_mm256_extractf128_si256(final_right_shift, 0));
const __m256i final_right_shift_high =
intrin_utils::mm256_cvtepi32_epi64<path>(
_mm256_extractf128_si256(final_right_shift, 1));
- // Really we want 0x100000000, but use half to avoid overflowing.
-
- const __m256i convert_to_signed_halved =
- intrin_utils::mm256_srlv_epi32<path>(_mm256_set1_epi32(0x80000000),
- right_shift);
const __m256i convert_to_unsigned_64 =
_mm256_set1_epi64x(0x8000000000000000);
- __m256i post_scaling_offset = intrin_utils::mm256_add_epi32<path>(
- convert_to_signed_halved, convert_to_signed_halved);
-
- const __m256i offset_vector =
- intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30);
- // Really these should be shifted by neg_e_vector, but tests pass when
- // using right_shift.
- const __m256i offset_vector_low = intrin_utils::mm256_add_epi64<path>(
- intrin_utils::mm256_sllv_epi64<path>(
- offset_vector, intrin_utils::mm256_cvtepi32_epi64<path>(
- _mm256_extractf128_si256(right_shift, 0))),
- convert_to_unsigned_64);
- const __m256i offset_vector_high = intrin_utils::mm256_add_epi64<path>(
- intrin_utils::mm256_sllv_epi64<path>(
- offset_vector, intrin_utils::mm256_cvtepi32_epi64<path>(
- _mm256_extractf128_si256(right_shift, 1))),
+ __m256i post_scaling_offset = _mm256_setzero_si256();
+
+ // A "half" added for rounding prior to truncation of 64-bit value.
+ const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>(
+ intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30),
convert_to_unsigned_64);
if (params.dst_zero_point) {
- const __m256i dst_zero_point =
- _mm256_set1_epi32(params.dst_zero_point);
- // The post-scaling offset is subtracted later, so this has the effect
- // of adding the zero point.
- post_scaling_offset = intrin_utils::mm256_sub_epi32<path>(
- post_scaling_offset, dst_zero_point);
+ post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
}
// We cannot do
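
Note on the simplified offsets above: final_right_shift is now a constant 31 because the exponent part of the rescale has moved into the new rounding right shift. offset_vector folds together the rounding "half" (1 << 30, which rounds the Q31 product at bit 31) and the 2^63 sign bias; after the logical 64-bit shift by 31 the bias contributes exactly 2^32, which vanishes once the low 32 bits are repacked. A worked check of one lane (values illustrative):

#include <cstdint>
#include <cassert>

int main() {
  const int64_t prod = int64_t{-3} << 30;       // a 64-bit product; -1.5 in Q31
  const uint64_t bias = 0x8000000000000000ull;  // convert_to_unsigned_64
  const uint64_t shifted =
      (static_cast<uint64_t>(prod) + (uint64_t{1} << 30) + bias) >> 31;
  // The low 32 bits equal the arithmetic round-half-up result, -1.
  assert(static_cast<int32_t>(shifted) == -1);
  return 0;
}
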
@@ -831,7 +811,7 @@ void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
// __m256i results =
// _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
// results = _mm256_permutevar8x32_epi32(results, repack_perm);
- // accum_data_v[j] = intrin_utils::mm256_sub_epi32<path>(results,
+ // accum_data_v[j] = intrin_utils::mm256_add_epi32<path>(results,
// post_scaling_offset);
// This multiplier code is complex and expensive enough on x86, that
@@ -850,6 +830,41 @@ void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
&accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
&accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
}
+
+ auto rounding_right_shift = [=](__m256i& results,
+ const __m256i& exponent) {
+ // Construct the "nudge" value for each lane if the exponent is
+ // greater than 0. Otherwise, the nudge is 0.
+ const __m256i zeros = _mm256_setzero_si256();
+ const __m256i mask_rightshift_gtz =
+ intrin_utils::mm256_cmpgt_epi32<path>(exponent, zeros);
+ const __m256i one_shift_exp_minus1 =
+ intrin_utils::mm256_sllv_epi32<path>(
+ _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
+ exponent, _mm256_set1_epi32(1)));
+ __m256i nudge = intrin_utils::mm256_blendv_epi32(
+ zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+ // Calculate the shifted sum (results + nudge) >> exp.
+ const __m256i r_plus_nudge =
+ intrin_utils::mm256_add_epi32<path>(results, nudge);
+ const __m256i shifted_sum =
+ intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, exponent);
+
+ // Identify overflow in each lane and create mask.
+ const __m256i one_shift_31minus_exp =
+ intrin_utils::mm256_sllv_epi32<path>(
+ _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
+ _mm256_set1_epi32(31), exponent));
+ const __m256i mask_num_plus_nudge_overflow =
+ intrin_utils::mm256_cmpgt_epi32<path>(
+ results, intrin_utils::mm256_sub_epi32<path>(
+ _mm256_set1_epi32(0x7fffffff), nudge));
+ // Fill results with either (results + nudge) >> exponent or
+ // 1 << (31 - exp) in the case of overflow.
+ results = intrin_utils::mm256_blendv_epi32(
+ shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+ };
+
auto apply_multiplier = [=](__m256i& accum) {
__m256i shifted_accum =
intrin_utils::mm256_sllv_epi32<path>(accum, left_shift);
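
The overflow mask in rounding_right_shift exists because the nudge is added in 32 bits. For lanes where results > 0x7fffffff - nudge, the wrapped sum would shift to garbage, so the lambda substitutes 1 << (31 - exponent), which is exactly what the exact wide computation yields. A one-lane illustration (values are hypothetical):

#include <cstdint>
#include <cassert>

int main() {
  const int32_t results = INT32_MAX;  // triggers the overflow mask
  const int exponent = 1;
  // Exact value, computed without wrapping in 64 bits:
  const int64_t exact =
      (static_cast<int64_t>(results) + (1 << (exponent - 1))) >> exponent;
  assert(exact == (int64_t{1} << (31 - exponent)));  // the substituted value
  return 0;
}
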
@@ -863,9 +878,9 @@ void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
_mm256_extractf128_si256(shifted_accum, 1)),
m_64bit_high);
scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low,
- offset_vector_low);
+ offset_vector);
scaled_v_high = intrin_utils::mm256_add_epi64<path>(
- scaled_v_high, offset_vector_high);
+ scaled_v_high, offset_vector);
scaled_v_low = intrin_utils::mm256_srlv_epi64<path>(
scaled_v_low, final_right_shift_low);
@@ -879,8 +894,9 @@ void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
// lo->hi (0, 2, 4, 6, 1, 3, 5, 7);
results = intrin_utils::PermuteEpi32EvenOdds<path>(results);
+ rounding_right_shift(results, right_shift);
accum =
- intrin_utils::mm256_sub_epi32<path>(results, post_scaling_offset);
+ intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset);
};
apply_multiplier(accum_data_v0);
apply_multiplier(accum_data_v1);
@@ -1306,45 +1322,25 @@ void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) {
intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector);
const __m256i right_shift =
intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector);
- const __m256i final_right_shift = intrin_utils::mm256_add_epi32<path>(
- right_shift, _mm256_set1_epi32(31));
+ const __m256i final_right_shift = _mm256_set1_epi32(31);
const __m256i final_right_shift_low =
intrin_utils::mm256_cvtepi32_epi64<path>(
_mm256_extractf128_si256(final_right_shift, 0));
const __m256i final_right_shift_high =
intrin_utils::mm256_cvtepi32_epi64<path>(
_mm256_extractf128_si256(final_right_shift, 1));
- // Really we want 0x100000000, but use half to avoid overflowing.
- const __m256i convert_to_signed_halved =
- intrin_utils::mm256_srlv_epi32<path>(_mm256_set1_epi32(0x80000000),
- right_shift);
const __m256i convert_to_unsigned_64 =
_mm256_set1_epi64x(0x8000000000000000);
- __m256i post_scaling_offset = intrin_utils::mm256_add_epi32<path>(
- convert_to_signed_halved, convert_to_signed_halved);
-
- const __m256i offset_vector =
- intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30);
- // Really these should be shifted by neg_e_vector, but tests pass when
- // using right_shift.
- const __m256i offset_vector_low = intrin_utils::mm256_add_epi64<path>(
- intrin_utils::mm256_sllv_epi64<path>(
- offset_vector, intrin_utils::mm256_cvtepi32_epi64<path>(
- _mm256_extractf128_si256(right_shift, 0))),
- convert_to_unsigned_64);
- const __m256i offset_vector_high = intrin_utils::mm256_add_epi64<path>(
- intrin_utils::mm256_sllv_epi64<path>(
- offset_vector, intrin_utils::mm256_cvtepi32_epi64<path>(
- _mm256_extractf128_si256(right_shift, 1))),
+ __m256i post_scaling_offset = _mm256_setzero_si256();
+
+ // A "half" added for rounding prior to truncation of 64-bit value.
+ const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>(
+ intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30),
convert_to_unsigned_64);
if (params.dst_zero_point) {
- const __m256i dst_zero_point = _mm256_set1_epi32(params.dst_zero_point);
- // The post-scaling offset is subtracted later, so this has the effect
- // of adding the zero point.
- post_scaling_offset = intrin_utils::mm256_sub_epi32<path>(
- post_scaling_offset, dst_zero_point);
+ post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
}
// See GEMM version for details of this process.
@@ -1362,9 +1358,9 @@ void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) {
m_64bit_high);
scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low,
- offset_vector_low);
+ offset_vector);
scaled_v_high = intrin_utils::mm256_add_epi64<path>(scaled_v_high,
- offset_vector_high);
+ offset_vector);
scaled_v_low = intrin_utils::mm256_srlv_epi64<path>(
scaled_v_low, final_right_shift_low);
@@ -1377,8 +1373,40 @@ void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) {
// Permute results to this ordering of int32 elements
// lo->hi (0, 2, 4, 6, 1, 3, 5, 7);
results = intrin_utils::PermuteEpi32EvenOdds<path>(results);
+
+ // Now perform the Rounding Right Shift.
+ // First, construct the "nudge" value for each lane if the exponent is
+ // greater than 0. Otherwise, the nudge is 0.
+ const __m256i zeros = _mm256_setzero_si256();
+ const __m256i mask_rightshift_gtz =
+ intrin_utils::mm256_cmpgt_epi32<path>(right_shift, zeros);
+ const __m256i one_shift_exp_minus1 =
+ intrin_utils::mm256_sllv_epi32<path>(
+ _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
+ right_shift, _mm256_set1_epi32(1)));
+ __m256i nudge = intrin_utils::mm256_blendv_epi32(
+ zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+ // Calculate the shifted sum (results + nudge) >> exp.
+ const __m256i r_plus_nudge =
+ intrin_utils::mm256_add_epi32<path>(results, nudge);
+ const __m256i shifted_sum =
+ intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, right_shift);
+
+ // Identify overflow in each lane and create mask.
+ const __m256i one_shift_31minus_exp =
+ intrin_utils::mm256_sllv_epi32<path>(
+ _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
+ _mm256_set1_epi32(31), right_shift));
+ const __m256i mask_num_plus_nudge_overflow =
+ intrin_utils::mm256_cmpgt_epi32<path>(
+ results, intrin_utils::mm256_sub_epi32<path>(
+ _mm256_set1_epi32(0x7fffffff), nudge));
+ // Fill results with either (results + nudge) >> exponent or
+ // 1 << (31 - exp) in the case of overflow.
+ results = intrin_utils::mm256_blendv_epi32(
+ shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
accum_data_v0 =
- intrin_utils::mm256_sub_epi32<path>(results, post_scaling_offset);
+ intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset);
}
}
const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc
index 2d8f016..2aefbe3 100644
--- a/ruy/kernel_avx2_fma.cc
+++ b/ruy/kernel_avx2_fma.cc
@@ -88,6 +88,14 @@ inline __m128i mm256_extracti128_si256<Path::kAvx2Fma>(const __m256i& a,
}
}
+__m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
+ const __m256i& mask) {
+ __m256 result =
+ _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
+ _mm256_castsi256_ps(mask));
+ return _mm256_castps_si256(result);
+}
+
} // namespace intrin_utils
} // namespace
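
mm256_blendv_epi32 here rides on _mm256_blendv_ps, which selects by bit 31 of each mask lane only; that is sound in this patch because every mask comes from a compare and is all-ones or all-zeros per lane. A hypothetical standalone demonstration (not part of the patch):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  const __m256i a = _mm256_set1_epi32(1);
  const __m256i b = _mm256_set1_epi32(2);
  // Only bit 31 of each mask lane matters to blendv_ps.
  const __m256i mask = _mm256_set1_epi32(INT32_MIN);  // 0x80000000
  const __m256 blended =
      _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
                       _mm256_castsi256_ps(mask));
  alignas(32) int32_t out[8];
  _mm256_store_si256(reinterpret_cast<__m256i*>(out),
                     _mm256_castps_si256(blended));
  std::printf("%d\n", out[0]);  // prints 2: every lane taken from b
  return 0;
}
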
@@ -352,43 +360,22 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
- const __m256i final_right_shift =
- _mm256_add_epi32(right_shift, _mm256_set1_epi32(31));
+ const __m256i final_right_shift = _mm256_set1_epi32(31);
const __m256i final_right_shift_low = _mm256_cvtepi32_epi64(
_mm256_extracti128_si256(final_right_shift, 0));
const __m256i final_right_shift_high = _mm256_cvtepi32_epi64(
_mm256_extracti128_si256(final_right_shift, 1));
- // Really we want 0x100000000, but use half to avoid overflowing.
- const __m256i convert_to_signed_halved =
- _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift);
const __m256i convert_to_unsigned_64 =
_mm256_set1_epi64x(0x8000000000000000);
- __m256i post_scaling_offset = _mm256_add_epi32(
- convert_to_signed_halved, convert_to_signed_halved);
-
- const __m256i offset_vector =
- _mm256_slli_epi64(_mm256_set1_epi64x(1), 30);
- // Really these should be shifted by neg_e_vector, but tests pass when
- // using right_shift.
- const __m256i offset_vector_low = _mm256_add_epi64(
- _mm256_sllv_epi64(offset_vector,
- _mm256_cvtepi32_epi64(
- _mm256_extracti128_si256(right_shift, 0))),
- convert_to_unsigned_64);
- const __m256i offset_vector_high = _mm256_add_epi64(
- _mm256_sllv_epi64(offset_vector,
- _mm256_cvtepi32_epi64(
- _mm256_extracti128_si256(right_shift, 1))),
+ __m256i post_scaling_offset = _mm256_setzero_si256();
+ // A "half" added for rounding prior to truncation of 64-bit value.
+ const __m256i offset_vector = _mm256_add_epi64(
+ _mm256_slli_epi64(_mm256_set1_epi64x(1), 30),
convert_to_unsigned_64);
if (params.dst_zero_point) {
- const __m256i dst_zero_point =
- _mm256_set1_epi32(params.dst_zero_point);
- // The post-scaling offset is subtracted later, so this has the effect
- // of adding the zero point.
- post_scaling_offset =
- _mm256_sub_epi32(post_scaling_offset, dst_zero_point);
+ post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
}
const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
@@ -446,7 +433,7 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
// __m256i results =
// _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
// results = _mm256_permutevar8x32_epi32(results, repack_perm);
- // accum_data_v[j] = _mm256_sub_epi32(results, post_scaling_offset);
+ // accum_data_v[j] = _mm256_add_epi32(results, post_scaling_offset);
// This multiplier code is complex and expensive enough on x86, that
// we prefer to implement the channels-are-columns case by transposing
@@ -464,6 +451,35 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
&accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
&accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
}
+
+ auto rounding_right_shift = [=](__m256i& results,
+ const __m256i& exponent) {
+ // Construct the "nudge" value for each lane if the exponent is
+ // greater than 0. Otherwise, the nudge is 0.
+ const __m256i zeros = _mm256_setzero_si256();
+ const __m256i mask_rightshift_gtz =
+ _mm256_cmpgt_epi32(exponent, zeros);
+ const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
+ _mm256_set1_epi32(1),
+ _mm256_sub_epi32(exponent, _mm256_set1_epi32(1)));
+ __m256i nudge = intrin_utils::mm256_blendv_epi32(
+ zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+ // Calculate the shifted sum (results + nudge) >> exp.
+ const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
+ const __m256i shifted_sum = _mm256_srav_epi32(r_plus_nudge, exponent);
+
+ // Identify overflow in each lane and create mask.
+ const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
+ _mm256_set1_epi32(1),
+ _mm256_sub_epi32(_mm256_set1_epi32(31), exponent));
+ const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
+ results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
+ // Fill results with either (results + nudge) >> exponent or
+ // 1 << (31 - exp) in the case of overflow.
+ results = intrin_utils::mm256_blendv_epi32(
+ shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+ };
+
auto apply_multiplier = [=](__m256i& accum) {
__m256i shifted_accum = _mm256_sllv_epi32(accum, left_shift);
// Apply the fixed-point part of the multiplier.
@@ -474,8 +490,8 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
_mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
m_64bit_high);
- scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
- scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+ scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
+ scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);
scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
scaled_v_high =
@@ -485,8 +501,9 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
__m256i results =
_mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
- accum = _mm256_sub_epi32(results, post_scaling_offset);
+ // Now do a Rounding Right Shift.
+ rounding_right_shift(results, right_shift);
+ accum = _mm256_add_epi32(results, post_scaling_offset);
};
apply_multiplier(accum_data_v0);
apply_multiplier(accum_data_v1);
@@ -856,42 +873,22 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
- const __m256i final_right_shift =
- _mm256_add_epi32(right_shift, _mm256_set1_epi32(31));
+ const __m256i final_right_shift = _mm256_set1_epi32(31);
const __m256i final_right_shift_low =
_mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0));
const __m256i final_right_shift_high =
_mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1));
- // Really we want 0x100000000, but use half to avoid overflowing.
- const __m256i convert_to_signed_halved =
- _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift);
const __m256i convert_to_unsigned_64 =
_mm256_set1_epi64x(0x8000000000000000);
- __m256i post_scaling_offset =
- _mm256_add_epi32(convert_to_signed_halved, convert_to_signed_halved);
-
- const __m256i offset_vector =
- _mm256_slli_epi64(_mm256_set1_epi64x(1), 30);
- // Really these should be shifted by neg_e_vector, but tests pass when
- // using right_shift.
- const __m256i offset_vector_low = _mm256_add_epi64(
- _mm256_sllv_epi64(
- offset_vector,
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 0))),
- convert_to_unsigned_64);
- const __m256i offset_vector_high = _mm256_add_epi64(
- _mm256_sllv_epi64(
- offset_vector,
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 1))),
+ __m256i post_scaling_offset = _mm256_setzero_si256();
+ // A "half" added for rounding prior to truncation of 64-bit value.
+ const __m256i offset_vector = _mm256_add_epi64(
+ _mm256_slli_epi64(_mm256_set1_epi64x(1), 30),
convert_to_unsigned_64);
if (params.dst_zero_point) {
- const __m256i dst_zero_point = _mm256_set1_epi32(params.dst_zero_point);
- // The post-scaling offset is subtracted later, so this has the effect
- // of adding the zero point.
- post_scaling_offset =
- _mm256_sub_epi32(post_scaling_offset, dst_zero_point);
+ post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
}
const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
@@ -907,8 +904,8 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
_mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
m_64bit_high);
- scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
- scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+ scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
+ scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);
scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
scaled_v_high =
@@ -918,7 +915,33 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
__m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
results = _mm256_permutevar8x32_epi32(results, repack_perm);
- accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset);
+ // Now do a Rounding Right Shift.
+ // First, construct the nudge value for each lane.
+ const __m256i zeros = _mm256_setzero_si256();
+ const __m256i mask_rightshift_gtz =
+ _mm256_cmpgt_epi32(right_shift, zeros);
+ const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
+ _mm256_set1_epi32(1),
+ _mm256_sub_epi32(right_shift, _mm256_set1_epi32(1)));
+ __m256i nudge = intrin_utils::mm256_blendv_epi32(
+ zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+ // Calculate the shifted sum (results + nudge) >> exp.
+ const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
+ const __m256i shifted_sum =
+ _mm256_srav_epi32(r_plus_nudge, right_shift);
+
+ // Identify overflow in each lane and create mask.
+ const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
+ _mm256_set1_epi32(1),
+ _mm256_sub_epi32(_mm256_set1_epi32(31), right_shift));
+ const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
+ results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
+ // Fill results with either (results + nudge) >> exponent or
+ // 1 << (31 - exp) in the case of overflow.
+ results = intrin_utils::mm256_blendv_epi32(
+ shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+
+ accum_data_v0 = _mm256_add_epi32(results, post_scaling_offset);
}
}
const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
index 6a65ca7..fddb482 100644
--- a/ruy/kernel_avx512.cc
+++ b/ruy/kernel_avx512.cc
@@ -52,6 +52,45 @@ void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) {
#else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
+namespace {
+namespace intrin_utils {
+
+__m256i mm256_blendv_epi64(const __m256i& a, const __m256i& b,
+ const __m256i& mask) {
+ __m256d result =
+ _mm256_blendv_pd(_mm256_castsi256_pd(a), _mm256_castsi256_pd(b),
+ _mm256_castsi256_pd(mask));
+ return _mm256_castpd_si256(result);
+}
+
+__m512i mm512_blendv_epi64(const __m512i& a, const __m512i& b,
+ const __m512i& mask) {
+ __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);
+ __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);
+ __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);
+ __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);
+ __m256i mask_lo = _mm512_extracti64x4_epi64(mask, 0);
+ __m256i mask_hi = _mm512_extracti64x4_epi64(mask, 1);
+ __m256i lo = mm256_blendv_epi64(a_lo, b_lo, mask_lo);
+ __m256i hi = mm256_blendv_epi64(a_hi, b_hi, mask_hi);
+ __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0);
+ return _mm512_inserti64x4(result, hi, 1);
+}
+
+__m512i mm512_cmpgt_epi64(const __m512i& a, const __m512i& b) {
+ __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);
+ __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);
+ __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);
+ __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);
+ __m256i lo = _mm256_cmpgt_epi64(a_lo, b_lo);
+ __m256i hi = _mm256_cmpgt_epi64(a_hi, b_hi);
+ __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0);
+ return _mm512_inserti64x4(result, hi, 1);
+}
+
+} // namespace intrin_utils
+} // namespace
+
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
profiler::ScopeLabel label("Kernel kAvx512 8-bit");
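
The helpers added above emulate the 64-bit compare and blend with two 256-bit halves. AVX-512F also offers mask-register forms that could express the same select directly; a sketch of that alternative (an assumption about a possible simplification, not what this patch does, which perhaps keeps the 256-bit style for consistency with the surrounding helpers):

#include <immintrin.h>

// Per-lane select of if_gt where exponent > 0, else if_le (AVX-512F).
inline __m512i BlendGtZero(__m512i if_le, __m512i if_gt, __m512i exponent) {
  const __mmask8 gt =
      _mm512_cmpgt_epi64_mask(exponent, _mm512_setzero_si512());
  return _mm512_mask_blend_epi64(gt, if_le, if_gt);
}
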
@@ -333,23 +372,49 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
- const __m512i final_right_shift =
- _mm512_add_epi32(right_shift, _mm512_set1_epi32(31));
+ const __m512i final_right_shift = _mm512_set1_epi32(31);
+ const __m512i right_shift_low =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
+ const __m512i right_shift_high =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
_mm512_extracti32x8_epi32(final_right_shift, 0));
const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
_mm512_extracti32x8_epi32(final_right_shift, 1));
+ // A "half" added for rounding prior to truncation of 64-bit value.
const __m512i offset_vector =
_mm512_slli_epi64(_mm512_set1_epi64(1), 30);
- // Really these should be shifted by neg_e_vector, but tests pass when
- // using right_shift.
- const __m512i offset_vector_low = _mm512_sllv_epi64(
- offset_vector,
- _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)));
- const __m512i offset_vector_high = _mm512_sllv_epi64(
- offset_vector,
- _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
+
+ auto rounding_right_shift = [=](__m512i& results,
+ const __m512i& exponent) {
+ // Construct the "nudge" value for each lane if the exponent is
+ // greater than 0. Otherwise, the nudge is 0.
+ const __m512i zeros = _mm512_setzero_si512();
+ const __m512i mask_rightshift_gtz =
+ intrin_utils::mm512_cmpgt_epi64(exponent, zeros);
+ const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64(
+ _mm512_set1_epi64(1),
+ _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
+ __m512i nudge = intrin_utils::mm512_blendv_epi64(
+ zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+ // Calculate the shifted sum (results + nudge) >> exp.
+ const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
+ const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);
+
+ // Identify overflow in each lane and create mask.
+ const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
+ _mm512_set1_epi64(1),
+ _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
+ const __m512i mask_num_plus_nudge_overflow =
+ intrin_utils::mm512_cmpgt_epi64(
+ results,
+ _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
+ // Fill results with either (results + nudge) >> exponent or
+ // 1 << (31 - exp) in the case of overflow.
+ results = intrin_utils::mm512_blendv_epi64(
+ shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+ };
if (per_column_multiplier) {
auto apply_multiplier = [=](__m512i& accum, int col) {
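
Unlike the 256-bit kernels, this 512-bit rounding_right_shift runs while the scaled values are still in 64-bit lanes, so the nudge, the shift, and the 0x7fffffff overflow threshold are all computed as epi64 before the final narrowing to 32 bits. One lane restated in scalar form (illustrative only):

#include <cstdint>

// 64-bit analogue of the rounding right shift, valid for exponent in [0, 31].
int64_t RoundingRightShift64(int64_t x, int64_t exponent) {
  if (exponent <= 0) return x;
  const int64_t nudge = int64_t{1} << (exponent - 1);
  if (x > int64_t{0x7fffffff} - nudge) return int64_t{1} << (31 - exponent);
  return (x + nudge) >> exponent;
}
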
@@ -360,11 +425,12 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
__m512i m_64bit_val = _mm512_permutexvar_epi64(
perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high);
__m512i offset_vector_val = _mm512_permutexvar_epi64(
- perm_64bit_vals,
- col < 8 ? offset_vector_low : offset_vector_high);
+ perm_64bit_vals, offset_vector);
__m512i final_right_shift_val = _mm512_permutexvar_epi64(
perm_64bit_vals,
col < 8 ? final_right_shift_low : final_right_shift_high);
+ __m512i right_shift_val = _mm512_permutexvar_epi64(
+ perm_64bit_vals, col < 8 ? right_shift_low : right_shift_high);
accum = _mm512_sllv_epi32(accum, left_shift_val);
__m512i scaled_v_low = _mm512_mul_epi32(
@@ -382,6 +448,9 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
scaled_v_high =
_mm512_srav_epi64(scaled_v_high, final_right_shift_val);
+ rounding_right_shift(scaled_v_low, right_shift_val);
+ rounding_right_shift(scaled_v_high, right_shift_val);
+
accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
accum = _mm512_inserti32x8(accum,
_mm512_cvtepi64_epi32(scaled_v_high), 1);
@@ -413,14 +482,16 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
m_64bit_high);
- scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
- scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);
scaled_v_low =
_mm512_srav_epi64(scaled_v_low, final_right_shift_low);
scaled_v_high =
_mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+ rounding_right_shift(scaled_v_low, right_shift_low);
+ rounding_right_shift(scaled_v_high, right_shift_high);
accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
accum = _mm512_inserti32x8(accum,
_mm512_cvtepi64_epi32(scaled_v_high), 1);
@@ -713,22 +784,48 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
- const __m512i final_right_shift =
- _mm512_add_epi32(right_shift, _mm512_set1_epi32(31));
+ const __m512i final_right_shift = _mm512_set1_epi32(31);
+ const __m512i right_shift_low =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
+ const __m512i right_shift_high =
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
_mm512_extracti32x8_epi32(final_right_shift, 0));
const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
_mm512_extracti32x8_epi32(final_right_shift, 1));
+ // A "half" added for rounding prior to truncation of 64-bit value.
const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
- // Really these should be shifted by neg_e_vector, but tests pass when
- // using right_shift.
- const __m512i offset_vector_low = _mm512_sllv_epi64(
- offset_vector,
- _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)));
- const __m512i offset_vector_high = _mm512_sllv_epi64(
- offset_vector,
- _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
+
+ auto rounding_right_shift = [=](__m512i& results,
+ const __m512i& exponent) {
+ // Construct the "nudge" value for each lane if the exponent is
+ // greater than 0. Otherwise, the nudge is 0.
+ const __m512i zeros = _mm512_setzero_si512();
+ const __m512i mask_rightshift_gtz =
+ intrin_utils::mm512_cmpgt_epi64(exponent, zeros);
+ const __m512i one_shift_exp_minus1 =
+ _mm512_sllv_epi64(_mm512_set1_epi64(1),
+ _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
+ __m512i nudge = intrin_utils::mm512_blendv_epi64(
+ zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+ // Calculate the shifted sum (results + nudge) >> exp.
+ const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
+ const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);
+
+ // Identify overflow in each lane and create mask.
+ const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
+ _mm512_set1_epi64(1),
+ _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
+ const __m512i mask_num_plus_nudge_overflow =
+ intrin_utils::mm512_cmpgt_epi64(
+ results,
+ _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
+ // Fill results with either (results + nudge) >> exponent or
+ // 1 << (31 - exp) in the case of overflow.
+ results = intrin_utils::mm512_blendv_epi64(
+ shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+ };
// Shift and round column 0.
accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift);
@@ -740,12 +837,15 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)),
m_64bit_high);
- scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
- scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);
scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+ rounding_right_shift(scaled_v_low, right_shift_low);
+ rounding_right_shift(scaled_v_high, right_shift_high);
+
accum_data_v0 =
_mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
accum_data_v0 = _mm512_inserti32x8(