Eigen  5.0.1-dev+60122df6
 
Loading...
Searching...
No Matches
TriangularSolver.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2008 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SPARSETRIANGULARSOLVER_H
11#define EIGEN_SPARSETRIANGULARSOLVER_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20template <typename Lhs, typename Rhs, int Mode,
21 int UpLo = (Mode & Lower) ? Lower
22 : (Mode & Upper) ? Upper
23 : -1,
24 int StorageOrder = int(traits<Lhs>::Flags) & RowMajorBit>
25struct sparse_solve_triangular_selector;
26
27// forward substitution, row-major
28template <typename Lhs, typename Rhs, int Mode>
29struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Lower, RowMajor> {
30 typedef typename Rhs::Scalar Scalar;
31 typedef evaluator<Lhs> LhsEval;
32 typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
33 static void run(const Lhs& lhs, Rhs& other) {
34 LhsEval lhsEval(lhs);
35 for (Index col = 0; col < other.cols(); ++col) {
36 for (Index i = 0; i < lhs.rows(); ++i) {
37 Scalar tmp = other.coeff(i, col);
38 Scalar lastVal(0);
39 Index lastIndex = 0;
40 for (LhsIterator it(lhsEval, i); it; ++it) {
41 lastVal = it.value();
42 lastIndex = it.index();
43 if (lastIndex == i) break;
44 tmp = numext::madd<Scalar>(-lastVal, other.coeff(lastIndex, col), tmp);
45 }
46 if (Mode & UnitDiag)
47 other.coeffRef(i, col) = tmp;
48 else {
49 eigen_assert(lastIndex == i);
50 other.coeffRef(i, col) = tmp / lastVal;
51 }
52 }
53 }
54 }
55};
56
57// backward substitution, row-major
58template <typename Lhs, typename Rhs, int Mode>
59struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Upper, RowMajor> {
60 typedef typename Rhs::Scalar Scalar;
61 typedef evaluator<Lhs> LhsEval;
62 typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
63 static void run(const Lhs& lhs, Rhs& other) {
64 LhsEval lhsEval(lhs);
65 for (Index col = 0; col < other.cols(); ++col) {
66 for (Index i = lhs.rows() - 1; i >= 0; --i) {
67 Scalar tmp = other.coeff(i, col);
68 Scalar l_ii(0);
69 LhsIterator it(lhsEval, i);
70 while (it && it.index() < i) ++it;
71 if (!(Mode & UnitDiag)) {
72 eigen_assert(it && it.index() == i);
73 l_ii = it.value();
74 ++it;
75 } else if (it && it.index() == i)
76 ++it;
77 for (; it; ++it) {
78 tmp = numext::madd<Scalar>(-it.value(), other.coeff(it.index(), col), tmp);
79 }
80
81 if (Mode & UnitDiag)
82 other.coeffRef(i, col) = tmp;
83 else
84 other.coeffRef(i, col) = tmp / l_ii;
85 }
86 }
87 }
88};
89
90// forward substitution, col-major
91template <typename Lhs, typename Rhs, int Mode>
92struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Lower, ColMajor> {
93 typedef typename Rhs::Scalar Scalar;
94 typedef evaluator<Lhs> LhsEval;
95 typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
96 static void run(const Lhs& lhs, Rhs& other) {
97 LhsEval lhsEval(lhs);
98 for (Index col = 0; col < other.cols(); ++col) {
99 for (Index i = 0; i < lhs.cols(); ++i) {
100 Scalar& tmp = other.coeffRef(i, col);
101 if (!numext::is_exactly_zero(tmp)) // optimization when other is actually sparse
102 {
103 LhsIterator it(lhsEval, i);
104 while (it && it.index() < i) ++it;
105 if (!(Mode & UnitDiag)) {
106 eigen_assert(it && it.index() == i);
107 tmp /= it.value();
108 }
109 if (it && it.index() == i) ++it;
110 for (; it; ++it) {
111 other.coeffRef(it.index(), col) = numext::madd<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
112 }
113 }
114 }
115 }
116 }
117};
118
119// backward substitution, col-major
120template <typename Lhs, typename Rhs, int Mode>
121struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Upper, ColMajor> {
122 typedef typename Rhs::Scalar Scalar;
123 typedef evaluator<Lhs> LhsEval;
124 typedef typename evaluator<Lhs>::InnerIterator LhsIterator;
125 static void run(const Lhs& lhs, Rhs& other) {
126 LhsEval lhsEval(lhs);
127 for (Index col = 0; col < other.cols(); ++col) {
128 for (Index i = lhs.cols() - 1; i >= 0; --i) {
129 Scalar& tmp = other.coeffRef(i, col);
130 if (!numext::is_exactly_zero(tmp)) // optimization when other is actually sparse
131 {
132 if (!(Mode & UnitDiag)) {
133 // TODO replace this by a binary search. make sure the binary search is safe for partially sorted elements
134 LhsIterator it(lhsEval, i);
135 while (it && it.index() != i) ++it;
136 eigen_assert(it && it.index() == i);
137 other.coeffRef(i, col) /= it.value();
138 }
139 LhsIterator it(lhsEval, i);
140 for (; it && it.index() < i; ++it) {
141 other.coeffRef(it.index(), col) = numext::madd<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
142 }
143 }
144 }
145 }
146 }
147};
148
149} // end namespace internal
150
151#ifndef EIGEN_PARSED_BY_DOXYGEN
152
153template <typename ExpressionType, unsigned int Mode>
154template <typename OtherDerived>
155void TriangularViewImpl<ExpressionType, Mode, Sparse>::solveInPlace(MatrixBase<OtherDerived>& other) const {
156 eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows());
157 eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper | Lower)));
158
159 enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
160
161 typedef std::conditional_t<copy, typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&>
162 OtherCopy;
163 OtherCopy otherCopy(other.derived());
164
165 internal::sparse_solve_triangular_selector<ExpressionType, std::remove_reference_t<OtherCopy>, Mode>::run(
166 derived().nestedExpression(), otherCopy);
167
168 if (copy) other = otherCopy;
169}
170#endif
171
172// pure sparse path
173
174namespace internal {
175
176template <typename Lhs, typename Rhs, int Mode,
177 int UpLo = (Mode & Lower) ? Lower
178 : (Mode & Upper) ? Upper
179 : -1,
180 int StorageOrder = int(Lhs::Flags) & (RowMajorBit)>
181struct sparse_solve_triangular_sparse_selector;
182
183// forward (Lower) or backward (Upper) substitution, col-major
184template <typename Lhs, typename Rhs, int Mode, int UpLo>
185struct sparse_solve_triangular_sparse_selector<Lhs, Rhs, Mode, UpLo, ColMajor> {
186 typedef typename Rhs::Scalar Scalar;
187 typedef typename promote_index_type<typename traits<Lhs>::StorageIndex, typename traits<Rhs>::StorageIndex>::type
188 StorageIndex;
189 static void run(const Lhs& lhs, Rhs& other) {
190 const bool IsLower = (UpLo == Lower);
191 AmbiVector<Scalar, StorageIndex> tempVector(other.rows() * 2);
192 tempVector.setBounds(0, other.rows());
193
194 Rhs res(other.rows(), other.cols());
195 res.reserve(other.nonZeros());
196
197 for (Index col = 0; col < other.cols(); ++col) {
198 // FIXME estimate number of non zeros
199 tempVector.init(.99 /*float(other.col(col).nonZeros())/float(other.rows())*/);
200 tempVector.setZero();
201 tempVector.restart();
202 for (typename Rhs::InnerIterator rhsIt(other, col); rhsIt; ++rhsIt) {
203 tempVector.coeffRef(rhsIt.index()) = rhsIt.value();
204 }
205
206 for (Index i = IsLower ? 0 : lhs.cols() - 1; IsLower ? i < lhs.cols() : i >= 0; i += IsLower ? 1 : -1) {
207 tempVector.restart();
208 Scalar& ci = tempVector.coeffRef(i);
209 if (!numext::is_exactly_zero(ci)) {
210 // find
211 typename Lhs::InnerIterator it(lhs, i);
212 if (!(Mode & UnitDiag)) {
213 if (IsLower) {
214 eigen_assert(it.index() == i);
215 ci /= it.value();
216 } else
217 ci /= lhs.coeff(i, i);
218 }
219 tempVector.restart();
220 if (IsLower) {
221 if (it.index() == i) ++it;
222 for (; it; ++it) {
223 tempVector.coeffRef(it.index()) = numext::madd<Scalar>(-ci, it.value(), tempVector.coeffRef(it.index()));
224 }
225 } else {
226 for (; it && it.index() < i; ++it) {
227 tempVector.coeffRef(it.index()) = numext::madd<Scalar>(-ci, it.value(), tempVector.coeffRef(it.index()));
228 }
229 }
230 }
231 }
232
233 // Index count = 0;
234 // FIXME compute a reference value to filter zeros
235 for (typename AmbiVector<Scalar, StorageIndex>::Iterator it(tempVector /*,1e-12*/); it; ++it) {
236 // ++ count;
237 // std::cerr << "fill " << it.index() << ", " << col << "\n";
238 // std::cout << it.value() << " ";
239 // FIXME use insertBack
240 res.insert(it.index(), col) = it.value();
241 }
242 // std::cout << "tempVector.nonZeros() == " << int(count) << " / " << (other.rows()) << "\n";
243 }
244 res.finalize();
245 other = res.markAsRValue();
246 }
247};
248
249} // end namespace internal
250
251#ifndef EIGEN_PARSED_BY_DOXYGEN
252template <typename ExpressionType, unsigned int Mode>
253template <typename OtherDerived>
254void TriangularViewImpl<ExpressionType, Mode, Sparse>::solveInPlace(SparseMatrixBase<OtherDerived>& other) const {
255 eigen_assert(derived().cols() == derived().rows() && derived().cols() == other.rows());
256 eigen_assert((!(Mode & ZeroDiag)) && bool(Mode & (Upper | Lower)));
257
258 // enum { copy = internal::traits<OtherDerived>::Flags & RowMajorBit };
259
260 // typedef std::conditional_t<copy,
261 // typename internal::plain_matrix_type_column_major<OtherDerived>::type, OtherDerived&> OtherCopy;
262 // OtherCopy otherCopy(other.derived());
263
264 internal::sparse_solve_triangular_sparse_selector<ExpressionType, OtherDerived, Mode>::run(
265 derived().nestedExpression(), other.derived());
266
267 // if (copy)
268 // other = otherCopy;
269}
270#endif
271
272} // end namespace Eigen
273
274#endif // EIGEN_SPARSETRIANGULARSOLVER_H
Base class for all dense matrices, vectors, and expressions.
Definition MatrixBase.h:52
Base class of any sparse matrices or sparse expressions.
Definition SparseMatrixBase.h:30
Definition AmbiVector.h:251
@ UnitDiag
Definition Constants.h:215
@ ZeroDiag
Definition Constants.h:217
@ Lower
Definition Constants.h:211
@ Upper
Definition Constants.h:213
const unsigned int RowMajorBit
Definition Constants.h:70
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82