Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenoit Jacob <benoitjacob@google.com>2020-07-13 20:23:39 +0300
committerCopybara-Service <copybara-worker@google.com>2020-07-13 20:24:03 +0300
commit27d16d0b47ad31a81aa1d7b044a4a2162159d928 (patch)
tree00ac5155f6e132c1c0ae5d82757b1ccd6403c6d9 /ruy/kernel_avx512.cc
parent592d30cc49aa5ee3410677cf15e7d3a43b59b257 (diff)
Efficient support for any channel_dimension for float kernels on AVX-512.
PiperOrigin-RevId: 320981916
Diffstat (limited to 'ruy/kernel_avx512.cc')
-rw-r--r--ruy/kernel_avx512.cc158
1 files changed, 118 insertions, 40 deletions
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
index a502ead..34e1038 100644
--- a/ruy/kernel_avx512.cc
+++ b/ruy/kernel_avx512.cc
@@ -1279,10 +1279,12 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
const float* adj_lhs_col_ptr =
params.lhs_base_ptr - params.start_row * lhs_stride;
- const float* bias_col_ptr = params.bias;
+ const float* bias_ptr = params.bias;
const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max);
const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min);
+ const bool channel_dimension_is_col =
+ params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
int col = params.start_col;
for (; col <= end_col - 16; col += 16) {
@@ -1293,23 +1295,45 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
for (; row <= end_row - 16; row += 16) {
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
float* dst_ptr = dst_col_ptr + row;
- const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
-
- // Initialize with bias.
- const __m512 initial_accum_data = _mm512_loadu_ps(bias_ptr);
// Process block in two halves, split by columns.
{
constexpr int mmm = 0;
- __m512 accum_data_v0 = initial_accum_data;
- __m512 accum_data_v1 = initial_accum_data;
- __m512 accum_data_v2 = initial_accum_data;
- __m512 accum_data_v3 = initial_accum_data;
- __m512 accum_data_v4 = initial_accum_data;
- __m512 accum_data_v5 = initial_accum_data;
- __m512 accum_data_v6 = initial_accum_data;
- __m512 accum_data_v7 = initial_accum_data;
+ __m512 accum_data_v0;
+ __m512 accum_data_v1;
+ __m512 accum_data_v2;
+ __m512 accum_data_v3;
+ __m512 accum_data_v4;
+ __m512 accum_data_v5;
+ __m512 accum_data_v6;
+ __m512 accum_data_v7;
+
+ // Initialize with bias.
+ if (channel_dimension_is_col) {
+ const float* bias_elem_ptr =
+ bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
+ accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
+ accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
+ accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
+ accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
+ accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
+ accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
+ accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
+ accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
+ } else {
+ const __m512 initial_accum_data =
+ _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
+
+ accum_data_v0 = initial_accum_data;
+ accum_data_v1 = initial_accum_data;
+ accum_data_v2 = initial_accum_data;
+ accum_data_v3 = initial_accum_data;
+ accum_data_v4 = initial_accum_data;
+ accum_data_v5 = initial_accum_data;
+ accum_data_v6 = initial_accum_data;
+ accum_data_v7 = initial_accum_data;
+ }
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
@@ -1411,15 +1435,40 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
{
constexpr int mmm = 1;
- __m512 accum_data_v0 = initial_accum_data;
- __m512 accum_data_v1 = initial_accum_data;
- __m512 accum_data_v2 = initial_accum_data;
- __m512 accum_data_v3 = initial_accum_data;
- __m512 accum_data_v4 = initial_accum_data;
- __m512 accum_data_v5 = initial_accum_data;
- __m512 accum_data_v6 = initial_accum_data;
- __m512 accum_data_v7 = initial_accum_data;
-
+ __m512 accum_data_v0;
+ __m512 accum_data_v1;
+ __m512 accum_data_v2;
+ __m512 accum_data_v3;
+ __m512 accum_data_v4;
+ __m512 accum_data_v5;
+ __m512 accum_data_v6;
+ __m512 accum_data_v7;
+
+ // Initialize with bias.
+ if (channel_dimension_is_col) {
+ const float* bias_elem_ptr =
+ bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
+ accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
+ accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
+ accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
+ accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
+ accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
+ accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
+ accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
+ accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
+ } else {
+ const __m512 initial_accum_data =
+ _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
+
+ accum_data_v0 = initial_accum_data;
+ accum_data_v1 = initial_accum_data;
+ accum_data_v2 = initial_accum_data;
+ accum_data_v3 = initial_accum_data;
+ accum_data_v4 = initial_accum_data;
+ accum_data_v5 = initial_accum_data;
+ accum_data_v6 = initial_accum_data;
+ accum_data_v7 = initial_accum_data;
+ }
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
for (int d = 0; d < (params.depth - 1); ++d) {
@@ -1521,24 +1570,46 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
float* dst_ptr = dst_col_ptr + row;
- const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
- // Initialize with bias.
const __mmask16 row_mask =
(static_cast<std::uint32_t>(1) << residual_rows) - 1;
- const __m512 initial_accum_data =
- _mm512_maskz_loadu_ps(row_mask, bias_ptr);
// Process block in two halves, split by columns.
for (int mmm = 0; mmm < 2; ++mmm) {
- __m512 accum_data_v0 = initial_accum_data;
- __m512 accum_data_v1 = initial_accum_data;
- __m512 accum_data_v2 = initial_accum_data;
- __m512 accum_data_v3 = initial_accum_data;
- __m512 accum_data_v4 = initial_accum_data;
- __m512 accum_data_v5 = initial_accum_data;
- __m512 accum_data_v6 = initial_accum_data;
- __m512 accum_data_v7 = initial_accum_data;
+ __m512 accum_data_v0;
+ __m512 accum_data_v1;
+ __m512 accum_data_v2;
+ __m512 accum_data_v3;
+ __m512 accum_data_v4;
+ __m512 accum_data_v5;
+ __m512 accum_data_v6;
+ __m512 accum_data_v7;
+
+ // Initialize with bias.
+ if (channel_dimension_is_col) {
+ const float* bias_elem_ptr =
+ bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
+ accum_data_v0 = _mm512_set1_ps(bias_elem_ptr[0]);
+ accum_data_v1 = _mm512_set1_ps(bias_elem_ptr[1]);
+ accum_data_v2 = _mm512_set1_ps(bias_elem_ptr[2]);
+ accum_data_v3 = _mm512_set1_ps(bias_elem_ptr[3]);
+ accum_data_v4 = _mm512_set1_ps(bias_elem_ptr[4]);
+ accum_data_v5 = _mm512_set1_ps(bias_elem_ptr[5]);
+ accum_data_v6 = _mm512_set1_ps(bias_elem_ptr[6]);
+ accum_data_v7 = _mm512_set1_ps(bias_elem_ptr[7]);
+ } else {
+ const __m512 initial_accum_data =
+ _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
+
+ accum_data_v0 = initial_accum_data;
+ accum_data_v1 = initial_accum_data;
+ accum_data_v2 = initial_accum_data;
+ accum_data_v3 = initial_accum_data;
+ accum_data_v4 = initial_accum_data;
+ accum_data_v5 = initial_accum_data;
+ accum_data_v6 = initial_accum_data;
+ accum_data_v7 = initial_accum_data;
+ }
const float* lhs_ptr = lhs_col_ptr;
const float* rhs_ptr = rhs_col_ptr + 8 * mmm;
@@ -1657,18 +1728,25 @@ void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) {
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
float* dst_ptr = dst_col_ptr + row;
- const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
- // Initialize with bias.
const __mmask16 row_mask =
(static_cast<std::uint32_t>(1) << residual_rows) - 1;
- const __m512 initial_accum_data =
- _mm512_maskz_loadu_ps(row_mask, bias_ptr);
// Process block in two halves, split by columns.
for (int mmm = 0; mmm < 2; ++mmm) {
- for (int j = 0; j < 8; ++j) {
- accum_data_v[j] = initial_accum_data;
+ // Initialize with bias.
+ if (channel_dimension_is_col) {
+ const float* bias_elem_ptr =
+ bias_ptr + (col + 8 * mmm) * bias_ptr_block_increment;
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = _mm512_set1_ps(bias_elem_ptr[j]);
+ }
+ } else {
+ const __m512 initial_accum_data =
+ _mm512_loadu_ps(bias_ptr + row * bias_ptr_block_increment);
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = initial_accum_data;
+ }
}
const float* lhs_ptr = lhs_col_ptr;