author     T.J. Alumbaugh <talumbau@google.com>           2020-10-19 21:06:47 +0300
committer  Copybara-Service <copybara-worker@google.com>  2020-10-19 21:07:09 +0300
commit     dd1102a6ce6ce501f92f6abd72c89ac59a95afeb
tree       0431d351a878c79aa477c82278f6020885e60f3c
parent     a28320aaf5fe2bd8a8aa9c777fc1264a9b49a14f
Update AVX, AVX2, AVX512 Rescale operations with Rounding Right Shift
PiperOrigin-RevId: 337892847
 ruy/kernel_avx.cc      | 254
 ruy/kernel_avx2_fma.cc | 145
 ruy/kernel_avx512.cc   | 152

 3 files changed, 351 insertions(+), 200 deletions(-)
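All three kernels below vectorize the same per-lane "Rounding Right Shift". As an editorial reference (not part of the commit; the function name and scalar form are illustrative), the lane-wise logic is:

    #include <cstdint>

    // Scalar sketch of the rounding right shift each kernel implements below,
    // assuming 0 <= exponent <= 31. The name RoundingRightShift is hypothetical.
    int32_t RoundingRightShift(int32_t x, int exponent) {
      // "Nudge" by half of the discarded range so the truncating shift rounds
      // to nearest rather than toward negative infinity.
      const int32_t nudge = exponent > 0 ? (int32_t{1} << (exponent - 1)) : 0;
      // x + nudge can overflow int32 near INT32_MAX; the kernels detect this
      // case and substitute 1 << (31 - exponent).
      if (x > INT32_MAX - nudge) {
        return int32_t{1} << (31 - exponent);
      }
      // Arithmetic right shift of the nudged value (the kernels use srav).
      return (x + nudge) >> exponent;
    }

The exponent here is right_shift in the diffs, the negative part of the quantized multiplier's exponent, computed per lane. Note that for negative x this rounds half-up rather than half-away-from-zero, matching the always-positive nudge in the vector code.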
diff --git a/ruy/kernel_avx.cc b/ruy/kernel_avx.cc
index 21e8826..057b2d0 100644
--- a/ruy/kernel_avx.cc
+++ b/ruy/kernel_avx.cc
@@ -198,6 +198,10 @@ inline __m256i mm256_mullo_epi32(const __m256i& a, const __m256i& b) {
   _mm_castps_si128(_mm_blend_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), imm))
 
 // Defined as a macro since `imm` must be an immediate.
+#define BlendM128_epi64(a, b, imm) \
+  _mm_castpd_si128(_mm_blend_pd(_mm_castsi128_pd(a), _mm_castsi128_pd(b), imm))
+
+// Defined as a macro since `imm` must be an immediate.
 #define mm256_blend_epi32(ans, a, b, imm)          \
   __m128i a_lo = _mm256_extractf128_si256(a, 0);   \
   __m128i a_hi = _mm256_extractf128_si256(a, 1);   \
@@ -278,56 +282,6 @@ inline __m256i mm256_sllv_epi64(const __m256i& a, const __m256i& b) {
 #define PermuteM128_epi32(a, imm) \
   _mm_castps_si128(_mm_permute_ps(_mm_castsi128_ps(a), imm));
 
-inline __m128i mm_srlv_epi32(const __m128i& a, const __m128i& b) {
-  // shift all elements of a by first 32bits of b.
-  __m128i res0 = _mm_srl_epi32(a, BlendM128_epi32(_mm_setzero_si128(), b, 1));
-  // put bits 32-63 of b in the first slot.
-  __m128i tmp1 = PermuteM128_epi32(b, 1);
-  // put bits 32-63 of a in the first slot.
-  __m128i a1 = PermuteM128_epi32(a, 1);
-  // shift all elements of a by second 32bits of b.
-  __m128i res1 =
-      _mm_srl_epi32(a1, BlendM128_epi32(_mm_setzero_si128(), tmp1, 1));
-  // put bits 64-95 of b in the first slot.
-  __m128i tmp2 = PermuteM128_epi32(b, 2);
-  // shift all elements of a by third 32bits of b.
-  __m128i res2 =
-      _mm_srl_epi32(a, BlendM128_epi32(_mm_setzero_si128(), tmp2, 1));
-  // put bits 96-127 of b in the first slot.
-  __m128i tmp3 = PermuteM128_epi32(b, 3);
-  // put bits 96-127 of a in the third slot.
-  __m128i a3 = PermuteM128_epi32(a, 48);
-  // shift all elements of a3 by fourth 32bits of b.
-  __m128i res3 =
-      _mm_srl_epi32(a3, BlendM128_epi32(_mm_setzero_si128(), tmp3, 1));
-  // Take bits 0-31 of res0, bits 0-31 of res1,
-  // bits 64-95 of res2, and bits 64-95 of res3.
-  // res0 _ _ _ 0
-  // res1 _ _ _ 1
-  // res2 _ 2 _ _
-  // res3 _ 3 _ _
-  // f_01 _ _ 1 0
-  // f_23 _ _ 3 2
-  __m128i f_01 = _mm_unpacklo_epi32(res0, res1);
-  __m128i f_23 = _mm_unpacklo_epi32(res2, res3);
-  // The lower 64 bits of res_lo and the lower 64 bits of tmp_hi.
-  return _mm_unpacklo_epi64(f_01, f_23);
-}
-
-template <Path path>
-inline __m256i mm256_srlv_epi32(const __m256i& a, const __m256i& b) {
-  __m128i a_lo = _mm256_extractf128_si256(a, 0);
-  __m128i a_hi = _mm256_extractf128_si256(a, 1);
-  __m128i b_lo = _mm256_extractf128_si256(b, 0);
-  __m128i b_hi = _mm256_extractf128_si256(b, 1);
-  __m128i lo = mm_srlv_epi32(a_lo, b_lo);
-  __m128i hi = mm_srlv_epi32(a_hi, b_hi);
-  __m256i ans = _mm256_set_m128i(hi, lo);
-  return ans;
-}
-
 inline __m128i mm_sllv_epi32(const __m128i& a, const __m128i& b) {
   // shift all elements of a by first 32bits of b.
   __m128i res0 = _mm_sll_epi32(a, BlendM128_epi32(_mm_setzero_si128(), b, 1));
@@ -426,6 +380,54 @@ inline __m256i AddBiasEpi32(const __m256i& a, const int32_t* bias, int offset) {
   return mm256_add_epi32<path>(a, bias0);
 }
 
+__m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
+                           const __m256i& mask) {
+  __m256 result =
+      _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
+                       _mm256_castsi256_ps(mask));
+  return _mm256_castps_si256(result);
+}
+
+template <Path path>
+inline __m256i mm256_cmpgt_epi32(const __m256i& a, const __m256i& b) {
+  __m128i a_lo = _mm256_extractf128_si256(a, 0);
+  __m128i a_hi = _mm256_extractf128_si256(a, 1);
+  __m128i b_lo = _mm256_extractf128_si256(b, 0);
+  __m128i b_hi = _mm256_extractf128_si256(b, 1);
+  __m128i lo = _mm_cmpgt_epi32(a_lo, b_lo);
+  __m128i hi = _mm_cmpgt_epi32(a_hi, b_hi);
+  return _mm256_set_m128i(hi, lo);
+}
+
+template <Path path>
+inline __m256i mm256_srav_epi32(const __m256i& a, const __m256i& b) {
+  __m128i a_lo = _mm256_extractf128_si256(a, 0);
+  __m128i a_hi = _mm256_extractf128_si256(a, 1);
+
+  __m128i r0 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 0));
+  __m128i r1 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 1));
+  __m128i r2 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 2));
+  __m128i r3 = _mm_srai_epi32(a_lo, _mm256_extract_epi32(b, 3));
+  __m128i r4 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 4));
+  __m128i r5 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 5));
+  __m128i r6 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 6));
+  __m128i r7 = _mm_srai_epi32(a_hi, _mm256_extract_epi32(b, 7));
+
+  // get element 0 from r0, element 1 from r1
+  __m128i r01 = BlendM128_epi32(r0, r1, 2);
+  // get element 2 from r2, element 3 from r3
+  __m128i r23 = BlendM128_epi32(r2, r3, 8);
+  // get element 0 from r4, element 1 from r5
+  __m128i r45 = BlendM128_epi32(r4, r5, 2);
+  // get element 2 from r6, element 3 from r7
+  __m128i r67 = BlendM128_epi32(r6, r7, 8);
+  // get lower 64 bits of r01, upper 64 bits of r23
+  __m128i r0123 = BlendM128_epi64(r01, r23, 2);
+  // get lower 64 bits of r45, upper 64 bits of r67
+  __m128i r4567 = BlendM128_epi64(r45, r67, 2);
+  return _mm256_set_m128i(r4567, r0123);
+}
+
 // AVX doesn't have fused multiply-add so we define an inline function to be
 // used in the common code following.
 template <>
@@ -735,47 +737,25 @@ void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
           intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector);
       const __m256i right_shift =
           intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector);
-      const __m256i final_right_shift = intrin_utils::mm256_add_epi32<path>(
-          right_shift, _mm256_set1_epi32(31));
+      const __m256i final_right_shift = _mm256_set1_epi32(31);
       const __m256i final_right_shift_low =
           intrin_utils::mm256_cvtepi32_epi64<path>(
               _mm256_extractf128_si256(final_right_shift, 0));
       const __m256i final_right_shift_high =
           intrin_utils::mm256_cvtepi32_epi64<path>(
              _mm256_extractf128_si256(final_right_shift, 1));
-      // Really we want 0x100000000, but use half to avoid overflowing.
-
-      const __m256i convert_to_signed_halved =
-          intrin_utils::mm256_srlv_epi32<path>(_mm256_set1_epi32(0x80000000),
-                                               right_shift);
       const __m256i convert_to_unsigned_64 =
           _mm256_set1_epi64x(0x8000000000000000);
-      __m256i post_scaling_offset = intrin_utils::mm256_add_epi32<path>(
-          convert_to_signed_halved, convert_to_signed_halved);
-
-      const __m256i offset_vector =
-          intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30);
-      // Really these should be shifted by neg_e_vector, but tests pass when
-      // using right_shift.
-      const __m256i offset_vector_low = intrin_utils::mm256_add_epi64<path>(
-          intrin_utils::mm256_sllv_epi64<path>(
-              offset_vector, intrin_utils::mm256_cvtepi32_epi64<path>(
-                                 _mm256_extractf128_si256(right_shift, 0))),
-          convert_to_unsigned_64);
-      const __m256i offset_vector_high = intrin_utils::mm256_add_epi64<path>(
-          intrin_utils::mm256_sllv_epi64<path>(
-              offset_vector, intrin_utils::mm256_cvtepi32_epi64<path>(
                                 _mm256_extractf128_si256(right_shift, 1))),
+      __m256i post_scaling_offset = _mm256_setzero_si256();
+
+      // A "half" added for rounding prior to truncation of 64-bit value.
+      const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>(
+          intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30),
           convert_to_unsigned_64);
 
       if (params.dst_zero_point) {
-        const __m256i dst_zero_point =
-            _mm256_set1_epi32(params.dst_zero_point);
-        // The post-scaling offset is subtracted later, so this has the effect
-        // of adding the zero point.
-        post_scaling_offset = intrin_utils::mm256_sub_epi32<path>(
-            post_scaling_offset, dst_zero_point);
+        post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
       }
 
       // We cannot do
@@ -831,7 +811,7 @@ void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
       //     __m256i results =
       //         _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
       //     results = _mm256_permutevar8x32_epi32(results, repack_perm);
-      //     accum_data_v[j] = intrin_utils::mm256_sub_epi32<path>(results,
+      //     accum_data_v[j] = intrin_utils::mm256_add_epi32<path>(results,
       //                                             post_scaling_offset);
 
       // This multiplier code is complex and expensive enough on x86, that
@@ -850,6 +830,41 @@ void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
             &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
             &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
       }
+
+      auto rounding_right_shift = [=](__m256i& results,
+                                      const __m256i& exponent) {
+        // Construct the "nudge" value for each lane if the exponent is
+        // greater than 0. Otherwise, the nudge is 0.
+        const __m256i zeros = _mm256_setzero_si256();
+        const __m256i mask_rightshift_gtz =
+            intrin_utils::mm256_cmpgt_epi32<path>(exponent, zeros);
+        const __m256i one_shift_exp_minus1 =
+            intrin_utils::mm256_sllv_epi32<path>(
+                _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
+                                          exponent, _mm256_set1_epi32(1)));
+        __m256i nudge = intrin_utils::mm256_blendv_epi32(
+            zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+        // Calculate the shifted sum (results + nudge) >> exp.
+        const __m256i r_plus_nudge =
+            intrin_utils::mm256_add_epi32<path>(results, nudge);
+        const __m256i shifted_sum =
+            intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, exponent);
+
+        // Identify overflow in each lane and create mask.
+        const __m256i one_shift_31minus_exp =
+            intrin_utils::mm256_sllv_epi32<path>(
+                _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
+                                          _mm256_set1_epi32(31), exponent));
+        const __m256i mask_num_plus_nudge_overflow =
+            intrin_utils::mm256_cmpgt_epi32<path>(
+                results, intrin_utils::mm256_sub_epi32<path>(
+                             _mm256_set1_epi32(0x7fffffff), nudge));
+        // Fill results with either (results + nudge) >> exponent or
+        // 1 << (31 - exp) in the case of overflow.
+        results = intrin_utils::mm256_blendv_epi32(
+            shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+      };
+
       auto apply_multiplier = [=](__m256i& accum) {
         __m256i shifted_accum =
             intrin_utils::mm256_sllv_epi32<path>(accum, left_shift);
@@ -863,9 +878,9 @@ void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
                 _mm256_extractf128_si256(shifted_accum, 1)),
             m_64bit_high);
 
         scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low,
-                                                           offset_vector_low);
+                                                           offset_vector);
         scaled_v_high = intrin_utils::mm256_add_epi64<path>(
-            scaled_v_high, offset_vector_high);
+            scaled_v_high, offset_vector);
 
         scaled_v_low = intrin_utils::mm256_srlv_epi64<path>(
             scaled_v_low, final_right_shift_low);
@@ -879,8 +894,9 @@ void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
         // lo->hi (0, 2, 4, 6, 1, 3, 5, 7);
         results = intrin_utils::PermuteEpi32EvenOdds<path>(results);
+        rounding_right_shift(results, right_shift);
         accum =
-            intrin_utils::mm256_sub_epi32<path>(results, post_scaling_offset);
+            intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset);
       };
       apply_multiplier(accum_data_v0);
       apply_multiplier(accum_data_v1);
@@ -1306,45 +1322,25 @@ void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) {
         intrin_utils::mm256_sub_epi32<path>(zero_vector, e_vector);
     const __m256i right_shift =
         intrin_utils::mm256_max_epi32<path>(neg_e_vector, zero_vector);
-    const __m256i final_right_shift = intrin_utils::mm256_add_epi32<path>(
-        right_shift, _mm256_set1_epi32(31));
+    const __m256i final_right_shift = _mm256_set1_epi32(31);
     const __m256i final_right_shift_low =
         intrin_utils::mm256_cvtepi32_epi64<path>(
            _mm256_extractf128_si256(final_right_shift, 0));
     const __m256i final_right_shift_high =
         intrin_utils::mm256_cvtepi32_epi64<path>(
           _mm256_extractf128_si256(final_right_shift, 1));
-    // Really we want 0x100000000, but use half to avoid overflowing.
-    const __m256i convert_to_signed_halved =
-        intrin_utils::mm256_srlv_epi32<path>(_mm256_set1_epi32(0x80000000),
-                                             right_shift);
     const __m256i convert_to_unsigned_64 =
         _mm256_set1_epi64x(0x8000000000000000);
-    __m256i post_scaling_offset = intrin_utils::mm256_add_epi32<path>(
-        convert_to_signed_halved, convert_to_signed_halved);
-
-    const __m256i offset_vector =
-        intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30);
-    // Really these should be shifted by neg_e_vector, but tests pass when
-    // using right_shift.
-    const __m256i offset_vector_low = intrin_utils::mm256_add_epi64<path>(
-        intrin_utils::mm256_sllv_epi64<path>(
-            offset_vector, intrin_utils::mm256_cvtepi32_epi64<path>(
-                               _mm256_extractf128_si256(right_shift, 0))),
-        convert_to_unsigned_64);
-    const __m256i offset_vector_high = intrin_utils::mm256_add_epi64<path>(
-        intrin_utils::mm256_sllv_epi64<path>(
-            offset_vector, intrin_utils::mm256_cvtepi32_epi64<path>(
                               _mm256_extractf128_si256(right_shift, 1))),
+    __m256i post_scaling_offset = _mm256_setzero_si256();
+
+    // A "half" added for rounding prior to truncation of 64-bit value.
+    const __m256i offset_vector = intrin_utils::mm256_add_epi64<path>(
+        intrin_utils::mm256_slli_epi64<path>(_mm256_set1_epi64x(1), 30),
         convert_to_unsigned_64);
 
     if (params.dst_zero_point) {
-      const __m256i dst_zero_point = _mm256_set1_epi32(params.dst_zero_point);
-      // The post-scaling offset is subtracted later, so this has the effect
-      // of adding the zero point.
-      post_scaling_offset = intrin_utils::mm256_sub_epi32<path>(
-          post_scaling_offset, dst_zero_point);
+      post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
     }
 
     // See GEMM version for details of this process.
@@ -1362,9 +1358,9 @@ void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) {
             m_64bit_high);
 
     scaled_v_low = intrin_utils::mm256_add_epi64<path>(scaled_v_low,
-                                                       offset_vector_low);
+                                                       offset_vector);
     scaled_v_high = intrin_utils::mm256_add_epi64<path>(scaled_v_high,
-                                                        offset_vector_high);
+                                                        offset_vector);
 
     scaled_v_low = intrin_utils::mm256_srlv_epi64<path>(
         scaled_v_low, final_right_shift_low);
@@ -1377,8 +1373,40 @@ void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) {
     // Permute results to this ordering of int32 elements
     // lo->hi (0, 2, 4, 6, 1, 3, 5, 7);
     results = intrin_utils::PermuteEpi32EvenOdds<path>(results);
+
+    // Now perform the Rounding Right Shift.
+    // First, construct the "nudge" value for each lane if the exponent is
+    // greater than 0. Otherwise, the nudge is 0.
+    const __m256i zeros = _mm256_setzero_si256();
+    const __m256i mask_rightshift_gtz =
+        intrin_utils::mm256_cmpgt_epi32<path>(right_shift, zeros);
+    const __m256i one_shift_exp_minus1 =
+        intrin_utils::mm256_sllv_epi32<path>(
+            _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
+                                      right_shift, _mm256_set1_epi32(1)));
+    __m256i nudge = intrin_utils::mm256_blendv_epi32(
+        zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+    // Calculate the shifted sum (results + nudge) >> exp.
+    const __m256i r_plus_nudge =
+        intrin_utils::mm256_add_epi32<path>(results, nudge);
+    const __m256i shifted_sum =
+        intrin_utils::mm256_srav_epi32<path>(r_plus_nudge, right_shift);
+
+    // Identify overflow in each lane and create mask.
+    const __m256i one_shift_31minus_exp =
+        intrin_utils::mm256_sllv_epi32<path>(
+            _mm256_set1_epi32(1), intrin_utils::mm256_sub_epi32<path>(
+                                      _mm256_set1_epi32(31), right_shift));
+    const __m256i mask_num_plus_nudge_overflow =
+        intrin_utils::mm256_cmpgt_epi32<path>(
+            results, intrin_utils::mm256_sub_epi32<path>(
+                         _mm256_set1_epi32(0x7fffffff), nudge));
+    // Fill results with either (results + nudge) >> exponent or
+    // 1 << (31 - exp) in the case of overflow.
+    results = intrin_utils::mm256_blendv_epi32(
+        shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
     accum_data_v0 =
-        intrin_utils::mm256_sub_epi32<path>(results, post_scaling_offset);
+        intrin_utils::mm256_add_epi32<path>(results, post_scaling_offset);
   }
 }
 
 const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc
index 2d8f016..2aefbe3 100644
--- a/ruy/kernel_avx2_fma.cc
+++ b/ruy/kernel_avx2_fma.cc
@@ -88,6 +88,14 @@ inline __m128i mm256_extracti128_si256<Path::kAvx2Fma>(const __m256i& a,
   }
 }
 
+__m256i mm256_blendv_epi32(const __m256i& a, const __m256i& b,
+                           const __m256i& mask) {
+  __m256 result =
+      _mm256_blendv_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b),
+                       _mm256_castsi256_ps(mask));
+  return _mm256_castps_si256(result);
+}
+
 }  // namespace intrin_utils
 }  // namespace
 
@@ -352,43 +360,22 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
       const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
       const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
       const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
-      const __m256i final_right_shift =
-          _mm256_add_epi32(right_shift, _mm256_set1_epi32(31));
+      const __m256i final_right_shift = _mm256_set1_epi32(31);
       const __m256i final_right_shift_low = _mm256_cvtepi32_epi64(
          _mm256_extracti128_si256(final_right_shift, 0));
       const __m256i final_right_shift_high = _mm256_cvtepi32_epi64(
          _mm256_extracti128_si256(final_right_shift, 1));
-      // Really we want 0x100000000, but use half to avoid overflowing.
-      const __m256i convert_to_signed_halved =
-          _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift);
       const __m256i convert_to_unsigned_64 =
           _mm256_set1_epi64x(0x8000000000000000);
-      __m256i post_scaling_offset = _mm256_add_epi32(
-          convert_to_signed_halved, convert_to_signed_halved);
-
-      const __m256i offset_vector =
-          _mm256_slli_epi64(_mm256_set1_epi64x(1), 30);
-      // Really these should be shifted by neg_e_vector, but tests pass when
-      // using right_shift.
-      const __m256i offset_vector_low = _mm256_add_epi64(
-          _mm256_sllv_epi64(offset_vector,
-                            _mm256_cvtepi32_epi64(
-                                _mm256_extracti128_si256(right_shift, 0))),
-          convert_to_unsigned_64);
-      const __m256i offset_vector_high = _mm256_add_epi64(
-          _mm256_sllv_epi64(offset_vector,
-                            _mm256_cvtepi32_epi64(
                                _mm256_extracti128_si256(right_shift, 1))),
+      __m256i post_scaling_offset = _mm256_setzero_si256();
+      // A "half" added for rounding prior to truncation of 64-bit value.
+      const __m256i offset_vector = _mm256_add_epi64(
+          _mm256_slli_epi64(_mm256_set1_epi64x(1), 30),
           convert_to_unsigned_64);
 
       if (params.dst_zero_point) {
-        const __m256i dst_zero_point =
-            _mm256_set1_epi32(params.dst_zero_point);
-        // The post-scaling offset is subtracted later, so this has the effect
-        // of adding the zero point.
-        post_scaling_offset =
-            _mm256_sub_epi32(post_scaling_offset, dst_zero_point);
+        post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
       }
 
       const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
@@ -446,7 +433,7 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
       //     __m256i results =
       //         _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
       //     results = _mm256_permutevar8x32_epi32(results, repack_perm);
-      //     accum_data_v[j] = _mm256_sub_epi32(results, post_scaling_offset);
+      //     accum_data_v[j] = _mm256_add_epi32(results, post_scaling_offset);
 
       // This multiplier code is complex and expensive enough on x86, that
       // we prefer to implement the channels-are-columns case by transposing
@@ -464,6 +451,35 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
             &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
             &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
       }
+
+      auto rounding_right_shift = [=](__m256i& results,
+                                      const __m256i& exponent) {
+        // Construct the "nudge" value for each lane if the exponent is
+        // greater than 0. Otherwise, the nudge is 0.
+        const __m256i zeros = _mm256_setzero_si256();
+        const __m256i mask_rightshift_gtz =
+            _mm256_cmpgt_epi32(exponent, zeros);
+        const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
+            _mm256_set1_epi32(1),
+            _mm256_sub_epi32(exponent, _mm256_set1_epi32(1)));
+        __m256i nudge = intrin_utils::mm256_blendv_epi32(
+            zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+        // Calculate the shifted sum (results + nudge) >> exp.
+        const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
+        const __m256i shifted_sum = _mm256_srav_epi32(r_plus_nudge, exponent);
+
+        // Identify overflow in each lane and create mask.
+        const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
+            _mm256_set1_epi32(1),
+            _mm256_sub_epi32(_mm256_set1_epi32(31), exponent));
+        const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
+            results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
+        // Fill results with either (results + nudge) >> exponent or
+        // 1 << (31 - exp) in the case of overflow.
+        results = intrin_utils::mm256_blendv_epi32(
+            shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+      };
+
       auto apply_multiplier = [=](__m256i& accum) {
         __m256i shifted_accum = _mm256_sllv_epi32(accum, left_shift);
         // Apply the fixed-point part of the multiplier.
@@ -474,8 +490,8 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
             _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
             m_64bit_high);
 
-        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
+        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);
 
         scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
         scaled_v_high =
@@ -485,8 +501,9 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
         __m256i results =
             _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
         results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-        accum = _mm256_sub_epi32(results, post_scaling_offset);
+        // Now do a Rounding Right Shift.
+        rounding_right_shift(results, right_shift);
+        accum = _mm256_add_epi32(results, post_scaling_offset);
       };
       apply_multiplier(accum_data_v0);
       apply_multiplier(accum_data_v1);
@@ -856,42 +873,22 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
     const __m256i left_shift = _mm256_max_epi32(e_vector, zero_vector);
     const __m256i neg_e_vector = _mm256_sub_epi32(zero_vector, e_vector);
     const __m256i right_shift = _mm256_max_epi32(neg_e_vector, zero_vector);
-    const __m256i final_right_shift =
-        _mm256_add_epi32(right_shift, _mm256_set1_epi32(31));
+    const __m256i final_right_shift = _mm256_set1_epi32(31);
     const __m256i final_right_shift_low =
        _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 0));
     const __m256i final_right_shift_high =
        _mm256_cvtepi32_epi64(_mm256_extracti128_si256(final_right_shift, 1));
-    // Really we want 0x100000000, but use half to avoid overflowing.
-    const __m256i convert_to_signed_halved =
-        _mm256_srlv_epi32(_mm256_set1_epi32(0x80000000), right_shift);
     const __m256i convert_to_unsigned_64 =
         _mm256_set1_epi64x(0x8000000000000000);
-    __m256i post_scaling_offset =
-        _mm256_add_epi32(convert_to_signed_halved, convert_to_signed_halved);
-
-    const __m256i offset_vector =
-        _mm256_slli_epi64(_mm256_set1_epi64x(1), 30);
-    // Really these should be shifted by neg_e_vector, but tests pass when
-    // using right_shift.
-    const __m256i offset_vector_low = _mm256_add_epi64(
-        _mm256_sllv_epi64(
-            offset_vector,
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 0))),
-        convert_to_unsigned_64);
-    const __m256i offset_vector_high = _mm256_add_epi64(
-        _mm256_sllv_epi64(
-            offset_vector,
            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(right_shift, 1))),
+    __m256i post_scaling_offset = _mm256_setzero_si256();
+    // A "half" added for rounding prior to truncation of 64-bit value.
+    const __m256i offset_vector = _mm256_add_epi64(
+        _mm256_slli_epi64(_mm256_set1_epi64x(1), 30),
         convert_to_unsigned_64);
 
     if (params.dst_zero_point) {
-      const __m256i dst_zero_point = _mm256_set1_epi32(params.dst_zero_point);
-      // The post-scaling offset is subtracted later, so this has the effect
-      // of adding the zero point.
-      post_scaling_offset =
-          _mm256_sub_epi32(post_scaling_offset, dst_zero_point);
+      post_scaling_offset = _mm256_set1_epi32(params.dst_zero_point);
     }
 
     const __m256i repack_perm = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
@@ -907,8 +904,8 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
         _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
         m_64bit_high);
 
-    scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-    scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
+    scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector);
+    scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector);
 
     scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
     scaled_v_high =
@@ -918,7 +915,33 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
     __m256i results = _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
     results = _mm256_permutevar8x32_epi32(results, repack_perm);
-    accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset);
+    // Now do a Rounding Right Shift.
+    // First, construct the nudge value for each lane.
+    const __m256i zeros = _mm256_setzero_si256();
+    const __m256i mask_rightshift_gtz =
+        _mm256_cmpgt_epi32(right_shift, zeros);
+    const __m256i one_shift_exp_minus1 = _mm256_sllv_epi32(
+        _mm256_set1_epi32(1),
+        _mm256_sub_epi32(right_shift, _mm256_set1_epi32(1)));
+    __m256i nudge = intrin_utils::mm256_blendv_epi32(
+        zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+    // Calculate the shifted sum (results + nudge) >> exp.
+    const __m256i r_plus_nudge = _mm256_add_epi32(results, nudge);
+    const __m256i shifted_sum =
+        _mm256_srav_epi32(r_plus_nudge, right_shift);
+
+    // Identify overflow in each lane and create mask.
+    const __m256i one_shift_31minus_exp = _mm256_sllv_epi32(
+        _mm256_set1_epi32(1),
+        _mm256_sub_epi32(_mm256_set1_epi32(31), right_shift));
+    const __m256i mask_num_plus_nudge_overflow = _mm256_cmpgt_epi32(
+        results, _mm256_sub_epi32(_mm256_set1_epi32(0x7fffffff), nudge));
+    // Fill results with either (results + nudge) >> exponent or
+    // 1 << (31 - exp) in the case of overflow.
+    results = intrin_utils::mm256_blendv_epi32(
+        shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+
+    accum_data_v0 = _mm256_add_epi32(results, post_scaling_offset);
   }
 }
 
 const __m256i clamp_max_v = _mm256_set1_epi32(params.clamp_max);
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
index 6a65ca7..fddb482 100644
--- a/ruy/kernel_avx512.cc
+++ b/ruy/kernel_avx512.cc
@@ -52,6 +52,45 @@ void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) {
 
 #else  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
 
+namespace {
+namespace intrin_utils {
+
+__m256i mm256_blendv_epi64(const __m256i& a, const __m256i& b,
+                           const __m256i& mask) {
+  __m256d result =
+      _mm256_blendv_pd(_mm256_castsi256_pd(a), _mm256_castsi256_pd(b),
+                       _mm256_castsi256_pd(mask));
+  return _mm256_castpd_si256(result);
+}
+
+__m512i mm512_blendv_epi64(const __m512i& a, const __m512i& b,
+                           const __m512i& mask) {
+  __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);
+  __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);
+  __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);
+  __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);
+  __m256i mask_lo = _mm512_extracti64x4_epi64(mask, 0);
+  __m256i mask_hi = _mm512_extracti64x4_epi64(mask, 1);
+  __m256i lo = mm256_blendv_epi64(a_lo, b_lo, mask_lo);
+  __m256i hi = mm256_blendv_epi64(a_hi, b_hi, mask_hi);
+  __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0);
+  return _mm512_inserti64x4(result, hi, 1);
+}
+
+__m512i mm512_cmpgt_epi64(const __m512i& a, const __m512i& b) {
+  __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);
+  __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);
+  __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);
+  __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);
+  __m256i lo = _mm256_cmpgt_epi64(a_lo, b_lo);
+  __m256i hi = _mm256_cmpgt_epi64(a_hi, b_hi);
+  __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0);
+  return _mm512_inserti64x4(result, hi, 1);
+}
+
+}  // namespace intrin_utils
+}  // namespace
+
 void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
   profiler::ScopeLabel label("Kernel kAvx512 8-bit");
 
@@ -333,23 +372,49 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
       const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
       const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
       const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
-      const __m512i final_right_shift =
-          _mm512_add_epi32(right_shift, _mm512_set1_epi32(31));
+      const __m512i final_right_shift = _mm512_set1_epi32(31);
+      const __m512i right_shift_low =
+          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
+      const __m512i right_shift_high =
+          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
       const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
          _mm512_extracti32x8_epi32(final_right_shift, 0));
       const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
          _mm512_extracti32x8_epi32(final_right_shift, 1));
 
+      // A "half" added for rounding prior to truncation of 64-bit value.
       const __m512i offset_vector =
           _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
-      // Really these should be shifted by neg_e_vector, but tests pass when
-      // using right_shift.
-      const __m512i offset_vector_low = _mm512_sllv_epi64(
-          offset_vector,
-          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)));
-      const __m512i offset_vector_high = _mm512_sllv_epi64(
-          offset_vector,
-          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
+
+      auto rounding_right_shift = [=](__m512i& results,
+                                      const __m512i& exponent) {
+        // Construct the "nudge" value for each lane if the exponent is
+        // greater than 0. Otherwise, the nudge is 0.
+        const __m512i zeros = _mm512_setzero_si512();
+        const __m512i mask_rightshift_gtz =
+            intrin_utils::mm512_cmpgt_epi64(exponent, zeros);
+        const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64(
+            _mm512_set1_epi64(1),
+            _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
+        __m512i nudge = intrin_utils::mm512_blendv_epi64(
+            zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+        // Calculate the shifted sum (results + nudge) >> exp.
+        const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
+        const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);
+
+        // Identify overflow in each lane and create mask.
+        const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
+            _mm512_set1_epi64(1),
+            _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
+        const __m512i mask_num_plus_nudge_overflow =
+            intrin_utils::mm512_cmpgt_epi64(
+                results,
+                _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
+        // Fill results with either (results + nudge) >> exponent or
+        // 1 << (31 - exp) in the case of overflow.
+        results = intrin_utils::mm512_blendv_epi64(
+            shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+      };
 
       if (per_column_multiplier) {
         auto apply_multiplier = [=](__m512i& accum, int col) {
@@ -360,11 +425,12 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
           __m512i m_64bit_val = _mm512_permutexvar_epi64(
              perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high);
           __m512i offset_vector_val = _mm512_permutexvar_epi64(
-              perm_64bit_vals,
-              col < 8 ? offset_vector_low : offset_vector_high);
+              perm_64bit_vals, offset_vector);
           __m512i final_right_shift_val = _mm512_permutexvar_epi64(
              perm_64bit_vals,
              col < 8 ? final_right_shift_low : final_right_shift_high);
+          __m512i right_shift_val = _mm512_permutexvar_epi64(
+              perm_64bit_vals, col < 8 ? right_shift_low : right_shift_high);
 
           accum = _mm512_sllv_epi32(accum, left_shift_val);
           __m512i scaled_v_low = _mm512_mul_epi32(
@@ -382,6 +448,9 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
           scaled_v_high =
              _mm512_srav_epi64(scaled_v_high, final_right_shift_val);
 
+          rounding_right_shift(scaled_v_low, right_shift_val);
+          rounding_right_shift(scaled_v_high, right_shift_val);
+
           accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
           accum = _mm512_inserti32x8(accum,
                                      _mm512_cvtepi64_epi32(scaled_v_high), 1);
@@ -413,14 +482,16 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
              _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
              m_64bit_high);
 
-          scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
-          scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+          scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
+          scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);
 
           scaled_v_low =
              _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
           scaled_v_high =
              _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
 
+          rounding_right_shift(scaled_v_low, right_shift_low);
+          rounding_right_shift(scaled_v_high, right_shift_high);
           accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
           accum = _mm512_inserti32x8(accum,
                                      _mm512_cvtepi64_epi32(scaled_v_high), 1);
@@ -713,22 +784,48 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
     const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector);
     const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector);
     const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector);
-    const __m512i final_right_shift =
-        _mm512_add_epi32(right_shift, _mm512_set1_epi32(31));
+    const __m512i final_right_shift = _mm512_set1_epi32(31);
+    const __m512i right_shift_low =
+        _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0));
+    const __m512i right_shift_high =
+        _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1));
     const __m512i final_right_shift_low = _mm512_cvtepi32_epi64(
        _mm512_extracti32x8_epi32(final_right_shift, 0));
     const __m512i final_right_shift_high = _mm512_cvtepi32_epi64(
        _mm512_extracti32x8_epi32(final_right_shift, 1));
 
+    // A "half" added for rounding prior to truncation of 64-bit value.
     const __m512i offset_vector =
         _mm512_slli_epi64(_mm512_set1_epi64(1), 30);
-    // Really these should be shifted by neg_e_vector, but tests pass when
-    // using right_shift.
-    const __m512i offset_vector_low = _mm512_sllv_epi64(
-        offset_vector,
-        _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0)));
-    const __m512i offset_vector_high = _mm512_sllv_epi64(
-        offset_vector,
-        _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
+
+    auto rounding_right_shift = [=](__m512i& results,
+                                    const __m512i& exponent) {
+      // Construct the "nudge" value for each lane if the exponent is
+      // greater than 0. Otherwise, the nudge is 0.
+      const __m512i zeros = _mm512_setzero_si512();
+      const __m512i mask_rightshift_gtz =
+          intrin_utils::mm512_cmpgt_epi64(exponent, zeros);
+      const __m512i one_shift_exp_minus1 =
+          _mm512_sllv_epi64(_mm512_set1_epi64(1),
+                            _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
+      __m512i nudge = intrin_utils::mm512_blendv_epi64(
+          zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+      // Calculate the shifted sum (results + nudge) >> exp.
+      const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
+      const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);
+
+      // Identify overflow in each lane and create mask.
+      const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
+          _mm512_set1_epi64(1),
+          _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
+      const __m512i mask_num_plus_nudge_overflow =
+          intrin_utils::mm512_cmpgt_epi64(
+              results,
+              _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
+      // Fill results with either (results + nudge) >> exponent or
+      // 1 << (31 - exp) in the case of overflow.
+      results = intrin_utils::mm512_blendv_epi64(
+          shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+    };
 
     // Shift and round column 0.
     accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift);
@@ -740,12 +837,15 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
        _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)),
        m_64bit_high);
 
-    scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
-    scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+    scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector);
+    scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector);
 
     scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
     scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
 
+    rounding_right_shift(scaled_v_low, right_shift_low);
+    rounding_right_shift(scaled_v_high, right_shift_high);
+
    accum_data_v0 =
        _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
    accum_data_v0 = _mm512_inserti32x8(
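End-to-end, the updated rescale path in each kernel is now: pre-shift by the positive exponent, 64-bit fixed-point multiply, add the constant 1 << 30 "half", shift down by the now-constant 31, apply the rounding right shift by the negative exponent, and add the destination zero point (post_scaling_offset is now simply the zero point, added rather than subtracted). A scalar model of that pipeline, again editorial and using illustrative names rather than ruy's API:

    #include <cstdint>

    int32_t RoundingRightShift(int32_t x, int exponent);  // sketch shown earlier

    // Editorial model of the per-lane rescale performed by the updated kernels.
    int32_t Rescale(int32_t accum, int32_t multiplier_fixedpoint, int exponent,
                    int32_t dst_zero_point) {
      const int left_shift = exponent > 0 ? exponent : 0;    // left_shift vector
      const int right_shift = exponent > 0 ? 0 : -exponent;  // right_shift vector
      // 32-bit lane left shift, as sllv_epi32 does (wraps modulo 2^32).
      const int32_t shifted_accum =
          static_cast<int32_t>(static_cast<uint32_t>(accum) << left_shift);
      // Fixed-point multiply widened to 64 bits (mul_epi32 on sign-extended
      // lanes).
      int64_t scaled = static_cast<int64_t>(shifted_accum) *
                       static_cast<int64_t>(multiplier_fixedpoint);
      // The "half" added for rounding prior to truncation (offset_vector),
      // then the constant final_right_shift of 31.
      scaled += int64_t{1} << 30;
      scaled >>= 31;
      const int32_t narrowed = static_cast<int32_t>(scaled);
      // The new Rounding Right Shift, then the zero point.
      return RoundingRightShift(narrowed, right_shift) + dst_zero_point;
    }

The extra 0x8000000000000000 folded into offset_vector in the AVX and AVX2 kernels is a vectorization detail: it biases each 64-bit lane to be non-negative so that the logical srlv_epi64 behaves like an arithmetic shift, and after the shift by 31 the bias lands above bit 31, outside the 32 bits kept when lanes are repacked. The AVX-512 kernels use the arithmetic srav_epi64 directly and need no bias; they also apply rounding_right_shift while values are still in 64-bit lanes, before narrowing with cvtepi64_epi32.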