author    benoitjacob <benoitjacob@google.com>           2020-07-24 19:47:32 +0300
committer Copybara-Service <copybara-worker@google.com>  2020-07-24 21:52:36 +0300
commit    78b870988e5ed9587dc23838d800352731a0f58c (patch)
tree      7c07a8651ed4965c9738a958f14df3c65fa5f160
parent    1efd97066bbc59c3bcce267f99d46ec5387e876e (diff)
Let cpu_backend_gemm support all storage order combinations, unconditionally using ruy as the backend for the combinations other than RowMajor*ColMajor->ColMajor, which were previously unsupported. Ruy differs from the other back-ends in that it supports all combinations as runtime parameters, with no code size increase.
PiperOrigin-RevId: 323013778
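For context, ruy exposes storage orders as plain runtime layout parameters rather than compile-time template parameters, which is what makes this change cheap. A minimal sketch of a caller exercising that (the function name and shapes here are illustrative, not part of this change):

// Illustrative sketch (not from this commit): one ruy::Mul call covers
// every combination of row-major/column-major lhs, rhs, and dst.
#include "ruy/ruy.h"

void GemmWithAnyStorageOrders(const float* lhs_data, ruy::Order lhs_order,
                              const float* rhs_data, ruy::Order rhs_order,
                              float* dst_data, ruy::Order dst_order,
                              int rows, int depth, int cols,
                              ruy::Context* context) {
  ruy::Matrix<float> lhs;
  ruy::MakeSimpleLayout(rows, depth, lhs_order, lhs.mutable_layout());
  lhs.set_data(lhs_data);
  ruy::Matrix<float> rhs;
  ruy::MakeSimpleLayout(depth, cols, rhs_order, rhs.mutable_layout());
  rhs.set_data(rhs_data);
  ruy::Matrix<float> dst;
  ruy::MakeSimpleLayout(rows, cols, dst_order, dst.mutable_layout());
  dst.set_data(dst_data);
  // No compile-time specialization on the orders: the same entry point
  // handles all of them as runtime values.
  ruy::Mul(lhs, rhs, ruy::MulParams<float, float>(), context, &dst);
}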
-rw-r--r-- | ruy/kernel_avx2_fma.cc | 357
-rw-r--r-- | ruy/kernel_avx512.cc   | 245
2 files changed, 128 insertions(+), 474 deletions(-)
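The AVX2 diff below replaces eight copy-pasted per-column accumulation blocks (and eight rescaling blocks) with a single lambda applied to each accumulator register. A standalone sketch of the pattern, with hypothetical names, assuming an AVX2 target:

#include <immintrin.h>
#include <cstdint>

// Hypothetical condensation of the pattern: the lambda captures the shared
// inputs by value and mutates one accumulator register per call, so the
// per-column logic is written once instead of eight times.
inline void AccumulateAllColumns(const std::int32_t* rhs_data,
                                 __m256i lhs_16_bit_low,
                                 __m256i lhs_16_bit_high, __m256i accum[8]) {
  auto process_column = [=](int col, __m256i& acc) {
    const __m256i rhs_dup_low = _mm256_set1_epi32(rhs_data[col * 2]);
    const __m256i rhs_dup_high = _mm256_set1_epi32(rhs_data[col * 2 + 1]);
    acc = _mm256_add_epi32(acc,
                           _mm256_madd_epi16(lhs_16_bit_low, rhs_dup_low));
    acc = _mm256_add_epi32(acc,
                           _mm256_madd_epi16(lhs_16_bit_high, rhs_dup_high));
  };
  for (int col = 0; col < 8; ++col) {
    process_column(col, accum[col]);
  }
}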
diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc
index 463bb4b..4ba73f5 100644
--- a/ruy/kernel_avx2_fma.cc
+++ b/ruy/kernel_avx2_fma.cc
@@ -319,39 +319,6 @@ inline float mm256_get1_ps(const __m256 a, int i) {
   return float_val;
 }
 
-inline __m256 mm256_n_loadu_ps(int i, const float* src) {
-  switch (i) {
-    case 0:
-      return _mm256_setzero_ps();
-    case 1:
-      return _mm256_setr_m128(_mm_setr_ps(src[0], .0f, .0f, .0f),
-                              _mm_setzero_ps());
-    case 2:
-      return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], .0f, .0f),
-                              _mm_setzero_ps());
-    case 3:
-      return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], .0f),
-                              _mm_setzero_ps());
-    case 4:
-      return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], src[3]),
-                              _mm_setzero_ps());
-    case 5:
-      return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], .0f, .0f,
-                            .0f);
-    case 6:
-      return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5],
-                            .0f, .0f);
-    case 7:
-      return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5],
-                            src[6], .0f);
-    case 8:
-      return _mm256_loadu_ps(src);
-    default:
-      RUY_DCHECK_LT(i, 9);
-      return _mm256_setzero_ps();
-  }
-}
-
 inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
   for (int i = 0; i < residual_rows; ++i) {
     dst[i] = intrin_utils::mm256_get1_ps(v, i);
@@ -589,126 +556,26 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
         // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
         const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
             lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
-        // Accumulate for column 0.
-        {
-          const std::int32_t low_rhs_value = rhs_data[0];
-          const std::int32_t high_rhs_value = rhs_data[1];
-
-          const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-          const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-          accum_data_v0 = _mm256_add_epi32(
-              accum_data_v0,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v0 = _mm256_add_epi32(
-              accum_data_v0,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
-        // Accumulate for column 1.
-        {
-          const std::int32_t low_rhs_value = rhs_data[2];
-          const std::int32_t high_rhs_value = rhs_data[3];
-
-          const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-          const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-          accum_data_v1 = _mm256_add_epi32(
-              accum_data_v1,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v1 = _mm256_add_epi32(
-              accum_data_v1,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
-        // Accumulate for column 2.
-        {
-          const std::int32_t low_rhs_value = rhs_data[4];
-          const std::int32_t high_rhs_value = rhs_data[5];
-
-          const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-          const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-          accum_data_v2 = _mm256_add_epi32(
-              accum_data_v2,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v2 = _mm256_add_epi32(
-              accum_data_v2,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
-        // Accumulate for column 3.
-        {
-          const std::int32_t low_rhs_value = rhs_data[6];
-          const std::int32_t high_rhs_value = rhs_data[7];
+        auto process_column = [=](int col, __m256i& accum) {
+          const std::int32_t low_rhs_value = rhs_data[col * 2];
+          const std::int32_t high_rhs_value = rhs_data[col * 2 + 1];
 
           const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
           const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
 
-          accum_data_v3 = _mm256_add_epi32(
-              accum_data_v3,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v3 = _mm256_add_epi32(
-              accum_data_v3,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
-        // Accumulate for column 4.
-        {
-          const std::int32_t low_rhs_value = rhs_data[8];
-          const std::int32_t high_rhs_value = rhs_data[9];
-
-          const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-          const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-          accum_data_v4 = _mm256_add_epi32(
-              accum_data_v4,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v4 = _mm256_add_epi32(
-              accum_data_v4,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
-        // Accumulate for column 5.
-        {
-          const std::int32_t low_rhs_value = rhs_data[10];
-          const std::int32_t high_rhs_value = rhs_data[11];
-
-          const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-          const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-          accum_data_v5 = _mm256_add_epi32(
-              accum_data_v5,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v5 = _mm256_add_epi32(
-              accum_data_v5,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
-        // Accumulate for column 6.
-        {
-          const std::int32_t low_rhs_value = rhs_data[12];
-          const std::int32_t high_rhs_value = rhs_data[13];
-
-          const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-          const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-          accum_data_v6 = _mm256_add_epi32(
-              accum_data_v6,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v6 = _mm256_add_epi32(
-              accum_data_v6,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
-        // Accumulate for column 7.
-        {
-          const std::int32_t low_rhs_value = rhs_data[14];
-          const std::int32_t high_rhs_value = rhs_data[15];
-
-          const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-          const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-          accum_data_v7 = _mm256_add_epi32(
-              accum_data_v7,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v7 = _mm256_add_epi32(
-              accum_data_v7,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
+          accum = _mm256_add_epi32(
+              accum, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+          accum = _mm256_add_epi32(
+              accum, _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+        };
+        process_column(0, accum_data_v0);
+        process_column(1, accum_data_v1);
+        process_column(2, accum_data_v2);
+        process_column(3, accum_data_v3);
+        process_column(4, accum_data_v4);
+        process_column(5, accum_data_v5);
+        process_column(6, accum_data_v6);
+        process_column(7, accum_data_v7);
 
         lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
         rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
@@ -844,8 +711,8 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
             &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
             &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
       }
-      {
-        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift);
+      auto apply_multiplier = [=](__m256i& accum) {
+        __m256i shifted_accum = _mm256_sllv_epi32(accum, left_shift);
         // Apply the fixed-point part of the multiplier.
         __m256i scaled_v_low = _mm256_mul_epi32(
             _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
@@ -866,176 +733,16 @@ void Kernel8bitAvx2(const KernelParams8bit<8, 8>& params) {
             _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
         results = _mm256_permutevar8x32_epi32(results, repack_perm);
 
-        accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset);
-      }
-      {
-        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v1, left_shift);
-        // Apply the fixed-point part of the multiplier.
-        __m256i scaled_v_low = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-            m_64bit_low);
-        __m256i scaled_v_high = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-            m_64bit_high);
-
-        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-        scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-        scaled_v_high =
-            _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-        scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-        __m256i results =
-            _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-        results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-        accum_data_v1 = _mm256_sub_epi32(results, post_scaling_offset);
-      }
-      {
-        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v2, left_shift);
-        // Apply the fixed-point part of the multiplier.
-        __m256i scaled_v_low = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-            m_64bit_low);
-        __m256i scaled_v_high = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-            m_64bit_high);
-
-        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-        scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-        scaled_v_high =
-            _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-        scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-        __m256i results =
-            _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-        results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-        accum_data_v2 = _mm256_sub_epi32(results, post_scaling_offset);
-      }
-      {
-        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v3, left_shift);
-        // Apply the fixed-point part of the multiplier.
-        __m256i scaled_v_low = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-            m_64bit_low);
-        __m256i scaled_v_high = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-            m_64bit_high);
-
-        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-        scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-        scaled_v_high =
-            _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-        scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-        __m256i results =
-            _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-        results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-        accum_data_v3 = _mm256_sub_epi32(results, post_scaling_offset);
-      }
-      {
-        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v4, left_shift);
-        // Apply the fixed-point part of the multiplier.
-        __m256i scaled_v_low = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-            m_64bit_low);
-        __m256i scaled_v_high = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-            m_64bit_high);
-
-        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-        scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-        scaled_v_high =
-            _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-        scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-        __m256i results =
-            _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-        results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-        accum_data_v4 = _mm256_sub_epi32(results, post_scaling_offset);
-      }
-      {
-        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v5, left_shift);
-        // Apply the fixed-point part of the multiplier.
-        __m256i scaled_v_low = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-            m_64bit_low);
-        __m256i scaled_v_high = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-            m_64bit_high);
-
-        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-        scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-        scaled_v_high =
-            _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-        scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-        __m256i results =
-            _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-        results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-        accum_data_v5 = _mm256_sub_epi32(results, post_scaling_offset);
-      }
-      {
-        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v6, left_shift);
-        // Apply the fixed-point part of the multiplier.
-        __m256i scaled_v_low = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-            m_64bit_low);
-        __m256i scaled_v_high = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-            m_64bit_high);
-
-        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-        scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-        scaled_v_high =
-            _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-        scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-        __m256i results =
-            _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-        results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-        accum_data_v6 = _mm256_sub_epi32(results, post_scaling_offset);
-      }
-      {
-        __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v7, left_shift);
-        // Apply the fixed-point part of the multiplier.
-        __m256i scaled_v_low = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-            m_64bit_low);
-        __m256i scaled_v_high = _mm256_mul_epi32(
-            _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-            m_64bit_high);
-
-        scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-        scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-        scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-        scaled_v_high =
-            _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-        scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-        __m256i results =
-            _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-        results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-        accum_data_v7 = _mm256_sub_epi32(results, post_scaling_offset);
-      }
+        accum = _mm256_sub_epi32(results, post_scaling_offset);
+      };
+      apply_multiplier(accum_data_v0);
+      apply_multiplier(accum_data_v1);
+      apply_multiplier(accum_data_v2);
+      apply_multiplier(accum_data_v3);
+      apply_multiplier(accum_data_v4);
+      apply_multiplier(accum_data_v5);
+      apply_multiplier(accum_data_v6);
+      apply_multiplier(accum_data_v7);
 
       // See above comment: here we transpose again to undo the transposition
       // of the 8x8 block of accumulators used to implement the
      // channels-are-columns case.
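Note that the AVX2 8-bit kernel above keeps the transpose-around-the-multiplier approach for the channels-are-columns case, while the AVX-512 changes further down replace it with a dedicated per-column path. A scalar sketch of the transpose-around idea, with hypothetical names:

#include <cstdint>
#include <utility>

// Hypothetical scalar illustration: apply per-COLUMN rescaling using
// machinery written for per-ROW rescaling, by transposing the accumulator
// block, rescaling rows, and transposing back.
template <int N, typename RowRescaleFn>
void RescaleColumnsViaTranspose(std::int32_t (&acc)[N][N],
                                RowRescaleFn rescale_row) {
  auto transpose = [&acc] {
    for (int r = 0; r < N; ++r)
      for (int c = r + 1; c < N; ++c) std::swap(acc[r][c], acc[c][r]);
  };
  transpose();  // columns become rows
  for (int r = 0; r < N; ++r) {
    rescale_row(acc[r], r);  // reuse of the existing per-row code
  }
  transpose();  // undo the transposition
}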
@@ -1549,8 +1256,7 @@ void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
         }
       } else {
         const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
-        const __m256 initial_accum_data =
-            intrin_utils::mm256_n_loadu_ps(residual_rows, bias_elem_ptr);
+        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
 
         for (int j = 0; j < 8; ++j) {
           accum_data_v[j] = initial_accum_data;
@@ -1620,8 +1326,7 @@ void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params) {
         }
       } else {
         const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
-        const __m256 initial_accum_data =
-            intrin_utils::mm256_n_loadu_ps(residual_rows, bias_elem_ptr);
+        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
 
         for (int j = 0; j < 8; ++j) {
           accum_data_v[j] = initial_accum_data;
@@ -1739,7 +1444,7 @@ void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params) {
     const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
 
     // Initialize with bias.
-    accum_data_v = intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr);
+    accum_data_v = _mm256_loadu_ps(bias_ptr);
     const float* lhs_ptr = lhs_col_ptr;
     const float* rhs_ptr = rhs_col_ptr;
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
index 3d36516..d2d1f4c 100644
--- a/ruy/kernel_avx512.cc
+++ b/ruy/kernel_avx512.cc
@@ -52,88 +52,6 @@ void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) {
 
 #else  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
 
-namespace {
-namespace intrin_utils {
-
-// Transpose a 8x8 matrix of int32's.
-void mm512_transpose16x16_epi32(__m512i* v0, __m512i* v1, __m512i* v2,
-                                __m512i* v3, __m512i* v4, __m512i* v5,
-                                __m512i* v6, __m512i* v7, __m512i* v8,
-                                __m512i* v9, __m512i* va, __m512i* vb,
-                                __m512i* vc, __m512i* vd, __m512i* ve,
-                                __m512i* vf) {
-  __m512i t2x2_0 = _mm512_unpacklo_epi32(*v0, *v1);
-  __m512i t2x2_1 = _mm512_unpackhi_epi32(*v0, *v1);
-  __m512i t2x2_2 = _mm512_unpacklo_epi32(*v2, *v3);
-  __m512i t2x2_3 = _mm512_unpackhi_epi32(*v2, *v3);
-  __m512i t2x2_4 = _mm512_unpacklo_epi32(*v4, *v5);
-  __m512i t2x2_5 = _mm512_unpackhi_epi32(*v4, *v5);
-  __m512i t2x2_6 = _mm512_unpacklo_epi32(*v6, *v7);
-  __m512i t2x2_7 = _mm512_unpackhi_epi32(*v6, *v7);
-  __m512i t2x2_8 = _mm512_unpacklo_epi32(*v8, *v9);
-  __m512i t2x2_9 = _mm512_unpackhi_epi32(*v8, *v9);
-  __m512i t2x2_a = _mm512_unpacklo_epi32(*va, *vb);
-  __m512i t2x2_b = _mm512_unpackhi_epi32(*va, *vb);
-  __m512i t2x2_c = _mm512_unpacklo_epi32(*vc, *vd);
-  __m512i t2x2_d = _mm512_unpackhi_epi32(*vc, *vd);
-  __m512i t2x2_e = _mm512_unpacklo_epi32(*ve, *vf);
-  __m512i t2x2_f = _mm512_unpackhi_epi32(*ve, *vf);
-
-  __m512i t4x4_0 = _mm512_unpacklo_epi64(t2x2_0, t2x2_2);
-  __m512i t4x4_1 = _mm512_unpackhi_epi64(t2x2_0, t2x2_2);
-  __m512i t4x4_2 = _mm512_unpacklo_epi64(t2x2_1, t2x2_3);
-  __m512i t4x4_3 = _mm512_unpackhi_epi64(t2x2_1, t2x2_3);
-  __m512i t4x4_4 = _mm512_unpacklo_epi64(t2x2_4, t2x2_6);
-  __m512i t4x4_5 = _mm512_unpackhi_epi64(t2x2_4, t2x2_6);
-  __m512i t4x4_6 = _mm512_unpacklo_epi64(t2x2_5, t2x2_7);
-  __m512i t4x4_7 = _mm512_unpackhi_epi64(t2x2_5, t2x2_7);
-  __m512i t4x4_8 = _mm512_unpacklo_epi64(t2x2_8, t2x2_a);
-  __m512i t4x4_9 = _mm512_unpackhi_epi64(t2x2_8, t2x2_a);
-  __m512i t4x4_a = _mm512_unpacklo_epi64(t2x2_9, t2x2_b);
-  __m512i t4x4_b = _mm512_unpackhi_epi64(t2x2_9, t2x2_b);
-  __m512i t4x4_c = _mm512_unpacklo_epi64(t2x2_c, t2x2_e);
-  __m512i t4x4_d = _mm512_unpackhi_epi64(t2x2_c, t2x2_e);
-  __m512i t4x4_e = _mm512_unpacklo_epi64(t2x2_d, t2x2_f);
-  __m512i t4x4_f = _mm512_unpackhi_epi64(t2x2_d, t2x2_f);
-
-  __m512i t8x8_0 = _mm512_shuffle_i32x4(t4x4_0, t4x4_4, 0x88);
-  __m512i t8x8_1 = _mm512_shuffle_i32x4(t4x4_1, t4x4_5, 0x88);
-  __m512i t8x8_2 = _mm512_shuffle_i32x4(t4x4_2, t4x4_6, 0x88);
-  __m512i t8x8_3 = _mm512_shuffle_i32x4(t4x4_3, t4x4_7, 0x88);
-  __m512i t8x8_4 = _mm512_shuffle_i32x4(t4x4_0, t4x4_4, 0xdd);
-  __m512i t8x8_5 = _mm512_shuffle_i32x4(t4x4_1, t4x4_5, 0xdd);
-  __m512i t8x8_6 = _mm512_shuffle_i32x4(t4x4_2, t4x4_6, 0xdd);
-  __m512i t8x8_7 = _mm512_shuffle_i32x4(t4x4_3, t4x4_7, 0xdd);
-  __m512i t8x8_8 = _mm512_shuffle_i32x4(t4x4_8, t4x4_c, 0x88);
-  __m512i t8x8_9 = _mm512_shuffle_i32x4(t4x4_9, t4x4_d, 0x88);
-  __m512i t8x8_a = _mm512_shuffle_i32x4(t4x4_a, t4x4_e, 0x88);
-  __m512i t8x8_b = _mm512_shuffle_i32x4(t4x4_b, t4x4_f, 0x88);
-  __m512i t8x8_c = _mm512_shuffle_i32x4(t4x4_8, t4x4_c, 0xdd);
-  __m512i t8x8_d = _mm512_shuffle_i32x4(t4x4_9, t4x4_d, 0xdd);
-  __m512i t8x8_e = _mm512_shuffle_i32x4(t4x4_a, t4x4_e, 0xdd);
-  __m512i t8x8_f = _mm512_shuffle_i32x4(t4x4_b, t4x4_f, 0xdd);
-
-  *v0 = _mm512_shuffle_i32x4(t8x8_0, t8x8_8, 0x88);
-  *v1 = _mm512_shuffle_i32x4(t8x8_1, t8x8_9, 0x88);
-  *v2 = _mm512_shuffle_i32x4(t8x8_2, t8x8_a, 0x88);
-  *v3 = _mm512_shuffle_i32x4(t8x8_3, t8x8_b, 0x88);
-  *v4 = _mm512_shuffle_i32x4(t8x8_4, t8x8_c, 0x88);
-  *v5 = _mm512_shuffle_i32x4(t8x8_5, t8x8_d, 0x88);
-  *v6 = _mm512_shuffle_i32x4(t8x8_6, t8x8_e, 0x88);
-  *v7 = _mm512_shuffle_i32x4(t8x8_7, t8x8_f, 0x88);
-  *v8 = _mm512_shuffle_i32x4(t8x8_0, t8x8_8, 0xdd);
-  *v9 = _mm512_shuffle_i32x4(t8x8_1, t8x8_9, 0xdd);
-  *va = _mm512_shuffle_i32x4(t8x8_2, t8x8_a, 0xdd);
-  *vb = _mm512_shuffle_i32x4(t8x8_3, t8x8_b, 0xdd);
-  *vc = _mm512_shuffle_i32x4(t8x8_4, t8x8_c, 0xdd);
-  *vd = _mm512_shuffle_i32x4(t8x8_5, t8x8_d, 0xdd);
-  *ve = _mm512_shuffle_i32x4(t8x8_6, t8x8_e, 0xdd);
-  *vf = _mm512_shuffle_i32x4(t8x8_7, t8x8_f, 0xdd);
-}
-
-}  // namespace intrin_utils
-}  // namespace
-
 void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
   profiler::ScopeLabel label("Kernel kAvx512 8-bit");
 
@@ -391,6 +309,13 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
     }
 
     if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+      // The non-per-channel case could equivalently be handled in the per-row
+      // or per-column code path. The per-row code path is slightly more
+      // efficient so we handle it there.
+      const bool per_column_multiplier =
+          (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
+          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
+
       __m512i m_vector;
       __m512i e_vector;
       // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
@@ -426,72 +351,96 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
           offset_vector,
          _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
 
-      // This multiplier code is complex and expensive enough on x86, that
-      // we prefer to implement the channels-are-columns case by transposing
-      // around it, rather than duplicate it (which would also require
-      // duplicating the above code computing the multiplier constants).
-      // This is one instance where channels-are-columns has lower performance
-      // than channels-are-rows.
-      const bool transpose_around_multiplier =
-          (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
-          (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
-      if (transpose_around_multiplier) {
-        // Transpose the 16x16 accumulators block. Will be un-transposed below
-        // after the multplier implementation.
-        intrin_utils::mm512_transpose16x16_epi32(
-            &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
-            &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7,
-            &accum_data_v8, &accum_data_v9, &accum_data_va, &accum_data_vb,
-            &accum_data_vc, &accum_data_vd, &accum_data_ve, &accum_data_vf);
-      }
-
-      auto apply_multiplier = [=](__m512i& accum) {
-        accum = _mm512_sllv_epi32(accum, left_shift);
-        // Apply the fixed-point part of the multiplier.
-        __m512i scaled_v_low = _mm512_mul_epi32(
-            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
-            m_64bit_low);
-        __m512i scaled_v_high = _mm512_mul_epi32(
-            _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
-            m_64bit_high);
-
-        scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
-        scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
-
-        scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
-        scaled_v_high =
-            _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
-
-        accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
-        accum = _mm512_inserti32x8(accum,
-                                   _mm512_cvtepi64_epi32(scaled_v_high), 1);
-      };
-      apply_multiplier(accum_data_v0);
-      apply_multiplier(accum_data_v1);
-      apply_multiplier(accum_data_v2);
-      apply_multiplier(accum_data_v3);
-      apply_multiplier(accum_data_v4);
-      apply_multiplier(accum_data_v5);
-      apply_multiplier(accum_data_v6);
-      apply_multiplier(accum_data_v7);
-      apply_multiplier(accum_data_v8);
-      apply_multiplier(accum_data_v9);
-      apply_multiplier(accum_data_va);
-      apply_multiplier(accum_data_vb);
-      apply_multiplier(accum_data_vc);
-      apply_multiplier(accum_data_vd);
-      apply_multiplier(accum_data_ve);
-      apply_multiplier(accum_data_vf);
-
-      if (transpose_around_multiplier) {
-        // See above comment: here we transpose again to undo the
-        // transposition of the 16x16 block of accumulators used to implement
-        // the channels-are-columns case.
-        intrin_utils::mm512_transpose16x16_epi32(
-            &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
-            &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7,
-            &accum_data_v8, &accum_data_v9, &accum_data_va, &accum_data_vb,
-            &accum_data_vc, &accum_data_vd, &accum_data_ve, &accum_data_vf);
+      if (per_column_multiplier) {
+        auto apply_multiplier = [=](__m512i& accum, int col) {
+          __m512i perm_64bit_vals = _mm512_set1_epi64(col % 8);
+          // Apply the fixed-point part of the multiplier.
+          __m512i left_shift_val =
+              _mm512_permutexvar_epi32(_mm512_set1_epi32(col), left_shift);
+          __m512i m_64bit_val = _mm512_permutexvar_epi64(
+              perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high);
+          __m512i offset_vector_val = _mm512_permutexvar_epi64(
+              perm_64bit_vals,
+              col < 8 ? offset_vector_low : offset_vector_high);
+          __m512i final_right_shift_val = _mm512_permutexvar_epi64(
+              perm_64bit_vals,
+              col < 8 ? final_right_shift_low : final_right_shift_high);
+
+          accum = _mm512_sllv_epi32(accum, left_shift_val);
+          __m512i scaled_v_low = _mm512_mul_epi32(
+              _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
+              m_64bit_val);
+          __m512i scaled_v_high = _mm512_mul_epi32(
+              _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
+              m_64bit_val);
+
+          scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_val);
+          scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_val);
+
+          scaled_v_low =
+              _mm512_srav_epi64(scaled_v_low, final_right_shift_val);
+          scaled_v_high =
+              _mm512_srav_epi64(scaled_v_high, final_right_shift_val);
+
+          accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+          accum = _mm512_inserti32x8(accum,
+                                     _mm512_cvtepi64_epi32(scaled_v_high), 1);
+        };
+        apply_multiplier(accum_data_v0, 0);
+        apply_multiplier(accum_data_v1, 1);
+        apply_multiplier(accum_data_v2, 2);
+        apply_multiplier(accum_data_v3, 3);
+        apply_multiplier(accum_data_v4, 4);
+        apply_multiplier(accum_data_v5, 5);
+        apply_multiplier(accum_data_v6, 6);
+        apply_multiplier(accum_data_v7, 7);
+        apply_multiplier(accum_data_v8, 8);
+        apply_multiplier(accum_data_v9, 9);
+        apply_multiplier(accum_data_va, 10);
+        apply_multiplier(accum_data_vb, 11);
+        apply_multiplier(accum_data_vc, 12);
+        apply_multiplier(accum_data_vd, 13);
+        apply_multiplier(accum_data_ve, 14);
+        apply_multiplier(accum_data_vf, 15);
+      } else {  // not per-column, so per-row
+        auto apply_multiplier = [=](__m512i& accum) {
+          accum = _mm512_sllv_epi32(accum, left_shift);
+          // Apply the fixed-point part of the multiplier.
+          __m512i scaled_v_low = _mm512_mul_epi32(
+              _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
+              m_64bit_low);
+          __m512i scaled_v_high = _mm512_mul_epi32(
+              _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
+              m_64bit_high);
+
+          scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+          scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+          scaled_v_low =
+              _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+          scaled_v_high =
+              _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+          accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+          accum = _mm512_inserti32x8(accum,
+                                     _mm512_cvtepi64_epi32(scaled_v_high), 1);
+        };
+        apply_multiplier(accum_data_v0);
+        apply_multiplier(accum_data_v1);
+        apply_multiplier(accum_data_v2);
+        apply_multiplier(accum_data_v3);
+        apply_multiplier(accum_data_v4);
+        apply_multiplier(accum_data_v5);
+        apply_multiplier(accum_data_v6);
+        apply_multiplier(accum_data_v7);
+        apply_multiplier(accum_data_v8);
+        apply_multiplier(accum_data_v9);
+        apply_multiplier(accum_data_va);
+        apply_multiplier(accum_data_vb);
+        apply_multiplier(accum_data_vc);
+        apply_multiplier(accum_data_vd);
+        apply_multiplier(accum_data_ve);
+        apply_multiplier(accum_data_vf);
       }
 
       if (params.dst_zero_point != 0) {
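For reference, a scalar sketch of the per-lane arithmetic that each vectorized apply_multiplier above performs. The assumption that the offset vectors hold the usual rounding term 1 << (shift - 1) is inferred from the add/shift pairing, not stated in this diff, and the names are illustrative:

#include <cstdint>

// Illustrative scalar model of the fixed-point rescaling: shift left,
// widen-multiply by the 32-bit fixed-point multiplier into 64 bits, add a
// rounding offset, then arithmetic-shift right back down to 32 bits.
std::int32_t ApplyMultiplierScalar(std::int32_t acc, std::int32_t multiplier,
                                   int left_shift, int right_shift) {
  std::int64_t x = static_cast<std::int64_t>(acc) << left_shift;
  x *= static_cast<std::int64_t>(multiplier);
  const std::int64_t rounding =
      right_shift > 0 ? (std::int64_t{1} << (right_shift - 1)) : 0;
  return static_cast<std::int32_t>((x + rounding) >> right_shift);
}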