diff options
author | Dayeong Lee <dayeongl@google.com> | 2021-11-02 09:27:15 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2021-11-02 09:27:41 +0300 |
commit | 409296d21c15800375d894f2e0d1cc4c88c14cd9 (patch) | |
tree | 4d947f25b75a53c38b31574705b3677d529e4561 | |
parent | f805132145194c44839db198afa61e08d7852cc4 (diff) |
Ruy: Support 8x16 avx512 kernel
PiperOrigin-RevId: 407005437
-rw-r--r-- | ruy/kernel_avx512.cc | 48 | ||||
-rw-r--r-- | ruy/kernel_x86.h | 18 |
2 files changed, 46 insertions, 20 deletions
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc index 25ce2c7..3d2219b 100644 --- a/ruy/kernel_avx512.cc +++ b/ruy/kernel_avx512.cc @@ -67,8 +67,7 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) { RUY_DCHECK(false); } - const std::int8_t* rhs_col_ptr = - static_cast<const int8_t*>(params.rhs_base_ptr); + const void* rhs_col_ptr = params.rhs_base_ptr; void* dst_col_ptr = params.dst_base_ptr; for (int col = params.start_col; col <= params.last_col; col += 16) { @@ -248,27 +247,34 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) { } const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; + const void* rhs_ptr = rhs_col_ptr; for (int d = 0; d < params.depth; d += 4) { const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr); __m512i rhs_data_8bit = _mm512_loadu_si512(rhs_ptr); // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. - std::int32_t rhs_data[32]; - const __m256i rhs_data_bottom_lane = - _mm512_castsi512_si256(rhs_data_8bit); - const __m256i rhs_data_top_lane = - _mm512_extracti32x8_epi32(rhs_data_8bit, 1); - const __m512i rhs_16_bit_dup_low = - _mm512_cvtepi8_epi16(rhs_data_bottom_lane); - const __m512i rhs_16_bit_dup_high = - _mm512_cvtepi8_epi16(rhs_data_top_lane); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data), - rhs_16_bit_dup_low); - _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data + 16), - rhs_16_bit_dup_high); + std::int32_t rhs_data_buf[32]; + const std::int32_t* rhs_data = + reinterpret_cast<const std::int32_t*>(rhs_ptr); + if (params.rhs_scalar_size == 1) { + rhs_data = rhs_data_buf; + const __m256i rhs_data_bottom_lane = + _mm512_castsi512_si256(rhs_data_8bit); + const __m256i rhs_data_top_lane = + _mm512_extracti32x8_epi32(rhs_data_8bit, 1); + const __m512i rhs_16_bit_dup_low = + _mm512_cvtepi8_epi16(rhs_data_bottom_lane); + const __m512i rhs_16_bit_dup_high = + _mm512_cvtepi8_epi16(rhs_data_top_lane); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data_buf), + rhs_16_bit_dup_low); + _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data_buf + 16), + rhs_16_bit_dup_high); + } else { + RUY_DCHECK(params.rhs_scalar_size == 2); + } // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. const __m512i lhs_16_bit_low = @@ -306,7 +312,8 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) { process_column(15, accum_data_vf); lhs_ptr += 16 * 4; - rhs_ptr += 16 * 4; + rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) + + 16 * 4 * params.rhs_scalar_size); } if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { @@ -613,7 +620,8 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) { dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + 16 * params.dst_stride); - rhs_col_ptr += 16 * params.rhs_stride; + rhs_col_ptr = static_cast<const void*>( + static_cast<const char*>(rhs_col_ptr) + 16 * params.rhs_stride); } // End col-block loop. } // NOLINT(readability/fn_size) diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h index 051c894..0a6cf90 100644 --- a/ruy/kernel_x86.h +++ b/ruy/kernel_x86.h @@ -60,6 +60,24 @@ struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, std::int32_t, DstScalar> } }; +template <typename DstScalar> +struct Kernel<Path::kAvx512, std::int8_t, std::int16_t, std::int32_t, + DstScalar> { + static constexpr Path kPath = Path::kAvx512; + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>; + using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<std::int8_t>& lhs, const PMat<std::int16_t>& rhs, + const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { + KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, ¶ms); + Kernel8bitAvx512(params); + } +}; + void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params); void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& param); |