From 2c5f035bbcbe6aee09eb0b4bd29aa0d52b36e061 Mon Sep 17 00:00:00 2001 From: Dayeong Lee Date: Wed, 3 Nov 2021 23:08:22 -0700 Subject: Ruy: Support 8x16 avx512/avx2_fma kernel for single_column. PiperOrigin-RevId: 407507985 --- ruy/kernel_avx2_fma.cc | 33 ++++++++++++++++++++++----------- ruy/kernel_avx512.cc | 30 +++++++++++++++++++----------- ruy/kernel_x86.h | 14 ++++++++++++-- 3 files changed, 53 insertions(+), 24 deletions(-) diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc index 50488a4..e725777 100644 --- a/ruy/kernel_avx2_fma.cc +++ b/ruy/kernel_avx2_fma.cc @@ -755,8 +755,7 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) { int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; - const std::int8_t* rhs_col_ptr = - static_cast(params.rhs_base_ptr); + const void* rhs_col_ptr = params.rhs_base_ptr; void* dst_col_ptr = params.dst_base_ptr; const std::int32_t* bias_col_ptr = params.bias; if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { @@ -820,20 +819,29 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) { } const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; + const void* rhs_ptr = rhs_col_ptr; for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { const __m256i lhs_data = _mm256_load_si256(reinterpret_cast(lhs_ptr)); - const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32(rhs_ptr); + const std::int32_t* rhs_data = + reinterpret_cast(rhs_ptr); // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. // For simplicity we load 4x the data that we need and process twice the // data that we need and store only the data we need. - std::int32_t rhs_data[2]; - const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); + std::int32_t rhs_data_buf[2]; + if (params.rhs_scalar_size == 1) { + rhs_data = rhs_data_buf; + const __m128i rhs_data_8bit = + intrin_utils::mm_loadu_si32(rhs_ptr); + const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf), + rhs_16_bit_dup); + } else { + RUY_DCHECK(params.rhs_scalar_size == 2); + } // NOTE: There may be opportunities for permuting the data in the packing // code instead of here. @@ -864,7 +872,9 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) { _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr = static_cast(static_cast(rhs_ptr) + + kAvx8bitBlockSize * kAvx8bitInnerSize * + params.rhs_scalar_size); } if (params.dst_type_id != DstTypeId::kValue) { @@ -1002,7 +1012,8 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) { dst_col_ptr = static_cast(static_cast(dst_col_ptr) + kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; + rhs_col_ptr = static_cast(static_cast(rhs_col_ptr) + + kAvx8bitBlockSize * params.rhs_stride); } // NOLINT(readability/fn_size) void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc index 3d2219b..654ba27 100644 --- a/ruy/kernel_avx512.cc +++ b/ruy/kernel_avx512.cc @@ -634,8 +634,7 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0; - const std::int8_t* rhs_col_ptr = - static_cast(params.rhs_base_ptr); + const void* rhs_col_ptr = params.rhs_base_ptr; void* dst_col_ptr = params.dst_base_ptr; const std::int32_t* bias_col_ptr = params.bias; if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { @@ -694,20 +693,28 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { } const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; + const void* rhs_ptr = rhs_col_ptr; for (int d = 0; d < params.depth; d += 4) { const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr); - const __m128i rhs_data_8bit = - _mm_loadu_si128(reinterpret_cast(rhs_ptr)); + const std::int32_t* rhs_data = + reinterpret_cast(rhs_ptr); // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. // For simplicity we load 4x the data that we need and process twice the // data that we need and store only the data we need. - std::int32_t rhs_data[2]; - const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); + std::int32_t rhs_data_buf[2]; + if (params.rhs_scalar_size == 1) { + rhs_data = rhs_data_buf; + const __m128i rhs_data_8bit = + _mm_loadu_si128(reinterpret_cast(rhs_ptr)); + const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf), + rhs_16_bit_dup); + } else { + RUY_DCHECK(params.rhs_scalar_size == 2); + } // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. const __m512i lhs_16_bit_low = @@ -731,7 +738,8 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { accum_data_v0 = accum_v; lhs_ptr += 16 * 4; - rhs_ptr += 16 * 4; + rhs_ptr = static_cast(static_cast(rhs_ptr) + + 16 * 4 * params.rhs_scalar_size); } if (params.dst_type_id != DstTypeId::kValue) { diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h index 0a6cf90..51787b9 100644 --- a/ruy/kernel_x86.h +++ b/ruy/kernel_x86.h @@ -74,7 +74,12 @@ struct Kernel params; MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, end_col, dst, ¶ms); - Kernel8bitAvx512(params); + if (dst->layout.cols == 1 && + mul_params.channel_dimension() == ChannelDimension::kRow) { + Kernel8bitAvx512SingleCol(params); + } else { + Kernel8bitAvx512(params); + } } }; @@ -143,7 +148,12 @@ struct Kernel params; MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, end_col, dst, ¶ms); - Kernel8bitAvx2(params); + if (dst->layout.cols == 1 && + mul_params.channel_dimension() == ChannelDimension::kRow) { + Kernel8bitAvx2SingleCol(params); + } else { + Kernel8bitAvx2(params); + } } }; -- cgit v1.2.3