github.com/google/ruy.git
Diffstat (limited to 'ruy/kernel_sse42.cc')
 ruy/kernel_sse42.cc | 428 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 428 insertions(+), 0 deletions(-)
diff --git a/ruy/kernel_sse42.cc b/ruy/kernel_sse42.cc
new file mode 100644
index 0000000..747ca1c
--- /dev/null
+++ b/ruy/kernel_sse42.cc
@@ -0,0 +1,428 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <cstdint>
+
+#include "ruy/check_macros.h"
+#include "ruy/kernel.h"
+#include "ruy/opt_set.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+
+#if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+#include <immintrin.h> // IWYU pragma: keep
+#endif
+
+namespace ruy {
+
+#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM))
+
+void Kernel8bitSse42(const KernelParams8bit<8, 8>& params) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+void KernelFloatSse42(const KernelParamsFloat<8, 8>& params) {
+ // CPU-ID-based checks should disable the path that would reach this point.
+ RUY_DCHECK(false);
+}
+
+#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+static constexpr int kAvxFloatBlockSize = 8;
+static constexpr int kAvx8bitBlockSize = 8;
+static constexpr int kAvx8bitInnerSize = 4;
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete and serves
+// as a placeholder. Optimization is not yet finished; in particular, the
+// dimensions of the kernel blocks can be changed as desired.
+//
+// When removing this comment, update profiling label below.
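+//
+// This is a plain standard-C++ reference kernel: it computes one 8x8
+// destination block at a time from packed LHS/RHS data, using no SIMD
+// intrinsics.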
+void Kernel8bitSse42(const KernelParams8bit<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kSse42 8-bit (UNFINISHED)");
+ std::int32_t accum_data[kAvx8bitBlockSize][kAvx8bitBlockSize];
+
+ int bias_ptr_block_increment =
+ params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvx8bitBlockSize : 0;
+
+ const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
+ void* dst_col_ptr = params.dst_base_ptr;
+ const std::int32_t* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ for (int col = params.start_col; col <= params.last_col;
+ col += kAvx8bitBlockSize) {
+ const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
+ void* dst_ptr = dst_col_ptr;
+ const std::int32_t* bias_ptr = bias_col_ptr;
+
+ for (int row = params.start_row; row <= params.last_row;
+ row += kAvx8bitBlockSize) {
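+ // Edge blocks may be narrower or shorter than 8x8; residual_rows and
+ // residual_cols give the valid extent of this destination block.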
+ const int residual_rows =
+ std::min(params.dst_rows - row, kAvx8bitBlockSize);
+ const int residual_cols =
+ std::min(params.dst_cols - col, kAvx8bitBlockSize);
+
+ // Initialize with bias.
+ std::int32_t initial_accum_data[kAvx8bitBlockSize];
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ initial_accum_data[i] = 0;
+ }
+ for (int i = 0; i < residual_rows; ++i) {
+ initial_accum_data[i] = bias_ptr[i];
+ }
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] = initial_accum_data[i];
+ }
+ }
+ bias_ptr += bias_ptr_block_increment;
+
+ std::int8_t lhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize];
+ std::int8_t rhs_data[kAvx8bitBlockSize][kAvx8bitInnerSize];
+ const std::int8_t* lhs_ptr = lhs_col_ptr;
+ const std::int8_t* rhs_ptr = rhs_col_ptr;
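+ // Packed layout: each depth step stores kAvx8bitInnerSize (4) consecutive
+ // depth entries per LHS row / RHS column, so each iteration below
+ // consumes an 8x4 tile from each side.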
+ for (int d = 0; d < params.depth; d += kAvx8bitInnerSize) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ for (int x = 0; x < kAvx8bitInnerSize; ++x) {
+ lhs_data[i][x] = lhs_ptr[i * kAvx8bitInnerSize + x];
+ rhs_data[i][x] = rhs_ptr[i * kAvx8bitInnerSize + x];
+ }
+ }
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ for (int x = 0; x < kAvx8bitInnerSize; ++x) {
+ accum_data[j][i] += lhs_data[i][x] * rhs_data[j][x];
+ }
+ }
+ }
+ lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
+ }
+
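+ // Zero-point corrections. With zero points lz and rz,
+ //   sum_d (lhs - lz) * (rhs - rz) = sum_d lhs * rhs
+ //       - rz * sum_d lhs - lz * sum_d rhs + lz * rz * depth,
+ // so subtract the precomputed row/column sums and add prod_zp_depth.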
+ if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && params.rhs_zero_point) {
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] -=
+ params.rhs_zero_point * params.lhs_sums[row + i];
+ }
+ }
+ }
+ if ((params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && params.lhs_zero_point) {
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] -=
+ params.lhs_zero_point * params.rhs_sums[col + j];
+ }
+ }
+ }
+ if (params.lhs_zero_point && params.rhs_zero_point) {
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] += params.prod_zp_depth;
+ }
+ }
+ }
+
+ if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+ std::int32_t m_vector[kAvx8bitBlockSize];
+ std::int32_t e_vector[kAvx8bitBlockSize];
+ // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
+ if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) {
+ int i = 0;
+ for (; i < residual_rows; ++i) {
+ m_vector[i] = params.multiplier_fixedpoint[row + i];
+ e_vector[i] = params.multiplier_exponent[row + i];
+ }
+ for (; i < kAvx8bitBlockSize; ++i) {
+ m_vector[i] = m_vector[0];
+ e_vector[i] = e_vector[0];
+ }
+ } else {
+ // These arrays have size LhsCols, and are pre-filled.
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ m_vector[i] = params.multiplier_fixedpoint[i];
+ e_vector[i] = params.multiplier_exponent[i];
+ }
+ }
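+ // Requantize: scale each int32 accumulator by its output row's
+ // fixed-point multiplier and power-of-two exponent.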
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] = MultiplyByQuantizedMultiplier(
+ accum_data[j][i], m_vector[i], e_vector[i]);
+ }
+ }
+
+ if (params.dst_zero_point) {
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] += params.dst_zero_point;
+ }
+ }
+ }
+
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ accum_data[j][i] =
+ std::min<std::int32_t>(accum_data[j][i], params.clamp_max);
+ accum_data[j][i] =
+ std::max<std::int32_t>(accum_data[j][i], params.clamp_min);
+ }
+ }
+ }
+
+ const bool store_full_block = (residual_rows == kAvx8bitBlockSize) &&
+ (residual_cols == kAvx8bitBlockSize);
+
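+ // Full 8x8 blocks are stored directly to the destination; partial edge
+ // blocks are staged in dst_tmp_buf and then copied element-wise to avoid
+ // writing out of bounds.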
+ if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) {
+ std::int8_t* tmp_ptr =
+ store_full_block
+ ? static_cast<std::int8_t*>(dst_ptr)
+ : const_cast<std::int8_t*>(
+ reinterpret_cast<const std::int8_t*>(params.dst_tmp_buf));
+ const int block_col_offset =
+ store_full_block ? params.dst_stride / sizeof(std::int8_t)
+ : kAvx8bitBlockSize;
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+
+ if (!store_full_block) {
+ const std::int8_t* block_ptr =
+ reinterpret_cast<const std::int8_t*>(params.dst_tmp_buf);
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ static_cast<std::int8_t*>(
+ dst_ptr)[j * params.dst_stride / sizeof(std::int8_t) + i] =
+ block_ptr[i];
+ }
+ block_ptr += kAvx8bitBlockSize;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) {
+ std::uint8_t* tmp_ptr = store_full_block
+ ? static_cast<std::uint8_t*>(dst_ptr)
+ : const_cast<std::uint8_t*>(
+ reinterpret_cast<const std::uint8_t*>(
+ params.dst_tmp_buf));
+ const int block_col_offset =
+ store_full_block ? params.dst_stride / sizeof(std::uint8_t)
+ : kAvx8bitBlockSize;
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+
+ if (!store_full_block) {
+ const std::uint8_t* block_ptr =
+ reinterpret_cast<const std::uint8_t*>(params.dst_tmp_buf);
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ static_cast<std::uint8_t*>(
+ dst_ptr)[j * params.dst_stride / sizeof(std::uint8_t) + i] =
+ block_ptr[i];
+ }
+ block_ptr += kAvx8bitBlockSize;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) {
+ if (store_full_block) {
+ std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr);
+ const int block_col_offset = params.dst_stride / sizeof(std::int16_t);
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+ } else {
+ std::int16_t* tmp_ptr = const_cast<std::int16_t*>(
+ reinterpret_cast<const std::int16_t*>(params.dst_tmp_buf));
+ const int block_col_offset = kAvx8bitBlockSize;
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+ const std::int16_t* block_ptr =
+ reinterpret_cast<const std::int16_t*>(params.dst_tmp_buf);
+ std::int16_t* dst_block_ptr = static_cast<std::int16_t*>(dst_ptr);
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ dst_block_ptr[i] = block_ptr[i];
+ }
+ dst_block_ptr += params.dst_stride / sizeof(std::int16_t);
+ block_ptr += kAvx8bitBlockSize;
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) {
+ if (store_full_block) {
+ std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr);
+ const int block_col_offset = params.dst_stride / sizeof(std::int32_t);
+ for (int j = 0; j < kAvx8bitBlockSize; ++j) {
+ for (int i = 0; i < kAvx8bitBlockSize; ++i) {
+ tmp_ptr[i] = accum_data[j][i];
+ }
+ tmp_ptr += block_col_offset;
+ }
+ } else {
+ std::int32_t* dst_block_ptr = static_cast<std::int32_t*>(dst_ptr);
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ dst_block_ptr[i] = accum_data[j][i];
+ }
+ dst_block_ptr += params.dst_stride / sizeof(std::int32_t);
+ }
+ }
+ dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) +
+ kAvx8bitBlockSize);
+ } else {
+ RUY_DCHECK(false);
+ }
+
+ lhs_col_ptr += kAvx8bitBlockSize * params.lhs_stride;
+ } // End row-block loop.
+
+ dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) +
+ kAvx8bitBlockSize * params.dst_stride);
+ rhs_col_ptr += kAvx8bitBlockSize * params.rhs_stride;
+ } // End col-block loop.
+} // NOLINT(readability/fn_size)
+
+// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete and serves
+// as a placeholder. Optimization is not yet finished; in particular, the
+// dimensions of the kernel blocks can be changed as desired.
+//
+// When removing this comment, update profiling label below.
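+//
+// Like the 8-bit kernel above, this is a plain standard-C++ reference
+// implementation using no SIMD intrinsics.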
+void KernelFloatSse42(const KernelParamsFloat<8, 8>& params) {
+ profiler::ScopeLabel label("Kernel kSse42 float (UNFINISHED)");
+
+ float lhs_data[kAvxFloatBlockSize];
+ float rhs_data[kAvxFloatBlockSize];
+ float accum_data[kAvxFloatBlockSize][kAvxFloatBlockSize];
+ int bias_ptr_block_increment =
+ params.flags & RUY_ASM_FLAG_HAS_BIAS ? kAvxFloatBlockSize : 0;
+
+ const float* rhs_col_ptr = params.rhs_base_ptr;
+ float* dst_col_ptr = params.dst_base_ptr;
+ const float* bias_col_ptr = params.bias;
+ if (params.flags & RUY_ASM_FLAG_HAS_BIAS) {
+ bias_col_ptr += params.start_row;
+ }
+
+ for (int col = params.start_col; col <= params.last_col;
+ col += kAvxFloatBlockSize) {
+ const float* lhs_col_ptr = params.lhs_base_ptr;
+ float* dst_ptr = dst_col_ptr;
+ const float* bias_ptr = bias_col_ptr;
+
+ for (int row = params.start_row; row <= params.last_row;
+ row += kAvxFloatBlockSize) {
+ const int residual_rows =
+ std::min(params.dst_rows - row, kAvxFloatBlockSize);
+ const int residual_cols =
+ std::min(params.dst_cols - col, kAvxFloatBlockSize);
+
+ // Initialize with bias.
+ float initial_accum_data[kAvxFloatBlockSize];
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ initial_accum_data[i] = 0.0f;
+ }
+ for (int i = 0; i < residual_rows; ++i) {
+ initial_accum_data[i] = bias_ptr[i];
+ }
+ for (int j = 0; j < kAvxFloatBlockSize; ++j) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ accum_data[j][i] = initial_accum_data[i];
+ }
+ }
+ bias_ptr += bias_ptr_block_increment;
+
+ const float* lhs_ptr = lhs_col_ptr;
+ const float* rhs_ptr = rhs_col_ptr;
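+ // Each depth step performs a rank-1 update:
+ //   accum[j][i] += lhs[i] * rhs[j].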
+ for (int d = 0; d < params.depth; ++d) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ lhs_data[i] = lhs_ptr[i];
+ rhs_data[i] = rhs_ptr[i];
+ }
+ for (int j = 0; j < kAvxFloatBlockSize; ++j) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ accum_data[j][i] += lhs_data[i] * rhs_data[j];
+ }
+ }
+ lhs_ptr += kAvxFloatBlockSize;
+ rhs_ptr += kAvxFloatBlockSize;
+ }
+
+ for (int j = 0; j < kAvxFloatBlockSize; ++j) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ accum_data[j][i] =
+ std::min<float>(accum_data[j][i], params.clamp_max);
+ accum_data[j][i] =
+ std::max<float>(accum_data[j][i], params.clamp_min);
+ }
+ }
+
+ const bool store_full_block = (residual_rows == kAvxFloatBlockSize) &&
+ (residual_cols == kAvxFloatBlockSize);
+
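+ // As in the 8-bit kernel, a full block is stored directly to dst_ptr,
+ // while a partial edge block is staged in dst_tmp_buf first.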
+ {
+ float* block_ptr =
+ store_full_block ? dst_ptr : const_cast<float*>(params.dst_tmp_buf);
+ const int block_col_offset = store_full_block
+ ? params.dst_stride / sizeof(float)
+ : kAvxFloatBlockSize;
+ for (int j = 0; j < kAvxFloatBlockSize; ++j) {
+ for (int i = 0; i < kAvxFloatBlockSize; ++i) {
+ block_ptr[i] = accum_data[j][i];
+ }
+ block_ptr += block_col_offset;
+ }
+ }
+ if (!store_full_block) {
+ const float* block_ptr = params.dst_tmp_buf;
+ for (int j = 0; j < residual_cols; ++j) {
+ for (int i = 0; i < residual_rows; ++i) {
+ dst_ptr[j * params.dst_stride / sizeof(float) + i] = block_ptr[i];
+ }
+ block_ptr += kAvxFloatBlockSize;
+ }
+ }
+
+ lhs_col_ptr += kAvxFloatBlockSize * params.lhs_stride / sizeof(float);
+ dst_ptr += kAvxFloatBlockSize;
+ } // End row-block loop.
+
+ dst_col_ptr += kAvxFloatBlockSize * params.dst_stride / sizeof(float);
+ rhs_col_ptr += kAvxFloatBlockSize * params.rhs_stride / sizeof(float);
+ } // End col-block loop.
+}
+
+#endif // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)
+
+} // namespace ruy