SparseDiagonalProduct.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SPARSE_DIAGONAL_PRODUCT_H
11#define EIGEN_SPARSE_DIAGONAL_PRODUCT_H
12
13namespace Eigen {
14
// The product of a diagonal matrix with a sparse matrix can easily be
// implemented using expression templates.
// We have to consider two very different cases:
// 1 - diag * row-major sparse
//     => each inner vector <=> scalar * sparse vector product
//     => so we can reuse CwiseUnaryOp::InnerIterator
// 2 - diag * col-major sparse
//     => each inner vector <=> dense vector * sparse vector cwise product
//     => again, we can reuse the specialization of CwiseBinaryOp::InnerIterator
//        for that particular case
// The two other cases (sparse * diag) are symmetric.
26
27namespace internal {
28
29template<typename Lhs, typename Rhs>
30struct traits<SparseDiagonalProduct<Lhs, Rhs> >
31{
32 typedef typename remove_all<Lhs>::type _Lhs;
33 typedef typename remove_all<Rhs>::type _Rhs;
34 typedef typename _Lhs::Scalar Scalar;
35 typedef typename promote_index_type<typename traits<Lhs>::Index,
36 typename traits<Rhs>::Index>::type Index;
37 typedef Sparse StorageKind;
38 typedef MatrixXpr XprKind;
39 enum {
40 RowsAtCompileTime = _Lhs::RowsAtCompileTime,
41 ColsAtCompileTime = _Rhs::ColsAtCompileTime,
42
43 MaxRowsAtCompileTime = _Lhs::MaxRowsAtCompileTime,
44 MaxColsAtCompileTime = _Rhs::MaxColsAtCompileTime,
45
46 SparseFlags = is_diagonal<_Lhs>::ret ? int(_Rhs::Flags) : int(_Lhs::Flags),
47 Flags = (SparseFlags&RowMajorBit),
48 CoeffReadCost = Dynamic
49 };
50};
51
// Storage classification of one operand of a SparseDiagonalProduct; the pair
// (lhs mode, rhs mode) selects the inner-iterator implementation.
enum {SDP_IsDiagonal, SDP_IsSparseRowMajor, SDP_IsSparseColMajor};

// Primary template: declared only, never defined. Each supported combination of
// operand modes is provided by a partial specialization; any other combination
// is an incomplete type and fails to compile.
// The 4th template argument is the mode of the LEFT operand and the 5th the mode
// of the RIGHT operand (SparseDiagonalProduct instantiates it as
// <..., LhsMode, RhsMode>); the parameter names below reflect that order.
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType, int LhsMode, int RhsMode>
class sparse_diagonal_product_inner_iterator_selector;
55
56} // end namespace internal
57
58template<typename Lhs, typename Rhs>
59class SparseDiagonalProduct
60 : public SparseMatrixBase<SparseDiagonalProduct<Lhs,Rhs> >,
61 internal::no_assignment_operator
62{
63 typedef typename Lhs::Nested LhsNested;
64 typedef typename Rhs::Nested RhsNested;
65
66 typedef typename internal::remove_all<LhsNested>::type _LhsNested;
67 typedef typename internal::remove_all<RhsNested>::type _RhsNested;
68
69 enum {
70 LhsMode = internal::is_diagonal<_LhsNested>::ret ? internal::SDP_IsDiagonal
71 : (_LhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor,
72 RhsMode = internal::is_diagonal<_RhsNested>::ret ? internal::SDP_IsDiagonal
73 : (_RhsNested::Flags&RowMajorBit) ? internal::SDP_IsSparseRowMajor : internal::SDP_IsSparseColMajor
74 };
75
76 public:
77
78 EIGEN_SPARSE_PUBLIC_INTERFACE(SparseDiagonalProduct)
79
80 typedef internal::sparse_diagonal_product_inner_iterator_selector
81 <_LhsNested,_RhsNested,SparseDiagonalProduct,LhsMode,RhsMode> InnerIterator;
82
83 EIGEN_STRONG_INLINE SparseDiagonalProduct(const Lhs& lhs, const Rhs& rhs)
84 : m_lhs(lhs), m_rhs(rhs)
85 {
86 eigen_assert(lhs.cols() == rhs.rows() && "invalid sparse matrix * diagonal matrix product");
87 }
88
89 EIGEN_STRONG_INLINE Index rows() const { return m_lhs.rows(); }
90 EIGEN_STRONG_INLINE Index cols() const { return m_rhs.cols(); }
91
92 EIGEN_STRONG_INLINE const _LhsNested& lhs() const { return m_lhs; }
93 EIGEN_STRONG_INLINE const _RhsNested& rhs() const { return m_rhs; }
94
95 protected:
96 LhsNested m_lhs;
97 RhsNested m_rhs;
98};
99
100namespace internal {
101
102template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
103class sparse_diagonal_product_inner_iterator_selector
104<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor>
105 : public CwiseUnaryOp<scalar_multiple_op<typename Lhs::Scalar>,const Rhs>::InnerIterator
106{
107 typedef typename CwiseUnaryOp<scalar_multiple_op<typename Lhs::Scalar>,const Rhs>::InnerIterator Base;
108 typedef typename Lhs::Index Index;
109 public:
110 inline sparse_diagonal_product_inner_iterator_selector(
111 const SparseDiagonalProductType& expr, Index outer)
112 : Base(expr.rhs()*(expr.lhs().diagonal().coeff(outer)), outer)
113 {}
114};
115
116template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
117class sparse_diagonal_product_inner_iterator_selector
118<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseColMajor>
119 : public CwiseBinaryOp<
120 scalar_product_op<typename Lhs::Scalar>,
121 SparseInnerVectorSet<Rhs,1>,
122 typename Lhs::DiagonalVectorType>::InnerIterator
123{
124 typedef typename CwiseBinaryOp<
125 scalar_product_op<typename Lhs::Scalar>,
126 SparseInnerVectorSet<Rhs,1>,
127 typename Lhs::DiagonalVectorType>::InnerIterator Base;
128 typedef typename Lhs::Index Index;
129 Index m_outer;
130 public:
131 inline sparse_diagonal_product_inner_iterator_selector(
132 const SparseDiagonalProductType& expr, Index outer)
133 : Base(expr.rhs().innerVector(outer) .cwiseProduct(expr.lhs().diagonal()), 0), m_outer(outer)
134 {}
135
136 inline Index outer() const { return m_outer; }
137 inline Index col() const { return m_outer; }
138};
139
140template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
141class sparse_diagonal_product_inner_iterator_selector
142<Lhs,Rhs,SparseDiagonalProductType,SDP_IsSparseColMajor,SDP_IsDiagonal>
143 : public CwiseUnaryOp<scalar_multiple_op<typename Rhs::Scalar>,const Lhs>::InnerIterator
144{
145 typedef typename CwiseUnaryOp<scalar_multiple_op<typename Rhs::Scalar>,const Lhs>::InnerIterator Base;
146 typedef typename Lhs::Index Index;
147 public:
148 inline sparse_diagonal_product_inner_iterator_selector(
149 const SparseDiagonalProductType& expr, Index outer)
150 : Base(expr.lhs()*expr.rhs().diagonal().coeff(outer), outer)
151 {}
152};
153
154template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
155class sparse_diagonal_product_inner_iterator_selector
156<Lhs,Rhs,SparseDiagonalProductType,SDP_IsSparseRowMajor,SDP_IsDiagonal>
157 : public CwiseBinaryOp<
158 scalar_product_op<typename Rhs::Scalar>,
159 SparseInnerVectorSet<Lhs,1>,
160 Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator
161{
162 typedef typename CwiseBinaryOp<
163 scalar_product_op<typename Rhs::Scalar>,
164 SparseInnerVectorSet<Lhs,1>,
165 Transpose<const typename Rhs::DiagonalVectorType> >::InnerIterator Base;
166 typedef typename Lhs::Index Index;
167 Index m_outer;
168 public:
169 inline sparse_diagonal_product_inner_iterator_selector(
170 const SparseDiagonalProductType& expr, Index outer)
171 : Base(expr.lhs().innerVector(outer) .cwiseProduct(expr.rhs().diagonal().transpose()), 0), m_outer(outer)
172 {}
173
174 inline Index outer() const { return m_outer; }
175 inline Index row() const { return m_outer; }
176};
177
178} // end namespace internal
179
180// SparseMatrixBase functions
181
182template<typename Derived>
183template<typename OtherDerived>
184const SparseDiagonalProduct<Derived,OtherDerived>
185SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) const
186{
187 return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived());
188}
189
190} // end namespace Eigen
191
192#endif // EIGEN_SPARSE_DIAGONAL_PRODUCT_H
const unsigned int RowMajorBit
Definition Constants.h:48
Definition LDLT.h:18