Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDayeong Lee <dayeongl@google.com>2021-11-01 11:53:50 +0300
committerCopybara-Service <copybara-worker@google.com>2021-11-02 16:12:31 +0300
commit2e9f250c79883dbd3759432a4b19d3b10f6576a7 (patch)
tree743c4fcb5e13a761cb9313cb04b61c8199a39235
parent409296d21c15800375d894f2e0d1cc4c88c14cd9 (diff)
Update TFLite kernel to use Ruy 16x8 Gemm instead of reference kernel.test_406772541
PiperOrigin-RevId: 406772541
-rw-r--r--ruy/kernel_avx2_fma.cc33
-rw-r--r--ruy/kernel_avx512.cc30
-rw-r--r--ruy/kernel_x86.h14
3 files changed, 53 insertions, 24 deletions
diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc
index 50488a4..e725777 100644
--- a/ruy/kernel_avx2_fma.cc
+++ b/ruy/kernel_avx2_fma.cc
@@ -755,8 +755,7 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
int bias_ptr_block_increment =
params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
- const std::int8_t* rhs_col_ptr =
- static_cast<const int8_t*>(params.rhs_base_ptr);
+ const void* rhs_col_ptr = params.rhs_base_ptr;
void* dst_col_ptr = params.dst_base_ptr;
const std::int32_t* bias_col_ptr = params.bias;
if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
@@ -820,20 +819,29 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
}
const std::int8_t* lhs_ptr = lhs_col_ptr;
- const std::int8_t* rhs_ptr = rhs_col_ptr;
+ const void* rhs_ptr = rhs_col_ptr;
for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
const __m256i lhs_data =
_mm256_load_si256(reinterpret_cast<const __m256i*>(lhs_ptr));
- const __m128i rhs_data_8bit = intrin_utils::mm_loadu_si32<path>(rhs_ptr);
+ const std::int32_t* rhs_data =
+ reinterpret_cast<const std::int32_t*>(rhs_ptr);
// Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
// For simplicity we load 4x the data that we need and process twice the
// data that we need and store only the data we need.
- std::int32_t rhs_data[2];
- const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
- // Now that we have cast the RHS data, we store it so that each value
- // can be separately loaded in the accumulation loop.
- _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
+ std::int32_t rhs_data_buf[2];
+ if (params.rhs_scalar_size == 1) {
+ rhs_data = rhs_data_buf;
+ const __m128i rhs_data_8bit =
+ intrin_utils::mm_loadu_si32<path>(rhs_ptr);
+ const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf),
+ rhs_16_bit_dup);
+ } else {
+ RUY_DCHECK(params.rhs_scalar_size == 2);
+ }
// NOTE: There may be opportunities for permuting the data in the packing
// code instead of here.
@@ -864,7 +872,9 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
_mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
- rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) +
+ kAvx8bitBlockSize * kAvx8bitInnerSize *
+ params.rhs_scalar_size);
}
if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
@@ -1002,7 +1012,8 @@ void Kernel8bitAvx2SingleColImpl(const KernelParams8bit<8, 8>& params) {
dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
kAvx8bitBlockSize * params.dst_stride);
- rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+ rhs_col_ptr = static_cast<const void*>(static_cast<const char*>(rhs_col_ptr) +
+ kAvx8bitBlockSize * params.rhs_stride);
} // NOLINT(readability/fn_size)
void Kernel8bitAvx2SingleCol(const KernelParams8bit<8, 8>& params) {
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
index 3d2219b..654ba27 100644
--- a/ruy/kernel_avx512.cc
+++ b/ruy/kernel_avx512.cc
@@ -634,8 +634,7 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0;
- const std::int8_t* rhs_col_ptr =
- static_cast<const int8_t*>(params.rhs_base_ptr);
+ const void* rhs_col_ptr = params.rhs_base_ptr;
void* dst_col_ptr = params.dst_base_ptr;
const std::int32_t* bias_col_ptr = params.bias;
if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
@@ -694,20 +693,28 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
}
const std::int8_t* lhs_ptr = lhs_col_ptr;
- const std::int8_t* rhs_ptr = rhs_col_ptr;
+ const void* rhs_ptr = rhs_col_ptr;
for (int d = 0; d < params.depth; d += 4) {
const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr);
- const __m128i rhs_data_8bit =
- _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr));
+ const std::int32_t* rhs_data =
+ reinterpret_cast<const std::int32_t*>(rhs_ptr);
// Each "int32" is two 16-bit RHS values, sign extended from 8-bit.
// For simplicity we load 4x the data that we need and process twice the
// data that we need and store only the data we need.
- std::int32_t rhs_data[2];
- const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
- // Now that we have cast the RHS data, we store it so that each value
- // can be separately loaded in the accumulation loop.
- _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup);
+ std::int32_t rhs_data_buf[2];
+ if (params.rhs_scalar_size == 1) {
+ rhs_data = rhs_data_buf;
+ const __m128i rhs_data_8bit =
+ _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr));
+ const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit);
+ // Now that we have cast the RHS data, we store it so that each value
+ // can be separately loaded in the accumulation loop.
+ _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data_buf),
+ rhs_16_bit_dup);
+ } else {
+ RUY_DCHECK(params.rhs_scalar_size == 2);
+ }
// Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit.
const __m512i lhs_16_bit_low =
@@ -731,7 +738,8 @@ void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) {
accum_data_v0 = accum_v;
lhs_ptr += 16 * 4;
- rhs_ptr += 16 * 4;
+ rhs_ptr = static_cast<const void*>(static_cast<const char*>(rhs_ptr) +
+ 16 * 4 * params.rhs_scalar_size);
}
if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index 0a6cf90..51787b9 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -74,7 +74,12 @@ struct Kernel<Path::kAvx512, std::int8_t, std::int16_t, std::int32_t,
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
- Kernel8bitAvx512(params);
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
+ Kernel8bitAvx512SingleCol(params);
+ } else {
+ Kernel8bitAvx512(params);
+ }
}
};
@@ -143,7 +148,12 @@ struct Kernel<Path::kAvx2Fma, std::int8_t, std::int16_t, std::int32_t,
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
- Kernel8bitAvx2(params);
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
+ Kernel8bitAvx2SingleCol(params);
+ } else {
+ Kernel8bitAvx2(params);
+ }
}
};