diff options
Diffstat (limited to 'extern/Eigen3/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h')
-rw-r--r-- | extern/Eigen3/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h | 112 |
1 files changed, 88 insertions, 24 deletions
diff --git a/extern/Eigen3/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h b/extern/Eigen3/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h index 432d3a9dc84..5c37639091c 100644 --- a/extern/Eigen3/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h +++ b/extern/Eigen3/Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h @@ -12,6 +12,9 @@ namespace Eigen { +template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjLhs, bool ConjRhs> +struct selfadjoint_rank1_update; + namespace internal { /********************************************************************** @@ -39,7 +42,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder, { typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar; static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* lhs, Index lhsStride, - const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride, ResScalar alpha) + const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resStride, const ResScalar& alpha) { general_matrix_matrix_triangular_product<Index, RhsScalar, RhsStorageOrder==RowMajor ? ColMajor : RowMajor, ConjugateRhs, @@ -55,7 +58,7 @@ struct general_matrix_matrix_triangular_product<Index,LhsScalar,LhsStorageOrder, { typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar; static EIGEN_STRONG_INLINE void run(Index size, Index depth,const LhsScalar* _lhs, Index lhsStride, - const RhsScalar* _rhs, Index rhsStride, ResScalar* res, Index resStride, ResScalar alpha) + const RhsScalar* _rhs, Index rhsStride, ResScalar* res, Index resStride, const ResScalar& alpha) { const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride); const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride); @@ -133,7 +136,7 @@ struct tribb_kernel enum { BlockSize = EIGEN_PLAIN_ENUM_MAX(mr,nr) }; - void operator()(ResScalar* res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, ResScalar alpha, RhsScalar* workspace) + void operator()(ResScalar* res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index size, Index depth, const ResScalar& alpha, RhsScalar* workspace) { gebp_kernel<LhsScalar, RhsScalar, Index, mr, nr, ConjLhs, ConjRhs> gebp_kernel; Matrix<ResScalar,BlockSize,BlockSize,ColMajor> buffer; @@ -180,31 +183,92 @@ struct tribb_kernel // high level API +template<typename MatrixType, typename ProductType, int UpLo, bool IsOuterProduct> +struct general_product_to_triangular_selector; + + +template<typename MatrixType, typename ProductType, int UpLo> +struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,true> +{ + static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha) + { + typedef typename MatrixType::Scalar Scalar; + typedef typename MatrixType::Index Index; + + typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs; + typedef internal::blas_traits<Lhs> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs; + typedef typename internal::remove_all<ActualLhs>::type _ActualLhs; + typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs()); + + typedef typename internal::remove_all<typename ProductType::RhsNested>::type Rhs; + typedef internal::blas_traits<Rhs> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs; + typedef typename internal::remove_all<ActualRhs>::type _ActualRhs; + typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs()); + + Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); + + enum { + StorageOrder = (internal::traits<MatrixType>::Flags&RowMajorBit) ? RowMajor : ColMajor, + UseLhsDirectly = _ActualLhs::InnerStrideAtCompileTime==1, + UseRhsDirectly = _ActualRhs::InnerStrideAtCompileTime==1 + }; + + internal::gemv_static_vector_if<Scalar,Lhs::SizeAtCompileTime,Lhs::MaxSizeAtCompileTime,!UseLhsDirectly> static_lhs; + ei_declare_aligned_stack_constructed_variable(Scalar, actualLhsPtr, actualLhs.size(), + (UseLhsDirectly ? const_cast<Scalar*>(actualLhs.data()) : static_lhs.data())); + if(!UseLhsDirectly) Map<typename _ActualLhs::PlainObject>(actualLhsPtr, actualLhs.size()) = actualLhs; + + internal::gemv_static_vector_if<Scalar,Rhs::SizeAtCompileTime,Rhs::MaxSizeAtCompileTime,!UseRhsDirectly> static_rhs; + ei_declare_aligned_stack_constructed_variable(Scalar, actualRhsPtr, actualRhs.size(), + (UseRhsDirectly ? const_cast<Scalar*>(actualRhs.data()) : static_rhs.data())); + if(!UseRhsDirectly) Map<typename _ActualRhs::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs; + + + selfadjoint_rank1_update<Scalar,Index,StorageOrder,UpLo, + LhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex, + RhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex> + ::run(actualLhs.size(), mat.data(), mat.outerStride(), actualLhsPtr, actualRhsPtr, actualAlpha); + } +}; + +template<typename MatrixType, typename ProductType, int UpLo> +struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false> +{ + static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha) + { + typedef typename MatrixType::Index Index; + + typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs; + typedef internal::blas_traits<Lhs> LhsBlasTraits; + typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs; + typedef typename internal::remove_all<ActualLhs>::type _ActualLhs; + typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs()); + + typedef typename internal::remove_all<typename ProductType::RhsNested>::type Rhs; + typedef internal::blas_traits<Rhs> RhsBlasTraits; + typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs; + typedef typename internal::remove_all<ActualRhs>::type _ActualRhs; + typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs()); + + typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); + + internal::general_matrix_matrix_triangular_product<Index, + typename Lhs::Scalar, _ActualLhs::Flags&RowMajorBit ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate, + typename Rhs::Scalar, _ActualRhs::Flags&RowMajorBit ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate, + MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo> + ::run(mat.cols(), actualLhs.cols(), + &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(), + mat.data(), mat.outerStride(), actualAlpha); + } +}; + template<typename MatrixType, unsigned int UpLo> template<typename ProductDerived, typename _Lhs, typename _Rhs> TriangularView<MatrixType,UpLo>& TriangularView<MatrixType,UpLo>::assignProduct(const ProductBase<ProductDerived, _Lhs,_Rhs>& prod, const Scalar& alpha) { - typedef typename internal::remove_all<typename ProductDerived::LhsNested>::type Lhs; - typedef internal::blas_traits<Lhs> LhsBlasTraits; - typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs; - typedef typename internal::remove_all<ActualLhs>::type _ActualLhs; - typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs()); - - typedef typename internal::remove_all<typename ProductDerived::RhsNested>::type Rhs; - typedef internal::blas_traits<Rhs> RhsBlasTraits; - typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs; - typedef typename internal::remove_all<ActualRhs>::type _ActualRhs; - typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs()); - - typename ProductDerived::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived()); - - internal::general_matrix_matrix_triangular_product<Index, - typename Lhs::Scalar, _ActualLhs::Flags&RowMajorBit ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate, - typename Rhs::Scalar, _ActualRhs::Flags&RowMajorBit ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate, - MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo> - ::run(m_matrix.cols(), actualLhs.cols(), - &actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(), - const_cast<Scalar*>(m_matrix.data()), m_matrix.outerStride(), actualAlpha); + general_product_to_triangular_selector<MatrixType, ProductDerived, UpLo, (_Lhs::ColsAtCompileTime==1) || (_Rhs::RowsAtCompileTime==1)>::run(m_matrix.const_cast_derived(), prod.derived(), alpha); return *this; } |