Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
Diffstat (limited to 'ruy/kernel_avx512.cc')
-rw-r--r-- ruy/kernel_avx512.cc 30
1 files changed, 19 insertions, 11 deletions
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
index 3d2219b..654ba27 100644
--- a/ruy/kernel_avx512.cc
+++ b/ruy/kernel_avx512.cc
@@ -634,8 +634,7 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
- const std::int8_t* rhs_col_ptr =
- static_cast<const int8_t*>(params.rhs_base_ptr);
+ const void* rhs_col_ptr = params.rhs_base_ptr;
void* dst_col_ptr = params.dst_base_ptr;
const std::int32_t* bias_col_ptr = params.bias;
if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
@@ -694,20 +693,28 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
}
const std::int8_t* lhs_ptr = lhs_col_ptr;
- const std::int8_t* rhs_ptr = rhs_col_ptr;
+ const void* rhs_ptr = rhs_col_ptr;
for (int d = 0; d < params.depth; d += 4) {
const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
- const __m128i rhs_data_8bit =
- _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr));
+ const std::int32_t* rhs_data =
+ reinterpret_cast<const std::int32_t*>(rhs_ptr);
// Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
// For simplicity we load 4x the data that we need and process twice the
// data that we need and store only the data we need.
- std::int32_t rhs_data[2];
- const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
- // Now that we have cast the RHS data, we store it so that each value
- // can be separately loaded in the accumulation loop.
- _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
+ std::int32_t rhs_data_buf[2];
+ if (params.rhs_scalar_size == 1) {
+ rhs_data = rhs_data_buf;
+ const __m128i rhs_data_8bit =
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr));
+ const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf),
+ rhs_16_bit_dup);
+ } else {
+ RUY_DCHECK(params.rhs_scalar_size == 2);
+ }
// Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
const __m512i lhs_16_bit_low =
@@ -731,7 +738,8 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
accum_data_v0 = accum_v;
lhs_ptr += 16 * 4;
- rhs_ptr += 16 * 4;
+ rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) +
+ 16 * 4 * params.rhs_scalar_size);
}
if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {