diff options
Diffstat (limited to 'ruy/prepack.h')
-rw-r--r-- | ruy/prepack.h | 108 |
1 files changed, 108 insertions, 0 deletions
diff --git a/ruy/prepack.h b/ruy/prepack.h new file mode 100644 index 0000000..4bfc9ed --- /dev/null +++ b/ruy/prepack.h @@ -0,0 +1,108 @@ +/* Copyright 2019 Google LLC. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Implementation of low-level pre-packing API. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ + +#include <cstddef> +#include <functional> + +#include "ruy/check_macros.h" +#include "ruy/context.h" +#include "ruy/dispatch.h" +#include "ruy/internal_matrix.h" +#include "ruy/matrix.h" +#include "ruy/path.h" +#include "ruy/profiler/instrumentation.h" +#include "ruy/side_pair.h" +#include "ruy/spec.h" +#include "ruy/trmul.h" +#include "ruy/trmul_params.h" +#include "ruy/tune.h" + +namespace ruy { + +template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, + typename DstScalar, typename Spec> +void PrePackForMulInternal(const Matrix<LhsScalar>& lhs, + const Matrix<RhsScalar>& rhs, const Spec& spec, + Context* context, Matrix<DstScalar>* dst, + SidePair<PrepackedMatrix*> prepacked, + std::function<void*(std::size_t)> alloc_fn) { + profiler::ScopeLabel label("PrePackForMul"); + Path the_path = context->GetPathToTake<CompiledPaths>(); + RUY_CHECK_NE(the_path, Path::kReference); + constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; + Matrix<LhsScalar> transposed_lhs(lhs); + Transpose(&transposed_lhs); + TrMulParams params; + CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst, + the_path, ¶ms); + + const SidePair<int> origin{0, 0}; + const SidePair<int> rounded_dims{params.packed[Side::kLhs].layout.cols, + params.packed[Side::kRhs].layout.cols}; + + Tuning tuning = context->GetMainThreadTuning(); + for (Side side : {Side::kLhs, Side::kRhs}) { + if (prepacked[side]) { + prepacked[side]->data_size = DataSize(params.packed[side]); + prepacked[side]->sums_size = SumsSize(params.packed[side]); + prepacked[side]->data = alloc_fn(prepacked[side]->data_size); + prepacked[side]->sums = alloc_fn(prepacked[side]->sums_size); + params.packed[side].data = prepacked[side]->data; + params.packed[side].sums = prepacked[side]->sums; + params.RunPack(side, tuning, origin[side], rounded_dims[side]); + } + } +} + +template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, + typename DstScalar, typename Spec> +void MulWithPrepackedInternal(const Matrix<LhsScalar>& lhs, + const Matrix<RhsScalar>& rhs, const Spec& spec, + Context* context, Matrix<DstScalar>* dst, + SidePair<PrepackedMatrix*> prepacked) { + profiler::ScopeLabel label("MulWithPrepacked"); + + EnforceLayoutSupport<Spec>(lhs.layout, rhs.layout, dst->layout); + EnforceZeroPointSupport<Spec>(lhs.zero_point, rhs.zero_point, + dst->zero_point); + + Path the_path = context->GetPathToTake<CompiledPaths>(); + RUY_CHECK_NE(the_path, Path::kReference); + constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference; + Matrix<LhsScalar> transposed_lhs(lhs); + Transpose(&transposed_lhs); + TrMulParams params; + CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, spec, context, dst, + the_path, ¶ms); + + for (Side side : {Side::kLhs, Side::kRhs}) { + if (prepacked[side]) { + params.packed[side].data = prepacked[side]->data; + params.packed[side].sums = prepacked[side]->sums; + params.is_prepacked[side] = true; + } + } + + TrMul(¶ms, context); +} + +} // namespace ruy + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_RUY_RUY_PREPACK_H_ |