diff options
author | Benoit Jacob <benoitjacob@google.com> | 2020-04-22 05:51:26 +0300 |
---|---|---|
committer | Copybara-Service <copybara-worker@google.com> | 2020-04-22 05:51:49 +0300 |
commit | 145aecd896b44cfc455803aef4ed98745054567e (patch) | |
tree | a4d4e08dc63c6eca9e326985ebb91fd6b4468381 /ruy/matrix.h | |
parent | f3c69a73897c2d97851d9528b34ba8c9371da886 (diff) |
Rename:
internal_matrix.h -> mat.h
Introduce the Mat class, internal counterpart of Matrix.
Renaming: functions that erase/unerase scalar types are now named
EraseType/UneraseType.
Renaming:
PackedLayout -> PMatLayout
PackedMatrix -> PMat
DMatrix -> EMat
PMatrix -> PEMat
PiperOrigin-RevId: 307730877
Diffstat (limited to 'ruy/matrix.h')
-rw-r--r-- | ruy/matrix.h | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/ruy/matrix.h b/ruy/matrix.h index 16f330c..389f37f 100644 --- a/ruy/matrix.h +++ b/ruy/matrix.h @@ -223,6 +223,35 @@ template <Order tOrder, int tRows, int tCols> constexpr int FixedKernelLayout<tOrder, tRows, tCols>::kRows; #endif +// TODO(b/130417400) add a unit test +inline int Offset(const Layout& layout, int row, int col) { + // TODO(benoitjacob) - should check this but this make the _slow tests take + // 5x longer. Find a mitigation like in Eigen with an 'internal' variant + // bypassing the check? + // RUY_DCHECK_GE(row, 0); + // RUY_DCHECK_GE(col, 0); + // RUY_DCHECK_LT(row, layout.rows); + // RUY_DCHECK_LT(col, layout.cols); + int row_stride = layout.order == Order::kColMajor ? 1 : layout.stride; + int col_stride = layout.order == Order::kRowMajor ? 1 : layout.stride; + return row * row_stride + col * col_stride; +} + +template <typename Scalar> +const Scalar* ElementPtr(const Matrix<Scalar>& mat, int row, int col) { + return mat.data.get() + Offset(mat.layout, row, col); +} + +template <typename Scalar> +Scalar* ElementPtr(Matrix<Scalar>* mat, int row, int col) { + return mat->data.get() + Offset(mat->layout, row, col); +} + +template <typename Scalar> +Scalar Element(const Matrix<Scalar>& mat, int row, int col) { + return *ElementPtr(mat, row, col); +} + } // namespace ruy #endif // RUY_RUY_MATRIX_H_ |