Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/ruy
diff options
context:
space:
mode:
authorBenoit Jacob <benoitjacob@google.com>2020-04-22 05:51:26 +0300
committerCopybara-Service <copybara-worker@google.com>2020-04-22 05:51:49 +0300
commit145aecd896b44cfc455803aef4ed98745054567e (patch)
treea4d4e08dc63c6eca9e326985ebb91fd6b4468381 /ruy
parentf3c69a73897c2d97851d9528b34ba8c9371da886 (diff)
Rename:
internal_matrix.h -> mat.h Introduce the Mat class, internal counterpart of Matrix. Renaming: functions that erase/unerase scalar types are now named EraseType/UneraseType. Renaming: PackedLayout -> PMatLayout PackedMatrix -> PMat DMatrix -> EMat PMatrix -> PEMat PiperOrigin-RevId: 307730877
Diffstat (limited to 'ruy')
-rw-r--r--ruy/BUILD18
-rw-r--r--ruy/dispatch.h44
-rw-r--r--ruy/kernel_arm.h32
-rw-r--r--ruy/kernel_common.h40
-rw-r--r--ruy/kernel_x86.h42
-rw-r--r--ruy/mat.h (renamed from ruy/internal_matrix.h)180
-rw-r--r--ruy/matrix.h29
-rw-r--r--ruy/pack_arm.h24
-rw-r--r--ruy/pack_common.h15
-rw-r--r--ruy/pack_x86.h38
-rw-r--r--ruy/prepack.h17
-rw-r--r--ruy/ruy.h12
-rw-r--r--ruy/ruy_advanced.h16
-rw-r--r--ruy/trmul.cc12
-rw-r--r--ruy/trmul_params.h14
15 files changed, 292 insertions, 241 deletions
diff --git a/ruy/BUILD b/ruy/BUILD
index 2e61c66..cb6ab4a 100644
--- a/ruy/BUILD
+++ b/ruy/BUILD
@@ -365,8 +365,8 @@ cc_test(
)
cc_library(
- name = "internal_matrix",
- hdrs = ["internal_matrix.h"],
+ name = "mat",
+ hdrs = ["mat.h"],
copts = ruy_copts_base(),
deps = [
":check_macros",
@@ -403,7 +403,7 @@ cc_library(
deps = [
":check_macros",
":common",
- ":internal_matrix",
+ ":mat",
":matrix",
":mul_params",
":opt_set",
@@ -428,7 +428,7 @@ cc_library(
deps = [
":check_macros",
":common",
- ":internal_matrix",
+ ":mat",
":matrix",
":opt_set",
":path",
@@ -699,13 +699,13 @@ cc_library(
deps = [
":check_macros",
":common",
- ":internal_matrix",
":kernel_arm", # fixdeps: keep
":kernel_avx2", # fixdeps: keep
":kernel_avx512", # fixdeps: keep
":kernel_avxvnni", # fixdeps: keep
":kernel_common",
":kernel_sse42", # fixdeps: keep
+ ":mat",
":matrix",
":mul_params",
":opt_set",
@@ -728,7 +728,7 @@ cc_library(
deps = [
":check_macros",
":common",
- ":internal_matrix",
+ ":mat",
":matrix",
":opt_set",
":pack_arm", # fixdeps: keep
@@ -835,7 +835,7 @@ cc_library(
hdrs = ["trmul_params.h"],
copts = ruy_copts_base(),
deps = [
- ":internal_matrix",
+ ":mat",
":side_pair",
":tune",
],
@@ -853,7 +853,7 @@ cc_library(
":common",
":context",
":context_internal",
- ":internal_matrix",
+ ":mat",
":matrix",
":mul_params",
":opt_set",
@@ -885,8 +885,8 @@ cc_library(
":common",
":context",
":context_internal",
- ":internal_matrix",
":kernel",
+ ":mat",
":matrix",
":mul_params",
":opt_set",
diff --git a/ruy/dispatch.h b/ruy/dispatch.h
index 4c9e099..69234ee 100644
--- a/ruy/dispatch.h
+++ b/ruy/dispatch.h
@@ -42,9 +42,9 @@ limitations under the License.
#include "ruy/common.h"
#include "ruy/context.h"
#include "ruy/context_internal.h"
-#include "ruy/internal_matrix.h"
#include "ruy/kernel.h"
#include "ruy/kernel_common.h"
+#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
#include "ruy/opt_set.h"
@@ -63,8 +63,9 @@ namespace ruy {
// this function enforces that the matrix multiplication at hand falls into
// that special case.
template <typename MulParamsType>
-void EnforceLayoutSupport(const Layout& lhs_layout, const Layout& rhs_layout,
- const Layout& dst_layout) {
+void EnforceLayoutSupport(const MatLayout& lhs_layout,
+ const MatLayout& rhs_layout,
+ const MatLayout& dst_layout) {
if (MulParamsType::kLayoutSupport == LayoutSupport::kRCC) {
RUY_DCHECK(IsRowMajor(lhs_layout));
RUY_DCHECK(IsColMajor(rhs_layout));
@@ -133,9 +134,9 @@ inline bool IsColMajorTrMul(const TrMulParams& params) {
IsColMajor(params.dst.layout);
}
-inline void CreatePackedLayout(const InternalLayout& src, const Type& scalar,
+inline void CreatePackedLayout(const MatLayout& src, const Type& scalar,
const KernelLayout& kernel_layout,
- PackedLayout* packed) {
+ PMatLayout* packed) {
packed->order = Order::kColMajor;
packed->rows = round_up_pot(src.rows, kernel_layout.rows);
packed->cols = round_up_pot(src.cols, kernel_layout.cols);
@@ -161,8 +162,8 @@ void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout,
typename std::conditional<std::is_floating_point<Scalar>::value, Scalar,
std::int32_t>::type;
- const DMatrix& src = params->src[side];
- PMatrix* packed = &params->packed[side];
+ const EMat& src = params->src[side];
+ PEMat* packed = &params->packed[side];
packed->data_type = Type::Create<PackedScalar>();
packed->sums_type = Type::Create<SumsType>();
CreatePackedLayout(src.layout, packed->data_type, kernel_layout,
@@ -323,14 +324,13 @@ void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) {
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
typename DstScalar, typename MulParamsType>
-void CreateTrMulParams(const Matrix<LhsScalar>& lhs,
- const Matrix<RhsScalar>& rhs,
- const MulParamsType& mul_params, Matrix<DstScalar>* dst,
+void CreateTrMulParams(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
+ const MulParamsType& mul_params, Mat<DstScalar>* dst,
Path the_path, TrMulParams* params) {
// Fill in the fields we already know.
- params->src[Side::kLhs] = ToDMatrix(lhs);
- params->src[Side::kRhs] = ToDMatrix(rhs);
- params->dst = ToDMatrix(*dst);
+ params->src[Side::kLhs] = EraseType(lhs);
+ params->src[Side::kRhs] = EraseType(rhs);
+ params->dst = EraseType(*dst);
params->mul_params = ToVoidPtr(&mul_params);
// Create inner loops and packed matrices based on the Path.
@@ -341,8 +341,8 @@ void CreateTrMulParams(const Matrix<LhsScalar>& lhs,
template <typename LhsScalar, typename RhsScalar, typename DstScalar,
typename MulParamsType>
-void ReferenceMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
- const MulParamsType& mul_params, Matrix<DstScalar>* dst) {
+void ReferenceMul(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
+ const MulParamsType& mul_params, Mat<DstScalar>* dst) {
profiler::ScopeLabel label("ReferenceMul");
for (int i = 0; i < lhs.layout.rows; i++) {
for (int j = 0; j < rhs.layout.cols; j++) {
@@ -371,8 +371,8 @@ template <bool ReferenceMulIsEnabled>
struct CompileTimeEnabledReferenceMul {
template <typename LhsScalar, typename RhsScalar, typename DstScalar,
typename MulParamsType>
- static void Run(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
- const MulParamsType& mul_params, Matrix<DstScalar>* dst) {
+ static void Run(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
+ const MulParamsType& mul_params, Mat<DstScalar>* dst) {
ReferenceMul(lhs, rhs, mul_params, dst);
}
};
@@ -383,8 +383,8 @@ template <>
struct CompileTimeEnabledReferenceMul</*ReferenceMulIsEnabled=*/false> {
template <typename LhsScalar, typename RhsScalar, typename DstScalar,
typename MulParamsType>
- static void Run(const Matrix<LhsScalar>&, const Matrix<RhsScalar>&,
- const MulParamsType&, Matrix<DstScalar>*) {
+ static void Run(const Mat<LhsScalar>&, const Mat<RhsScalar>&,
+ const MulParamsType&, Mat<DstScalar>*) {
RUY_DCHECK(false);
}
};
@@ -431,9 +431,9 @@ inline void HandlePrepackedCaching(TrMulParams* params,
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
typename DstScalar, typename MulParamsType>
-void DispatchMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
+void DispatchMul(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
const MulParamsType& mul_params, Context* context,
- Matrix<DstScalar>* dst) {
+ Mat<DstScalar>* dst) {
static_assert(CompiledPaths != Path::kNone, "Must compile at least one Path");
static_assert((CompiledPaths & ~kAllPaths) == Path::kNone,
"CompiledPaths must be a subset of ruy::kAllPaths");
@@ -475,7 +475,7 @@ void DispatchMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
//
// This is Ruy's main code path.
constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference;
- Matrix<LhsScalar> transposed_lhs(lhs);
+ Mat<LhsScalar> transposed_lhs(lhs);
Transpose(&transposed_lhs);
TrMulParams params;
CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, mul_params, dst,
diff --git a/ruy/kernel_arm.h b/ruy/kernel_arm.h
index 770aba6..7cf68c9 100644
--- a/ruy/kernel_arm.h
+++ b/ruy/kernel_arm.h
@@ -20,8 +20,8 @@ limitations under the License.
#include <cstdint>
#include "ruy/common.h"
-#include "ruy/internal_matrix.h"
#include "ruy/kernel_common.h"
+#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
#include "ruy/opt_set.h"
@@ -56,11 +56,9 @@ struct Kernel<Path::kNeon, std::int8_t, std::int8_t, DstScalar,
using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 4>;
Tuning tuning = Tuning::kAuto;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PackedMatrix<std::int8_t>& lhs,
- const PackedMatrix<std::int8_t>& rhs,
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
- int start_col, int end_row, int end_col,
- Matrix<DstScalar>* dst) const {
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
@@ -85,11 +83,9 @@ struct Kernel<Path::kNeon, std::int8_t, std::int8_t, DstScalar,
using RhsLayout = FixedKernelLayout<Order::kColMajor, 16, 2>;
Tuning tuning = Tuning::kAuto;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PackedMatrix<std::int8_t>& lhs,
- const PackedMatrix<std::int8_t>& rhs,
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
- int start_col, int end_row, int end_col,
- Matrix<DstScalar>* dst) const {
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
@@ -110,11 +106,9 @@ struct Kernel<Path::kNeonDotprod, std::int8_t, std::int8_t, DstScalar,
using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PackedMatrix<std::int8_t>& lhs,
- const PackedMatrix<std::int8_t>& rhs,
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
- int start_col, int end_row, int end_col,
- Matrix<DstScalar>* dst) const {
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
@@ -142,9 +136,9 @@ struct Kernel<Path::kNeon, float, float, float, MulParams<float, float>> {
using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+ void Run(const PMat<float>& lhs, const PMat<float>& rhs,
const MulParams<float, float>& mul_params, int start_row,
- int start_col, int end_row, int end_col, Matrix<float>* dst) const {
+ int start_col, int end_row, int end_col, Mat<float>* dst) const {
KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
@@ -165,9 +159,9 @@ struct Kernel<Path::kNeon, float, float, float, MulParams<float, float>> {
using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 4>;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+ void Run(const PMat<float>& lhs, const PMat<float>& rhs,
const MulParams<float, float>& mul_params, int start_row,
- int start_col, int end_row, int end_col, Matrix<float>* dst) const {
+ int start_col, int end_row, int end_col, Mat<float>* dst) const {
KernelParamsFloat<8, 4> params;
MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
@@ -190,9 +184,9 @@ struct Kernel<Path::kNeonDotprod, float, float, float,
using Base =
Kernel<Path::kNeon, float, float, float, MulParams<float, float>>;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+ void Run(const PMat<float>& lhs, const PMat<float>& rhs,
const MulParams<float, float>& mul_params, int start_row,
- int start_col, int end_row, int end_col, Matrix<float>* dst) const {
+ int start_col, int end_row, int end_col, Mat<float>* dst) const {
KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
diff --git a/ruy/kernel_common.h b/ruy/kernel_common.h
index 5717a63..6b22e7d 100644
--- a/ruy/kernel_common.h
+++ b/ruy/kernel_common.h
@@ -22,7 +22,7 @@ limitations under the License.
#include "ruy/check_macros.h"
#include "ruy/common.h"
-#include "ruy/internal_matrix.h"
+#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
#include "ruy/opt_set.h"
@@ -41,11 +41,10 @@ struct Kernel {};
template <Path ThePath, typename LhsScalar, typename RhsScalar,
typename DstScalar, typename MulParamsType>
-void RunKernelTyped(Tuning tuning, const PackedMatrix<LhsScalar>& lhs,
- const PackedMatrix<RhsScalar>& rhs,
- const MulParamsType& mul_params, int start_row,
- int start_col, int end_row, int end_col,
- Matrix<DstScalar>* dst) {
+void RunKernelTyped(Tuning tuning, const PMat<LhsScalar>& lhs,
+ const PMat<RhsScalar>& rhs, const MulParamsType& mul_params,
+ int start_row, int start_col, int end_row, int end_col,
+ Mat<DstScalar>* dst) {
using Kernel =
Kernel<ThePath, LhsScalar, RhsScalar, DstScalar, MulParamsType>;
Kernel kernel(tuning);
@@ -82,13 +81,13 @@ void RunKernelTyped(Tuning tuning, const PackedMatrix<LhsScalar>& lhs,
// Main entry point for kernels.
template <Path ThePath, typename LhsScalar, typename RhsScalar,
typename DstScalar, typename MulParamsType>
-void RunKernel(Tuning tuning, const SidePair<PMatrix>& src, void* mul_params,
+void RunKernel(Tuning tuning, const SidePair<PEMat>& src, void* mul_params,
const SidePair<int>& start, const SidePair<int>& end,
- DMatrix* dst) {
- Matrix<DstScalar> mdst = ToMatrix<DstScalar>(*dst);
+ EMat* dst) {
+ Mat<DstScalar> mdst = UneraseType<DstScalar>(*dst);
RunKernelTyped<ThePath, LhsScalar, RhsScalar, DstScalar, MulParamsType>(
- tuning, ToPackedMatrix<LhsScalar>(src[Side::kLhs]),
- ToPackedMatrix<RhsScalar>(src[Side::kRhs]),
+ tuning, UneraseType<LhsScalar>(src[Side::kLhs]),
+ UneraseType<RhsScalar>(src[Side::kRhs]),
*static_cast<const MulParamsType*>(mul_params), start[Side::kLhs],
start[Side::kRhs], end[Side::kLhs], end[Side::kRhs], &mdst);
}
@@ -179,10 +178,9 @@ struct Kernel<Path::kStandardCpp, LhsScalar, RhsScalar, DstScalar,
using LhsLayout = typename MulParamsType::StandardCppKernelLhsLayout;
using RhsLayout = typename MulParamsType::StandardCppKernelRhsLayout;
explicit Kernel(Tuning) {}
- void Run(const PackedMatrix<LhsScalar>& lhs,
- const PackedMatrix<RhsScalar>& rhs, const MulParamsType& mul_params,
- int start_row, int start_col, int end_row, int end_col,
- Matrix<DstScalar>* dst) const {
+ void Run(const PMat<LhsScalar>& lhs, const PMat<RhsScalar>& rhs,
+ const MulParamsType& mul_params, int start_row, int start_col,
+ int end_row, int end_col, Mat<DstScalar>* dst) const {
// See the comment in RunKernelTyped. end_row may be larger than
// dst->layout.rows. It's the responsibility of the kernel to avoid
// overrunning dst boundaries, which we do here by computing
@@ -331,11 +329,11 @@ struct KernelParams8bit {
};
template <typename DstScalar, int LhsCols, int RhsCols>
-void MakeKernelParams8bit(const PackedMatrix<std::int8_t>& lhs,
- const PackedMatrix<std::int8_t>& rhs,
+void MakeKernelParams8bit(const PMat<std::int8_t>& lhs,
+ const PMat<std::int8_t>& rhs,
const MulParams<std::int32_t, DstScalar>& mul_params,
int start_row, int start_col, int end_row,
- int end_col, Matrix<DstScalar>* dst,
+ int end_col, Mat<DstScalar>* dst,
KernelParams8bit<LhsCols, RhsCols>* params) {
using Params = KernelParams8bit<LhsCols, RhsCols>;
@@ -428,11 +426,11 @@ struct KernelParamsFloat {
};
template <int LhsCols, int RhsCols>
-inline void MakeKernelParamsFloat(const PackedMatrix<float>& lhs,
- const PackedMatrix<float>& rhs,
+inline void MakeKernelParamsFloat(const PMat<float>& lhs,
+ const PMat<float>& rhs,
const MulParams<float, float>& mul_params,
int start_row, int start_col, int end_row,
- int end_col, Matrix<float>* dst,
+ int end_col, Mat<float>* dst,
KernelParamsFloat<LhsCols, RhsCols>* params) {
const int depth = lhs.layout.rows;
RUY_DCHECK_EQ(start_row % LhsCols, 0);
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index b948ff8..45695d1 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -19,8 +19,8 @@ limitations under the License.
#include <cstdint>
#include "ruy/common.h"
-#include "ruy/internal_matrix.h"
#include "ruy/kernel_common.h"
+#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
#include "ruy/opt_set.h"
@@ -44,11 +44,9 @@ struct Kernel<Path::kSse42, std::int8_t, std::int8_t, DstScalar,
using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PackedMatrix<std::int8_t>& lhs,
- const PackedMatrix<std::int8_t>& rhs,
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
- int start_col, int end_row, int end_col,
- Matrix<DstScalar>* dst) const {
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
@@ -64,9 +62,9 @@ struct Kernel<Path::kSse42, float, float, float, MulParams<float, float>> {
using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+ void Run(const PMat<float>& lhs, const PMat<float>& rhs,
const MulParams<float, float>& mul_params, int start_row,
- int start_col, int end_row, int end_col, Matrix<float>* dst) const {
+ int start_col, int end_row, int end_col, Mat<float>* dst) const {
KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
@@ -84,11 +82,9 @@ struct Kernel<Path::kAvx512, std::int8_t, std::int8_t, DstScalar,
using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PackedMatrix<std::int8_t>& lhs,
- const PackedMatrix<std::int8_t>& rhs,
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
- int start_col, int end_row, int end_col,
- Matrix<DstScalar>* dst) const {
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
@@ -109,9 +105,9 @@ struct Kernel<Path::kAvx512, float, float, float, MulParams<float, float>> {
using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+ void Run(const PMat<float>& lhs, const PMat<float>& rhs,
const MulParams<float, float>& mul_params, int start_row,
- int start_col, int end_row, int end_col, Matrix<float>* dst) const {
+ int start_col, int end_row, int end_col, Mat<float>* dst) const {
KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
@@ -133,11 +129,9 @@ struct Kernel<Path::kAvx2, std::int8_t, std::int8_t, DstScalar,
using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 8>;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PackedMatrix<std::int8_t>& lhs,
- const PackedMatrix<std::int8_t>& rhs,
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
- int start_col, int end_row, int end_col,
- Matrix<DstScalar>* dst) const {
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
@@ -158,9 +152,9 @@ struct Kernel<Path::kAvx2, float, float, float, MulParams<float, float>> {
using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+ void Run(const PMat<float>& lhs, const PMat<float>& rhs,
const MulParams<float, float>& mul_params, int start_row,
- int start_col, int end_row, int end_col, Matrix<float>* dst) const {
+ int start_col, int end_row, int end_col, Mat<float>* dst) const {
KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
@@ -185,11 +179,9 @@ struct Kernel<Path::kAvxVnni, std::int8_t, std::int8_t, DstScalar,
using LhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
using RhsLayout = FixedKernelLayout<Order::kColMajor, 4, 16>;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PackedMatrix<std::int8_t>& lhs,
- const PackedMatrix<std::int8_t>& rhs,
+ void Run(const PMat<std::int8_t>& lhs, const PMat<std::int8_t>& rhs,
const MulParams<std::int32_t, DstScalar>& mul_params, int start_row,
- int start_col, int end_row, int end_col,
- Matrix<DstScalar>* dst) const {
+ int start_col, int end_row, int end_col, Mat<DstScalar>* dst) const {
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
@@ -205,9 +197,9 @@ struct Kernel<Path::kAvxVnni, float, float, float, MulParams<float, float>> {
using LhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
using RhsLayout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
explicit Kernel(Tuning tuning_) : tuning(tuning_) {}
- void Run(const PackedMatrix<float>& lhs, const PackedMatrix<float>& rhs,
+ void Run(const PMat<float>& lhs, const PMat<float>& rhs,
const MulParams<float, float>& mul_params, int start_row,
- int start_col, int end_row, int end_col, Matrix<float>* dst) const {
+ int start_col, int end_row, int end_col, Mat<float>* dst) const {
KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, &params);
diff --git a/ruy/internal_matrix.h b/ruy/mat.h
index 834352d..fe4dfe8 100644
--- a/ruy/internal_matrix.h
+++ b/ruy/mat.h
@@ -14,32 +14,47 @@ limitations under the License.
==============================================================================*/
// Internal types and helpers for matrices.
+// "Mat" is the name we use to refer to our internal matrix classes; it can be
+// thought of as a shorter version of "InternalMatrix"`
//
-// Ruy has a couple slightly different notions of matrices, besides the
+// Ruy has four internal matrix classes, besides the
// Matrix<T> class that we expose to the user-facing API.
//
// TODO(silvasean): Put parts of this architecture description somewhere more
// prominent.
//
-// The 4 main matrix types are:
-// - Matrix<T>: This is a user-facing type on Ruy's external API boundary. It is
-// also used internally.
-// - DMatrix: This is a type-erased version of Matrix<T>. "D" = "dynamic".
-// - PMatrix: This represents a packed matrix, which requires tracking kernel
-// layout and row/column sums for quantization. It is type-erased.
-// - PackedMatrix<T>: This is a statically typed variant of PMatrix for
-// convenience inside typed routines.
+// The 4 internal matrix classes are named Mat, EMat, PMat, PEMat, where:
+// - "E" indicates a type-erased class, storing a void* pointer and a runtime
+// enum value to track the scalar type, as opposed to being templatized
+// on a Scalar type and storing a Scalar* pointer.
+// - "P" indicates a packed matrix class, the output of the packing code and
+// input of the kernel code. See comments in pack.h regarding packing.
+//
+// In other words:
+//
+// Plain matrices Packed matrices
+// +----------------------------------
+// Templated | Mat, Matrix PMat
+// Type-erased | EMat PEMat
//
// Note that Matrix<T> is *not* implemented in terms of the internal types. It
-// is an independent, simple, and user-facing type.
+// is an independent, simple, and user-facing type. Matrix<T> is functionally
+// equivalent to Mat, but we keep it separate to insulate internals from
+// interface and to be able to make different compromises in internals
+// vs interface: in internals we prefer Mat to be a C-style struct with
+// raw data member access and to be similar to the other PMat/EMat/PEMat
+// classes for consistency.
//
// The use of type-erasure might seem surprising for a library like Ruy with a
// heavily-templated entry point, but it is motivated by the desire for most of
// Ruy's "middle-end" to be non-templated. Ruy can be thought of as having 3
// main parts:
-// - "front-end" (dispatch.h) - this is the highly templated ruy::Mul entry
-// point, along with routines that select RunKernel and RunPack implementations
-// statically based on those template parameters.
+// - "entry-point" (ruy.h) - this is the highly templated ruy::Mul entry
+// point.
+// - "front-end" (dispatch.h) - the work to handle the entry-point call down
+// to the point where it can be handed off to the middle/back ends below. That
+// includes routines that select RunKernel and RunPack
+// implementations statically based on those template parameters.
// - "back-end" (kernel.h, pack.h)- this consists of the implementations of
// RunKernel and RunPack, often in assembly code, which are the building blocks
// that Ruy calls to perform matrix multiplication. These are templated so that
@@ -56,9 +71,10 @@ limitations under the License.
// and thus the static type information is still present.
//
// Each layer of Ruy uses matrix types:
-// - "front-end": Matrix<T>
-// - "middle-end": DMatrix, PMatrix
-// - "back-end": Matrix<T>, PackedMatrix<T>
+// - "entry-point": Matrix<T>
+// - "front-end": Mat
+// - "middle-end": EMat, PEMat
+// - "back-end": Mat, PMat
//
// The use of separate types for packed matrices is not essential, but makes it
// obvious at a glance whether a matrix is a packed matrix or not. We would
@@ -70,22 +86,14 @@ limitations under the License.
// definition for Matrix<T> and see a very simple definition with no internal
// details like sums and kernel block layout.
//
-// To present another structured view of our various matrix types, here's a
-// table:
-// Plain matrices Packed matrices
-// +----------------------------------
-// Templated | Matrix<T> PackedMatrix<T>
-// Type-erased | DMatrix PMatrix
-//
-//
// There is 1 additional matrix type not mentioned above, due to its low
// importance:
-// - PrepackedMatrix: This is a user-facing version of PMatrix. It has the bare
-// minimum of fields needed for representing the raw data and sums buffers of a
-// packed matrix for the "advanced" explicit pre-packing API. This type plays no
-// role in Ruy's internals and can generally by ignored. The only reason it
-// exists is so that PMatrix is not exposed to users -- we prefer to keep the
-// internal matrix types hidden, even from "advanced" users.
+// - PrepackedMatrix: This is a user-facing version of PEMat. It has
+// the bare minimum of fields needed for representing the raw data and sums
+// buffers of a packed matrix for the "advanced" explicit pre-packing API. This
+// type plays no role in Ruy's internals and can generally by ignored. The only
+// reason it exists is so that PEMat is not exposed to users -- we
+// prefer to keep the internal matrix types hidden, even from "advanced" users.
#ifndef RUY_RUY_INTERNAL_MATRIX_H_
#define RUY_RUY_INTERNAL_MATRIX_H_
@@ -102,8 +110,8 @@ limitations under the License.
namespace ruy {
-// Internal counterpart of Layout.
-struct InternalLayout {
+// Internal counterpart of Layout, used by Mat.
+struct MatLayout final {
std::int32_t rows = 0;
std::int32_t cols = 0;
// Stride is the offset between two adjacent matrix elements
@@ -112,17 +120,8 @@ struct InternalLayout {
Order order = Order::kColMajor;
};
-inline Layout ToLayout(const InternalLayout& src) {
- Layout ret;
- ret.set_rows(src.rows);
- ret.set_cols(src.cols);
- ret.set_stride(src.stride);
- ret.set_order(src.order);
- return ret;
-}
-
-inline InternalLayout ToInternalLayout(const Layout& src) {
- InternalLayout ret;
+inline MatLayout ToInternal(const Layout& src) {
+ MatLayout ret;
ret.rows = src.rows;
ret.cols = src.cols;
ret.stride = src.stride;
@@ -130,6 +129,35 @@ inline InternalLayout ToInternalLayout(const Layout& src) {
return ret;
}
+// Internal counterpart of Matrix
+template <typename Scalar>
+struct Mat final {
+ detail::ConstCheckingPtr<Scalar> data;
+ MatLayout layout;
+ Scalar zero_point = 0;
+ bool cacheable = false;
+};
+
+template <typename Scalar>
+inline Mat<Scalar> ToInternal(const Matrix<Scalar>& src) {
+ Mat<Scalar> ret;
+ ret.data = src.data;
+ ret.layout = ToInternal(src.layout);
+ ret.zero_point = src.zero_point;
+ ret.cacheable = src.cacheable;
+ return ret;
+}
+
+template <typename Scalar>
+inline Mat<Scalar> ToInternal(Matrix<Scalar>& src) {
+ Mat<Scalar> ret;
+ ret.data = src.data;
+ ret.layout = ToInternal(src.layout);
+ ret.zero_point = src.zero_point;
+ ret.cacheable = src.cacheable;
+ return ret;
+}
+
// KernelLayout describes small-scale block structure in a packed matrix layout.
// It's a runtime (as opposed to compile-time-constant) version of the
// FixedKernelLayout struct used to declare kernel layouts.
@@ -144,7 +172,7 @@ inline InternalLayout ToInternalLayout(const Layout& src) {
// Note that in the case of kcols=1, krows=1, this degenerates to
// `[cols, rows, 1, 1]` which is equivalent to having no small-scale block
// structure.
-struct KernelLayout {
+struct KernelLayout final {
Order order = Order::kColMajor;
std::uint8_t rows = 1;
std::uint8_t cols = 1;
@@ -154,9 +182,9 @@ struct KernelLayout {
// the input matrices. This block structure is necessary for the kernels to
// process data efficiently.
//
-// This struct is very similar to InternalLayout, but has the extra KernelLayout
+// This struct is very similar to MatLayout, but has the extra KernelLayout
// field.
-struct PackedLayout {
+struct PMatLayout final {
std::int32_t rows = 0;
std::int32_t cols = 0;
// Stride is the offset between two adjacent matrix elements
@@ -179,7 +207,7 @@ struct PackedLayout {
// this file, Ruy's "front-end", which is templated, instantiates all the
// necessary "back-end" routines with complete static knowledge of all the
// types.
-struct Type {
+struct Type final {
template <typename T>
static Type Create() {
Type ret;
@@ -202,26 +230,26 @@ struct Type {
};
// Type-erased matrix.
-struct DMatrix {
+struct EMat final {
Type data_type;
void* data = nullptr;
- InternalLayout layout;
+ MatLayout layout;
std::int32_t zero_point = 0;
};
// Type-erased packed matrix.
-struct PMatrix {
+struct PEMat final {
Type data_type;
void* data = nullptr;
Type sums_type;
void* sums = nullptr;
- PackedLayout layout;
+ PMatLayout layout;
std::int32_t zero_point = 0;
};
// Convenient typed helper for packed matrices.
template <typename Scalar>
-struct PackedMatrix {
+struct PMat final {
// The row/column sums needed for quantized matrix multiplication when
// the opposite operand of the multiplication uses a non-symmetric zero
// point.
@@ -238,36 +266,36 @@ struct PackedMatrix {
Scalar* data = nullptr;
SumsType* sums = nullptr;
- PackedLayout layout;
+ PMatLayout layout;
std::int32_t zero_point = 0;
};
template <typename T>
-DMatrix ToDMatrix(const Matrix<T>& matrix) {
- DMatrix ret;
+EMat EraseType(const Mat<T>& matrix) {
+ EMat ret;
ret.data_type = Type::Create<T>();
ret.data = ToVoidPtr(matrix.data.get());
- ret.layout = ToInternalLayout(matrix.layout);
+ ret.layout = matrix.layout;
ret.zero_point = matrix.zero_point;
return ret;
}
template <typename T>
-Matrix<T> ToMatrix(const DMatrix& dmatrix) {
+Mat<T> UneraseType(const EMat& dmatrix) {
dmatrix.data_type.AssertIs<T>();
- Matrix<T> ret;
+ Mat<T> ret;
ret.data = static_cast<T*>(dmatrix.data);
- ret.layout = ToLayout(dmatrix.layout);
+ ret.layout = dmatrix.layout;
ret.zero_point = dmatrix.zero_point;
return ret;
}
template <typename T>
-PackedMatrix<T> ToPackedMatrix(const PMatrix& pmatrix) {
- using SumsType = typename PackedMatrix<T>::SumsType;
+PMat<T> UneraseType(const PEMat& pmatrix) {
+ using SumsType = typename PMat<T>::SumsType;
pmatrix.data_type.AssertIs<T>();
pmatrix.sums_type.AssertIs<SumsType>();
- PackedMatrix<T> ret;
+ PMat<T> ret;
ret.data = static_cast<T*>(pmatrix.data);
ret.sums = static_cast<SumsType*>(pmatrix.sums);
ret.layout = pmatrix.layout;
@@ -275,7 +303,7 @@ PackedMatrix<T> ToPackedMatrix(const PMatrix& pmatrix) {
return ret;
}
-// Helpers for InternalLayout / PackedLayout.
+// Helpers for MatLayout / PMatLayout.
template <typename LayoutType>
inline bool IsUnstrided(const LayoutType& layout) {
@@ -304,7 +332,7 @@ int FlatSize(const LayoutType& layout) {
}
// TODO(b/130417400) add a unit test
-inline int Offset(const Layout& layout, int row, int col) {
+inline int Offset(const MatLayout& layout, int row, int col) {
// TODO(benoitjacob) - should check this but this make the _slow tests take
// 5x longer. Find a mitigation like in Eigen with an 'internal' variant
// bypassing the check?
@@ -318,7 +346,7 @@ inline int Offset(const Layout& layout, int row, int col) {
}
// TODO(b/130417400) add a unit test
-inline int Offset(const PackedLayout& layout, int row, int col) {
+inline int Offset(const PMatLayout& layout, int row, int col) {
RUY_DCHECK(is_pot(layout.kernel.rows));
RUY_DCHECK(is_pot(layout.kernel.cols));
int row_outer = row & ~(layout.kernel.rows - 1);
@@ -340,48 +368,48 @@ inline int Offset(const PackedLayout& layout, int row, int col) {
return offset_outer + offset_inner;
}
-// Helpers for Matrix<T>.
+// Helpers for Mat<T>.
template <typename Scalar>
-const Scalar* ElementPtr(const Matrix<Scalar>& mat, int row, int col) {
+const Scalar* ElementPtr(const Mat<Scalar>& mat, int row, int col) {
return mat.data.get() + Offset(mat.layout, row, col);
}
template <typename Scalar>
-Scalar* ElementPtr(Matrix<Scalar>* mat, int row, int col) {
+Scalar* ElementPtr(Mat<Scalar>* mat, int row, int col) {
return mat->data.get() + Offset(mat->layout, row, col);
}
template <typename Scalar>
-Scalar Element(const Matrix<Scalar>& mat, int row, int col) {
+Scalar Element(const Mat<Scalar>& mat, int row, int col) {
return *ElementPtr(mat, row, col);
}
-// Helpers for PackedMatrix<T>.
+// Helpers for PMat<T>.
// Duplicated from Matrix<T>, but the duplication seems acceptable.
template <typename Scalar>
-const Scalar* ElementPtr(const PackedMatrix<Scalar>& mat, int row, int col) {
+const Scalar* ElementPtr(const PMat<Scalar>& mat, int row, int col) {
return mat.data + Offset(mat.layout, row, col);
}
template <typename Scalar>
-Scalar* ElementPtr(PackedMatrix<Scalar>* mat, int row, int col) {
+Scalar* ElementPtr(PMat<Scalar>* mat, int row, int col) {
return mat->data + Offset(mat->layout, row, col);
}
template <typename Scalar>
-Scalar Element(const PackedMatrix<Scalar>& mat, int row, int col) {
+Scalar Element(const PMat<Scalar>& mat, int row, int col) {
return *ElementPtr(mat, row, col);
}
-// Helpers for PMatrix.
+// Helpers for PEMat.
-inline int DataSize(const PMatrix& packed) {
+inline int DataSize(const PEMat& packed) {
return FlatSize(packed.layout) * packed.data_type.size;
}
-inline int SumsSize(const PMatrix& packed) {
+inline int SumsSize(const PEMat& packed) {
// Packed matrices are only relevant for Ruy's TrMul implementations. For
// TrMul, the number of sums is always equal to the number of columns.
return packed.layout.cols * packed.sums_type.size;
diff --git a/ruy/matrix.h b/ruy/matrix.h
index 16f330c..389f37f 100644
--- a/ruy/matrix.h
+++ b/ruy/matrix.h
@@ -223,6 +223,35 @@ template <Order tOrder, int tRows, int tCols>
constexpr int FixedKernelLayout<tOrder, tRows, tCols>::kRows;
#endif
+// TODO(b/130417400) add a unit test
+inline int Offset(const Layout& layout, int row, int col) {
+ // TODO(benoitjacob) - should check this but this make the _slow tests take
+ // 5x longer. Find a mitigation like in Eigen with an 'internal' variant
+ // bypassing the check?
+ // RUY_DCHECK_GE(row, 0);
+ // RUY_DCHECK_GE(col, 0);
+ // RUY_DCHECK_LT(row, layout.rows);
+ // RUY_DCHECK_LT(col, layout.cols);
+ int row_stride = layout.order == Order::kColMajor ? 1 : layout.stride;
+ int col_stride = layout.order == Order::kRowMajor ? 1 : layout.stride;
+ return row * row_stride + col * col_stride;
+}
+
+template <typename Scalar>
+const Scalar* ElementPtr(const Matrix<Scalar>& mat, int row, int col) {
+ return mat.data.get() + Offset(mat.layout, row, col);
+}
+
+template <typename Scalar>
+Scalar* ElementPtr(Matrix<Scalar>* mat, int row, int col) {
+ return mat->data.get() + Offset(mat->layout, row, col);
+}
+
+template <typename Scalar>
+Scalar Element(const Matrix<Scalar>& mat, int row, int col) {
+ return *ElementPtr(mat, row, col);
+}
+
} // namespace ruy
#endif // RUY_RUY_MATRIX_H_
diff --git a/ruy/pack_arm.h b/ruy/pack_arm.h
index b08e8c9..1b52b6e 100644
--- a/ruy/pack_arm.h
+++ b/ruy/pack_arm.h
@@ -88,7 +88,7 @@ limitations under the License.
#include "ruy/check_macros.h"
#include "ruy/common.h"
-#include "ruy/internal_matrix.h"
+#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/opt_set.h"
#include "ruy/pack_common.h"
@@ -142,8 +142,8 @@ struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 4>, Scalar,
static constexpr int kInputXor =
std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
- static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
- PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+ static void Run(Tuning tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
int end_col) {
RUY_DCHECK(IsColMajor(src_matrix.layout));
RUY_DCHECK(IsColMajor(packed_matrix->layout));
@@ -225,8 +225,8 @@ struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kColMajor, 16, 2>, Scalar,
"");
static constexpr int kInputXor =
std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
- static void Run(Tuning, const Matrix<Scalar>& src_matrix,
- PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
int end_col) {
RUY_DCHECK(IsColMajor(src_matrix.layout));
RUY_DCHECK(IsColMajor(packed_matrix->layout));
@@ -274,8 +274,8 @@ struct PackImpl<Path::kNeonDotprod, FixedKernelLayout<Order::kColMajor, 4, 8>,
static constexpr int kInputXor =
std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
- static void Run(Tuning tuning, const Matrix<Scalar>& src_matrix,
- PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+ static void Run(Tuning tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
int end_col) {
RUY_DCHECK(IsColMajor(src_matrix.layout));
RUY_DCHECK(IsColMajor(packed_matrix->layout));
@@ -355,9 +355,8 @@ void PackFloatNeonOutOfOrder(const float* src_ptr0, const float* src_ptr1,
template <>
struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
float, float> {
- static void Run(Tuning tuning, const Matrix<float>& src_matrix,
- PackedMatrix<float>* packed_matrix, int start_col,
- int end_col) {
+ static void Run(Tuning tuning, const Mat<float>& src_matrix,
+ PMat<float>* packed_matrix, int start_col, int end_col) {
RUY_DCHECK(IsColMajor(src_matrix.layout));
RUY_DCHECK(IsColMajor(packed_matrix->layout));
RUY_DCHECK_EQ(start_col % 8, 0);
@@ -432,9 +431,8 @@ struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
template <>
struct PackImpl<Path::kNeon, FixedKernelLayout<Order::kRowMajor, 1, 4>, float,
float, float> {
- static void Run(Tuning, const Matrix<float>& src_matrix,
- PackedMatrix<float>* packed_matrix, int start_col,
- int end_col) {
+ static void Run(Tuning, const Mat<float>& src_matrix,
+ PMat<float>* packed_matrix, int start_col, int end_col) {
RUY_DCHECK(IsColMajor(src_matrix.layout));
RUY_DCHECK(IsColMajor(packed_matrix->layout));
RUY_DCHECK_EQ(start_col % 4, 0);
diff --git a/ruy/pack_common.h b/ruy/pack_common.h
index d5df567..8e62729 100644
--- a/ruy/pack_common.h
+++ b/ruy/pack_common.h
@@ -87,7 +87,7 @@ limitations under the License.
#include "ruy/check_macros.h"
#include "ruy/common.h"
-#include "ruy/internal_matrix.h"
+#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/opt_set.h"
#include "ruy/path.h"
@@ -193,8 +193,8 @@ template <typename FixedKernelLayout, typename Scalar, typename PackedScalar,
typename SumsType>
struct PackImpl<Path::kStandardCpp, FixedKernelLayout, Scalar, PackedScalar,
SumsType> {
- static void Run(Tuning, const Matrix<Scalar>& src_matrix,
- PackedMatrix<PackedScalar>* packed_matrix, int start_col,
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<PackedScalar>* packed_matrix, int start_col,
int end_col) {
profiler::ScopeLabel label("Pack (generic)");
RUY_DCHECK_EQ((end_col - start_col) % FixedKernelLayout::kCols, 0);
@@ -231,12 +231,11 @@ RUY_INHERIT_PACK(Path::kAvx512, Path::kAvxVnni)
// Main entry point for packing.
template <Path ThePath, typename FixedKernelLayout, typename Scalar,
typename PackedScalar>
-void RunPack(Tuning tuning, const DMatrix& src_matrix, PMatrix* packed_matrix,
+void RunPack(Tuning tuning, const EMat& src_matrix, PEMat* packed_matrix,
int start_col, int end_col) {
- using SumsType = typename PackedMatrix<PackedScalar>::SumsType;
- Matrix<Scalar> src = ToMatrix<Scalar>(src_matrix);
- PackedMatrix<PackedScalar> packed =
- ToPackedMatrix<PackedScalar>(*packed_matrix);
+ using SumsType = typename PMat<PackedScalar>::SumsType;
+ Mat<Scalar> src = UneraseType<Scalar>(src_matrix);
+ PMat<PackedScalar> packed = UneraseType<PackedScalar>(*packed_matrix);
PackImpl<ThePath, FixedKernelLayout, Scalar, PackedScalar, SumsType>::Run(
tuning, src, &packed, start_col, end_col);
}
diff --git a/ruy/pack_x86.h b/ruy/pack_x86.h
index 4334b25..10d2148 100644
--- a/ruy/pack_x86.h
+++ b/ruy/pack_x86.h
@@ -89,7 +89,7 @@ limitations under the License.
#include "ruy/check_macros.h"
#include "ruy/common.h"
-#include "ruy/internal_matrix.h"
+#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/opt_set.h"
#include "ruy/pack_common.h"
@@ -122,8 +122,8 @@ struct PackImpl<Path::kSse42, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
static constexpr std::int8_t kInputXor =
std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
- static void Run(Tuning, const Matrix<Scalar>& src_matrix,
- PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
int end_col) {
profiler::ScopeLabel label("Pack (SSE 4.2 8-bit)");
@@ -165,9 +165,8 @@ template <>
struct PackImpl<Path::kSse42, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
float, float> {
using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
- static void Run(Tuning, const Matrix<float>& src_matrix,
- PackedMatrix<float>* packed_matrix, int start_col,
- int end_col) {
+ static void Run(Tuning, const Mat<float>& src_matrix,
+ PMat<float>* packed_matrix, int start_col, int end_col) {
profiler::ScopeLabel label("Pack (SSE 4.2 float)");
RUY_DCHECK(IsColMajor(src_matrix.layout));
@@ -209,8 +208,8 @@ struct PackImpl<Path::kAvx2, FixedKernelLayout<Order::kColMajor, 4, 8>, Scalar,
static constexpr std::int8_t kInputXor =
std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
- static void Run(Tuning, const Matrix<Scalar>& src_matrix,
- PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
int end_col) {
profiler::ScopeLabel label("Pack (AVX2 8-bit)");
@@ -248,9 +247,8 @@ template <>
struct PackImpl<Path::kAvx2, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
float, float> {
using Layout = FixedKernelLayout<Order::kRowMajor, 1, 8>;
- static void Run(Tuning, const Matrix<float>& src_matrix,
- PackedMatrix<float>* packed_matrix, int start_col,
- int end_col) {
+ static void Run(Tuning, const Mat<float>& src_matrix,
+ PMat<float>* packed_matrix, int start_col, int end_col) {
profiler::ScopeLabel label("Pack (AVX2 float)");
RUY_DCHECK(IsColMajor(src_matrix.layout));
@@ -294,8 +292,8 @@ struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kColMajor, 4, 16>,
static constexpr std::int8_t kInputXor =
std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
- static void Run(Tuning, const Matrix<Scalar>& src_matrix,
- PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
int end_col) {
profiler::ScopeLabel label("Pack (AVX-512 8-bit)");
@@ -333,9 +331,8 @@ void PackFloatAvx512(const float* src_ptr, const float* zerobuf, int src_stride,
template <>
struct PackImpl<Path::kAvx512, FixedKernelLayout<Order::kRowMajor, 1, 16>,
float, float, float> {
- static void Run(Tuning, const Matrix<float>& src_matrix,
- PackedMatrix<float>* packed_matrix, int start_col,
- int end_col) {
+ static void Run(Tuning, const Mat<float>& src_matrix,
+ PMat<float>* packed_matrix, int start_col, int end_col) {
profiler::ScopeLabel label("Pack (AVX-512 float)");
using Layout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
RUY_DCHECK(IsColMajor(src_matrix.layout));
@@ -383,8 +380,8 @@ struct PackImpl<Path::kAvxVnni, FixedKernelLayout<Order::kColMajor, 4, 16>,
static constexpr std::int8_t kInputXor =
std::is_same<Scalar, std::int8_t>::value ? 0 : 0x80;
- static void Run(Tuning, const Matrix<Scalar>& src_matrix,
- PackedMatrix<std::int8_t>* packed_matrix, int start_col,
+ static void Run(Tuning, const Mat<Scalar>& src_matrix,
+ PMat<std::int8_t>* packed_matrix, int start_col,
int end_col) {
profiler::ScopeLabel label("Pack (AVX-512 8-bit)");
@@ -427,9 +424,8 @@ void PackFloatAvxVnni(const float* src_ptr, const float* zerobuf,
template <>
struct PackImpl<Path::kAvxVnni, FixedKernelLayout<Order::kRowMajor, 1, 16>,
float, float, float> {
- static void Run(Tuning, const Matrix<float>& src_matrix,
- PackedMatrix<float>* packed_matrix, int start_col,
- int end_col) {
+ static void Run(Tuning, const Mat<float>& src_matrix,
+ PMat<float>* packed_matrix, int start_col, int end_col) {
profiler::ScopeLabel label("Pack (AVX-512 float)");
using Layout = FixedKernelLayout<Order::kRowMajor, 1, 16>;
diff --git a/ruy/prepack.h b/ruy/prepack.h
index c2d5975..847992b 100644
--- a/ruy/prepack.h
+++ b/ruy/prepack.h
@@ -25,7 +25,7 @@ limitations under the License.
#include "ruy/context.h"
#include "ruy/context_internal.h"
#include "ruy/dispatch.h"
-#include "ruy/internal_matrix.h"
+#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
#include "ruy/path.h"
@@ -39,17 +39,16 @@ namespace ruy {
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
typename DstScalar, typename MulParamsType>
-void PrePackForMulInternal(const Matrix<LhsScalar>& lhs,
- const Matrix<RhsScalar>& rhs,
+void PrePackForMulInternal(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs,
const MulParamsType& mul_params, Context* context,
- Matrix<DstScalar>* dst,
+ Mat<DstScalar>* dst,
SidePair<PrepackedMatrix*> prepacked,
std::function<void*(int)> alloc_fn) {
profiler::ScopeLabel label("PrePackForMul");
Path the_path = ContextInternal::GetPathToTake<CompiledPaths>(context);
RUY_CHECK_NE(the_path, Path::kReference);
constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference;
- Matrix<LhsScalar> transposed_lhs(lhs);
+ Mat<LhsScalar> transposed_lhs(lhs);
Transpose(&transposed_lhs);
TrMulParams params;
CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, mul_params, dst,
@@ -75,10 +74,10 @@ void PrePackForMulInternal(const Matrix<LhsScalar>& lhs,
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
typename DstScalar, typename MulParamsType>
-void MulWithPrepackedInternal(const Matrix<LhsScalar>& lhs,
- const Matrix<RhsScalar>& rhs,
+void MulWithPrepackedInternal(const Mat<LhsScalar>& lhs,
+ const Mat<RhsScalar>& rhs,
const MulParamsType& mul_params, Context* context,
- Matrix<DstScalar>* dst,
+ Mat<DstScalar>* dst,
SidePair<PrepackedMatrix*> prepacked) {
profiler::ScopeLabel label("MulWithPrepacked");
@@ -89,7 +88,7 @@ void MulWithPrepackedInternal(const Matrix<LhsScalar>& lhs,
Path the_path = ContextInternal::GetPathToTake<CompiledPaths>(context);
RUY_CHECK_NE(the_path, Path::kReference);
constexpr Path TrMulCompiledPaths = CompiledPaths & ~Path::kReference;
- Matrix<LhsScalar> transposed_lhs(lhs);
+ Mat<LhsScalar> transposed_lhs(lhs);
Transpose(&transposed_lhs);
TrMulParams params;
CreateTrMulParams<TrMulCompiledPaths>(transposed_lhs, rhs, mul_params, dst,
diff --git a/ruy/ruy.h b/ruy/ruy.h
index 2260d71..eaebf9f 100644
--- a/ruy/ruy.h
+++ b/ruy/ruy.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "ruy/context.h"
#include "ruy/dispatch.h"
+#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
#include "ruy/path.h"
@@ -74,8 +75,12 @@ template <typename LhsScalar, typename RhsScalar, typename DstScalar,
void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
const MulParamsType& mul_params, Context* context,
Matrix<DstScalar>* dst) {
+ Mat<LhsScalar> internal_lhs = ToInternal(lhs);
+ Mat<RhsScalar> internal_rhs = ToInternal(rhs);
+ Mat<DstScalar> internal_dst = ToInternal(*dst);
DispatchMul<ruy::kDefaultPaths, LhsScalar, RhsScalar, DstScalar,
- MulParamsType>(lhs, rhs, mul_params, context, dst);
+ MulParamsType>(internal_lhs, internal_rhs, mul_params, context,
+ &internal_dst);
}
// Variant of ruy::Mul allowing to specify a custom OR-ed set of Path's to
@@ -85,8 +90,11 @@ template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
void Mul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
const MulParamsType& mul_params, Context* context,
Matrix<DstScalar>* dst) {
+ Mat<LhsScalar> internal_lhs = ToInternal(lhs);
+ Mat<RhsScalar> internal_rhs = ToInternal(rhs);
+ Mat<DstScalar> internal_dst = ToInternal(*dst);
DispatchMul<CompiledPaths, LhsScalar, RhsScalar, DstScalar, MulParamsType>(
- lhs, rhs, mul_params, context, dst);
+ internal_lhs, internal_rhs, mul_params, context, &internal_dst);
}
} // namespace ruy
diff --git a/ruy/ruy_advanced.h b/ruy/ruy_advanced.h
index f034a51..8f54533 100644
--- a/ruy/ruy_advanced.h
+++ b/ruy/ruy_advanced.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <functional>
#include "ruy/context.h"
+#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/path.h"
#include "ruy/prepack.h"
@@ -49,9 +50,13 @@ void PrePackForMul(const Matrix<LhsScalar>& lhs, const Matrix<RhsScalar>& rhs,
Matrix<DstScalar>* dst, PrepackedMatrix* prepacked_lhs,
PrepackedMatrix* prepacked_rhs,
std::function<void*(int)> alloc_fn) {
+ Mat<LhsScalar> internal_lhs = ToInternal(lhs);
+ Mat<RhsScalar> internal_rhs = ToInternal(rhs);
+ Mat<DstScalar> internal_dst = ToInternal(*dst);
SidePair<PrepackedMatrix*> prepacked(prepacked_lhs, prepacked_rhs);
- PrePackForMulInternal<CompiledPaths>(lhs, rhs, mul_params, context, dst,
- prepacked, alloc_fn);
+ PrePackForMulInternal<CompiledPaths>(internal_lhs, internal_rhs, mul_params,
+ context, &internal_dst, prepacked,
+ alloc_fn);
}
template <Path CompiledPaths, typename LhsScalar, typename RhsScalar,
@@ -61,8 +66,13 @@ void MulWithPrepacked(const Matrix<LhsScalar>& lhs,
const MulParamsType& mul_params, Context* context,
Matrix<DstScalar>* dst, PrepackedMatrix* prepacked_lhs,
PrepackedMatrix* prepacked_rhs) {
+ Mat<LhsScalar> internal_lhs = ToInternal(lhs);
+ Mat<RhsScalar> internal_rhs = ToInternal(rhs);
+ Mat<DstScalar> internal_dst = ToInternal(*dst);
SidePair<PrepackedMatrix*> prepacked(prepacked_lhs, prepacked_rhs);
- MulWithPrepackedInternal<CompiledPaths>(lhs, rhs, mul_params, context, dst,
+
+ MulWithPrepackedInternal<CompiledPaths>(internal_lhs, internal_rhs,
+ mul_params, context, &internal_dst,
prepacked);
}
diff --git a/ruy/trmul.cc b/ruy/trmul.cc
index 5755e3b..1b26fc1 100644
--- a/ruy/trmul.cc
+++ b/ruy/trmul.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "ruy/check_macros.h"
#include "ruy/common.h"
#include "ruy/context_internal.h"
-#include "ruy/internal_matrix.h"
+#include "ruy/mat.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
#include "ruy/opt_set.h"
@@ -240,7 +240,7 @@ struct TrMulTask final : Task {
SidePair<bool*> local_packed;
};
-void AllocatePMatrix(Allocator* allocator, PMatrix* packed) {
+void AllocatePMatrix(Allocator* allocator, PEMat* packed) {
packed->data = allocator->AllocateBytes(DataSize(*packed));
packed->sums = allocator->AllocateBytes(SumsSize(*packed));
}
@@ -284,10 +284,10 @@ void TrMul(TrMulParams* params, Context* context) {
static_cast<int>(params->path), context->max_num_threads,
params->is_prepacked[Side::kLhs], params->is_prepacked[Side::kRhs]);
- PMatrix& packed_lhs = params->packed[Side::kLhs];
- PMatrix& packed_rhs = params->packed[Side::kRhs];
- DMatrix& lhs = params->src[Side::kLhs];
- DMatrix& rhs = params->src[Side::kRhs];
+ PEMat& packed_lhs = params->packed[Side::kLhs];
+ PEMat& packed_rhs = params->packed[Side::kRhs];
+ EMat& lhs = params->src[Side::kLhs];
+ EMat& rhs = params->src[Side::kRhs];
const int rows = lhs.layout.cols;
const int cols = rhs.layout.cols;
diff --git a/ruy/trmul_params.h b/ruy/trmul_params.h
index a7017f8..fecef99 100644
--- a/ruy/trmul_params.h
+++ b/ruy/trmul_params.h
@@ -16,16 +16,16 @@ limitations under the License.
#ifndef RUY_RUY_TRMUL_PARAMS_H_
#define RUY_RUY_TRMUL_PARAMS_H_
-#include "ruy/internal_matrix.h"
+#include "ruy/mat.h"
#include "ruy/side_pair.h"
#include "ruy/tune.h"
namespace ruy {
-using RunKernelFn = void(Tuning, const SidePair<PMatrix>&, void*,
- const SidePair<int>&, const SidePair<int>&, DMatrix*);
+using RunKernelFn = void(Tuning, const SidePair<PEMat>&, void*,
+ const SidePair<int>&, const SidePair<int>&, EMat*);
-using RunPackFn = void(Tuning, const DMatrix&, PMatrix*, int, int);
+using RunPackFn = void(Tuning, const EMat&, PEMat*, int, int);
// Type-erased data needed for implementing TrMul.
struct TrMulParams {
@@ -53,9 +53,9 @@ struct TrMulParams {
RunKernelFn* run_kernel = nullptr;
// Matrices and packed matrices.
- SidePair<DMatrix> src;
- DMatrix dst;
- SidePair<PMatrix> packed;
+ SidePair<EMat> src;
+ EMat dst;
+ SidePair<PEMat> packed;
SidePair<bool> is_prepacked;
// Type-erased MulParamsType.