From 516761f8278df5dd4a3a0a4d382a31b26767e7a3 Mon Sep 17 00:00:00 2001
From: Ruy Contributors
Date: Thu, 22 Apr 2021 11:00:46 -0700
Subject: 1.02x speedup of Ruy AVX2 f32 and AVX-512 f32/i8

AVX-512:
- broadcast without extra instruction (code size)
- use native mask ops
- re-roll mmm loop

AVX2: avoid slow permute, especially for AMD

PiperOrigin-RevId: 369907385
---
 ruy/kernel_avx512.cc | 428 +++++++++++----------------------------------------
 ruy/kernel_x86.h     |  16 +-
 2 files changed, 98 insertions(+), 346 deletions(-)

diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
index fddb482..84b9380 100644
--- a/ruy/kernel_avx512.cc
+++ b/ruy/kernel_avx512.cc
@@ -52,45 +52,6 @@ void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) {
 
 #else  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
 
-namespace {
-namespace intrin_utils {
-
-__m256i mm256_blendv_epi64(const __m256i& a, const __m256i& b,
-                           const __m256i& mask) {
-  __m256d result =
-      _mm256_blendv_pd(_mm256_castsi256_pd(a), _mm256_castsi256_pd(b),
-                       _mm256_castsi256_pd(mask));
-  return _mm256_castpd_si256(result);
-}
-
-__m512i mm512_blendv_epi64(const __m512i& a, const __m512i& b,
-                           const __m512i& mask) {
-  __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);
-  __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);
-  __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);
-  __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);
-  __m256i mask_lo = _mm512_extracti64x4_epi64(mask, 0);
-  __m256i mask_hi = _mm512_extracti64x4_epi64(mask, 1);
-  __m256i lo = mm256_blendv_epi64(a_lo, b_lo, mask_lo);
-  __m256i hi = mm256_blendv_epi64(a_hi, b_hi, mask_hi);
-  __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0);
-  return _mm512_inserti64x4(result, hi, 1);
-}
-
-__m512i mm512_cmpgt_epi64(const __m512i& a, const __m512i& b) {
-  __m256i a_lo = _mm512_extracti64x4_epi64(a, 0);
-  __m256i a_hi = _mm512_extracti64x4_epi64(a, 1);
-  __m256i b_lo = _mm512_extracti64x4_epi64(b, 0);
-  __m256i b_hi = _mm512_extracti64x4_epi64(b, 1);
-  __m256i lo = _mm256_cmpgt_epi64(a_lo, b_lo);
-  __m256i hi = _mm256_cmpgt_epi64(a_hi, b_hi);
-  __m512i result = _mm512_inserti64x4(_mm512_setzero_si512(), lo, 0);
-  return _mm512_inserti64x4(result, hi, 1);
-}
-
-}  // namespace intrin_utils
-}  // namespace
-
 void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
   profiler::ScopeLabel label("Kernel kAvx512 8-bit");
 
@@ -391,13 +352,13 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
         // Construct the "nudge" value for each lane if the exponent is
         // greater than 0. Otherwise, the nudge is 0.
         const __m512i zeros = _mm512_setzero_si512();
-        const __m512i mask_rightshift_gtz =
-            intrin_utils::mm512_cmpgt_epi64(exponent, zeros);
+        const auto mask_rightshift_gtz =
+            _mm512_cmpgt_epi64_mask(exponent, zeros);
         const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64(
             _mm512_set1_epi64(1),
             _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
-        __m512i nudge = intrin_utils::mm512_blendv_epi64(
-            zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+        __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
+                                              one_shift_exp_minus1);
         // Calculate the shifted sum (results + nudge) >> exp.
         const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
         const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);
@@ -406,14 +367,12 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
         const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
             _mm512_set1_epi64(1),
             _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
-        const __m512i mask_num_plus_nudge_overflow =
-            intrin_utils::mm512_cmpgt_epi64(
-                results,
-                _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
+        const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
+            results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
         // Fill results with either (results + nudge) >> exponent or
         // 1 << (31 - exp) in the case of overflow.
-        results = intrin_utils::mm512_blendv_epi64(
-            shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+        results = _mm512_mask_mov_epi64(
+            shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
       };
 
       if (per_column_multiplier) {
@@ -424,8 +383,8 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
               _mm512_permutexvar_epi32(_mm512_set1_epi32(col), left_shift);
           __m512i m_64bit_val = _mm512_permutexvar_epi64(
              perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high);
-          __m512i offset_vector_val = _mm512_permutexvar_epi64(
-              perm_64bit_vals, offset_vector);
+          __m512i offset_vector_val =
+              _mm512_permutexvar_epi64(perm_64bit_vals, offset_vector);
           __m512i final_right_shift_val = _mm512_permutexvar_epi64(
              perm_64bit_vals,
              col < 8 ? final_right_shift_low : final_right_shift_high);
@@ -802,13 +761,13 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
       // Construct the "nudge" value for each lane if the exponent is
       // greater than 0. Otherwise, the nudge is 0.
       const __m512i zeros = _mm512_setzero_si512();
-      const __m512i mask_rightshift_gtz =
-          intrin_utils::mm512_cmpgt_epi64(exponent, zeros);
+      const auto mask_rightshift_gtz =
+          _mm512_cmpgt_epi64_mask(exponent, zeros);
       const __m512i one_shift_exp_minus1 =
           _mm512_sllv_epi64(_mm512_set1_epi64(1),
                             _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
-      __m512i nudge = intrin_utils::mm512_blendv_epi64(
-          zeros, one_shift_exp_minus1, mask_rightshift_gtz);
+      __m512i nudge = _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz,
+                                            one_shift_exp_minus1);
       // Calculate the shifted sum (results + nudge) >> exp.
       const __m512i r_plus_nudge = _mm512_add_epi64(results, nudge);
       const __m512i shifted_sum = _mm512_srav_epi64(r_plus_nudge, exponent);
@@ -817,14 +776,12 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
       const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
          _mm512_set1_epi64(1),
          _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
-      const __m512i mask_num_plus_nudge_overflow =
-          intrin_utils::mm512_cmpgt_epi64(
-              results,
-              _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
+      const auto mask_num_plus_nudge_overflow = _mm512_cmpgt_epi64_mask(
+          results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
       // Fill results with either (results + nudge) >> exponent or
       // 1 << (31 - exp) in the case of overflow.
-      results = intrin_utils::mm512_blendv_epi64(
-          shifted_sum, one_shift_31minus_exp, mask_num_plus_nudge_overflow);
+      results = _mm512_mask_mov_epi64(
+          shifted_sum, mask_num_plus_nudge_overflow, one_shift_31minus_exp);
     };
 
     // Shift and round column 0.
@@ -930,9 +887,8 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
       float* dst_ptr = dst_col_ptr + row;
 
       // Process block in two halves, split by columns.
-      {
-        constexpr int mmm = 0;
-
+#pragma unroll(1)
+      for (int mmm = 0; mmm < 2; ++mmm) {
         __m512 accum_data_v0;
         __m512 accum_data_v1;
         __m512 accum_data_v2;
         __m512 accum_data_v3;
         __m512 accum_data_v4;
         __m512 accum_data_v5;
         __m512 accum_data_v6;
         __m512 accum_data_v7;
         const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
         for (int d = 0; d < (params.depth - 1); ++d) {
           const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
-          // In this version RHS values are loaded individually rather than
-          // first loading together and then extract with broadcasting. This is
-          // because AVX flavours and instrinsics and compilers in combination
-          // do not handle this pattern of extraction very well.
           const float* rhs_data = rhs_ptr;
           lhs_ptr += 16;
           rhs_ptr += 16;
-          {
-            // Load 8 float32 values.
-            __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
-            __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0);  // [0 1 2 3] X 4
-            __m512 rhs4_7 =
-                _mm512_shuffle_f32x4(rhs, rhs, 0x55);  // [4 5 6 7] X 4
-
-            const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
-            accum_data_v0 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
-            const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
-            accum_data_v1 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
-            const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
-            accum_data_v2 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
-            const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
-            accum_data_v3 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
-            const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
-            accum_data_v4 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
-            const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
-            accum_data_v5 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
-            const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
-            accum_data_v6 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
-            const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
-            accum_data_v7 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
-          }
+          // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
+          // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
+          // so if given an rvalue.
+          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
+                                          accum_data_v0);
+          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
+                                          accum_data_v1);
+          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
+                                          accum_data_v2);
+          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
+                                          accum_data_v3);
+          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
+                                          accum_data_v4);
+          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
+                                          accum_data_v5);
+          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
+                                          accum_data_v6);
+          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
+                                          accum_data_v7);
         }
-        {
+        {  // nested extra blocks lead to measurable speed gains
           const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
           const float* rhs_data = rhs_ptr;
-          {
-            // Load 8 float32 values.
-            __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
-            __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0);  // [0 1 2 3] X 4
-            __m512 rhs4_7 =
-                _mm512_shuffle_f32x4(rhs, rhs, 0x55);  // [4 5 6 7] X 4
-            const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
-            accum_data_v0 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
-            const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
-            accum_data_v1 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
-            const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
-            accum_data_v2 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
-            const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
-            accum_data_v3 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
-            const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
-            accum_data_v4 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
-            const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
-            accum_data_v5 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
-            const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
-            accum_data_v6 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
-            const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
-            accum_data_v7 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
-          }
+          accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
+                                          accum_data_v0);
+          accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
+                                          accum_data_v1);
+          accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
+                                          accum_data_v2);
+          accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
+                                          accum_data_v3);
+          accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
+                                          accum_data_v4);
+          accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
+                                          accum_data_v5);
+          accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
+                                          accum_data_v6);
+          accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
+                                          accum_data_v7);
           {
             float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
             accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
             accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
             _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
             accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
             accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
             _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
             accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
             accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
             _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
             accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
             accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
             _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
             accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
             accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
             _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
             accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
             accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
             _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
             accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
             accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
             _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
             accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
             accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
@@ -1075,147 +999,7 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
             _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
           }
         }
-      }  // Inner half-block loop, unrolled, first iteration.
-      {
-        constexpr int mmm = 1;
-
-        __m512 accum_data_v0;
-        __m512 accum_data_v1;
-        __m512 accum_data_v2;
-        __m512 accum_data_v3;
-        __m512 accum_data_v4;
-        __m512 accum_data_v5;
-        __m512 accum_data_v6;
-        __m512 accum_data_v7;
-
-        // Initialize with bias.
-        if (channel_dimension_is_col) {
-          const float* bias_elem_ptr =
-              bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
-          accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
-          accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
-          accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
-          accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
-          accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
-          accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
-          accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
-          accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
-        } else {
-          const __m512 initial_accum_data =
-              _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
-
-          accum_data_v0 = initial_accum_data;
-          accum_data_v1 = initial_accum_data;
-          accum_data_v2 = initial_accum_data;
-          accum_data_v3 = initial_accum_data;
-          accum_data_v4 = initial_accum_data;
-          accum_data_v5 = initial_accum_data;
-          accum_data_v6 = initial_accum_data;
-          accum_data_v7 = initial_accum_data;
-        }
-        const float* lhs_ptr = lhs_col_ptr;
-        const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
-        for (int d = 0; d < (params.depth - 1); ++d) {
-          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
-          const float* rhs_data = rhs_ptr;
-          lhs_ptr += 16;
-          rhs_ptr += 16;
-          {
-            // Load 8 float32 values.
-            __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
-            __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0);  // [0 1 2 3] X 4
-            __m512 rhs4_7 =
-                _mm512_shuffle_f32x4(rhs, rhs, 0x55);  // [4 5 6 7] X 4
-
-            const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
-            accum_data_v0 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
-            const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
-            accum_data_v1 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
-            const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
-            accum_data_v2 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
-            const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
-            accum_data_v3 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
-            const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
-            accum_data_v4 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
-            const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
-            accum_data_v5 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
-            const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
-            accum_data_v6 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
-            const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
-            accum_data_v7 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
-          }
-        }
-        {
-          const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
-          const float* rhs_data = rhs_ptr;
-          {
-            // Load 8 float32 values.
-            __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
-            __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0);  // [0 1 2 3] X 4
-            __m512 rhs4_7 =
-                _mm512_shuffle_f32x4(rhs, rhs, 0x55);  // [4 5 6 7] X 4
-            const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
-            accum_data_v0 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
-            const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
-            accum_data_v1 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
-            const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
-            accum_data_v2 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
-            const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
-            accum_data_v3 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
-            const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
-            accum_data_v4 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
-            const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
-            accum_data_v5 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
-            const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
-            accum_data_v6 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
-            const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
-            accum_data_v7 =
-                _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
-          }
-          {
-            float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
-            accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
-            accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0);
-            accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v);
-            accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1);
-            accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v);
-            accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2);
-            accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v);
-            accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3);
-            accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v);
-            accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4);
-            accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v);
-            accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5);
-            accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v);
-            accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6);
-            accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v);
-            accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v);
-            _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7);
-          }
-        }
-      }  // Inner half-block loop, unrolled, second iteration.
+      }
     }  // End row-block loop.
 
     // The unrolling within this conditional may be somewhat pointless. It
@@ -1273,73 +1057,45 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
         const float* rhs_data = rhs_ptr;
         lhs_ptr += 16;
         rhs_ptr += 16;
-        {
-          // Load 8 float32 values.
-          __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
-          __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0);  // [0 1 2 3] X 4
-          __m512 rhs4_7 =
-              _mm512_shuffle_f32x4(rhs, rhs, 0x55);  // [4 5 6 7] X 4
-
-          const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
-          accum_data_v0 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
-          const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
-          accum_data_v1 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
-          const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
-          accum_data_v2 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
-          const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
-          accum_data_v3 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
-          const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
-          accum_data_v4 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
-          const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
-          accum_data_v5 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
-          const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
-          accum_data_v6 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
-          const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
-          accum_data_v7 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
-        }
+        // GCC and clang can fuse set1+FMA into an FMA with EVEX broadcast:
+        // https://gcc.godbolt.org/z/xbfqWYfn1. Clang is more likely to do
+        // so if given an rvalue.
+        accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
+                                        accum_data_v0);
+        accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
+                                        accum_data_v1);
+        accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
+                                        accum_data_v2);
+        accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
+                                        accum_data_v3);
+        accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
+                                        accum_data_v4);
+        accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
+                                        accum_data_v5);
+        accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
+                                        accum_data_v6);
+        accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
+                                        accum_data_v7);
       }
       {
        const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr);
        const float* rhs_data = rhs_ptr;
-        {
-          // Load 8 float32 values.
-          __m512 rhs = _mm512_castps256_ps512(_mm256_loadu_ps(rhs_data));
-          __m512 rhs0_3 = _mm512_shuffle_f32x4(rhs, rhs, 0);  // [0 1 2 3] X 4
-          __m512 rhs4_7 =
-              _mm512_shuffle_f32x4(rhs, rhs, 0x55);  // [4 5 6 7] X 4
-          const __m512 dup_rhs_element_j0 = _mm512_permute_ps(rhs0_3, 0);
-          accum_data_v0 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0);
-          const __m512 dup_rhs_element_j1 = _mm512_permute_ps(rhs0_3, 0x55);
-          accum_data_v1 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1);
-          const __m512 dup_rhs_element_j2 = _mm512_permute_ps(rhs0_3, 0xaa);
-          accum_data_v2 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2);
-          const __m512 dup_rhs_element_j3 = _mm512_permute_ps(rhs0_3, 0xff);
-          accum_data_v3 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3);
-          const __m512 dup_rhs_element_j4 = _mm512_permute_ps(rhs4_7, 0);
-          accum_data_v4 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4);
-          const __m512 dup_rhs_element_j5 = _mm512_permute_ps(rhs4_7, 0x55);
-          accum_data_v5 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5);
-          const __m512 dup_rhs_element_j6 = _mm512_permute_ps(rhs4_7, 0xaa);
-          accum_data_v6 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6);
-          const __m512 dup_rhs_element_j7 = _mm512_permute_ps(rhs4_7, 0xff);
-          accum_data_v7 =
-              _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7);
-        }
+        accum_data_v0 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[0]),
+                                        accum_data_v0);
+        accum_data_v1 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[1]),
+                                        accum_data_v1);
+        accum_data_v2 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[2]),
+                                        accum_data_v2);
+        accum_data_v3 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[3]),
+                                        accum_data_v3);
+        accum_data_v4 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[4]),
+                                        accum_data_v4);
+        accum_data_v5 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[5]),
+                                        accum_data_v5);
+        accum_data_v6 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[6]),
+                                        accum_data_v6);
+        accum_data_v7 = _mm512_fmadd_ps(lhs_data, _mm512_set1_ps(rhs_data[7]),
+                                        accum_data_v7);
         {
           float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride;
           accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v);
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index 2f8fe19..b716502 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -607,14 +607,12 @@ inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
       const float* rhs_ptr = rhs_col_ptr;
       for (int d = 0; d < params.depth; ++d) {
         const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
-        const float* rhs_data = rhs_ptr;
-        // Load 8 RHS values, then use permute instructions to
-        // broadcast each value to a register.
-        __m256 rhs1 = _mm256_loadu_ps(rhs_data);  // Load [0 1 2 3 4 5 6 7]
+        // Load 8 RHS values, then use permute instructions to broadcast each
+        // value to a register. _mm256_permute2f128_ps is slow on AMD.
         __m256 rhs0_3 =
-            _mm256_permute2f128_ps(rhs1, rhs1, 0);  // [0 1 2 3 0 1 2 3]
+            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
         __m256 rhs4_7 =
-            _mm256_permute2f128_ps(rhs1, rhs1, 17);  // [4 5 6 7 4 5 6 7]
+            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));
 
         const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
         accum_data_v[0] = intrin_utils::MulAdd<path>(
@@ -707,13 +705,11 @@ inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
       const float* rhs_ptr = rhs_col_ptr;
       for (int d = 0; d < params.depth; ++d) {
         const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
-        const float* rhs_data = rhs_ptr;
-        __m256 rhs1 = _mm256_loadu_ps(rhs_data);  // Load [0 1 2 3 4 5 6 7]
         __m256 rhs0_3 =
-            _mm256_permute2f128_ps(rhs1, rhs1, 0);  // [0 1 2 3 0 1 2 3]
+            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
         __m256 rhs4_7 =
-            _mm256_permute2f128_ps(rhs1, rhs1, 17);  // [4 5 6 7 4 5 6 7]
+            _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr + 4));
 
         const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
         accum_data_v[0] = intrin_utils::MulAdd<path>(
-- 
cgit v1.2.3
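
For readers who want to experiment with these techniques outside of Ruy, below is a
minimal standalone sketch. It is not part of the patch and is not Ruy code; the
function names are illustrative. It assumes a toolchain and CPU with AVX-512F
(whose prerequisites cover the AVX needed by the last function) and builds with,
e.g., g++ -O2 -mavx512f.

// Standalone sketch (illustrative names, not Ruy code) of the techniques used
// in this commit. Requires AVX-512F.
#include <immintrin.h>

#include <cstdint>
#include <cstdio>

// 1. "Use native mask ops": an AVX-512F 64-bit compare yields a __mmask8
// k-register directly, and the blend becomes a masked move. This replaces the
// deleted intrin_utils helpers that split each __m512i into __m256i halves.
__m512i RoundingRightShift(__m512i results, __m512i exponent) {
  const __m512i zeros = _mm512_setzero_si512();
  const __mmask8 mask_rightshift_gtz = _mm512_cmpgt_epi64_mask(exponent, zeros);
  const __m512i one_shift_exp_minus1 = _mm512_sllv_epi64(
      _mm512_set1_epi64(1), _mm512_sub_epi64(exponent, _mm512_set1_epi64(1)));
  // The nudge is selected only in lanes where exponent > 0.
  const __m512i nudge =
      _mm512_mask_mov_epi64(zeros, mask_rightshift_gtz, one_shift_exp_minus1);
  const __m512i shifted_sum =
      _mm512_srav_epi64(_mm512_add_epi64(results, nudge), exponent);
  const __m512i one_shift_31minus_exp = _mm512_sllv_epi64(
      _mm512_set1_epi64(1), _mm512_sub_epi64(_mm512_set1_epi64(31), exponent));
  const __mmask8 overflow = _mm512_cmpgt_epi64_mask(
      results, _mm512_sub_epi64(_mm512_set1_epi64(0x7fffffff), nudge));
  return _mm512_mask_mov_epi64(shifted_sum, overflow, one_shift_31minus_exp);
}

// 2. "Broadcast without extra instruction": set1 of a scalar load feeding an
// FMA can be folded by GCC/clang into a single vfmadd231ps with an EVEX
// embedded broadcast operand, instead of a separate load-and-permute sequence.
__m512 FmaWithBroadcast(__m512 acc, __m512 lhs, const float* rhs) {
  return _mm512_fmadd_ps(lhs, _mm512_set1_ps(rhs[0]), acc);
}

// 3. AVX2 side: _mm256_broadcast_ps loads 128 bits and duplicates them into
// both halves of a __m256, avoiding _mm256_permute2f128_ps, which is slow on
// AMD parts.
__m256 BroadcastLow4(const float* rhs_ptr) {
  return _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
}

int main() {
  alignas(64) const std::int64_t in[8] = {100, -100, 0x7fffffff, 7,
                                          -7,  1,    0,          12345};
  alignas(64) std::int64_t out[8];
  _mm512_store_si512(out, RoundingRightShift(_mm512_load_si512(in),
                                             _mm512_set1_epi64(2)));
  for (int i = 0; i < 8; ++i) std::printf("%lld\n", (long long)out[i]);
  return 0;
}

The sketch mirrors the rounding hunks above: each lane is shifted right with a
round-half-up nudge, and lanes that would overflow the 31-bit check saturate to
1 << (31 - exponent). The point of the k-register version is that compare plus
select stays at roughly two instructions, where the deleted helpers expanded
into extract/compare/blend/insert sequences over 256-bit halves.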