Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorT.J. Alumbaugh <talumbau@google.com>2020-09-18 17:10:35 +0300
committerCopybara-Service <copybara-worker@google.com>2020-09-18 17:10:55 +0300
commitbe065e42dd898f565c3e439b70957debb28dfa34 (patch)
treedd7e35c5fafe86d41f247de2d98a0b1596e198e4
parentd7b739eb7573e23125d18d44d5c7bed936244911 (diff)
Optimize AVX/AVX2+FMA float path
PiperOrigin-RevId: 332444643
-rw-r--r--ruy/kernel_x86.h91
1 files changed, 77 insertions, 14 deletions
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index c530a1f..5681a43 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -603,17 +603,47 @@ inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
const float* rhs_ptr = rhs_col_ptr;
for (int d = 0; d < params.depth; ++d) {
const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
- // In this version RHS values are loaded individually rather than first
- // loading together and then extract with broadcasting. This is because
- // AVX flavours and instrinsics and compilers in combination do not
- // handle this pattern of extraction very well.
const float* rhs_data = rhs_ptr;
+ // Load 8 RHS values, then use permute instructions to
+ // broadcast each value to a register.
+ __m256 rhs1 = _mm256_loadu_ps(rhs_data); // Load [0 1 2 3 4 5 6 7]
+ __m256 rhs0_3 =
+ _mm256_permute2f128_ps(rhs1, rhs1, 0); // [0 1 2 3 0 1 2 3]
+ __m256 rhs4_7 =
+ _mm256_permute2f128_ps(rhs1, rhs1, 17); // [4 5 6 7 4 5 6 7]
+
+ const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
+ accum_data_v[0] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_0, accum_data_v[0]);
+
+ const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85);
+ accum_data_v[1] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_1, accum_data_v[1]);
+
+ const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170);
+ accum_data_v[2] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_2, accum_data_v[2]);
+
+ const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255);
+ accum_data_v[3] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_3, accum_data_v[3]);
+
+ const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0);
+ accum_data_v[4] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_4, accum_data_v[4]);
+
+ const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85);
+ accum_data_v[5] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_5, accum_data_v[5]);
+
+ const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170);
+ accum_data_v[6] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_6, accum_data_v[6]);
+
+ const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255);
+ accum_data_v[7] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_7, accum_data_v[7]);
- for (int j = 0; j < 8; ++j) {
- const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
- accum_data_v[j] = intrin_utils::MulAdd<path>(
- lhs_data, dup_rhs_element_j, accum_data_v[j]);
- }
lhs_ptr += 8;
rhs_ptr += 8;
}
@@ -675,11 +705,44 @@ inline void KernelFloatAvxCommon(const KernelParamsFloat<8, 8>& params) {
const __m256 lhs_data = _mm256_loadu_ps(lhs_ptr);
const float* rhs_data = rhs_ptr;
- for (int j = 0; j < 8; ++j) {
- const __m256 dup_rhs_element_j = _mm256_set1_ps(rhs_data[j]);
- accum_data_v[j] = intrin_utils::MulAdd<path>(
- lhs_data, dup_rhs_element_j, accum_data_v[j]);
- }
+ __m256 rhs1 = _mm256_loadu_ps(rhs_data); // Load [0 1 2 3 4 5 6 7]
+ __m256 rhs0_3 =
+ _mm256_permute2f128_ps(rhs1, rhs1, 0); // [0 1 2 3 0 1 2 3]
+ __m256 rhs4_7 =
+ _mm256_permute2f128_ps(rhs1, rhs1, 17); // [4 5 6 7 4 5 6 7]
+
+ const __m256 dup_rhs_element_0 = _mm256_permute_ps(rhs0_3, 0);
+ accum_data_v[0] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_0, accum_data_v[0]);
+
+ const __m256 dup_rhs_element_1 = _mm256_permute_ps(rhs0_3, 85);
+ accum_data_v[1] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_1, accum_data_v[1]);
+
+ const __m256 dup_rhs_element_2 = _mm256_permute_ps(rhs0_3, 170);
+ accum_data_v[2] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_2, accum_data_v[2]);
+
+ const __m256 dup_rhs_element_3 = _mm256_permute_ps(rhs0_3, 255);
+ accum_data_v[3] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_3, accum_data_v[3]);
+
+ const __m256 dup_rhs_element_4 = _mm256_permute_ps(rhs4_7, 0);
+ accum_data_v[4] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_4, accum_data_v[4]);
+
+ const __m256 dup_rhs_element_5 = _mm256_permute_ps(rhs4_7, 85);
+ accum_data_v[5] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_5, accum_data_v[5]);
+
+ const __m256 dup_rhs_element_6 = _mm256_permute_ps(rhs4_7, 170);
+ accum_data_v[6] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_6, accum_data_v[6]);
+
+ const __m256 dup_rhs_element_7 = _mm256_permute_ps(rhs4_7, 255);
+ accum_data_v[7] = intrin_utils::MulAdd<path>(
+ lhs_data, dup_rhs_element_7, accum_data_v[7]);
+
lhs_ptr += 8;
rhs_ptr += 8;
}