diff options
Diffstat (limited to 'extern/Eigen3/Eigen/src/SparseCore/SparseSparseProductWithPruning.h')
-rw-r--r-- | extern/Eigen3/Eigen/src/SparseCore/SparseSparseProductWithPruning.h | 149 |
1 files changed, 149 insertions, 0 deletions
diff --git a/extern/Eigen3/Eigen/src/SparseCore/SparseSparseProductWithPruning.h b/extern/Eigen3/Eigen/src/SparseCore/SparseSparseProductWithPruning.h new file mode 100644 index 00000000000..2438ac573d0 --- /dev/null +++ b/extern/Eigen3/Eigen/src/SparseCore/SparseSparseProductWithPruning.h @@ -0,0 +1,149 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2008-2011 Gael Guennebaud <gael.guennebaud@inria.fr> +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H +#define EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H + +namespace Eigen { + +namespace internal { + + +// perform a pseudo in-place sparse * sparse product assuming all matrices are col major +template<typename Lhs, typename Rhs, typename ResultType> +static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res, typename ResultType::RealScalar tolerance) +{ + // return sparse_sparse_product_with_pruning_impl2(lhs,rhs,res); + + typedef typename remove_all<Lhs>::type::Scalar Scalar; + typedef typename remove_all<Lhs>::type::Index Index; + + // make sure to call innerSize/outerSize since we fake the storage order. + Index rows = lhs.innerSize(); + Index cols = rhs.outerSize(); + //int size = lhs.outerSize(); + eigen_assert(lhs.outerSize() == rhs.innerSize()); + + // allocate a temporary buffer + AmbiVector<Scalar,Index> tempVector(rows); + + // estimate the number of non zero entries + // given a rhs column containing Y non zeros, we assume that the respective Y columns + // of the lhs differs in average of one non zeros, thus the number of non zeros for + // the product of a rhs column with the lhs is X+Y where X is the average number of non zero + // per column of the lhs. + // Therefore, we have nnz(lhs*rhs) = nnz(lhs) + nnz(rhs) + Index estimated_nnz_prod = lhs.nonZeros() + rhs.nonZeros(); + + // mimics a resizeByInnerOuter: + if(ResultType::IsRowMajor) + res.resize(cols, rows); + else + res.resize(rows, cols); + + res.reserve(estimated_nnz_prod); + double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols()); + for (Index j=0; j<cols; ++j) + { + // FIXME: + //double ratioColRes = (double(rhs.innerVector(j).nonZeros()) + double(lhs.nonZeros())/double(lhs.cols()))/double(lhs.rows()); + // let's do a more accurate determination of the nnz ratio for the current column j of res + tempVector.init(ratioColRes); + tempVector.setZero(); + for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt) + { + // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index()) + tempVector.restart(); + Scalar x = rhsIt.value(); + for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt) + { + tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x; + } + } + res.startVec(j); + for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector,tolerance); it; ++it) + res.insertBackByOuterInner(j,it.index()) = it.value(); + } + res.finalize(); +} + +template<typename Lhs, typename Rhs, typename ResultType, + int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit, + int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit, + int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit> +struct sparse_sparse_product_with_pruning_selector; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor> +{ + typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar; + typedef typename ResultType::RealScalar RealScalar; + + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance) + { + typename remove_all<ResultType>::type _res(res.rows(), res.cols()); + internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res, tolerance); + res.swap(_res); + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance) + { + // we need a col-major matrix to hold the result + typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType; + SparseTemporaryType _res(res.rows(), res.cols()); + internal::sparse_sparse_product_with_pruning_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res, tolerance); + res = _res; + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance) + { + // let's transpose the product to get a column x column product + typename remove_all<ResultType>::type _res(res.rows(), res.cols()); + internal::sparse_sparse_product_with_pruning_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res, tolerance); + res.swap(_res); + } +}; + +template<typename Lhs, typename Rhs, typename ResultType> +struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor> +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, RealScalar tolerance) + { + typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix; + ColMajorMatrix colLhs(lhs); + ColMajorMatrix colRhs(rhs); + internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrix,ColMajorMatrix,ResultType>(colLhs, colRhs, res, tolerance); + + // let's transpose the product to get a column x column product +// typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType; +// SparseTemporaryType _res(res.cols(), res.rows()); +// sparse_sparse_product_with_pruning_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res); +// res = _res.transpose(); + } +}; + +// NOTE the 2 others cases (col row *) must never occur since they are caught +// by ProductReturnType which transforms it to (col col *) by evaluating rhs. + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_SPARSESPARSEPRODUCTWITHPRUNING_H |