diff options
author | Benoit Jacob <benoitjacob@google.com> | 2020-07-13 20:23:39 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2020-07-13 20:24:03 +0300 |
commit | 27d16d0b47ad31a81aa1d7b044a4a2162159d928 (patch) | |
tree | 00ac5155f6e132c1c0ae5d82757b1ccd6403c6d9 /ruy/kernel_avx512.cc | |
parent | 592d30cc49aa5ee3410677cf15e7d3a43b59b257 (diff) |
Efficient support for any channel_dimension for float kernels on AVX-512.
PiperOrigin-RevId: 320981916
Diffstat (limited to 'ruy/kernel_avx512.cc')
-rw-r--r-- | ruy/kernel_avx512.cc | 158 |
1 files changed, 118 insertions, 40 deletions
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc index a502ead..34e1038 100644 --- a/ruy/kernel_avx512.cc +++ b/ruy/kernel_avx512.cc @@ -1279,10 +1279,12 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { params.dst_base_ptr - params.start_col * dst_stride - params.start_row; const float* adj_lhs_col_ptr = params.lhs_base_ptr - params.start_row * lhs_stride; - const float* bias_col_ptr = params.bias; + const float* bias_ptr = params.bias; const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max); const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min); + const bool channel_dimension_is_col = + params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL; int col = params.start_col; for (; col <= end_col - 16; col += 16) { @@ -1293,23 +1295,45 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { for (; row <= end_row - 16; row += 16) { const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - - // Initialize with bias. - const __m512 initial_accum_data = _mm512_loadu_ps(bias_ptr); // Process block in two halves, split by columns. { constexpr int mmm = 0; - __m512 accum_data_v0 = initial_accum_data; - __m512 accum_data_v1 = initial_accum_data; - __m512 accum_data_v2 = initial_accum_data; - __m512 accum_data_v3 = initial_accum_data; - __m512 accum_data_v4 = initial_accum_data; - __m512 accum_data_v5 = initial_accum_data; - __m512 accum_data_v6 = initial_accum_data; - __m512 accum_data_v7 = initial_accum_data; + __m512 accum_data_v0; + __m512 accum_data_v1; + __m512 accum_data_v2; + __m512 accum_data_v3; + __m512 accum_data_v4; + __m512 accum_data_v5; + __m512 accum_data_v6; + __m512 accum_data_v7; + + // Initialize with bias. + if (channel_dimension_is_col) { + const float* bias_elem_ptr = + bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment; + accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]); + accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]); + accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]); + accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]); + accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]); + accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]); + accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]); + accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]); + } else { + const __m512 initial_accum_data = + _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment); + + accum_data_v0 = initial_accum_data; + accum_data_v1 = initial_accum_data; + accum_data_v2 = initial_accum_data; + accum_data_v3 = initial_accum_data; + accum_data_v4 = initial_accum_data; + accum_data_v5 = initial_accum_data; + accum_data_v6 = initial_accum_data; + accum_data_v7 = initial_accum_data; + } const float* lhs_ptr = lhs_col_ptr; const float* rhs_ptr = rhs_col_ptr + 8 * mmm; @@ -1411,15 +1435,40 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { { constexpr int mmm = 1; - __m512 accum_data_v0 = initial_accum_data; - __m512 accum_data_v1 = initial_accum_data; - __m512 accum_data_v2 = initial_accum_data; - __m512 accum_data_v3 = initial_accum_data; - __m512 accum_data_v4 = initial_accum_data; - __m512 accum_data_v5 = initial_accum_data; - __m512 accum_data_v6 = initial_accum_data; - __m512 accum_data_v7 = initial_accum_data; - + __m512 accum_data_v0; + __m512 accum_data_v1; + __m512 accum_data_v2; + __m512 accum_data_v3; + __m512 accum_data_v4; + __m512 accum_data_v5; + __m512 accum_data_v6; + __m512 accum_data_v7; + + // Initialize with bias. + if (channel_dimension_is_col) { + const float* bias_elem_ptr = + bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment; + accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]); + accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]); + accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]); + accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]); + accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]); + accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]); + accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]); + accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]); + } else { + const __m512 initial_accum_data = + _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment); + + accum_data_v0 = initial_accum_data; + accum_data_v1 = initial_accum_data; + accum_data_v2 = initial_accum_data; + accum_data_v3 = initial_accum_data; + accum_data_v4 = initial_accum_data; + accum_data_v5 = initial_accum_data; + accum_data_v6 = initial_accum_data; + accum_data_v7 = initial_accum_data; + } const float* lhs_ptr = lhs_col_ptr; const float* rhs_ptr = rhs_col_ptr + 8 * mmm; for (int d = 0; d < (params.depth - 1); ++d) { @@ -1521,24 +1570,46 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - // Initialize with bias. const __mmask16 row_mask = (static_cast<std::uint32_t>(1) << residual_rows) - 1; - const __m512 initial_accum_data = - _mm512_maskz_loadu_ps(row_mask, bias_ptr); // Process block in two halves, split by columns. for (int mmm = 0; mmm < 2; ++mmm) { - __m512 accum_data_v0 = initial_accum_data; - __m512 accum_data_v1 = initial_accum_data; - __m512 accum_data_v2 = initial_accum_data; - __m512 accum_data_v3 = initial_accum_data; - __m512 accum_data_v4 = initial_accum_data; - __m512 accum_data_v5 = initial_accum_data; - __m512 accum_data_v6 = initial_accum_data; - __m512 accum_data_v7 = initial_accum_data; + __m512 accum_data_v0; + __m512 accum_data_v1; + __m512 accum_data_v2; + __m512 accum_data_v3; + __m512 accum_data_v4; + __m512 accum_data_v5; + __m512 accum_data_v6; + __m512 accum_data_v7; + + // Initialize with bias. + if (channel_dimension_is_col) { + const float* bias_elem_ptr = + bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment; + accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]); + accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]); + accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]); + accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]); + accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]); + accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]); + accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]); + accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]); + } else { + const __m512 initial_accum_data = + _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment); + + accum_data_v0 = initial_accum_data; + accum_data_v1 = initial_accum_data; + accum_data_v2 = initial_accum_data; + accum_data_v3 = initial_accum_data; + accum_data_v4 = initial_accum_data; + accum_data_v5 = initial_accum_data; + accum_data_v6 = initial_accum_data; + accum_data_v7 = initial_accum_data; + } const float* lhs_ptr = lhs_col_ptr; const float* rhs_ptr = rhs_col_ptr + 8 * mmm; @@ -1657,18 +1728,25 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; float* dst_ptr = dst_col_ptr + row; - const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; - // Initialize with bias. const __mmask16 row_mask = (static_cast<std::uint32_t>(1) << residual_rows) - 1; - const __m512 initial_accum_data = - _mm512_maskz_loadu_ps(row_mask, bias_ptr); // Process block in two halves, split by columns. for (int mmm = 0; mmm < 2; ++mmm) { - for (int j = 0; j < 8; ++j) { - accum_data_v[j] = initial_accum_data; + // Initialize with bias. + if (channel_dimension_is_col) { + const float* bias_elem_ptr = + bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment; + for (int j = 0; j < 8; ++j) { + accum_data_v[j] = _mm512_set1_ps(bias_elem_ptr[j]); + } + } else { + const __m512 initial_accum_data = + _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment); + for (int j = 0; j < 8; ++j) { + accum_data_v[j] = initial_accum_data; + } } const float* lhs_ptr = lhs_col_ptr; |