diff options
Diffstat (limited to 'extern/Eigen3/Eigen/src/Core/ArrayWrapper.h')
-rw-r--r-- | extern/Eigen3/Eigen/src/Core/ArrayWrapper.h | 151 |
1 files changed, 48 insertions, 103 deletions
diff --git a/extern/Eigen3/Eigen/src/Core/ArrayWrapper.h b/extern/Eigen3/Eigen/src/Core/ArrayWrapper.h index b4641e2a01f..688aadd6260 100644 --- a/extern/Eigen3/Eigen/src/Core/ArrayWrapper.h +++ b/extern/Eigen3/Eigen/src/Core/ArrayWrapper.h @@ -32,7 +32,8 @@ struct traits<ArrayWrapper<ExpressionType> > // Let's remove NestByRefBit enum { Flags0 = traits<typename remove_all<typename ExpressionType::Nested>::type >::Flags, - Flags = Flags0 & ~NestByRefBit + LvalueBitFlag = is_lvalue<ExpressionType>::value ? LvalueBit : 0, + Flags = (Flags0 & ~(NestByRefBit | LvalueBit)) | LvalueBitFlag }; }; } @@ -44,6 +45,7 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> > typedef ArrayBase<ArrayWrapper> Base; EIGEN_DENSE_PUBLIC_INTERFACE(ArrayWrapper) EIGEN_INHERIT_ASSIGNMENT_OPERATORS(ArrayWrapper) + typedef typename internal::remove_all<ExpressionType>::type NestedExpression; typedef typename internal::conditional< internal::is_lvalue<ExpressionType>::value, @@ -51,76 +53,45 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> > const Scalar >::type ScalarWithConstIfNotLvalue; - typedef typename internal::nested<ExpressionType>::type NestedExpressionType; + typedef typename internal::ref_selector<ExpressionType>::non_const_type NestedExpressionType; - inline ArrayWrapper(ExpressionType& matrix) : m_expression(matrix) {} + using Base::coeffRef; + EIGEN_DEVICE_FUNC + explicit EIGEN_STRONG_INLINE ArrayWrapper(ExpressionType& matrix) : m_expression(matrix) {} + + EIGEN_DEVICE_FUNC inline Index rows() const { return m_expression.rows(); } + EIGEN_DEVICE_FUNC inline Index cols() const { return m_expression.cols(); } + EIGEN_DEVICE_FUNC inline Index outerStride() const { return m_expression.outerStride(); } + EIGEN_DEVICE_FUNC inline Index innerStride() const { return m_expression.innerStride(); } - inline ScalarWithConstIfNotLvalue* data() { return m_expression.const_cast_derived().data(); } + EIGEN_DEVICE_FUNC + inline ScalarWithConstIfNotLvalue* data() { return m_expression.data(); } + EIGEN_DEVICE_FUNC inline const Scalar* data() const { return m_expression.data(); } - inline CoeffReturnType coeff(Index rowId, Index colId) const - { - return m_expression.coeff(rowId, colId); - } - - inline Scalar& coeffRef(Index rowId, Index colId) - { - return m_expression.const_cast_derived().coeffRef(rowId, colId); - } - + EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index rowId, Index colId) const { - return m_expression.const_cast_derived().coeffRef(rowId, colId); - } - - inline CoeffReturnType coeff(Index index) const - { - return m_expression.coeff(index); - } - - inline Scalar& coeffRef(Index index) - { - return m_expression.const_cast_derived().coeffRef(index); + return m_expression.coeffRef(rowId, colId); } + EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index index) const { - return m_expression.const_cast_derived().coeffRef(index); - } - - template<int LoadMode> - inline const PacketScalar packet(Index rowId, Index colId) const - { - return m_expression.template packet<LoadMode>(rowId, colId); - } - - template<int LoadMode> - inline void writePacket(Index rowId, Index colId, const PacketScalar& val) - { - m_expression.const_cast_derived().template writePacket<LoadMode>(rowId, colId, val); - } - - template<int LoadMode> - inline const PacketScalar packet(Index index) const - { - return m_expression.template packet<LoadMode>(index); - } - - template<int LoadMode> - inline void writePacket(Index index, const PacketScalar& val) - { - m_expression.const_cast_derived().template writePacket<LoadMode>(index, val); + return m_expression.coeffRef(index); } template<typename Dest> + EIGEN_DEVICE_FUNC inline void evalTo(Dest& dst) const { dst = m_expression; } const typename internal::remove_all<NestedExpressionType>::type& + EIGEN_DEVICE_FUNC nestedExpression() const { return m_expression; @@ -128,10 +99,12 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> > /** Forwards the resizing request to the nested expression * \sa DenseBase::resize(Index) */ - void resize(Index newSize) { m_expression.const_cast_derived().resize(newSize); } + EIGEN_DEVICE_FUNC + void resize(Index newSize) { m_expression.resize(newSize); } /** Forwards the resizing request to the nested expression * \sa DenseBase::resize(Index,Index)*/ - void resize(Index nbRows, Index nbCols) { m_expression.const_cast_derived().resize(nbRows,nbCols); } + EIGEN_DEVICE_FUNC + void resize(Index rows, Index cols) { m_expression.resize(rows,cols); } protected: NestedExpressionType m_expression; @@ -157,7 +130,8 @@ struct traits<MatrixWrapper<ExpressionType> > // Let's remove NestByRefBit enum { Flags0 = traits<typename remove_all<typename ExpressionType::Nested>::type >::Flags, - Flags = Flags0 & ~NestByRefBit + LvalueBitFlag = is_lvalue<ExpressionType>::value ? LvalueBit : 0, + Flags = (Flags0 & ~(NestByRefBit | LvalueBit)) | LvalueBitFlag }; }; } @@ -169,6 +143,7 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> > typedef MatrixBase<MatrixWrapper<ExpressionType> > Base; EIGEN_DENSE_PUBLIC_INTERFACE(MatrixWrapper) EIGEN_INHERIT_ASSIGNMENT_OPERATORS(MatrixWrapper) + typedef typename internal::remove_all<ExpressionType>::type NestedExpression; typedef typename internal::conditional< internal::is_lvalue<ExpressionType>::value, @@ -176,72 +151,40 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> > const Scalar >::type ScalarWithConstIfNotLvalue; - typedef typename internal::nested<ExpressionType>::type NestedExpressionType; + typedef typename internal::ref_selector<ExpressionType>::non_const_type NestedExpressionType; - inline MatrixWrapper(ExpressionType& a_matrix) : m_expression(a_matrix) {} + using Base::coeffRef; + EIGEN_DEVICE_FUNC + explicit inline MatrixWrapper(ExpressionType& matrix) : m_expression(matrix) {} + + EIGEN_DEVICE_FUNC inline Index rows() const { return m_expression.rows(); } + EIGEN_DEVICE_FUNC inline Index cols() const { return m_expression.cols(); } + EIGEN_DEVICE_FUNC inline Index outerStride() const { return m_expression.outerStride(); } + EIGEN_DEVICE_FUNC inline Index innerStride() const { return m_expression.innerStride(); } - inline ScalarWithConstIfNotLvalue* data() { return m_expression.const_cast_derived().data(); } + EIGEN_DEVICE_FUNC + inline ScalarWithConstIfNotLvalue* data() { return m_expression.data(); } + EIGEN_DEVICE_FUNC inline const Scalar* data() const { return m_expression.data(); } - inline CoeffReturnType coeff(Index rowId, Index colId) const - { - return m_expression.coeff(rowId, colId); - } - - inline Scalar& coeffRef(Index rowId, Index colId) - { - return m_expression.const_cast_derived().coeffRef(rowId, colId); - } - + EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index rowId, Index colId) const { return m_expression.derived().coeffRef(rowId, colId); } - inline CoeffReturnType coeff(Index index) const - { - return m_expression.coeff(index); - } - - inline Scalar& coeffRef(Index index) - { - return m_expression.const_cast_derived().coeffRef(index); - } - + EIGEN_DEVICE_FUNC inline const Scalar& coeffRef(Index index) const { - return m_expression.const_cast_derived().coeffRef(index); - } - - template<int LoadMode> - inline const PacketScalar packet(Index rowId, Index colId) const - { - return m_expression.template packet<LoadMode>(rowId, colId); - } - - template<int LoadMode> - inline void writePacket(Index rowId, Index colId, const PacketScalar& val) - { - m_expression.const_cast_derived().template writePacket<LoadMode>(rowId, colId, val); - } - - template<int LoadMode> - inline const PacketScalar packet(Index index) const - { - return m_expression.template packet<LoadMode>(index); - } - - template<int LoadMode> - inline void writePacket(Index index, const PacketScalar& val) - { - m_expression.const_cast_derived().template writePacket<LoadMode>(index, val); + return m_expression.coeffRef(index); } + EIGEN_DEVICE_FUNC const typename internal::remove_all<NestedExpressionType>::type& nestedExpression() const { @@ -250,10 +193,12 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> > /** Forwards the resizing request to the nested expression * \sa DenseBase::resize(Index) */ - void resize(Index newSize) { m_expression.const_cast_derived().resize(newSize); } + EIGEN_DEVICE_FUNC + void resize(Index newSize) { m_expression.resize(newSize); } /** Forwards the resizing request to the nested expression * \sa DenseBase::resize(Index,Index)*/ - void resize(Index nbRows, Index nbCols) { m_expression.const_cast_derived().resize(nbRows,nbCols); } + EIGEN_DEVICE_FUNC + void resize(Index rows, Index cols) { m_expression.resize(rows,cols); } protected: NestedExpressionType m_expression; |