Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'ruy/kernel_avx2_fma.cc')
-rw-r--r-- ruy/kernel_avx2_fma.cc | 53
1 file changed, 33 insertions, 20 deletions
diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc
index eae333c..50488a4 100644
--- a/ruy/kernel_avx2_fma.cc
+++ b/ruy/kernel_avx2_fma.cc
@@ -121,7 +121,7 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
RUY_DCHECK(false);
}
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const void* rhs_col_ptr = params.rhs_base_ptr;
void* dst_col_ptr = params.dst_base_ptr;
for (int col = params.start_col; col <= params.last_col;
@@ -251,7 +251,7 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
}
const std::int8_t* lhs_ptr = lhs_col_ptr;
- const std::int8_t* rhs_ptr = rhs_col_ptr;
+ const void* rhs_ptr = rhs_col_ptr;
for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
const __m256i lhs_data =
_mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
@@ -259,21 +259,29 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
_mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr));
// Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
- std::int32_t rhs_data[16];
- const __m128i rhs_data_bottom_lane =
- _mm256_castsi256_si128(rhs_data_8bit);
- const __m128i rhs_data_top_lane =
- _mm256_extracti128_si256(rhs_data_8bit, 1);
- const __m256i rhs_16_bit_dup_low =
- _mm256_cvtepi8_epi16(rhs_data_bottom_lane);
- const __m256i rhs_16_bit_dup_high =
- _mm256_cvtepi8_epi16(rhs_data_top_lane);
- // Now that we have cast the RHS data, we store it so that each value
- // can be separately loaded in the accumulation loop.
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data),
- rhs_16_bit_dup_low);
- _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8),
- rhs_16_bit_dup_high);
+ std::int32_t rhs_data_buf[16];
+ const std::int32_t* rhs_data =
+ reinterpret_cast<const std::int32_t*>(rhs_ptr);
+
+ if (params.rhs_scalar_size == 1) {
+ rhs_data = rhs_data_buf;
+ const __m128i rhs_data_bottom_lane =
+ _mm256_castsi256_si128(rhs_data_8bit);
+ const __m128i rhs_data_top_lane =
+ _mm256_extracti128_si256(rhs_data_8bit, 1);
+ const __m256i rhs_16_bit_dup_low =
+ _mm256_cvtepi8_epi16(rhs_data_bottom_lane);
+ const __m256i rhs_16_bit_dup_high =
+ _mm256_cvtepi8_epi16(rhs_data_top_lane);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf),
+ rhs_16_bit_dup_low);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf + 8),
+ rhs_16_bit_dup_high);
+ } else {
+ RUY_DCHECK(params.rhs_scalar_size == 2);
+ }
const __m256i lhs_data_split =
_mm256_shuffle_epi8(lhs_data, splitter_idx);
@@ -339,7 +347,9 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
process_column(tmp2, tmp3, accum_data_v7);
lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
- rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ rhs_ptr = static_cast<const void*>(
+ static_cast<const char*>(rhs_ptr) +
+ kAvx8bitBlockSize * kAvx8bitInnerSize * params.rhs_scalar_size);
}
if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
@@ -717,7 +727,9 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) {
dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
kAvx8bitBlockSize * params.dst_stride);
- rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+ rhs_col_ptr =
+ static_cast<const void*>(static_cast<const char*>(rhs_col_ptr) +
+ kAvx8bitBlockSize * params.rhs_stride);
} // End col-block loop.
} // NOLINT(readability/fn_size)
@@ -743,7 +755,8 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
int bias_ptr_block_increment =
params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
- const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ const std::int8_t* rhs_col_ptr =
+ static_cast<const int8_t*>(params.rhs_base_ptr);
void* dst_col_ptr = params.dst_base_ptr;
const std::int32_t* bias_col_ptr = params.bias;
if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {