diff options
author | Dayeong Lee <dayeongl@google.com> | 2021-11-04 09:08:22 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2021-11-04 09:08:48 +0300 |
commit | 2c5f035bbcbe6aee09eb0b4bd29aa0d52b36e061 (patch) | |
tree | 743c4fcb5e13a761cb9313cb04b61c8199a39235 | |
parent | 409296d21c15800375d894f2e0d1cc4c88c14cd9 (diff) |
Ruy: Support 8x16 avx512/avx2_fma kernel for single_column.
PiperOrigin-RevId: 407507985
-rw-r--r-- | ruy/kernel_avx2_fma.cc | 33 | ||||
-rw-r--r-- | ruy/kernel_avx512.cc | 30 | ||||
-rw-r--r-- | ruy/kernel_x86.h | 14 |
3 files changed, 53 insertions, 24 deletions
diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc index 50488a4..e725777 100644 --- a/ruy/kernel_avx2_fma.cc +++ b/ruy/kernel_avx2_fma.cc @@ -755,8 +755,7 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) { int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; - const std::int8_t* rhs_col_ptr = - static_cast<const int8_t*>(params.rhs_base_ptr); + const void* rhs_col_ptr = params.rhs_base_ptr; void* dst_col_ptr = params.dst_base_ptr; const std::int32_t* bias_col_ptr = params.bias; if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { @@ -820,20 +819,29 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) { } const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; + const void* rhs_ptr = rhs_col_ptr; for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { const __m256i lhs_data = _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); - const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr); + const std::int32_t* rhs_data = + reinterpret_cast<const std::int32_t*>(rhs_ptr); // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. // For simplicity we load 4x the data that we need and process twice the // data that we need and store only the data we need. - std::int32_t rhs_data[2]; - const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); + std::int32_t rhs_data_buf[2]; + if (params.rhs_scalar_size == 1) { + rhs_data = rhs_data_buf; + const __m128i rhs_data_8bit = + intrin_utils::mm_loadu_si32<path>(rhs_ptr); + const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. 
+ _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf), + rhs_16_bit_dup); + } else { + RUY_DCHECK(params.rhs_scalar_size == 2); + } // NOTE: There may be opportunities for permuting the data in the packing // code instead of here. @@ -864,7 +872,9 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) { _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) + + kAvx8bitBlockSize * kAvx8bitInnerSize * + params.rhs_scalar_size); } if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { @@ -1002,7 +1012,8 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) { dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; + rhs_col_ptr = static_cast<const void*>(static_cast<const char*>(rhs_col_ptr) + + kAvx8bitBlockSize * params.rhs_stride); } // NOLINT(readability/fn_size) void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) { diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc index 3d2219b..654ba27 100644 --- a/ruy/kernel_avx512.cc +++ b/ruy/kernel_avx512.cc @@ -634,8 +634,7 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 
16 : 0; - const std::int8_t* rhs_col_ptr = - static_cast<const int8_t*>(params.rhs_base_ptr); + const void* rhs_col_ptr = params.rhs_base_ptr; void* dst_col_ptr = params.dst_base_ptr; const std::int32_t* bias_col_ptr = params.bias; if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { @@ -694,20 +693,28 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { } const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; + const void* rhs_ptr = rhs_col_ptr; for (int d = 0; d < params.depth; d += 4) { const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr); - const __m128i rhs_data_8bit = - _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr)); + const std::int32_t* rhs_data = + reinterpret_cast<const std::int32_t*>(rhs_ptr); // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. // For simplicity we load 4x the data that we need and process twice the // data that we need and store only the data we need. - std::int32_t rhs_data[2]; - const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); + std::int32_t rhs_data_buf[2]; + if (params.rhs_scalar_size == 1) { + rhs_data = rhs_data_buf; + const __m128i rhs_data_8bit = + _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr)); + const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf), + rhs_16_bit_dup); + } else { + RUY_DCHECK(params.rhs_scalar_size == 2); + } // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. 
const __m512i lhs_16_bit_low = @@ -731,7 +738,8 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { accum_data_v0 = accum_v; lhs_ptr += 16 * 4; - rhs_ptr += 16 * 4; + rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) + + 16 * 4 * params.rhs_scalar_size); } if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h index 0a6cf90..51787b9 100644 --- a/ruy/kernel_x86.h +++ b/ruy/kernel_x86.h @@ -74,7 +74,12 @@ struct Kernel<Path::kAvx512, std::int8_t, std::int16_t, std::int32_t, KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, end_col, dst, &params); - Kernel8bitAvx512(params); + if (dst->layout.cols == 1 && + mul_params.channel_dimension() == ChannelDimension::kRow) { + Kernel8bitAvx512SingleCol(params); + } else { + Kernel8bitAvx512(params); + } } }; @@ -143,7 +148,12 @@ struct Kernel<Path::kAvx2Fma, std::int8_t, std::int16_t, std::int32_t, KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, end_col, dst, &params); - Kernel8bitAvx2(params); + if (dst->layout.cols == 1 && + mul_params.channel_dimension() == ChannelDimension::kRow) { + Kernel8bitAvx2SingleCol(params); + } else { + Kernel8bitAvx2(params); + } } }; |