
github.com/google/ruy.git
author     benoitjacob <benoitjacob@google.com>          2020-07-24 19:47:32 +0300
committer  Copybara-Service <copybara-worker@google.com>  2020-07-24 21:52:36 +0300
commit     78b870988e5ed9587dc23838d800352731a0f58c (patch)
tree       7c07a8651ed4965c9738a958f14df3c65fa5f160
parent     1efd97066bbc59c3bcce267f99d46ec5387e876e (diff)
Let cpu_backend_gemm support all storage-order combinations, unconditionally using ruy as the backend for combinations other than RowMajor*ColMajor->ColMajor, which were not supported until now. Ruy differs from the other backends in that it supports all combinations as runtime parameters, without any code-size increase.
PiperOrigin-RevId: 323013778
-rw-r--r--  ruy/kernel_avx2_fma.cc | 357
-rw-r--r--  ruy/kernel_avx512.cc   | 245
2 files changed, 128 insertions(+), 474 deletions(-)
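For context on what "all combinations as runtime parameters" means on the ruy side: ruy's public API takes the storage order of each matrix as a runtime value stored in the layout, so a caller such as cpu_backend_gemm can route any RowMajor/ColMajor combination through the same entry point. Below is a minimal sketch based on ruy's public ruy.h API (ruy::MakeSimpleLayout, ruy::Mul); names and exact signatures should be checked against the version in this tree, and the function name GemmAnyOrder is purely illustrative.

#include "ruy/ruy.h"

// Computes dst = lhs * rhs with the storage order of every matrix chosen at runtime.
void GemmAnyOrder(const float* lhs_data, ruy::Order lhs_order,
                  const float* rhs_data, ruy::Order rhs_order,
                  float* dst_data, ruy::Order dst_order,
                  int rows, int depth, int cols, ruy::Context* context) {
  ruy::Matrix<float> lhs;
  ruy::MakeSimpleLayout(rows, depth, lhs_order, lhs.mutable_layout());
  lhs.set_data(lhs_data);

  ruy::Matrix<float> rhs;
  ruy::MakeSimpleLayout(depth, cols, rhs_order, rhs.mutable_layout());
  rhs.set_data(rhs_data);

  ruy::Matrix<float> dst;
  ruy::MakeSimpleLayout(rows, cols, dst_order, dst.mutable_layout());
  dst.set_data(dst_data);

  ruy::MulParams<float, float> mul_params;  // plain float matmul, no quantization
  ruy::Mul(lhs, rhs, mul_params, context, &dst);
}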
diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc
index 463bb4b..4ba73f5 100644
--- a/ruy/kernel_avx2_fma.cc
+++ b/ruy/kernel_avx2_fma.cc
@@ -319,39 +319,6 @@ inline float mm256_get1_ps(const __m256 a, int i) {
return float_val;
}
-inline __m256 mm256_n_loadu_ps(int i, const float* src) {
- switch (i) {
- case 0:
- return _mm256_setzero_ps();
- case 1:
- return _mm256_setr_m128(_mm_setr_ps(src[0], .0f, .0f, .0f),
- _mm_setzero_ps());
- case 2:
- return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], .0f, .0f),
- _mm_setzero_ps());
- case 3:
- return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], .0f),
- _mm_setzero_ps());
- case 4:
- return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], src[3]),
- _mm_setzero_ps());
- case 5:
- return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], .0f, .0f,
- .0f);
- case 6:
- return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], .0f,
- .0f);
- case 7:
- return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5],
- src[6], .0f);
- case 8:
- return _mm256_loadu_ps(src);
- default:
- RUY_DCHECK_LT(i, 9);
- return _mm256_setzero_ps();
- }
-}
-
inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
for (int i = 0; i < residual_rows; ++i) {
dst[i] = intrin_utils::mm256_get1_ps(v, i);
@@ -589,126 +556,26 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
// Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
- // Accumulate for column 0.
- {
- const std::int32_t low_rhs_value = rhs_data[0];
- const std::int32_t high_rhs_value = rhs_data[1];
-
- const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
- const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
- accum_data_v0 = _mm256_add_epi32(
- accum_data_v0,
- _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
- accum_data_v0 = _mm256_add_epi32(
- accum_data_v0,
- _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
- }
- // Accumulate for column 1.
- {
- const std::int32_t low_rhs_value = rhs_data[2];
- const std::int32_t high_rhs_value = rhs_data[3];
-
- const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
- const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
- accum_data_v1 = _mm256_add_epi32(
- accum_data_v1,
- _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
- accum_data_v1 = _mm256_add_epi32(
- accum_data_v1,
- _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
- }
- // Accumulate for column 2.
- {
- const std::int32_t low_rhs_value = rhs_data[4];
- const std::int32_t high_rhs_value = rhs_data[5];
-
- const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
- const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
- accum_data_v2 = _mm256_add_epi32(
- accum_data_v2,
- _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
- accum_data_v2 = _mm256_add_epi32(
- accum_data_v2,
- _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
- }
- // Accumulate for column 3.
- {
- const std::int32_t low_rhs_value = rhs_data[6];
- const std::int32_t high_rhs_value = rhs_data[7];
+ auto process_column = [=](int col, __m256i& accum) {
+ const std::int32_t low_rhs_value = rhs_data[col * 2];
+ const std::int32_t high_rhs_value = rhs_data[col * 2 + 1];
const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
- accum_data_v3 = _mm256_add_epi32(
- accum_data_v3,
- _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
- accum_data_v3 = _mm256_add_epi32(
- accum_data_v3,
- _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
- }
- // Accumulate for column 4.
- {
- const std::int32_t low_rhs_value = rhs_data[8];
- const std::int32_t high_rhs_value = rhs_data[9];
-
- const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
- const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
- accum_data_v4 = _mm256_add_epi32(
- accum_data_v4,
- _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
- accum_data_v4 = _mm256_add_epi32(
- accum_data_v4,
- _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
- }
- // Accumulate for column 5.
- {
- const std::int32_t low_rhs_value = rhs_data[10];
- const std::int32_t high_rhs_value = rhs_data[11];
-
- const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
- const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
- accum_data_v5 = _mm256_add_epi32(
- accum_data_v5,
- _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
- accum_data_v5 = _mm256_add_epi32(
- accum_data_v5,
- _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
- }
- // Accumulate for column 6.
- {
- const std::int32_t low_rhs_value = rhs_data[12];
- const std::int32_t high_rhs_value = rhs_data[13];
-
- const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
- const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
- accum_data_v6 = _mm256_add_epi32(
- accum_data_v6,
- _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
- accum_data_v6 = _mm256_add_epi32(
- accum_data_v6,
- _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
- }
- // Accumulate for column 7.
- {
- const std::int32_t low_rhs_value = rhs_data[14];
- const std::int32_t high_rhs_value = rhs_data[15];
-
- const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
- const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
- accum_data_v7 = _mm256_add_epi32(
- accum_data_v7,
- _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
- accum_data_v7 = _mm256_add_epi32(
- accum_data_v7,
- _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
- }
+ accum = _mm256_add_epi32(
+ accum, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+ accum = _mm256_add_epi32(
+ accum, _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+ };
+ process_column(0, accum_data_v0);
+ process_column(1, accum_data_v1);
+ process_column(2, accum_data_v2);
+ process_column(3, accum_data_v3);
+ process_column(4, accum_data_v4);
+ process_column(5, accum_data_v5);
+ process_column(6, accum_data_v6);
+ process_column(7, accum_data_v7);
lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
@@ -844,8 +711,8 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
&accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
&accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
}
- {
- __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift);
+ auto apply_multiplier = [=](__m256i& accum) {
+ __m256i shifted_accum = _mm256_sllv_epi32(accum, left_shift);
// Apply the fixed-point part of the multiplier.
__m256i scaled_v_low = _mm256_mul_epi32(
_mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
@@ -866,176 +733,16 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
_mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
results = _mm256_permutevar8x32_epi32(results, repack_perm);
- accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset);
- }
- {
- __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v1, left_shift);
- // Apply the fixed-point part of the multiplier.
- __m256i scaled_v_low = _mm256_mul_epi32(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
- m_64bit_low);
- __m256i scaled_v_high = _mm256_mul_epi32(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
- m_64bit_high);
-
- scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
- scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
- scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
- scaled_v_high =
- _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
- scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
- __m256i results =
- _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
- results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
- accum_data_v1 = _mm256_sub_epi32(results, post_scaling_offset);
- }
- {
- __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v2, left_shift);
- // Apply the fixed-point part of the multiplier.
- __m256i scaled_v_low = _mm256_mul_epi32(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
- m_64bit_low);
- __m256i scaled_v_high = _mm256_mul_epi32(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
- m_64bit_high);
-
- scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
- scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
- scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
- scaled_v_high =
- _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
- scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
- __m256i results =
- _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
- results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
- accum_data_v2 = _mm256_sub_epi32(results, post_scaling_offset);
- }
- {
- __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v3, left_shift);
- // Apply the fixed-point part of the multiplier.
- __m256i scaled_v_low = _mm256_mul_epi32(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
- m_64bit_low);
- __m256i scaled_v_high = _mm256_mul_epi32(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
- m_64bit_high);
-
- scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
- scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
- scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
- scaled_v_high =
- _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
- scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
- __m256i results =
- _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
- results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
- accum_data_v3 = _mm256_sub_epi32(results, post_scaling_offset);
- }
- {
- __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v4, left_shift);
- // Apply the fixed-point part of the multiplier.
- __m256i scaled_v_low = _mm256_mul_epi32(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
- m_64bit_low);
- __m256i scaled_v_high = _mm256_mul_epi32(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
- m_64bit_high);
-
- scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
- scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
- scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
- scaled_v_high =
- _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
- scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
- __m256i results =
- _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
- results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
- accum_data_v4 = _mm256_sub_epi32(results, post_scaling_offset);
- }
- {
- __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v5, left_shift);
- // Apply the fixed-point part of the multiplier.
- __m256i scaled_v_low = _mm256_mul_epi32(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
- m_64bit_low);
- __m256i scaled_v_high = _mm256_mul_epi32(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
- m_64bit_high);
-
- scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
- scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
- scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
- scaled_v_high =
- _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
- scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
- __m256i results =
- _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
- results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
- accum_data_v5 = _mm256_sub_epi32(results, post_scaling_offset);
- }
- {
- __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v6, left_shift);
- // Apply the fixed-point part of the multiplier.
- __m256i scaled_v_low = _mm256_mul_epi32(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
- m_64bit_low);
- __m256i scaled_v_high = _mm256_mul_epi32(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
- m_64bit_high);
-
- scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
- scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
- scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
- scaled_v_high =
- _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
- scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
- __m256i results =
- _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
- results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
- accum_data_v6 = _mm256_sub_epi32(results, post_scaling_offset);
- }
- {
- __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v7, left_shift);
- // Apply the fixed-point part of the multiplier.
- __m256i scaled_v_low = _mm256_mul_epi32(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
- m_64bit_low);
- __m256i scaled_v_high = _mm256_mul_epi32(
- _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
- m_64bit_high);
-
- scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
- scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
- scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
- scaled_v_high =
- _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
- scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
- __m256i results =
- _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
- results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
- accum_data_v7 = _mm256_sub_epi32(results, post_scaling_offset);
- }
+ accum = _mm256_sub_epi32(results, post_scaling_offset);
+ };
+ apply_multiplier(accum_data_v0);
+ apply_multiplier(accum_data_v1);
+ apply_multiplier(accum_data_v2);
+ apply_multiplier(accum_data_v3);
+ apply_multiplier(accum_data_v4);
+ apply_multiplier(accum_data_v5);
+ apply_multiplier(accum_data_v6);
+ apply_multiplier(accum_data_v7);
// See above comment: here we transpose again to undo the transposition
// of the 8x8 block of accumulators used to implement the
// channels-are-columns case.
@@ -1549,8 +1256,7 @@ void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
}
} else {
const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
- const __m256 initial_accum_data =
- intrin_utils::mm256_n_loadu_ps(residual_rows, bias_elem_ptr);
+ const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
for (int j = 0; j < 8; ++j) {
accum_data_v[j] = initial_accum_data;
@@ -1620,8 +1326,7 @@ void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
}
} else {
const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
- const __m256 initial_accum_data =
- intrin_utils::mm256_n_loadu_ps(residual_rows, bias_elem_ptr);
+ const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
for (int j = 0; j < 8; ++j) {
accum_data_v[j] = initial_accum_data;
@@ -1739,7 +1444,7 @@ void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
// Initialize with bias.
- accum_data_v = intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr);
+ accum_data_v = _mm256_loadu_ps(bias_ptr);
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr;
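For readers not fluent in AVX2 intrinsics, the apply_multiplier lambda introduced in the diff above is a vectorized form of the usual fixed-point requantization step: shift the accumulator left by the multiplier's positive exponent, take a 32x32->64-bit product with the fixed-point mantissa, add a rounding offset, shift right, and re-center on the destination offset. Here is a scalar model of one lane, under the assumption that the rounding offset is 1 << (right_shift - 1); the names are illustrative and not the kernel's own.

#include <cstdint>

// One-lane model of the per-lane requantization performed by the vector code.
// 'multiplier' is the fixed-point mantissa; left_shift/right_shift come from
// the quantized multiplier's exponent; post_scaling_offset is subtracted at
// the end, mirroring the _mm256_sub_epi32(results, post_scaling_offset) step.
inline std::int32_t RequantizeLane(std::int32_t acc, std::int32_t multiplier,
                                   int left_shift, int right_shift,
                                   std::int32_t post_scaling_offset) {
  const std::int64_t shifted = static_cast<std::int64_t>(acc) << left_shift;
  const std::int64_t prod = shifted * static_cast<std::int64_t>(multiplier);
  const std::int64_t rounding =
      right_shift > 0 ? (std::int64_t{1} << (right_shift - 1)) : 0;  // assumed
  const std::int32_t scaled =
      static_cast<std::int32_t>((prod + rounding) >> right_shift);
  return scaled - post_scaling_offset;
}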
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
index 3d36516..d2d1f4c 100644
--- a/ruy/kernel_avx512.cc
+++ b/ruy/kernel_avx512.cc
@@ -52,88 +52,6 @@ void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) {
#else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
-namespace {
-namespace intrin_utils {
-
-// Transpose a 8x8 matrix of int32's.
-void mm512_transpose16x16_epi32(__m512i* v0, __m512i* v1, __m512i* v2,
- __m512i* v3, __m512i* v4, __m512i* v5,
- __m512i* v6, __m512i* v7, __m512i* v8,
- __m512i* v9, __m512i* va, __m512i* vb,
- __m512i* vc, __m512i* vd, __m512i* ve,
- __m512i* vf) {
- __m512i t2x2_0 = _mm512_unpacklo_epi32(*v0, *v1);
- __m512i t2x2_1 = _mm512_unpackhi_epi32(*v0, *v1);
- __m512i t2x2_2 = _mm512_unpacklo_epi32(*v2, *v3);
- __m512i t2x2_3 = _mm512_unpackhi_epi32(*v2, *v3);
- __m512i t2x2_4 = _mm512_unpacklo_epi32(*v4, *v5);
- __m512i t2x2_5 = _mm512_unpackhi_epi32(*v4, *v5);
- __m512i t2x2_6 = _mm512_unpacklo_epi32(*v6, *v7);
- __m512i t2x2_7 = _mm512_unpackhi_epi32(*v6, *v7);
- __m512i t2x2_8 = _mm512_unpacklo_epi32(*v8, *v9);
- __m512i t2x2_9 = _mm512_unpackhi_epi32(*v8, *v9);
- __m512i t2x2_a = _mm512_unpacklo_epi32(*va, *vb);
- __m512i t2x2_b = _mm512_unpackhi_epi32(*va, *vb);
- __m512i t2x2_c = _mm512_unpacklo_epi32(*vc, *vd);
- __m512i t2x2_d = _mm512_unpackhi_epi32(*vc, *vd);
- __m512i t2x2_e = _mm512_unpacklo_epi32(*ve, *vf);
- __m512i t2x2_f = _mm512_unpackhi_epi32(*ve, *vf);
-
- __m512i t4x4_0 = _mm512_unpacklo_epi64(t2x2_0, t2x2_2);
- __m512i t4x4_1 = _mm512_unpackhi_epi64(t2x2_0, t2x2_2);
- __m512i t4x4_2 = _mm512_unpacklo_epi64(t2x2_1, t2x2_3);
- __m512i t4x4_3 = _mm512_unpackhi_epi64(t2x2_1, t2x2_3);
- __m512i t4x4_4 = _mm512_unpacklo_epi64(t2x2_4, t2x2_6);
- __m512i t4x4_5 = _mm512_unpackhi_epi64(t2x2_4, t2x2_6);
- __m512i t4x4_6 = _mm512_unpacklo_epi64(t2x2_5, t2x2_7);
- __m512i t4x4_7 = _mm512_unpackhi_epi64(t2x2_5, t2x2_7);
- __m512i t4x4_8 = _mm512_unpacklo_epi64(t2x2_8, t2x2_a);
- __m512i t4x4_9 = _mm512_unpackhi_epi64(t2x2_8, t2x2_a);
- __m512i t4x4_a = _mm512_unpacklo_epi64(t2x2_9, t2x2_b);
- __m512i t4x4_b = _mm512_unpackhi_epi64(t2x2_9, t2x2_b);
- __m512i t4x4_c = _mm512_unpacklo_epi64(t2x2_c, t2x2_e);
- __m512i t4x4_d = _mm512_unpackhi_epi64(t2x2_c, t2x2_e);
- __m512i t4x4_e = _mm512_unpacklo_epi64(t2x2_d, t2x2_f);
- __m512i t4x4_f = _mm512_unpackhi_epi64(t2x2_d, t2x2_f);
-
- __m512i t8x8_0 = _mm512_shuffle_i32x4(t4x4_0, t4x4_4, 0x88);
- __m512i t8x8_1 = _mm512_shuffle_i32x4(t4x4_1, t4x4_5, 0x88);
- __m512i t8x8_2 = _mm512_shuffle_i32x4(t4x4_2, t4x4_6, 0x88);
- __m512i t8x8_3 = _mm512_shuffle_i32x4(t4x4_3, t4x4_7, 0x88);
- __m512i t8x8_4 = _mm512_shuffle_i32x4(t4x4_0, t4x4_4, 0xdd);
- __m512i t8x8_5 = _mm512_shuffle_i32x4(t4x4_1, t4x4_5, 0xdd);
- __m512i t8x8_6 = _mm512_shuffle_i32x4(t4x4_2, t4x4_6, 0xdd);
- __m512i t8x8_7 = _mm512_shuffle_i32x4(t4x4_3, t4x4_7, 0xdd);
- __m512i t8x8_8 = _mm512_shuffle_i32x4(t4x4_8, t4x4_c, 0x88);
- __m512i t8x8_9 = _mm512_shuffle_i32x4(t4x4_9, t4x4_d, 0x88);
- __m512i t8x8_a = _mm512_shuffle_i32x4(t4x4_a, t4x4_e, 0x88);
- __m512i t8x8_b = _mm512_shuffle_i32x4(t4x4_b, t4x4_f, 0x88);
- __m512i t8x8_c = _mm512_shuffle_i32x4(t4x4_8, t4x4_c, 0xdd);
- __m512i t8x8_d = _mm512_shuffle_i32x4(t4x4_9, t4x4_d, 0xdd);
- __m512i t8x8_e = _mm512_shuffle_i32x4(t4x4_a, t4x4_e, 0xdd);
- __m512i t8x8_f = _mm512_shuffle_i32x4(t4x4_b, t4x4_f, 0xdd);
-
- *v0 = _mm512_shuffle_i32x4(t8x8_0, t8x8_8, 0x88);
- *v1 = _mm512_shuffle_i32x4(t8x8_1, t8x8_9, 0x88);
- *v2 = _mm512_shuffle_i32x4(t8x8_2, t8x8_a, 0x88);
- *v3 = _mm512_shuffle_i32x4(t8x8_3, t8x8_b, 0x88);
- *v4 = _mm512_shuffle_i32x4(t8x8_4, t8x8_c, 0x88);
- *v5 = _mm512_shuffle_i32x4(t8x8_5, t8x8_d, 0x88);
- *v6 = _mm512_shuffle_i32x4(t8x8_6, t8x8_e, 0x88);
- *v7 = _mm512_shuffle_i32x4(t8x8_7, t8x8_f, 0x88);
- *v8 = _mm512_shuffle_i32x4(t8x8_0, t8x8_8, 0xdd);
- *v9 = _mm512_shuffle_i32x4(t8x8_1, t8x8_9, 0xdd);
- *va = _mm512_shuffle_i32x4(t8x8_2, t8x8_a, 0xdd);
- *vb = _mm512_shuffle_i32x4(t8x8_3, t8x8_b, 0xdd);
- *vc = _mm512_shuffle_i32x4(t8x8_4, t8x8_c, 0xdd);
- *vd = _mm512_shuffle_i32x4(t8x8_5, t8x8_d, 0xdd);
- *ve = _mm512_shuffle_i32x4(t8x8_6, t8x8_e, 0xdd);
- *vf = _mm512_shuffle_i32x4(t8x8_7, t8x8_f, 0xdd);
-}
-
-} // namespace intrin_utils
-} // namespace
-
void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
profiler::ScopeLabel label("Kernel kAvx512 8-bit");
@@ -391,6 +309,13 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
}
if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ // The non-per-channel case could equivalently be handled in the per-row
+ // or per-column code path. The per-row code path is slightly more
+ // efficient so we handle it there.
+ const bool per_column_multiplier =
+ (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
+ (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
+
__m512i m_vector;
__m512i e_vector;
// Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
@@ -426,72 +351,96 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
offset_vector,
_mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
- // This multiplier code is complex and expensive enough on x86, that
- // we prefer to implement the channels-are-columns case by transposing
- // around it, rather than duplicate it (which would also require
- // duplicating the above code computing the multiplier constants).
- // This is one instance where channels-are-columns has lower performance
- // than channels-are-rows.
- const bool transpose_around_multiplier =
- (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
- (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
- if (transpose_around_multiplier) {
- // Transpose the 16x16 accumulators block. Will be un-transposed below
- // after the multplier implementation.
- intrin_utils::mm512_transpose16x16_epi32(
- &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
- &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7,
- &accum_data_v8, &accum_data_v9, &accum_data_va, &accum_data_vb,
- &accum_data_vc, &accum_data_vd, &accum_data_ve, &accum_data_vf);
- }
-
- auto apply_multiplier = [=](__m512i& accum) {
- accum = _mm512_sllv_epi32(accum, left_shift);
- // Apply the fixed-point part of the multiplier.
- __m512i scaled_v_low = _mm512_mul_epi32(
- _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
- m_64bit_low);
- __m512i scaled_v_high = _mm512_mul_epi32(
- _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
- m_64bit_high);
-
- scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
- scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
-
- scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
- scaled_v_high =
- _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
-
- accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
- accum = _mm512_inserti32x8(accum,
- _mm512_cvtepi64_epi32(scaled_v_high), 1);
- };
- apply_multiplier(accum_data_v0);
- apply_multiplier(accum_data_v1);
- apply_multiplier(accum_data_v2);
- apply_multiplier(accum_data_v3);
- apply_multiplier(accum_data_v4);
- apply_multiplier(accum_data_v5);
- apply_multiplier(accum_data_v6);
- apply_multiplier(accum_data_v7);
- apply_multiplier(accum_data_v8);
- apply_multiplier(accum_data_v9);
- apply_multiplier(accum_data_va);
- apply_multiplier(accum_data_vb);
- apply_multiplier(accum_data_vc);
- apply_multiplier(accum_data_vd);
- apply_multiplier(accum_data_ve);
- apply_multiplier(accum_data_vf);
-
- if (transpose_around_multiplier) {
- // See above comment: here we transpose again to undo the
- // transposition of the 16x16 block of accumulators used to implement
- // the channels-are-columns case.
- intrin_utils::mm512_transpose16x16_epi32(
- &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
- &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7,
- &accum_data_v8, &accum_data_v9, &accum_data_va, &accum_data_vb,
- &accum_data_vc, &accum_data_vd, &accum_data_ve, &accum_data_vf);
+ if (per_column_multiplier) {
+ auto apply_multiplier = [=](__m512i& accum, int col) {
+ __m512i perm_64bit_vals = _mm512_set1_epi64(col % 8);
+ // Apply the fixed-point part of the multiplier.
+ __m512i left_shift_val =
+ _mm512_permutexvar_epi32(_mm512_set1_epi32(col), left_shift);
+ __m512i m_64bit_val = _mm512_permutexvar_epi64(
+ perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high);
+ __m512i offset_vector_val = _mm512_permutexvar_epi64(
+ perm_64bit_vals,
+ col < 8 ? offset_vector_low : offset_vector_high);
+ __m512i final_right_shift_val = _mm512_permutexvar_epi64(
+ perm_64bit_vals,
+ col < 8 ? final_right_shift_low : final_right_shift_high);
+
+ accum = _mm512_sllv_epi32(accum, left_shift_val);
+ __m512i scaled_v_low = _mm512_mul_epi32(
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
+ m_64bit_val);
+ __m512i scaled_v_high = _mm512_mul_epi32(
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
+ m_64bit_val);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_val);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_val);
+
+ scaled_v_low =
+ _mm512_srav_epi64(scaled_v_low, final_right_shift_val);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_val);
+
+ accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum = _mm512_inserti32x8(accum,
+ _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ };
+ apply_multiplier(accum_data_v0, 0);
+ apply_multiplier(accum_data_v1, 1);
+ apply_multiplier(accum_data_v2, 2);
+ apply_multiplier(accum_data_v3, 3);
+ apply_multiplier(accum_data_v4, 4);
+ apply_multiplier(accum_data_v5, 5);
+ apply_multiplier(accum_data_v6, 6);
+ apply_multiplier(accum_data_v7, 7);
+ apply_multiplier(accum_data_v8, 8);
+ apply_multiplier(accum_data_v9, 9);
+ apply_multiplier(accum_data_va, 10);
+ apply_multiplier(accum_data_vb, 11);
+ apply_multiplier(accum_data_vc, 12);
+ apply_multiplier(accum_data_vd, 13);
+ apply_multiplier(accum_data_ve, 14);
+ apply_multiplier(accum_data_vf, 15);
+ } else { // not per-column, so per-row
+ auto apply_multiplier = [=](__m512i& accum) {
+ accum = _mm512_sllv_epi32(accum, left_shift);
+ // Apply the fixed-point part of the multiplier.
+ __m512i scaled_v_low = _mm512_mul_epi32(
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
+ m_64bit_low);
+ __m512i scaled_v_high = _mm512_mul_epi32(
+ _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
+ m_64bit_high);
+
+ scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+ scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+ scaled_v_low =
+ _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+ scaled_v_high =
+ _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+ accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+ accum = _mm512_inserti32x8(accum,
+ _mm512_cvtepi64_epi32(scaled_v_high), 1);
+ };
+ apply_multiplier(accum_data_v0);
+ apply_multiplier(accum_data_v1);
+ apply_multiplier(accum_data_v2);
+ apply_multiplier(accum_data_v3);
+ apply_multiplier(accum_data_v4);
+ apply_multiplier(accum_data_v5);
+ apply_multiplier(accum_data_v6);
+ apply_multiplier(accum_data_v7);
+ apply_multiplier(accum_data_v8);
+ apply_multiplier(accum_data_v9);
+ apply_multiplier(accum_data_va);
+ apply_multiplier(accum_data_vb);
+ apply_multiplier(accum_data_vc);
+ apply_multiplier(accum_data_vd);
+ apply_multiplier(accum_data_ve);
+ apply_multiplier(accum_data_vf);
}
if (params.dst_zero_point != 0) {
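The new per_column_multiplier branch above replaces the old transpose-around-the-multiplier trick: since each __m512i accumulator holds one destination column (16 rows), a per-column multiplier is the same scalar for every lane of that register, so it can simply be broadcast into place (here via _mm512_permutexvar_epi32/_epi64) instead of transposing the whole 16x16 accumulator block before and after. A scalar model of the two channel-dimension cases follows, reusing the RequantizeLane sketch shown after the AVX2 diff; the array-based parameter passing and the zero post_scaling_offset are simplifications for illustration only.

#include <cstdint>

// Scalar model: accum[row][col] holds the int32 accumulators of a 16x16 block.
// channel_dim_is_col selects whether the quantization parameters are indexed
// by destination row (the common case) or by destination column.
void RequantizeBlock(std::int32_t accum[16][16],
                     const std::int32_t* multiplier,  // per-channel mantissas
                     const int* left_shift,           // per-channel left shifts
                     const int* right_shift,          // per-channel right shifts
                     bool channel_dim_is_col) {
  for (int col = 0; col < 16; ++col) {
    for (int row = 0; row < 16; ++row) {
      const int c = channel_dim_is_col ? col : row;
      accum[row][col] = RequantizeLane(accum[row][col], multiplier[c],
                                       left_shift[c], right_shift[c],
                                       /*post_scaling_offset=*/0);
    }
  }
}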