diff options
Diffstat (limited to 'ruy/kernel_avx2_fma.cc')
-rw-r--r-- | ruy/kernel_avx2_fma.cc | 53 |
1 files changed, 33 insertions, 20 deletions
diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc index eae333c..50488a4 100644 --- a/ruy/kernel_avx2_fma.cc +++ b/ruy/kernel_avx2_fma.cc @@ -121,7 +121,7 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) { RUY_DCHECK(false); } - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const void* rhs_col_ptr = params.rhs_base_ptr; void* dst_col_ptr = params.dst_base_ptr; for (int col = params.start_col; col <= params.last_col; @@ -251,7 +251,7 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) { } const std::int8_t* lhs_ptr = lhs_col_ptr; - const std::int8_t* rhs_ptr = rhs_col_ptr; + const void* rhs_ptr = rhs_col_ptr; for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) { const __m256i lhs_data = _mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr)); @@ -259,21 +259,29 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) { _mm256_load_si256(reinterpret_cast<const __m256i*>(rhs_ptr)); // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. - std::int32_t rhs_data[16]; - const __m128i rhs_data_bottom_lane = - _mm256_castsi256_si128(rhs_data_8bit); - const __m128i rhs_data_top_lane = - _mm256_extracti128_si256(rhs_data_8bit, 1); - const __m256i rhs_16_bit_dup_low = - _mm256_cvtepi8_epi16(rhs_data_bottom_lane); - const __m256i rhs_16_bit_dup_high = - _mm256_cvtepi8_epi16(rhs_data_top_lane); - // Now that we have cast the RHS data, we store it so that each value - // can be separately loaded in the accumulation loop. - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data), - rhs_16_bit_dup_low); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data + 8), - rhs_16_bit_dup_high); + std::int32_t rhs_data_buf[16]; + const std::int32_t* rhs_data = + reinterpret_cast<const std::int32_t*>(rhs_ptr); + + if (params.rhs_scalar_size == 1) { + rhs_data = rhs_data_buf; + const __m128i rhs_data_bottom_lane = + _mm256_castsi256_si128(rhs_data_8bit); + const __m128i rhs_data_top_lane = + _mm256_extracti128_si256(rhs_data_8bit, 1); + const __m256i rhs_16_bit_dup_low = + _mm256_cvtepi8_epi16(rhs_data_bottom_lane); + const __m256i rhs_16_bit_dup_high = + _mm256_cvtepi8_epi16(rhs_data_top_lane); + // Now that we have cast the RHS data, we store it so that each value + // can be separately loaded in the accumulation loop. + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf), + rhs_16_bit_dup_low); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(rhs_data_buf + 8), + rhs_16_bit_dup_high); + } else { + RUY_DCHECK(params.rhs_scalar_size == 2); + } const __m256i lhs_data_split = _mm256_shuffle_epi8(lhs_data, splitter_idx); @@ -339,7 +347,9 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) { process_column(tmp2, tmp3, accum_data_v7); lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; - rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize; + rhs_ptr = static_cast<const void*>( + static_cast<const char*>(rhs_ptr) + + kAvx8bitBlockSize * kAvx8bitInnerSize * params.rhs_scalar_size); } if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { @@ -717,7 +727,9 @@ void Kernel8bitAvx2Impl(const KernelParams8bit<8, 8>& params) { dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + kAvx8bitBlockSize * params.dst_stride); - rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride; + rhs_col_ptr = + static_cast<const void*>(static_cast<const char*>(rhs_col_ptr) + + kAvx8bitBlockSize * params.rhs_stride); } // End col-block loop. } // NOLINT(readability/fn_size) @@ -743,7 +755,8 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) { int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0; - const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; + const std::int8_t* rhs_col_ptr = + static_cast<const int8_t*>(params.rhs_base_ptr); void* dst_col_ptr = params.dst_base_ptr; const std::int32_t* bias_col_ptr = params.bias; if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { |