github.com/google/ruy.git
author     Dayeong Lee <dayeongl@google.com>              2021-11-01 11:00:23 +0300
committer  Copybara-Service <copybara-worker@google.com>  2021-11-01 11:00:48 +0300
commit     f805132145194c44839db198afa61e08d7852cc4 (patch)
tree       931a194f6a732729a334e1252e185cae7b3c96f5
parent     02d2088d84791eab6821f1a56510a2ea72e2cd77 (diff)

Ruy: Support 8x16 avx2_fma kernel

PiperOrigin-RevId: 406766575
-rw-r--r--  ruy/kernel_arm32.cc     |  6
-rw-r--r--  ruy/kernel_arm64.cc     | 21
-rw-r--r--  ruy/kernel_avx.cc       |  6
-rw-r--r--  ruy/kernel_avx2_fma.cc  | 53
-rw-r--r--  ruy/kernel_avx512.cc    |  6
-rw-r--r--  ruy/kernel_common.h     | 11
-rw-r--r--  ruy/kernel_x86.h        | 18

7 files changed, 84 insertions, 37 deletions
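
This commit makes the shared 8-bit kernel parameters carry an RHS of either 8-bit or 16-bit scalars and adds an avx2_fma kernel specialization for the int8 (LHS) x int16 (RHS) case. As a rough sketch of what such an "8x16" multiplication looks like through ruy's public API, assuming the combination is wired up end to end (this patch only covers the kernel side), with placeholder sizes and buffers:

#include <cstdint>
#include "ruy/ruy.h"

// Sketch: int8 LHS x int16 RHS -> int32 destination via ruy's public API.
// Assumes the 8x16 combination is routed to the new avx2_fma kernel by the
// surrounding dispatch code, which is outside this patch.
void Sketch8x16Mul(const std::int8_t* lhs_data, const std::int16_t* rhs_data,
                   std::int32_t* dst_data, int rows, int depth, int cols) {
  ruy::Context context;

  ruy::Matrix<std::int8_t> lhs;
  ruy::MakeSimpleLayout(rows, depth, ruy::Order::kRowMajor,
                        lhs.mutable_layout());
  lhs.set_data(lhs_data);

  ruy::Matrix<std::int16_t> rhs;
  ruy::MakeSimpleLayout(depth, cols, ruy::Order::kColMajor,
                        rhs.mutable_layout());
  rhs.set_data(rhs_data);

  ruy::Matrix<std::int32_t> dst;
  ruy::MakeSimpleLayout(rows, cols, ruy::Order::kColMajor,
                        dst.mutable_layout());
  dst.set_data(dst_data);

  // Raw int32 accumulators; quantized rescaling is omitted for brevity.
  ruy::MulParams<std::int32_t, std::int32_t> mul_params;
  ruy::Mul(lhs, rhs, mul_params, &context, &dst);
}
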
diff --git a/ruy/kernel_arm32.cc b/ruy/kernel_arm32.cc
index 8782dce..c8e053d 100644
--- a/ruy/kernel_arm32.cc
+++ b/ruy/kernel_arm32.cc
@@ -630,7 +630,8 @@ void Kernel8bitNeon(const KernelParams8bit<4, 2>& params) {
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
@@ -1630,7 +1631,8 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 2>& params) {
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
diff --git a/ruy/kernel_arm64.cc b/ruy/kernel_arm64.cc
index 5424107..532138d 100644
--- a/ruy/kernel_arm64.cc
+++ b/ruy/kernel_arm64.cc
@@ -101,7 +101,8 @@ void Kernel8bitNeon(const KernelParams8bit<4, 4>& params) {
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
@@ -1160,7 +1161,8 @@ void Kernel8bitNeon1Col(const KernelParams8bit<4, 4>& params) {
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
@@ -1832,7 +1834,8 @@ void Kernel8bitNeonA55ish(const KernelParams8bit<4, 4>& params) {
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
@@ -2987,7 +2990,8 @@ void Kernel8bitNeonDotprod(const KernelParams8bit<8, 8>& params) {
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
@@ -4413,7 +4417,8 @@ void Kernel8bitNeonDotprodX1(const KernelParams8bit<8, 8>& params) {
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
@@ -5667,7 +5672,8 @@ void Kernel8bitNeonDotprod1Col(const KernelParams8bit<8, 8>& params) {
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
@@ -6362,7 +6368,8 @@ void Kernel8bitNeonDotprodA55ish(const KernelParams8bit<8, 8>& params) {
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
diff --git a/ruy/kernel_avx.cc b/ruy/kernel_avx.cc
index 2405735..0f7e2e3 100644
--- a/ruy/kernel_avx.cc
+++ b/ruy/kernel_avx.cc
@@ -462,7 +462,8 @@ void Kernel8bitAvxImpl(const KernelParams8bit<8, 8>& params) {
RUY_DCHECK(false);
}
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
void* dst_col_ptr = params.dst_base_ptr;
for (int col = params.start_col; col <= params.last_col;
@@ -1184,7 +1185,8 @@ void Kernel8bitAvxSingleColImpl(const KernelParams8bit<8, 8>& params) {
int bias_ptr_block_increment =
params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
void* dst_col_ptr = params.dst_base_ptr;
const std::int32_t* bias_col_ptr = params.bias;
if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc
index eae333c..50488a4 100644
--- a/ruy/kernel_avx2_fma.cc
+++ b/ruy/kernel_avx2_fma.cc
@@ -121,7 +121,7 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
RUY_DCHECK(false);
}
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const void* rhs_col_ptr = params.rhs_base_ptr;
void* dst_col_ptr = params.dst_base_ptr;
for (int col = params.start_col; col <= params.last_col;
@@ -251,7 +251,7 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
}
const std::int8_t* lhs_ptr = lhs_col_ptr;
- const std::int8_t* rhs_ptr = rhs_col_ptr;
+ const void* rhs_ptr = rhs_col_ptr;
for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
const __m256i lhs_data =
_mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
@@ -259,21 +259,29 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
_mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr));
// Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
- std::int32_t rhs_data[16];
- const __m128i rhs_data_bottom_lane =
- _mm256_castsi256_si128(rhs_data_8bit);
- const __m128i rhs_data_top_lane =
- _mm256_extracti128_si256(rhs_data_8bit, 1);
- const __m256i rhs_16_bit_dup_low =
- _mm256_cvtepi8_epi16(rhs_data_bottom_lane);
- const __m256i rhs_16_bit_dup_high =
- _mm256_cvtepi8_epi16(rhs_data_top_lane);
- // Now that we have cast the RHS data, we store it so that each value
- // can be separately loaded in the accumulation loop.
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data),
- rhs_16_bit_dup_low);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8),
- rhs_16_bit_dup_high);
+ std::int32_t rhs_data_buf[16];
+ const std::int32_t* rhs_data =
+ reinterpret_cast<const std::int32_t*>(rhs_ptr);
+
+ if (params.rhs_scalar_size == 1) {
+ rhs_data = rhs_data_buf;
+ const __m128i rhs_data_bottom_lane =
+ _mm256_castsi256_si128(rhs_data_8bit);
+ const __m128i rhs_data_top_lane =
+ _mm256_extracti128_si256(rhs_data_8bit, 1);
+ const __m256i rhs_16_bit_dup_low =
+ _mm256_cvtepi8_epi16(rhs_data_bottom_lane);
+ const __m256i rhs_16_bit_dup_high =
+ _mm256_cvtepi8_epi16(rhs_data_top_lane);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf),
+ rhs_16_bit_dup_low);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf + 8),
+ rhs_16_bit_dup_high);
+ } else {
+ RUY_DCHECK(params.rhs_scalar_size == 2);
+ }
const __m256i lhs_data_split =
_mm256_shuffle_epi8(lhs_data, splitter_idx);
@@ -339,7 +347,9 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
process_column(tmp2, tmp3, accum_data_v7);
lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
- rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ rhs_ptr = static_cast<const void*>(
+ static_cast<const char*>(rhs_ptr) +
+ kAvx8bitBlockSize * kAvx8bitInnerSize * params.rhs_scalar_size);
}
if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
@@ -717,7 +727,9 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
kAvx8bitBlockSize * params.dst_stride);
- rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+ rhs_col_ptr =
+ static_cast<const void*>(static_cast<const char*>(rhs_col_ptr) +
+ kAvx8bitBlockSize * params.rhs_stride);
} // End col-block loop.
} // NOLINT(readability/fn_size)
@@ -743,7 +755,8 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
int bias_ptr_block_increment =
params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
void* dst_col_ptr = params.dst_base_ptr;
const std::int32_t* bias_col_ptr = params.bias;
if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
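
The core of the patch is in Kernel8bitAvx2Impl above: inside the depth loop, the RHS block is sign-extended from 8-bit to 16-bit only when params.rhs_scalar_size == 1; when it is 2, rhs_data is pointed directly at rhs_ptr and the already-16-bit values are consumed as packed std::int32_t pairs. The single-column variant (Kernel8bitAvx2SingleColImpl) keeps the plain 8-bit cast. Below is a standalone sketch of the widening step using the same intrinsics as the hunk (with an unaligned load, since the kernel's alignment guarantees are not assumed here); the helper name is hypothetical.

#include <immintrin.h>
#include <cstdint>

// Hypothetical standalone helper (not part of the patch): sign-extend 32
// packed int8 RHS values to int16 and store them as 16 std::int32_t words,
// matching the layout the accumulation loop expects.
inline void WidenRhs8To16(const std::int8_t* rhs8, std::int32_t out[16]) {
  const __m256i rhs_data_8bit =
      _mm256_loadu_si256(reinterpret_cast<const __m256i*>(rhs8));
  // Split the 256-bit register into its two 128-bit lanes (16 int8 each).
  const __m128i bottom_lane = _mm256_castsi256_si128(rhs_data_8bit);
  const __m128i top_lane = _mm256_extracti128_si256(rhs_data_8bit, 1);
  // Sign-extend each lane of 16 int8 values to 16 int16 values.
  const __m256i lo_16bit = _mm256_cvtepi8_epi16(bottom_lane);
  const __m256i hi_16bit = _mm256_cvtepi8_epi16(top_lane);
  // Store so each pair of int16 values can be reloaded as one int32.
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(out), lo_16bit);
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(out + 8), hi_16bit);
}

When params.rhs_scalar_size == 2, this step is skipped entirely: rhs_data stays pointed at rhs_ptr, and the only per-iteration change in the 16-bit path is the byte-based pointer increment at the bottom of the loop.
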
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
index 84b9380..25ce2c7 100644
--- a/ruy/kernel_avx512.cc
+++ b/ruy/kernel_avx512.cc
@@ -67,7 +67,8 @@ void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
RUY_DCHECK(false);
}
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
void* dst_col_ptr = params.dst_base_ptr;
for (int col = params.start_col; col <= params.last_col; col += 16) {
@@ -625,7 +626,8 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
void* dst_col_ptr = params.dst_base_ptr;
const std::int32_t* bias_col_ptr = params.bias;
if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h
index cff243b..69e819b 100644
--- a/ruy/kernel_common.h
+++ b/ruy/kernel_common.h
@@ -101,7 +101,8 @@ struct KernelParams8bit {
const std::int8_t* lhs_base_ptr;
const std::int32_t* multiplier_fixedpoint;
const std::int32_t* multiplier_exponent;
- const std::int8_t* rhs_base_ptr;
+ // Make it void* to support 8bit(LHS)x16bit(RHS) case.
+ const void* rhs_base_ptr;
void* dst_base_ptr;
std::int32_t lhs_zero_point;
std::int32_t rhs_zero_point;
@@ -125,11 +126,12 @@ struct KernelParams8bit {
std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize];
std::int32_t multiplier_fixedpoint_buf[LhsCols];
std::int32_t multiplier_exponent_buf[LhsCols];
+ std::size_t rhs_scalar_size;
};
-template <typename DstScalar, int LhsCols, int RhsCols>
+template <typename RhsScalar, typename DstScalar, int LhsCols, int RhsCols>
void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
- const PMat<std::int8_t>& rhs,
+ const PMat<RhsScalar>& rhs,
const MulParams<std::int32_t, DstScalar>& mul_params,
int start_row, int start_col, int end_row,
int end_col, Mat<DstScalar>* dst,
@@ -145,6 +147,7 @@ void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
RUY_DCHECK_EQ(end_col % RhsCols, 0);
params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
+ params->rhs_scalar_size = sizeof(RhsScalar);
params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
params->flags = 0;
params->bias = params->zero_data;
@@ -168,7 +171,7 @@ void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
params->last_row = end_row - LhsCols;
params->last_col = end_col - RhsCols;
params->lhs_stride = lhs.layout.stride;
- params->rhs_stride = rhs.layout.stride;
+ params->rhs_stride = params->rhs_scalar_size * rhs.layout.stride;
params->dst_stride = sizeof(DstScalar) * dst->layout.stride;
params->lhs_zero_point = lhs.zero_point;
params->rhs_zero_point = rhs.zero_point;
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index d2045de..051c894 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -111,6 +111,24 @@ struct Kernel<Path::kAvx2Fma, std::int8_t, std::int8_t, std::int32_t,
}
};
+template <typename DstScalar>
+struct Kernel<Path::kAvx2Fma, std::int8_t, std::int16_t, std::int32_t,
+ DstScalar> {
+ static constexpr Path kPath = Path::kAvx2Fma;
+ Tuning tuning = Tuning::kAuto;
+ using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
+ explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int16_t>& rhs,
+ const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
+ KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
+ MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
+ end_col, dst, &params);
+ Kernel8bitAvx2(params);
+ }
+};
+
void KernelFloatAvx2(const KernelParamsFloat<8, 8>& params);
void KernelFloatAvx2SingleCol(const KernelParamsFloat<8, 8>& params);
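
The new Kernel specialization in kernel_x86.h keeps the 4x8 column-major FixedKernelLayout of the existing 8-bit AVX2 kernel and forwards to Kernel8bitAvx2 through the now RHS-type-aware MakeKernelParams8bit. A sketch of how the specialization is instantiated, assuming packed operands, a dispatcher that selects Path::kAvx2Fma, and row/col bounds that are multiples of the 8x8 block (all of which live outside this patch):

#include <cstdint>
#include "ruy/kernel_x86.h"

// Sketch only: packed_lhs, packed_rhs, mul_params and dst are assumed to be
// prepared by ruy's pack/dispatch machinery, which is not part of this patch.
using Kernel8x16 = ruy::Kernel<ruy::Path::kAvx2Fma, std::int8_t, std::int16_t,
                               std::int32_t, std::int32_t>;

void RunKernel8x16(const ruy::PMat<std::int8_t>& packed_lhs,
                   const ruy::PMat<std::int16_t>& packed_rhs,
                   const ruy::MulParams<std::int32_t, std::int32_t>& mul_params,
                   int rows, int cols, ruy::Mat<std::int32_t>* dst) {
  Kernel8x16 kernel(ruy::Tuning::kAuto);
  // Cover the whole destination in one call; real dispatch tiles this range.
  kernel.Run(packed_lhs, packed_rhs, mul_params, /*start_row=*/0,
             /*start_col=*/0, /*end_row=*/rows, /*end_col=*/cols, dst);
}
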