Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Dayeong Lee <dayeongl@google.com>, 2021-11-02 09:27:15 +0300
committer: Copybara-Service <copybara-worker@google.com>, 2021-11-02 09:27:41 +0300
commit: 409296d21c15800375d894f2e0d1cc4c88c14cd9 (patch)
tree: 4d947f25b75a53c38b31574705b3677d529e4561
parent: f805132145194c44839db198afa61e08d7852cc4 (diff)
Ruy: Support 8x16 avx512 kernel
PiperOrigin-RevId: 407005437
-rw-r--r--  ruy/kernel_avx512.cc  48
-rw-r--r--  ruy/kernel_x86.h  18
2 files changed, 46 insertions, 20 deletions
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
index 25ce2c7..3d2219b 100644
--- a/ruy/kernel_avx512.cc
+++ b/ruy/kernel_avx512.cc
@@ -67,8 +67,7 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
RUY_DCHECK(false);
}
- const std::int8_t* rhs_col_ptr =
- static_cast<const int8_t*>(params.rhs_base_ptr);
+ const void* rhs_col_ptr = params.rhs_base_ptr;
void* dst_col_ptr = params.dst_base_ptr;
for (int col = params.start_col; col <= params.last_col; col += 16) {
@@ -248,27 +247,34 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
}
const std::int8_t* lhs_ptr = lhs_col_ptr;
- const std::int8_t* rhs_ptr = rhs_col_ptr;
+ const void* rhs_ptr = rhs_col_ptr;
for (int d = 0; d < params.depth; d += 4) {
const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
__m512i rhs_data_8bit = _mm512_loadu_si512(rhs_ptr);
// Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
- std::int32_t rhs_data[32];
- const __m256i rhs_data_bottom_lane =
- _mm512_castsi512_si256(rhs_data_8bit);
- const __m256i rhs_data_top_lane =
- _mm512_extracti32x8_epi32(rhs_data_8bit, 1);
- const __m512i rhs_16_bit_dup_low =
- _mm512_cvtepi8_epi16(rhs_data_bottom_lane);
- const __m512i rhs_16_bit_dup_high =
- _mm512_cvtepi8_epi16(rhs_data_top_lane);
- // Now that we have cast the RHS data, we store it so that each value
- // can be separately loaded in the accumulation loop.
- _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data),
- rhs_16_bit_dup_low);
- _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data + 16),
- rhs_16_bit_dup_high);
+ std::int32_t rhs_data_buf[32];
+ const std::int32_t* rhs_data =
+ reinterpret_cast<const std::int32_t*>(rhs_ptr);
+ if (params.rhs_scalar_size == 1) {
+ rhs_data = rhs_data_buf;
+ const __m256i rhs_data_bottom_lane =
+ _mm512_castsi512_si256(rhs_data_8bit);
+ const __m256i rhs_data_top_lane =
+ _mm512_extracti32x8_epi32(rhs_data_8bit, 1);
+ const __m512i rhs_16_bit_dup_low =
+ _mm512_cvtepi8_epi16(rhs_data_bottom_lane);
+ const __m512i rhs_16_bit_dup_high =
+ _mm512_cvtepi8_epi16(rhs_data_top_lane);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data_buf),
+ rhs_16_bit_dup_low);
+ _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data_buf + 16),
+ rhs_16_bit_dup_high);
+ } else {
+ RUY_DCHECK(params.rhs_scalar_size == 2);
+ }
// Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
const __m512i lhs_16_bit_low =
@@ -306,7 +312,8 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
process_column(15, accum_data_vf);
lhs_ptr += 16 * 4;
- rhs_ptr += 16 * 4;
+ rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) +
+ 16 * 4 * params.rhs_scalar_size);
}
if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
@@ -613,7 +620,8 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
16 * params.dst_stride);
- rhs_col_ptr += 16 * params.rhs_stride;
+ rhs_col_ptr = static_cast<const void*>(
+ static_cast<const char*>(rhs_col_ptr) + 16 * params.rhs_stride);
} // End col-block loop.
} // NOLINT(readability/fn_size)
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index 051c894..0a6cf90 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -60,6 +60,24 @@ struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, std::int32_t, DstScalar>
}
};
+template <typename DstScalar>
+struct Kernel<Path::kAvx512, std::int8_t, std::int16_t, std::int32_t,
+ DstScalar> {
+ static constexpr Path kPath = Path::kAvx512;
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int16_t>& rhs,
+ const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
+ end_col, dst, &params);
+ Kernel8bitAvx512(params);
+ }
+};
+
void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params);
void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& param);