diff options
author | Dayeong Lee <dayeongl@google.com> | 2021-11-01 11:00:23 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2021-11-01 11:00:48 +0300 |
commit | f805132145194c44839db198afa61e08d7852cc4 (patch) | |
tree | 931a194f6a732729a334e1252e185cae7b3c96f5 | |
parent | 02d2088d84791eab6821f1a56510a2ea72e2cd77 (diff) |
Ruy: Support 8x16 avx2_fma kernel
PiperOrigin-RevId: 406766575
-rw-r--r-- | ruy/kernel_arm32.cc | 6 | ||||
-rw-r--r-- | ruy/kernel_arm64.cc | 21 | ||||
-rw-r--r-- | ruy/kernel_avx.cc | 6 | ||||
-rw-r--r-- | ruy/kernel_avx2_fma.cc | 53 | ||||
-rw-r--r-- | ruy/kernel_avx512.cc | 6 | ||||
-rw-r--r-- | ruy/kernel_common.h | 11 | ||||
-rw-r--r-- | ruy/kernel_x86.h | 18 |
7 files changed, 84 insertions, 37 deletions
diff --git a/ruy/kernel_arm32.cc b/ruy/kernel_arm32.cc index 8782dce..c8e053d 100644 --- a/ruy/kernel_arm32.cc +++ b/ruy/kernel_arm32.cc @@ -630,7 +630,8 @@ void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) { CheckOffsetsInKernelParams8bit(params); const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); const std::int8_t* lhs_ptr = lhs_col_ptr; const std::int8_t* rhs_ptr = rhs_col_ptr; @@ -1630,7 +1631,8 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) { CheckOffsetsInKernelParams8bit(params); const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); const std::int8_t* lhs_ptr = lhs_col_ptr; const std::int8_t* rhs_ptr = rhs_col_ptr; diff --git a/ruy/kernel_arm64.cc b/ruy/kernel_arm64.cc index 5424107..532138d 100644 --- a/ruy/kernel_arm64.cc +++ b/ruy/kernel_arm64.cc @@ -101,7 +101,8 @@ void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) { CheckOffsetsInKernelParams8bit(params); const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); const std::int8_t* lhs_ptr = lhs_col_ptr; const std::int8_t* rhs_ptr = rhs_col_ptr; void* dst_col_ptr = params.dst_base_ptr; @@ -1160,7 +1161,8 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params) { CheckOffsetsInKernelParams8bit(params); const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); const std::int8_t* lhs_ptr = lhs_col_ptr; const std::int8_t* rhs_ptr = rhs_col_ptr; void* dst_col_ptr = params.dst_base_ptr; @@ -1832,7 +1834,8 @@ void 
Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) { CheckOffsetsInKernelParams8bit(params); const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); const std::int8_t* lhs_ptr = lhs_col_ptr; const std::int8_t* rhs_ptr = rhs_col_ptr; void* dst_col_ptr = params.dst_base_ptr; @@ -2987,7 +2990,8 @@ void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) { CheckOffsetsInKernelParams8bit(params); const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); const std::int8_t* lhs_ptr = lhs_col_ptr; const std::int8_t* rhs_ptr = rhs_col_ptr; void* dst_col_ptr = params.dst_base_ptr; @@ -4413,7 +4417,8 @@ void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params) { CheckOffsetsInKernelParams8bit(params); const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); const std::int8_t* lhs_ptr = lhs_col_ptr; const std::int8_t* rhs_ptr = rhs_col_ptr; void* dst_col_ptr = params.dst_base_ptr; @@ -5667,7 +5672,8 @@ void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params) { CheckOffsetsInKernelParams8bit(params); const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); const std::int8_t* lhs_ptr = lhs_col_ptr; const std::int8_t* rhs_ptr = rhs_col_ptr; void* dst_col_ptr = params.dst_base_ptr; @@ -6362,7 +6368,8 @@ void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) { CheckOffsetsInKernelParams8bit(params); const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; - const std::int8_t* rhs_col_ptr 
= params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); const std::int8_t* lhs_ptr = lhs_col_ptr; const std::int8_t* rhs_ptr = rhs_col_ptr; void* dst_col_ptr = params.dst_base_ptr; diff --git a/ruy/kernel_avx.cc b/ruy/kernel_avx.cc index 2405735..0f7e2e3 100644 --- a/ruy/kernel_avx.cc +++ b/ruy/kernel_avx.cc @@ -462,7 +462,8 @@ void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) { RUY_DCHECK(false); } - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); void* dst_col_ptr = params.dst_base_ptr; for (int col = params.start_col; col <= params.last_col; @@ -1184,7 +1185,8 @@ void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) { int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); void* dst_col_ptr = params.dst_base_ptr; const std::int32_t* bias_col_ptr = params.bias; if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc index eae333c..50488a4 100644 --- a/ruy/kernel_avx2_fma.cc +++ b/ruy/kernel_avx2_fma.cc @@ -121,7 +121,7 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) { RUY_DCHECK(false); } - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const void* rhs_col_ptr = params.rhs_base_ptr; void* dst_col_ptr = params.dst_base_ptr; for (int col = params.start_col; col <= params.last_col; @@ -251,7 +251,7 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) { } const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; + const void* rhs_ptr = rhs_col_ptr; for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { const __m256i lhs_data = _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); @@ -259,21 
+259,29 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) { _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr)); // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. - std::int32_t rhs_data[16]; - const __m128i rhs_data_bottom_lane = - _mm256_castsi256_si128(rhs_data_8bit); - const __m128i rhs_data_top_lane = - _mm256_extracti128_si256(rhs_data_8bit, 1); - const __m256i rhs_16_bit_dup_low = - _mm256_cvtepi8_epi16(rhs_data_bottom_lane); - const __m256i rhs_16_bit_dup_high = - _mm256_cvtepi8_epi16(rhs_data_top_lane); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data), - rhs_16_bit_dup_low); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8), - rhs_16_bit_dup_high); + std::int32_t rhs_data_buf[16]; + const std::int32_t* rhs_data = + reinterpret_cast<const std::int32_t*>(rhs_ptr); + + if (params.rhs_scalar_size == 1) { + rhs_data = rhs_data_buf; + const __m128i rhs_data_bottom_lane = + _mm256_castsi256_si128(rhs_data_8bit); + const __m128i rhs_data_top_lane = + _mm256_extracti128_si256(rhs_data_8bit, 1); + const __m256i rhs_16_bit_dup_low = + _mm256_cvtepi8_epi16(rhs_data_bottom_lane); + const __m256i rhs_16_bit_dup_high = + _mm256_cvtepi8_epi16(rhs_data_top_lane); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. 
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf), + rhs_16_bit_dup_low); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf + 8), + rhs_16_bit_dup_high); + } else { + RUY_DCHECK(params.rhs_scalar_size == 2); + } const __m256i lhs_data_split = _mm256_shuffle_epi8(lhs_data, splitter_idx); @@ -339,7 +347,9 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) { process_column(tmp2, tmp3, accum_data_v7); lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr = static_cast<const void*>( + static_cast<const char*>(rhs_ptr) + + kAvx8bitBlockSize * kAvx8bitInnerSize * params.rhs_scalar_size); } if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { @@ -717,7 +727,9 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) { dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; + rhs_col_ptr = + static_cast<const void*>(static_cast<const char*>(rhs_col_ptr) + + kAvx8bitBlockSize * params.rhs_stride); } // End col-block loop. } // NOLINT(readability/fn_size) @@ -743,7 +755,8 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) { int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 
kAvx8bitBlockSize : 0; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); void* dst_col_ptr = params.dst_base_ptr; const std::int32_t* bias_col_ptr = params.bias; if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc index 84b9380..25ce2c7 100644 --- a/ruy/kernel_avx512.cc +++ b/ruy/kernel_avx512.cc @@ -67,7 +67,8 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) { RUY_DCHECK(false); } - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); void* dst_col_ptr = params.dst_base_ptr; for (int col = params.start_col; col <= params.last_col; col += 16) { @@ -625,7 +626,8 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); void* dst_col_ptr = params.dst_base_ptr; const std::int32_t* bias_col_ptr = params.bias; if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h index cff243b..69e819b 100644 --- a/ruy/kernel_common.h +++ b/ruy/kernel_common.h @@ -101,7 +101,8 @@ struct KernelParams8bit { const std::int8_t* lhs_base_ptr; const std::int32_t* multiplier_fixedpoint; const std::int32_t* multiplier_exponent; - const std::int8_t* rhs_base_ptr; + // Make it void* to support 8bit(LHS)x16bit(RHS) case. 
+ const void* rhs_base_ptr; void* dst_base_ptr; std::int32_t lhs_zero_point; std::int32_t rhs_zero_point; @@ -125,11 +126,12 @@ struct KernelParams8bit { std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize]; std::int32_t multiplier_fixedpoint_buf[LhsCols]; std::int32_t multiplier_exponent_buf[LhsCols]; + std::size_t rhs_scalar_size; }; -template <typename DstScalar, int LhsCols, int RhsCols> +template <typename RhsScalar, typename DstScalar, int LhsCols, int RhsCols> void MakeKernelParams8bit(const PMat<std::int8_t>& lhs, - const PMat<std::int8_t>& rhs, + const PMat<RhsScalar>& rhs, const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, int start_col, int end_row, int end_col, Mat<DstScalar>* dst, @@ -145,6 +147,7 @@ void MakeKernelParams8bit(const PMat<std::int8_t>& lhs, RUY_DCHECK_EQ(end_col % RhsCols, 0); params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride; + params->rhs_scalar_size = sizeof(RhsScalar); params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride; params->flags = 0; params->bias = params->zero_data; @@ -168,7 +171,7 @@ void MakeKernelParams8bit(const PMat<std::int8_t>& lhs, params->last_row = end_row - LhsCols; params->last_col = end_col - RhsCols; params->lhs_stride = lhs.layout.stride; - params->rhs_stride = rhs.layout.stride; + params->rhs_stride = params->rhs_scalar_size * rhs.layout.stride; params->dst_stride = sizeof(DstScalar) * dst->layout.stride; params->lhs_zero_point = lhs.zero_point; params->rhs_zero_point = rhs.zero_point; diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h index d2045de..051c894 100644 --- a/ruy/kernel_x86.h +++ b/ruy/kernel_x86.h @@ -111,6 +111,24 @@ struct Kernel<Path::kAvx2Fma, std::int8_t, std::int8_t, std::int32_t, } }; +template <typename DstScalar> +struct Kernel<Path::kAvx2Fma, std::int8_t, std::int16_t, std::int32_t, + DstScalar> { + static constexpr Path kPath = Path::kAvx2Fma; + Tuning tuning = Tuning::kAuto; + using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 
8>; + using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>; + explicit Kernel(Tuning tuning_) : tuning(tuning_) {} + void Run(const PMat<std::int8_t>& lhs, const PMat<std::int16_t>& rhs, + const MulParams<std::int32_t, DstScalar>& mul_params, int start_row, + int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const { + KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params; + MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row, + end_col, dst, &params); + Kernel8bitAvx2(params); + } +}; + void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params); void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params); |