Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/google/ruy.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBenoit Jacob <benoitjacob@google.com>2020-04-22 05:51:26 +0300
committerCopybara-Service <copybara-worker@google.com>2020-04-22 05:51:49 +0300
commit145aecd896b44cfc455803aef4ed98745054567e (patch)
treea4d4e08dc63c6eca9e326985ebb91fd6b4468381 /ruy/matrix.h
parentf3c69a73897c2d97851d9528b34ba8c9371da886 (diff)
Rename:
internal_matrix.h -> mat.h Introduce the Mat class, internal counterpart of Matrix. Renaming: functions that erase/unerase scalar types are now named EraseType/UneraseType. Renaming: PackedLayout -> PMatLayout PackedMatrix -> PMat DMatrix -> EMat PMatrix -> PEMat PiperOrigin-RevId: 307730877
Diffstat (limited to 'ruy/matrix.h')
-rw-r--r--ruy/matrix.h29
1 files changed, 29 insertions, 0 deletions
diff --git a/ruy/matrix.h b/ruy/matrix.h
index 16f330c..389f37f 100644
--- a/ruy/matrix.h
+++ b/ruy/matrix.h
@@ -223,6 +223,35 @@ template <Order tOrder, int tRows, int tCols>
constexpr int FixedKernelLayout<tOrder, tRows, tCols>::kRows;
#endif
+// TODO(b/130417400) add a unit test
+inline int Offset(const Layout& layout, int row, int col) {
+ // TODO(benoitjacob) - should check this but this make the _slow tests take
+ // 5x longer. Find a mitigation like in Eigen with an 'internal' variant
+ // bypassing the check?
+ // RUY_DCHECK_GE(row, 0);
+ // RUY_DCHECK_GE(col, 0);
+ // RUY_DCHECK_LT(row, layout.rows);
+ // RUY_DCHECK_LT(col, layout.cols);
+ int row_stride = layout.order == Order::kColMajor ? 1 : layout.stride;
+ int col_stride = layout.order == Order::kRowMajor ? 1 : layout.stride;
+ return row * row_stride + col * col_stride;
+}
+
+template <typename Scalar>
+const Scalar* ElementPtr(const Matrix<Scalar>& mat, int row, int col) {
+ return mat.data.get() + Offset(mat.layout, row, col);
+}
+
+template <typename Scalar>
+Scalar* ElementPtr(Matrix<Scalar>* mat, int row, int col) {
+ return mat->data.get() + Offset(mat->layout, row, col);
+}
+
+template <typename Scalar>
+Scalar Element(const Matrix<Scalar>& mat, int row, int col) {
+ return *ElementPtr(mat, row, col);
+}
+
} // namespace ruy
#endif // RUY_RUY_MATRIX_H_