diff options
Diffstat (limited to 'ruy/kernel_avx512.cc')
-rw-r--r-- | ruy/kernel_avx512.cc | 30 |
1 files changed, 19 insertions, 11 deletions
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc index 3d2219b..654ba27 100644 --- a/ruy/kernel_avx512.cc +++ b/ruy/kernel_avx512.cc @@ -634,8 +634,7 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0; - const std::int8_t* rhs_col_ptr = - static_cast<const int8_t*>(params.rhs_base_ptr); + const void* rhs_col_ptr = params.rhs_base_ptr; void* dst_col_ptr = params.dst_base_ptr; const std::int32_t* bias_col_ptr = params.bias; if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { @@ -694,20 +693,28 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { } const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; + const void* rhs_ptr = rhs_col_ptr; for (int d = 0; d < params.depth; d += 4) { const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr); - const __m128i rhs_data_8bit = - _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr)); + const std::int32_t* rhs_data = + reinterpret_cast<const std::int32_t*>(rhs_ptr); // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. // For simplicity we load 4x the data that we need and process twice the // data that we need and store only the data we need. - std::int32_t rhs_data[2]; - const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); + std::int32_t rhs_data_buf[2]; + if (params.rhs_scalar_size == 1) { + rhs_data = rhs_data_buf; + const __m128i rhs_data_8bit = + _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr)); + const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf), + rhs_16_bit_dup); + } else { + RUY_DCHECK(params.rhs_scalar_size == 2); + } // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. const __m512i lhs_16_bit_low = @@ -731,7 +738,8 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { accum_data_v0 = accum_v; lhs_ptr += 16 * 4; - rhs_ptr += 16 * 4; + rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) + + 16 * 4 * params.rhs_scalar_size); } if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { |