diff options
author | T.J. Alumbaugh <talumbau@google.com> | 2020-09-18 17:10:35 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2020-09-18 17:10:55 +0300 |
commit | be065e42dd898f565c3e439b70957debb28dfa34 (patch) | |
tree | dd7e35c5fafe86d41f247de2d98a0b1596e198e4 | |
parent | d7b739eb7573e23125d18d44d5c7bed936244911 (diff) |
Optimize AVX/AVX2+FMA float path
PiperOrigin-RevId: 332444643
-rw-r--r-- | ruy/kernel_x86.h | 91 |
1 files changed, 77 insertions, 14 deletions
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h index c530a1f..5681a43 100644 --- a/ruy/kernel_x86.h +++ b/ruy/kernel_x86.h @@ -603,17 +603,47 @@ inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) { const float* rhs_ptr = rhs_col_ptr; for (int d = 0; d < params.depth; ++d) { const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); - // In this version RHS values are loaded individually rather than first - // loading together and then extract with broadcasting. This is because - // AVX flavours and instrinsics and compilers in combination do not - // handle this pattern of extraction very well. const float* rhs_data = rhs_ptr; + // Load 8 RHS values, then use permute instructions to + // broadcast each value to a register. + __m256 rhs1 = _mm256_loadu_ps(rhs_data); // Load [0 1 2 3 4 5 6 7] + __m256 rhs0_3 = + _mm256_permute2f128_ps(rhs1, rhs1, 0); // [0 1 2 3 0 1 2 3] + __m256 rhs4_7 = + _mm256_permute2f128_ps(rhs1, rhs1, 17); // [4 5 6 7 4 5 6 7] + + const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0); + accum_data_v[0] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_0, accum_data_v[0]); + + const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85); + accum_data_v[1] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_1, accum_data_v[1]); + + const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170); + accum_data_v[2] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_2, accum_data_v[2]); + + const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255); + accum_data_v[3] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_3, accum_data_v[3]); + + const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0); + accum_data_v[4] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_4, accum_data_v[4]); + + const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85); + accum_data_v[5] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_5, accum_data_v[5]); + + const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170); + accum_data_v[6] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_6, accum_data_v[6]); + + const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255); + accum_data_v[7] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_7, accum_data_v[7]); - for (int j = 0; j < 8; ++j) { - const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]); - accum_data_v[j] = intrin_utils::MulAdd<path>( - lhs_data, dup_rhs_element_j, accum_data_v[j]); - } lhs_ptr += 8; rhs_ptr += 8; } @@ -675,11 +705,44 @@ inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) { const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr); const float* rhs_data = rhs_ptr; - for (int j = 0; j < 8; ++j) { - const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]); - accum_data_v[j] = intrin_utils::MulAdd<path>( - lhs_data, dup_rhs_element_j, accum_data_v[j]); - } + __m256 rhs1 = _mm256_loadu_ps(rhs_data); // Load [0 1 2 3 4 5 6 7] + __m256 rhs0_3 = + _mm256_permute2f128_ps(rhs1, rhs1, 0); // [0 1 2 3 0 1 2 3] + __m256 rhs4_7 = + _mm256_permute2f128_ps(rhs1, rhs1, 17); // [4 5 6 7 4 5 6 7] + + const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0); + accum_data_v[0] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_0, accum_data_v[0]); + + const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85); + accum_data_v[1] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_1, accum_data_v[1]); + + const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170); + accum_data_v[2] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_2, accum_data_v[2]); + + const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255); + accum_data_v[3] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_3, accum_data_v[3]); + + const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0); + accum_data_v[4] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_4, accum_data_v[4]); + + const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85); + accum_data_v[5] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_5, accum_data_v[5]); + + const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170); + accum_data_v[6] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_6, accum_data_v[6]); + + const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255); + accum_data_v[7] = intrin_utils::MulAdd<path>( + lhs_data, dup_rhs_element_7, accum_data_v[7]); + lhs_ptr += 8; rhs_ptr += 8; } |