
github.com/google/ruy.git
author     Ruy Contributors <ruy-eng@google.com>          2021-04-22 21:00:46 +0300
committer  Copybara-Service <copybara-worker@google.com>  2021-04-22 21:01:04 +0300
commit     516761f8278df5dd4a3a0a4d382a31b26767e7a3 (patch)
tree       c16574ebd477186995ad4c89a434cc2fe7b8c7a3
parent     1c518e24f7a043ae26c5c6b8f6b7f0946bc013df (diff)
1.02x speedup of Ruy AVX2 f32 and AVX-512 f32/i8
AVX-512:
- broadcast without extra instruction (code size)
- use native mask ops
- re-roll mmm loop

AVX2: avoid slow permute, especially for AMD

PiperOrigin-RevId: 369907385
-rw-r--r--  ruy/kernel_avx512.cc | 428
-rw-r--r--  ruy/kernel_x86.h     |  16
2 files changed, 98 insertions(+), 346 deletions(-)
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
index fddb482..84b9380 100644
--- a/ruy/kernel_avx512.cc
+++ b/ruy/kernel_avx512.cc
@@ -52,45 +52,6 @@ void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) {
#else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
-namespace {
-namespace intrin_utils {
-
-__m256i mm256_blendv_epi64(const __m256i& a, const __m256i& b,
- const __m256i& mask) {
- __m256d result =
- _mm256_blendv_pd(_mm256_castsi256_pd(a), _mm256_castsi256_pd(b),
- _mm256_castsi256_pd(mask));
- return _mm256_castpd_si256(result);
-}
-
-__m512i mm512_blendv_epi64(const __m512i& a, const __m512i& b,
- const __m512i& mask) {
- __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);
- __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);
- __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);
- __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);
- __m256i mask_lo = _mm512_extracti64x4_epi64(mask, 0);
- __m256i mask_hi = _mm512_extracti64x4_epi64(mask, 1);
- __m256i lo = mm256_blendv_epi64(a_lo, b_lo, mask_lo);
- __m256i hi = mm256_blendv_epi64(a_hi, b_hi, mask_hi);
- __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0);
- return _mm512_inserti64x4(result, hi, 1);
-}
-
-__m512i mm512_cmpgt_epi64(const __m512i& a, const __m512i& b) {
- __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);
- __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);
- __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);
- __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);
- __m256i lo = _mm256_cmpgt_epi64(a_lo, b_lo);
- __m256i hi = _mm256_cmpgt_epi64(a_hi, b_hi);
- __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0);
- return _mm512_inserti64x4(result, hi, 1);
-}
-
-} // namespace intrin_utils
-} // namespace
-
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
profiler::ScopeLabel label("Kernel kAvx512 8-bit");
@@ -391,13 +352,13 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
// Construct the "nudge" value for each lane if the exponent is
// greater than 0. Otherwise, the nudge is 0.
const __m512i zeros = _mm512_setzero_si512();
- const __m512i mask_rightshift_gtz =
- intrin_utils::mm512_cmpgt_epi64(exponent, zeros);
+ const auto mask_rightshift_gtz =
+ _mm512_cmpgt_epi64_mask(exponent, zeros);
const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64(
_mm512_set1_epi64(1),
_mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
- __m512i nudge = intrin_utils::mm512_blendv_epi64(
- zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+ __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
+ one_shift_exp_minus1);
// Calculate the shifted sum (results + nudge) >> exp.
const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);
@@ -406,14 +367,12 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
_mm512_set1_epi64(1),
_mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
- const __m512i mask_num_plus_nudge_overflow =
- intrin_utils::mm512_cmpgt_epi64(
- results,
- _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
+ const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
+ results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
// Fill results with either (results + nudge) >> exponent or
// 1 << (31 - exp) in the case of overflow.
- results = intrin_utils::mm512_blendv_epi64(
- shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+ results = _mm512_mask_mov_epi64(
+ shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
};
if (per_column_multiplier) {
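
For reference, the "native mask ops" change in this hunk reduces to the following self-contained sketch (illustrative only; the function name is ours, the intrinsics are the ones the patch adopts). It selects 64-bit lanes of b where a > 0, using an AVX-512 mask register instead of the deleted two-half AVX2 emulation:

    #include <immintrin.h>

    // Sketch: lanes of `b` where `a > 0`, else zero. Compile with -mavx512f.
    __m512i SelectWhereGtZero(__m512i a, __m512i b) {
      const __m512i zeros = _mm512_setzero_si512();
      // The native compare writes an 8-bit mask register (__mmask8) directly,
      // replacing the former emulation built from 256-bit blendv halves.
      const __mmask8 gt_zero = _mm512_cmpgt_epi64_mask(a, zeros);
      // Masked move: lanes with a set mask bit come from `b`, others from `zeros`.
      return _mm512_mask_mov_epi64(zeros, gt_zero, b);
    }
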
@@ -424,8 +383,8 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
_mm512_permutexvar_epi32(_mm512_set1_epi32(col), left_shift);
__m512i m_64bit_val = _mm512_permutexvar_epi64(
perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high);
- __m512i offset_vector_val = _mm512_permutexvar_epi64(
- perm_64bit_vals, offset_vector);
+ __m512i offset_vector_val =
+ _mm512_permutexvar_epi64(perm_64bit_vals, offset_vector);
__m512i final_right_shift_val = _mm512_permutexvar_epi64(
perm_64bit_vals,
col < 8 ? final_right_shift_low : final_right_shift_high);
@@ -802,13 +761,13 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
// Construct the "nudge" value for each lane if the exponent is
// greater than 0. Otherwise, the nudge is 0.
const __m512i zeros = _mm512_setzero_si512();
- const __m512i mask_rightshift_gtz =
- intrin_utils::mm512_cmpgt_epi64(exponent, zeros);
+ const auto mask_rightshift_gtz =
+ _mm512_cmpgt_epi64_mask(exponent, zeros);
const __m512i one_shift_exp_minus1 =
_mm512_sllv_epi64(_mm512_set1_epi64(1),
_mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
- __m512i nudge = intrin_utils::mm512_blendv_epi64(
- zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+ __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
+ one_shift_exp_minus1);
// Calculate the shifted sum (results + nudge) >> exp.
const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);
@@ -817,14 +776,12 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
_mm512_set1_epi64(1),
_mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
- const __m512i mask_num_plus_nudge_overflow =
- intrin_utils::mm512_cmpgt_epi64(
- results,
- _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
+ const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
+ results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
// Fill results with either (results + nudge) >> exponent or
// 1 << (31 - exp) in the case of overflow.
- results = intrin_utils::mm512_blendv_epi64(
- shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+ results = _mm512_mask_mov_epi64(
+ shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
};
// Shift and round column 0.
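
Both 8-bit kernels apply the same rounding right shift per 64-bit lane. As a scalar restatement (a sketch under the kernel's own assumption that 0 <= exponent <= 31; the function name is ours):

    #include <cstdint>

    // What each lane computes: round (results >> exponent) to nearest, with
    // the overflow case patched to 1 << (31 - exponent) as in the kernel.
    int64_t RoundingRightShift(int64_t results, int64_t exponent) {
      const int64_t nudge = exponent > 0 ? (int64_t{1} << (exponent - 1)) : 0;
      // Arithmetic shift, matching _mm512_srav_epi64.
      const int64_t shifted_sum = (results + nudge) >> exponent;
      // results + nudge would exceed the int32 range: substitute the value
      // the kernel blends in instead.
      const bool overflow = results > INT64_C(0x7fffffff) - nudge;
      return overflow ? (int64_t{1} << (31 - exponent)) : shifted_sum;
    }
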
@@ -930,9 +887,8 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
float* dst_ptr = dst_col_ptr + row;
// Process block in two halves, split by columns.
- {
- constexpr int mmm = 0;
-
+#pragma unroll(1)
+ for (int mmm = 0; mmm < 2; ++mmm) {
__m512 accum_data_v0;
__m512 accum_data_v1;
__m512 accum_data_v2;
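
This is the "re-roll mmm loop" item from the commit message: the two hand-unrolled half-block bodies become a two-iteration loop, and #pragma unroll(1) asks the compiler to keep it rolled, so the body is emitted once (the code-size win) while still running twice. A toy sketch of the shape (Process stands in for the half-block body; clang honors #pragma unroll(1), GCC spells it #pragma GCC unroll 1):

    #include <cstdio>

    void Process(int mmm) { std::printf("half-block %d\n", mmm); }

    void HalfBlocks() {
    #pragma unroll(1)  // keep the loop rolled: one copy of the body in the binary
      for (int mmm = 0; mmm < 2; ++mmm) {
        Process(mmm);
      }
    }
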
@@ -972,81 +928,49 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
for (int d = 0; d < (params.depth - 1); ++d) {
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
- // In this version RHS values are loaded individually rather than
- // first loading together and then extract with broadcasting. This is
- // because AVX flavours and instrinsics and compilers in combination
- // do not handle this pattern of extraction very well.
const float* rhs_data = rhs_ptr;
lhs_ptr += 16;
rhs_ptr += 16;
- {
- // Load 8 float32 values.
- __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
- __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4
- __m512 rhs4_7 =
- _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4
-
- const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
- accum_data_v0 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
- const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
- accum_data_v1 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
- const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
- accum_data_v2 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
- const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
- accum_data_v3 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
- const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
- accum_data_v4 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
- const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
- accum_data_v5 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
- const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
- accum_data_v6 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
- const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
- accum_data_v7 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
- }
+ // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
+ // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
+ // so if given an rvalue.
+ accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
+ accum_data_v0);
+ accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
+ accum_data_v1);
+ accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
+ accum_data_v2);
+ accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
+ accum_data_v3);
+ accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
+ accum_data_v4);
+ accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
+ accum_data_v5);
+ accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
+ accum_data_v6);
+ accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
+ accum_data_v7);
}
- {
+ { // nested extra blocks lead to measurable speed gains
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const float* rhs_data = rhs_ptr;
- {
- // Load 8 float32 values.
- __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
- __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4
- __m512 rhs4_7 =
- _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4
- const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
- accum_data_v0 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
- const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
- accum_data_v1 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
- const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
- accum_data_v2 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
- const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
- accum_data_v3 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
- const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
- accum_data_v4 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
- const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
- accum_data_v5 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
- const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
- accum_data_v6 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
- const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
- accum_data_v7 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
- }
+ accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
+ accum_data_v0);
+ accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
+ accum_data_v1);
+ accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
+ accum_data_v2);
+ accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
+ accum_data_v3);
+ accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
+ accum_data_v4);
+ accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
+ accum_data_v5);
+ accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
+ accum_data_v6);
+ accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
+ accum_data_v7);
{
float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
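
The godbolt link in the new comment is the basis for this hunk: passing _mm512_set1_ps(scalar) as an rvalue straight into the FMA lets GCC and clang fold the broadcast into the FMA's memory operand (an EVEX embedded broadcast, {1to16}), so no standalone broadcast or permute instruction is emitted. Reduced to a sketch (function name ours; the fusion is compiler-dependent, hence "more likely" above):

    #include <immintrin.h>

    // One column's update: acc += lhs * rhs[j], with rhs[j] broadcast to all
    // 16 lanes. At -O2 -mavx512f this typically compiles to a single
    // vfmadd231ps zmm, zmm, dword ptr [mem]{1to16}.
    __m512 FmaBroadcastColumn(__m512 acc, __m512 lhs, const float* rhs, int j) {
      return _mm512_fmadd_ps(lhs, _mm512_set1_ps(rhs[j]), acc);
    }
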
@@ -1075,147 +999,7 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
_mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
}
}
- } // Inner half-block loop, unrolled, first iteration.
- {
- constexpr int mmm = 1;
-
- __m512 accum_data_v0;
- __m512 accum_data_v1;
- __m512 accum_data_v2;
- __m512 accum_data_v3;
- __m512 accum_data_v4;
- __m512 accum_data_v5;
- __m512 accum_data_v6;
- __m512 accum_data_v7;
-
- // Initialize with bias.
- if (channel_dimension_is_col) {
- const float* bias_elem_ptr =
- bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
- accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
- accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
- accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
- accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
- accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
- accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
- accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
- accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
- } else {
- const __m512 initial_accum_data =
- _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
-
- accum_data_v0 = initial_accum_data;
- accum_data_v1 = initial_accum_data;
- accum_data_v2 = initial_accum_data;
- accum_data_v3 = initial_accum_data;
- accum_data_v4 = initial_accum_data;
- accum_data_v5 = initial_accum_data;
- accum_data_v6 = initial_accum_data;
- accum_data_v7 = initial_accum_data;
- }
- const float* lhs_ptr = lhs_col_ptr;
- const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
- for (int d = 0; d < (params.depth - 1); ++d) {
- const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
- const float* rhs_data = rhs_ptr;
- lhs_ptr += 16;
- rhs_ptr += 16;
- {
- // Load 8 float32 values.
- __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
- __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4
- __m512 rhs4_7 =
- _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4
-
- const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
- accum_data_v0 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
- const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
- accum_data_v1 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
- const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
- accum_data_v2 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
- const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
- accum_data_v3 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
- const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
- accum_data_v4 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
- const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
- accum_data_v5 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
- const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
- accum_data_v6 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
- const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
- accum_data_v7 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
- }
- }
- {
- const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
- const float* rhs_data = rhs_ptr;
- {
- // Load 8 float32 values.
- __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
- __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4
- __m512 rhs4_7 =
- _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4
- const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
- accum_data_v0 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
- const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
- accum_data_v1 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
- const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
- accum_data_v2 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
- const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
- accum_data_v3 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
- const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
- accum_data_v4 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
- const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
- accum_data_v5 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
- const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
- accum_data_v6 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
- const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
- accum_data_v7 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
- }
- {
- float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
- accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
- accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
- _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
- accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
- accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
- _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
- accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
- accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
- _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
- accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
- accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
- _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
- accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
- accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
- _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
- accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
- accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
- _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
- accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
- accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
- _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
- accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
- accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
- _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
- }
- }
- } // Inner half-block loop, unrolled, second iteration.
+ }
} // End row-block loop.
// The unrolling within this conditional may be somewhat pointless. It
@@ -1273,73 +1057,45 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
const float* rhs_data = rhs_ptr;
lhs_ptr += 16;
rhs_ptr += 16;
- {
- // Load 8 float32 values.
- __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
- __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4
- __m512 rhs4_7 =
- _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4
-
- const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
- accum_data_v0 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
- const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
- accum_data_v1 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
- const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
- accum_data_v2 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
- const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
- accum_data_v3 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
- const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
- accum_data_v4 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
- const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
- accum_data_v5 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
- const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
- accum_data_v6 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
- const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
- accum_data_v7 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
- }
+ // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
+ // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
+ // so if given an rvalue.
+ accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
+ accum_data_v0);
+ accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
+ accum_data_v1);
+ accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
+ accum_data_v2);
+ accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
+ accum_data_v3);
+ accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
+ accum_data_v4);
+ accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
+ accum_data_v5);
+ accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
+ accum_data_v6);
+ accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
+ accum_data_v7);
}
{
const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
const float* rhs_data = rhs_ptr;
- {
- // Load 8 float32 values.
- __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
- __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0); // [0 1 2 3] X 4
- __m512 rhs4_7 =
- _mm512_shuffle_f32x4(rhs, rhs, 0x55); // [4 5 6 7] X 4
- const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
- accum_data_v0 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
- const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
- accum_data_v1 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
- const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
- accum_data_v2 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
- const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
- accum_data_v3 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
- const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
- accum_data_v4 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
- const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
- accum_data_v5 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
- const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
- accum_data_v6 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
- const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
- accum_data_v7 =
- _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
- }
+ accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
+ accum_data_v0);
+ accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
+ accum_data_v1);
+ accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
+ accum_data_v2);
+ accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
+ accum_data_v3);
+ accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
+ accum_data_v4);
+ accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
+ accum_data_v5);
+ accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
+ accum_data_v6);
+ accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
+ accum_data_v7);
{
float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index 2f8fe19..b716502 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -607,14 +607,12 @@ inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
const float* rhs_ptr = rhs_col_ptr;
for (int d = 0; d < params.depth; ++d) {
const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
- const float* rhs_data = rhs_ptr;
- // Load 8 RHS values, then use permute instructions to
- // broadcast each value to a register.
- __m256 rhs1 = _mm256_loadu_ps(rhs_data); // Load [0 1 2 3 4 5 6 7]
+ // Load 8 RHS values, then use permute instructions to broadcast each
+ // value to a register. _mm256_permute2f128_ps is slow on AMD.
__m256 rhs0_3 =
- _mm256_permute2f128_ps(rhs1, rhs1, 0); // [0 1 2 3 0 1 2 3]
+ _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
__m256 rhs4_7 =
- _mm256_permute2f128_ps(rhs1, rhs1, 17); // [4 5 6 7 4 5 6 7]
+ _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));
const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
accum_data_v[0] = intrin_utils::MulAdd<path>(
@@ -707,13 +705,11 @@ inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
const float* rhs_ptr = rhs_col_ptr;
for (int d = 0; d < params.depth; ++d) {
const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
- const float* rhs_data = rhs_ptr;
- __m256 rhs1 = _mm256_loadu_ps(rhs_data); // Load [0 1 2 3 4 5 6 7]
__m256 rhs0_3 =
- _mm256_permute2f128_ps(rhs1, rhs1, 0); // [0 1 2 3 0 1 2 3]
+ _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
__m256 rhs4_7 =
- _mm256_permute2f128_ps(rhs1, rhs1, 17); // [4 5 6 7 4 5 6 7]
+ _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));
const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
accum_data_v[0] = intrin_utils::MulAdd<path>(
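
The AVX2 change mirrors the AVX-512 one: _mm256_broadcast_ps loads four floats from memory and duplicates them into both 128-bit lanes in a single vbroadcastf128, replacing the load-then-_mm256_permute2f128_ps sequence that the commit notes is slow on AMD. A sketch (function name ours; p must point at four readable floats):

    #include <immintrin.h>

    // Produces [p0 p1 p2 p3 p0 p1 p2 p3] in one memory-source instruction.
    __m256 BroadcastFourFloats(const float* p) {
      return _mm256_broadcast_ps(reinterpret_cast<const __m128*>(p));
    }
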