Diffstat (limited to 'ruy/kernel_common.h')
-rw-r--r-- | ruy/kernel_common.h | 481
1 file changed, 481 insertions, 0 deletions
diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h
new file mode 100644
index 0000000..0cd123f
--- /dev/null
+++ b/ruy/kernel_common.h
@@ -0,0 +1,481 @@
+/* Copyright 2019 Google LLC. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_
+#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <type_traits>
+
+#include "ruy/check_macros.h"
+#include "ruy/common.h"
+#include "ruy/internal_matrix.h"
+#include "ruy/matrix.h"
+#include "ruy/opt_set.h"
+#include "ruy/path.h"
+#include "ruy/platform.h"
+#include "ruy/profiler/instrumentation.h"
+#include "ruy/side_pair.h"
+#include "ruy/size_util.h"
+#include "ruy/spec.h"
+#include "ruy/tune.h"
+
+namespace ruy {
+
+template <Path ThePath, typename LhsScalar, typename RhsScalar,
+          typename DstScalar, typename Spec>
+struct Kernel {};
+
+template <Path ThePath, typename LhsScalar, typename RhsScalar,
+          typename DstScalar, typename Spec>
+void RunKernelTyped(Tuning tuning, const PackedMatrix<LhsScalar>& lhs,
+                    const PackedMatrix<RhsScalar>& rhs, const Spec& spec,
+                    int start_row, int start_col, int end_row, int end_col,
+                    Matrix<DstScalar>* dst) {
+  using Kernel = Kernel<ThePath, LhsScalar, RhsScalar, DstScalar, Spec>;
+  Kernel kernel(tuning);
+#if !defined(NDEBUG) || !RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL)
+  using LhsLayout = typename Kernel::LhsLayout;
+  using RhsLayout = typename Kernel::RhsLayout;
+#endif
+  // end_row and end_col may be larger than the dst dimensions.
+  // That is because kernels write directly to the destination matrix, whose
+  // dimensions may not be a multiple of the kernel dimensions, and we try to
+  // keep this annoyance localized as an implementation detail in kernels,
+  // by allowing rounded-up values to be passed down as far as possible.
+  // These assertions encode the contract.
+  RUY_DCHECK_LE(0, start_row);
+  RUY_DCHECK_LE(start_row, end_row);
+  RUY_DCHECK_LT(end_row, dst->layout.rows + LhsLayout::kCols);
+  RUY_DCHECK_EQ((end_row - start_row) % LhsLayout::kCols, 0);
+  RUY_DCHECK_LE(0, start_col);
+  RUY_DCHECK_LE(start_col, end_col);
+  RUY_DCHECK_LT(end_col, dst->layout.cols + RhsLayout::kCols);
+  RUY_DCHECK_EQ((end_col - start_col) % RhsLayout::kCols, 0);
+#if RUY_OPT_ENABLED(RUY_OPT_FAT_KERNEL)
+  kernel.Run(lhs, rhs, spec, start_row, start_col, end_row, end_col, dst);
+#else
+  for (int col = start_col; col < end_col; col += RhsLayout::kCols) {
+    int block_end_col = std::min(col + RhsLayout::kCols, end_col);
+    for (int row = start_row; row < end_row; row += LhsLayout::kCols) {
+      int block_end_row = std::min(row + LhsLayout::kCols, end_row);
+      kernel.Run(lhs, rhs, spec, row, col, block_end_row, block_end_col, dst);
+    }
+  }
+#endif
+}
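
The traversal contract above allows end_row and end_col to be rounded up to the kernel's block size. A standalone sketch of the non-fat-kernel path with made-up numbers (a 10-row destination and a hypothetical kernel whose LhsLayout::kCols is 8), using nothing beyond the standard library:

    #include <algorithm>
    #include <cstdio>

    int main() {
      const int dst_rows = 10;    // actual destination rows
      const int kernel_rows = 8;  // hypothetical LhsLayout::kCols
      const int end_row = 16;     // rounded up to a multiple of kernel_rows
      for (int row = 0; row < end_row; row += kernel_rows) {
        const int block_end_row = std::min(row + kernel_rows, end_row);
        // The kernel itself clamps to the real bounds, as the standard C++
        // kernel below does with clamped_end_row.
        std::printf("block rows [%d, %d) -> writes rows [%d, %d)\n", row,
                    block_end_row, row, std::min(block_end_row, dst_rows));
      }
      return 0;
    }
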
+
+// Main entry point for kernels.
+template <Path ThePath, typename LhsScalar, typename RhsScalar,
+          typename DstScalar, typename Spec>
+void RunKernel(Tuning tuning, const SidePair<PMatrix>& src, void* spec,
+               const SidePair<int>& start, const SidePair<int>& end,
+               DMatrix* dst) {
+  Matrix<DstScalar> mdst = ToMatrix<DstScalar>(*dst);
+  RunKernelTyped<ThePath, LhsScalar, RhsScalar, DstScalar, Spec>(
+      tuning, ToPackedMatrix<LhsScalar>(src[Side::kLhs]),
+      ToPackedMatrix<RhsScalar>(src[Side::kRhs]),
+      *static_cast<const Spec*>(spec), start[Side::kLhs], start[Side::kRhs],
+      end[Side::kLhs], end[Side::kRhs], &mdst);
+}
+
+// Copied from gemmlowp/fixedpoint.
+inline std::int32_t SaturatingRoundingDoublingHighMul(std::int32_t a,
+                                                      std::int32_t b) {
+  bool overflow = a == b && a == std::numeric_limits<std::int32_t>::min();
+  std::int64_t a_64(a);
+  std::int64_t b_64(b);
+  std::int64_t ab_64 = a_64 * b_64;
+  std::int32_t nudge = ab_64 >= 0 ? (1 << 30) : (1 - (1 << 30));
+  std::int32_t ab_x2_high32 =
+      static_cast<std::int32_t>((ab_64 + nudge) / (1ll << 31));
+  return overflow ? std::numeric_limits<std::int32_t>::max() : ab_x2_high32;
+}
+
+inline std::int32_t RoundingDivideByPOT(std::int32_t numerator, int exponent) {
+  std::int32_t sign = numerator >= 0 ? 1 : -1;
+  std::int32_t abs_numerator = std::abs(numerator);
+  std::int32_t mask = (1LL << exponent) - 1;
+  std::int32_t remainder = abs_numerator & mask;
+  std::int32_t threshold = mask >> 1;
+  std::int32_t abs_result =
+      (abs_numerator >> exponent) + (remainder > threshold ? 1 : 0);
+  return sign * abs_result;
+}
+
+// Copied from TF Lite code.
+inline std::int32_t MultiplyByQuantizedMultiplier(
+    std::int32_t x, std::int32_t quantized_multiplier, int shift) {
+  int left_shift = shift > 0 ? shift : 0;
+  int right_shift = shift > 0 ? 0 : -shift;
+  return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
+                                 x * (1 << left_shift), quantized_multiplier),
+                             right_shift);
+}
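
Together these three helpers apply a real-valued scale encoded as a Q31 fixed-point multiplier plus a power-of-two exponent: SaturatingRoundingDoublingHighMul computes a rounded (a * b) / 2^31, saturating the single overflow case (both inputs equal to INT32_MIN), and RoundingDivideByPOT rounds half away from zero. A minimal check, assuming this header is on the include path; the input 100 and the scale 0.75 are made up for illustration:

    #include <cstdint>
    #include <cstdio>

    #include "ruy/kernel_common.h"

    int main() {
      const std::int32_t multiplier_fixedpoint = 1610612736;  // 0.75 * 2^31
      const int multiplier_exponent = 0;  // no extra power-of-two scaling
      const std::int32_t accum = 100;     // hypothetical raw accumulator
      std::printf("%d\n", ruy::MultiplyByQuantizedMultiplier(
                              accum, multiplier_fixedpoint,
                              multiplier_exponent));  // prints 75
      return 0;
    }
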
+
+// Helper to apply a fixed-point multiplier. Only 'applicable' if AccumScalar
+// is int32 (i.e. in all cases except floating-point) and if the destination is
+// not int32 (i.e. unless the user wants to get raw accumulators).
+template <typename Spec,
+          bool IsApplicable =
+              std::is_same<typename Spec::AccumScalar, std::int32_t>::value &&
+              !std::is_same<typename Spec::DstScalar, std::int32_t>::value>
+struct ApplyMultiplierImpl {};
+
+// Specialization in non-applicable case: do nothing, just check that values
+// are default.
+template <typename Spec>
+struct ApplyMultiplierImpl<Spec, false> {
+  using AccumScalar = typename Spec::AccumScalar;
+  using DstScalar = typename Spec::DstScalar;
+  static void Run(const Spec& spec, int row, AccumScalar* accum) {
+    RUY_DCHECK_EQ(spec.multiplier_fixedpoint, 0);
+    RUY_DCHECK_EQ(spec.multiplier_exponent, 0);
+  }
+};
+
+template <typename Spec>
+struct ApplyMultiplierImpl<Spec, true> {
+  using AccumScalar = typename Spec::AccumScalar;
+  using DstScalar = typename Spec::DstScalar;
+  static void Run(const Spec& spec, int row, AccumScalar* accum) {
+    AccumScalar m = spec.multiplier_fixedpoint_perchannel
+                        ? spec.multiplier_fixedpoint_perchannel[row]
+                        : spec.multiplier_fixedpoint;
+    int e = spec.multiplier_exponent_perchannel
+                ? spec.multiplier_exponent_perchannel[row]
+                : spec.multiplier_exponent;
+    *accum = MultiplyByQuantizedMultiplier(*accum, m, e);
+  }
+};
+
+template <typename Spec>
+void ApplyMultiplier(const Spec& spec, int row,
+                     typename Spec::AccumScalar* accum) {
+  ApplyMultiplierImpl<Spec>::Run(spec, row, accum);
+}
+
+template <typename LhsScalar, typename RhsScalar, typename DstScalar,
+          typename Spec>
+struct Kernel<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar, Spec> {
+  using AccumScalar = typename Spec::AccumScalar;
+  using LhsLayout = typename Spec::StandardCppKernelLhsLayout;
+  using RhsLayout = typename Spec::StandardCppKernelRhsLayout;
+  explicit Kernel(Tuning) {}
+  void Run(const PackedMatrix<LhsScalar>& lhs,
+           const PackedMatrix<RhsScalar>& rhs, const Spec& spec, int start_row,
+           int start_col, int end_row, int end_col,
+           Matrix<DstScalar>* dst) const {
+    // See the comment in RunKernelTyped. end_row may be larger than
+    // dst->layout.rows. It's the responsibility of the kernel to avoid
+    // overrunning dst boundaries, which we do here by computing
+    // clamped_end_row.
+    int clamped_end_row = std::min(end_row, dst->layout.rows);
+    int clamped_end_col = std::min(end_col, dst->layout.cols);
+    RUY_DCHECK_LE(0, start_row);
+    RUY_DCHECK_LE(start_row, clamped_end_row);
+    RUY_DCHECK_LE(clamped_end_row, dst->layout.rows);
+    RUY_DCHECK_LE(clamped_end_row, end_row);
+    RUY_DCHECK_LE(end_row - clamped_end_row, LhsLayout::kCols);
+    RUY_DCHECK_LE(0, start_col);
+    RUY_DCHECK_LE(start_col, clamped_end_col);
+    RUY_DCHECK_LE(clamped_end_col, dst->layout.cols);
+    RUY_DCHECK_LE(clamped_end_col, end_col);
+    RUY_DCHECK_LE(end_col - clamped_end_col, RhsLayout::kCols);
+    profiler::ScopeLabel label("Kernel (Standard Cpp)");
+    const int depth = lhs.layout.rows;
+    for (int i = start_row; i < clamped_end_row; i++) {
+      for (int j = start_col; j < clamped_end_col; j++) {
+        using AccumScalar = typename Spec::AccumScalar;
+        AccumScalar accum = 0;
+        for (int k = 0; k < depth; k++) {
+          AccumScalar lhs_val = Element(lhs, k, i);
+          AccumScalar rhs_val = Element(rhs, k, j);
+          accum += lhs_val * rhs_val;
+        }
+        if (spec.bias) {
+          accum += spec.bias[i];
+        }
+        if (lhs.zero_point) {
+          accum -= lhs.zero_point * rhs.sums[j];
+        }
+        if (rhs.zero_point) {
+          accum -= rhs.zero_point * lhs.sums[i];
+        }
+        if (lhs.zero_point && rhs.zero_point) {
+          accum += lhs.zero_point * rhs.zero_point * depth;
+        }
+        ApplyMultiplier(spec, i, &accum);
+        accum += dst->zero_point;
+        accum = std::min<AccumScalar>(accum, spec.clamp_max);
+        accum = std::max<AccumScalar>(accum, spec.clamp_min);
+        *ElementPtr(dst, i, j) = static_cast<DstScalar>(accum);
+      }
+    }
+  }
+};
+
+#define RUY_INHERIT_KERNEL(PARENT, CHILD)                                  \
+  template <typename LhsScalar, typename RhsScalar, typename DstScalar,   \
+            typename Spec>                                                 \
+  struct Kernel<CHILD, LhsScalar, RhsScalar, DstScalar, Spec>              \
+      : Kernel<PARENT, LhsScalar, RhsScalar, DstScalar, Spec> {            \
+    explicit Kernel(Tuning tuning)                                         \
+        : Kernel<PARENT, LhsScalar, RhsScalar, DstScalar, Spec>(tuning) {} \
+  };
+
+#if RUY_PLATFORM(NEON)
+RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kNeon)
+RUY_INHERIT_KERNEL(Path::kNeon, Path::kNeonDotprod)
+#elif RUY_PLATFORM(X86)
+RUY_INHERIT_KERNEL(Path::kStandardCpp, Path::kSse42)
+RUY_INHERIT_KERNEL(Path::kSse42, Path::kAvx2)
+RUY_INHERIT_KERNEL(Path::kAvx2, Path::kAvx512)
+RUY_INHERIT_KERNEL(Path::kAvx512, Path::kAvxVnni)
+#endif
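
The four accumulator adjustments in the standard C++ kernel above are the expanded form of the quantized product: with zero points zl and zr,

    sum_k (lhs[k] - zl) * (rhs[k] - zr)
        = sum_k lhs[k] * rhs[k] - zl * sum(rhs) - zr * sum(lhs) + zl * zr * depth

which is why the packed matrices carry precomputed row/column sums. A self-contained numeric check with made-up values (depth 3, one row, one column):

    #include <cstdint>
    #include <cstdio>

    int main() {
      const std::int32_t lhs[3] = {12, 7, 9}, rhs[3] = {3, 14, 5};
      const std::int32_t zl = 10, zr = 4;  // hypothetical zero points
      std::int32_t direct = 0, raw = 0, lhs_sum = 0, rhs_sum = 0;
      for (int k = 0; k < 3; ++k) {
        direct += (lhs[k] - zl) * (rhs[k] - zr);
        raw += lhs[k] * rhs[k];
        lhs_sum += lhs[k];
        rhs_sum += rhs[k];
      }
      // Same correction order as the kernel: -zl*sum(rhs), -zr*sum(lhs),
      // +zl*zr*depth.
      const std::int32_t corrected =
          raw - zl * rhs_sum - zr * lhs_sum + zl * zr * 3;
      std::printf("%d == %d\n", direct, corrected);  // both are -33
      return 0;
    }
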
+
+// KernelParams are shared across 32-bit and 64-bit NEON code, and x86 code.
+//
+// In other cases, we still define (empty) versions, so that dummy kernels
+// can use the classes in function signatures.
+#if ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) && \
+     RUY_OPT_ENABLED(RUY_OPT_ASM)) ||                    \
+    RUY_PLATFORM(X86)
+
+#define RUY_ASM_FLAG_HAS_BIAS 0x1
+#define RUY_ASM_FLAG_HAS_LHS_SUMS 0x2
+#define RUY_ASM_FLAG_HAS_RHS_SUMS 0x4
+#define RUY_ASM_FLAG_HAS_PERCHANNEL 0x8
+#define RUY_ASM_FLAG_NEEDS_LEFT_SHIFT 0x10
+
+#define RUY_ASM_TYPE_ID_UINT8 1
+#define RUY_ASM_TYPE_ID_INT8 2
+#define RUY_ASM_TYPE_ID_INT16 3
+#define RUY_ASM_TYPE_ID_INT32 4
+
+template <typename DstScalar>
+struct DstTypeId {};
+
+template <>
+struct DstTypeId<std::uint8_t> {
+  static constexpr int kValue = RUY_ASM_TYPE_ID_UINT8;
+};
+
+template <>
+struct DstTypeId<std::int8_t> {
+  static constexpr int kValue = RUY_ASM_TYPE_ID_INT8;
+};
+
+template <>
+struct DstTypeId<std::int16_t> {
+  static constexpr int kValue = RUY_ASM_TYPE_ID_INT16;
+};
+
+template <>
+struct DstTypeId<std::int32_t> {
+  static constexpr int kValue = RUY_ASM_TYPE_ID_INT32;
+};
+
+template <int LhsCols, int RhsCols>
+struct KernelParams8bit {
+  static constexpr int kMaxDstTypeSize = 4;
+
+  const std::int32_t* bias;
+  const std::int32_t* lhs_sums;
+  const std::int32_t* rhs_sums;
+  const std::int8_t* lhs_base_ptr;
+  const std::int32_t* multiplier_fixedpoint;
+  const std::int32_t* multiplier_exponent;
+  const std::int8_t* rhs_base_ptr;
+  void* dst_base_ptr;
+  std::int32_t lhs_zero_point;
+  std::int32_t rhs_zero_point;
+  std::int32_t dst_zero_point;
+  std::int32_t prod_zp_depth;
+  std::int32_t start_row;
+  std::int32_t start_col;
+  std::int32_t last_row;
+  std::int32_t last_col;
+  std::int32_t dst_rows;
+  std::int32_t dst_cols;
+  std::int32_t lhs_stride;
+  std::int32_t rhs_stride;
+  std::int32_t dst_stride;
+  std::int32_t depth;
+  std::int32_t clamp_min;
+  std::int32_t clamp_max;
+  std::uint8_t flags;
+  std::uint8_t dst_type_id;
+  const std::int32_t zero_data[LhsCols] = {0};
+  std::uint8_t dst_tmp_buf[LhsCols * RhsCols * kMaxDstTypeSize];
+  std::int32_t multiplier_fixedpoint_buf[LhsCols];
+  std::int32_t multiplier_exponent_buf[LhsCols];
+};
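
The flags word packs each optional feature into one bit. Note that MakeKernelParams8bit below always leaves bias pointing at valid storage (it falls back to zero_data), so RUY_ASM_FLAG_HAS_BIAS lets a kernel skip work rather than guard a null pointer. A hypothetical C++ rendering of such a flag test, not part of ruy itself:

    #include <cstdint>

    #include "ruy/kernel_common.h"

    // Hypothetical helper: read the bias for one row the way an asm kernel
    // consults the flags word.
    template <int LhsCols, int RhsCols>
    std::int32_t BiasForRow(
        const ruy::KernelParams8bit<LhsCols, RhsCols>& params, int row) {
      return (params.flags & RUY_ASM_FLAG_HAS_BIAS)
                 ? params.bias[row]  // points at spec.bias
                 : 0;  // bias points at zero_data here, so 0 either way
    }
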
+
+template <typename DstScalar, int LhsCols, int RhsCols>
+void MakeKernelParams8bit(const PackedMatrix<std::int8_t>& lhs,
+                          const PackedMatrix<std::int8_t>& rhs,
+                          const BasicSpec<std::int32_t, DstScalar>& spec,
+                          int start_row, int start_col, int end_row,
+                          int end_col, Matrix<DstScalar>* dst,
+                          KernelParams8bit<LhsCols, RhsCols>* params) {
+  using Params = KernelParams8bit<LhsCols, RhsCols>;
+
+  static_assert(sizeof(DstScalar) <= Params::kMaxDstTypeSize, "");
+
+  const int depth = lhs.layout.rows;
+  RUY_DCHECK_EQ(start_row % LhsCols, 0);
+  RUY_DCHECK_EQ(start_col % RhsCols, 0);
+  RUY_DCHECK_EQ(end_row % LhsCols, 0);
+  RUY_DCHECK_EQ(end_col % RhsCols, 0);
+
+  params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
+  params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
+  params->flags = 0;
+  params->bias = params->zero_data;
+  if (spec.bias) {
+    params->bias = spec.bias;
+    params->flags |= RUY_ASM_FLAG_HAS_BIAS;
+  }
+  if (lhs.sums) {
+    params->lhs_sums = lhs.sums;
+    params->flags |= RUY_ASM_FLAG_HAS_LHS_SUMS;
+  }
+  if (rhs.sums) {
+    params->rhs_sums = rhs.sums;
+    params->flags |= RUY_ASM_FLAG_HAS_RHS_SUMS;
+  }
+  params->start_row = start_row;
+  params->start_col = start_col;
+  params->last_row = end_row - LhsCols;
+  params->last_col = end_col - RhsCols;
+  params->lhs_stride = lhs.layout.stride;
+  params->rhs_stride = rhs.layout.stride;
+  params->dst_stride = sizeof(DstScalar) * dst->layout.stride;
+  params->lhs_zero_point = lhs.zero_point;
+  params->rhs_zero_point = rhs.zero_point;
+  params->dst_zero_point = dst->zero_point;
+  params->depth = depth;
+  params->prod_zp_depth = lhs.zero_point * rhs.zero_point * depth;
+  if (spec.multiplier_fixedpoint_perchannel) {
+    params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
+    params->flags |= RUY_ASM_FLAG_HAS_PERCHANNEL;
+    params->multiplier_fixedpoint = spec.multiplier_fixedpoint_perchannel;
+    params->multiplier_exponent = spec.multiplier_exponent_perchannel;
+  } else {
+    if (spec.multiplier_exponent > 0) {
+      params->flags |= RUY_ASM_FLAG_NEEDS_LEFT_SHIFT;
+    }
+    params->multiplier_fixedpoint = params->multiplier_fixedpoint_buf;
+    params->multiplier_exponent = params->multiplier_exponent_buf;
+    for (int i = 0; i < LhsCols; i++) {
+      params->multiplier_fixedpoint_buf[i] = spec.multiplier_fixedpoint;
+      params->multiplier_exponent_buf[i] = spec.multiplier_exponent;
+    }
+  }
+  params->clamp_min = spec.clamp_min;
+  params->clamp_max = spec.clamp_max;
+  params->dst_rows = dst->layout.rows;
+  params->dst_cols = dst->layout.cols;
+
+  RUY_DCHECK_LT(params->last_row, params->dst_rows);
+  RUY_DCHECK_LT(params->last_col, params->dst_cols);
+
+  params->dst_type_id = DstTypeId<DstScalar>::kValue;
+  params->dst_base_ptr =
+      dst->data.get() + start_col * dst->layout.stride + start_row;
+}
+
+template <int LhsCols, int RhsCols>
+struct KernelParamsFloat {
+  const float* lhs_base_ptr;
+  const float* rhs_base_ptr;
+  float* dst_base_ptr;
+  const float* bias;
+  std::int32_t start_row;
+  std::int32_t start_col;
+  std::int32_t last_row;
+  std::int32_t last_col;
+  std::int32_t dst_rows;
+  std::int32_t dst_cols;
+  std::int32_t lhs_stride;
+  std::int32_t rhs_stride;
+  std::int32_t dst_stride;
+  std::int32_t depth;
+  float clamp_min;
+  float clamp_max;
+  std::uint8_t flags;
+  const float zero_data[LhsCols] = {0};
+  float dst_tmp_buf[LhsCols * RhsCols];
+};
+
+template <int LhsCols, int RhsCols>
+inline void MakeKernelParamsFloat(const PackedMatrix<float>& lhs,
+                                  const PackedMatrix<float>& rhs,
+                                  const BasicSpec<float, float>& spec,
+                                  int start_row, int start_col, int end_row,
+                                  int end_col, Matrix<float>* dst,
+                                  KernelParamsFloat<LhsCols, RhsCols>* params) {
+  const int depth = lhs.layout.rows;
+  RUY_DCHECK_EQ(start_row % LhsCols, 0);
+  RUY_DCHECK_EQ(start_col % RhsCols, 0);
+  RUY_DCHECK_EQ(end_row % LhsCols, 0);
+  RUY_DCHECK_EQ(end_col % RhsCols, 0);
+
+  params->lhs_base_ptr = lhs.data + start_row * lhs.layout.stride;
+  params->rhs_base_ptr = rhs.data + start_col * rhs.layout.stride;
+  params->dst_base_ptr =
+      dst->data.get() + start_col * dst->layout.stride + start_row;
+
+  std::uint8_t flags = 0;
+  params->bias = params->zero_data;
+  if (spec.bias) {
+    params->bias = spec.bias;
+    flags |= RUY_ASM_FLAG_HAS_BIAS;
+  }
+  params->flags = flags;
+  params->start_row = start_row;
+  params->start_col = start_col;
+  params->last_row = end_row - LhsCols;
+  params->last_col = end_col - RhsCols;
+  params->lhs_stride = sizeof(float) * lhs.layout.stride;
+  params->rhs_stride = sizeof(float) * rhs.layout.stride;
+  params->dst_stride = sizeof(float) * dst->layout.stride;
+  params->depth = depth;
+  params->clamp_min = spec.clamp_min;
+  params->clamp_max = spec.clamp_max;
+  params->dst_rows = dst->layout.rows;
+  params->dst_cols = dst->layout.cols;
+
+  RUY_DCHECK_LT(params->last_row, params->dst_rows);
+  RUY_DCHECK_LT(params->last_col, params->dst_cols);
+}
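
In both param-filling functions, last_row and last_col name the origin of the final kernel block (end_row - LhsCols, end_col - RhsCols) rather than past-the-end bounds, which is what the trailing RUY_DCHECK_LTs assert. A plausible C++ rendering of the traversal a kernel derives from them (a sketch under that assumption, not actual ruy kernel code, which lives in per-architecture asm):

    #include <cstdint>

    #include "ruy/kernel_common.h"

    template <int LhsCols, int RhsCols>
    void VisitBlocks(const ruy::KernelParamsFloat<LhsCols, RhsCols>& params) {
      for (std::int32_t col = params.start_col; col <= params.last_col;
           col += RhsCols) {
        for (std::int32_t row = params.start_row; row <= params.last_row;
             row += LhsCols) {
          // ... compute the LhsCols x RhsCols destination block at
          // (row, col) ...
        }
      }
    }
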
+
+#else  // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) &&
+       // RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86)
+
+template <int LhsCols, int RhsCols>
+struct KernelParams8bit {};
+
+template <int LhsCols, int RhsCols>
+struct KernelParamsFloat {};
+
+#endif  // ((RUY_PLATFORM(NEON_64) || RUY_PLATFORM(NEON_32)) &&
+        // RUY_OPT_ENABLED(RUY_OPT_ASM)) || RUY_PLATFORM(X86)
+
+}  // namespace ruy
+
+#endif  // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_KERNEL_COMMON_H_