Eigen  5.0.1-dev+60122df6
 
Loading...
Searching...
No Matches
TriangularMatrixVector.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_TRIANGULARMATRIXVECTOR_H
11#define EIGEN_TRIANGULARMATRIXVECTOR_H
12
13// IWYU pragma: private
14#include "../InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
// Primary template: computes res += alpha * triangular(lhs) * rhs, where the
// triangle (and Unit/Zero diagonal) is selected by Mode. Specialized below for
// column-major and row-major storage of the triangular operand.
template <typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,
          int StorageOrder, int Version = Specialized>
struct triangular_matrix_vector_product;
23
// Column-major specialization: the triangular matrix is traversed by columns
// (axpy-style updates). The definition of run() is out-of-line below.
template <typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
struct triangular_matrix_vector_product<Index, Mode, LhsScalar, ConjLhs, RhsScalar, ConjRhs, ColMajor, Version> {
  // Result scalar type; supports mixed LhsScalar/RhsScalar products.
  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  static constexpr bool IsLower = ((Mode & Lower) == Lower);          // lower vs upper triangle
  static constexpr bool HasUnitDiag = (Mode & UnitDiag) == UnitDiag;  // diagonal implicitly all ones
  static constexpr bool HasZeroDiag = (Mode & ZeroDiag) == ZeroDiag;  // diagonal implicitly all zeros
  // res += alpha * triangular(lhs) * rhs. lhsStride is the outer (column)
  // stride; rhsIncr/resIncr are inner increments, all in number of elements.
  // NOTE(review): alpha is RhsScalar here but ResScalar in the row-major
  // specialization; callers pass a value compatible with both.
  static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* lhs_, Index lhsStride,
                                    const RhsScalar* rhs_, Index rhsIncr, ResScalar* res_, Index resIncr,
                                    const RhsScalar& alpha);
};
34
// Column-major kernel: res += alpha * triangular(lhs) * rhs.
// The diagonal band is processed in square panels of width PanelWidth: the
// triangular part of each panel is applied column by column (axpy updates),
// and the dense rectangular block below (lower case) or above (upper case)
// the panel is delegated to the general matrix-vector kernel.
template <typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index, Mode, LhsScalar, ConjLhs, RhsScalar, ConjRhs, ColMajor,
                                                        Version>::run(Index _rows, Index _cols, const LhsScalar* lhs_,
                                                                      Index lhsStride, const RhsScalar* rhs_,
                                                                      Index rhsIncr, ResScalar* res_, Index resIncr,
                                                                      const RhsScalar& alpha) {
  static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
  Index size = (std::min)(_rows, _cols);  // length of the structural diagonal
  // Only the first `size` columns (lower) resp. rows (upper) of the triangle
  // carry structural nonzeros; clamp the views accordingly.
  Index rows = IsLower ? _rows : (std::min)(_rows, _cols);
  Index cols = IsLower ? (std::min)(_rows, _cols) : _cols;

  // Strided Eigen views over the raw BLAS-style pointer arguments.
  typedef Map<const Matrix<LhsScalar, Dynamic, Dynamic, ColMajor>, 0, OuterStride<> > LhsMap;
  const LhsMap lhs(lhs_, rows, cols, OuterStride<>(lhsStride));
  typename conj_expr_if<ConjLhs, LhsMap>::type cjLhs(lhs);  // conjugated view when ConjLhs

  typedef Map<const Matrix<RhsScalar, Dynamic, 1>, 0, InnerStride<> > RhsMap;
  const RhsMap rhs(rhs_, cols, InnerStride<>(rhsIncr));
  typename conj_expr_if<ConjRhs, RhsMap>::type cjRhs(rhs);

  // NOTE(review): the result is mapped with unit inner stride; resIncr is only
  // forwarded to the general kernel (the caller in this file passes 1).
  typedef Map<Matrix<ResScalar, Dynamic, 1> > ResMap;
  ResMap res(res_, rows);

  typedef const_blas_data_mapper<LhsScalar, Index, ColMajor> LhsMapper;
  typedef const_blas_data_mapper<RhsScalar, Index, RowMajor> RhsMapper;

  for (Index pi = 0; pi < size; pi += PanelWidth) {
    Index actualPanelWidth = (std::min)(PanelWidth, size - pi);  // last panel may be narrower
    // Triangular part of the panel: one axpy per column i.
    for (Index k = 0; k < actualPanelWidth; ++k) {
      Index i = pi + k;
      // s/r: start and length of the stored segment of column i inside the
      // panel; with a unit/zero diagonal the diagonal entry is skipped (--r).
      Index s = IsLower ? ((HasUnitDiag || HasZeroDiag) ? i + 1 : i) : pi;
      Index r = IsLower ? actualPanelWidth - k : k + 1;
      if ((!(HasUnitDiag || HasZeroDiag)) || (--r) > 0)
        res.segment(s, r) += (alpha * cjRhs.coeff(i)) * cjLhs.col(i).segment(s, r);
      // The implicit unit diagonal still contributes alpha * rhs(i).
      if (HasUnitDiag) res.coeffRef(i) += alpha * cjRhs.coeff(i);
    }
    // Dense remainder of the panel's columns: the block below the diagonal
    // panel for lower, above it for upper.
    Index r = IsLower ? rows - pi - actualPanelWidth : pi;
    if (r > 0) {
      Index s = IsLower ? pi + actualPanelWidth : 0;
      general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjLhs, RhsScalar, RhsMapper, ConjRhs,
                                    BuiltIn>::run(r, actualPanelWidth, LhsMapper(&lhs.coeffRef(s, pi), lhsStride),
                                                  RhsMapper(&rhs.coeffRef(pi), rhsIncr), &res.coeffRef(s), resIncr,
                                                  alpha);
    }
  }
  // Upper case with more columns than rows: the trailing cols - size columns
  // form a fully dense block, handled by a single gemv.
  if ((!IsLower) && cols > size) {
    general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjLhs, RhsScalar, RhsMapper, ConjRhs>::run(
        rows, cols - size, LhsMapper(&lhs.coeffRef(0, size), lhsStride), RhsMapper(&rhs.coeffRef(size), rhsIncr), res_,
        resIncr, alpha);
  }
}
85
// Row-major specialization: the triangular matrix is traversed by rows
// (dot-product updates). The definition of run() is out-of-line below.
template <typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
struct triangular_matrix_vector_product<Index, Mode, LhsScalar, ConjLhs, RhsScalar, ConjRhs, RowMajor, Version> {
  // Result scalar type; supports mixed LhsScalar/RhsScalar products.
  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  static constexpr bool IsLower = ((Mode & Lower) == Lower);          // lower vs upper triangle
  static constexpr bool HasUnitDiag = (Mode & UnitDiag) == UnitDiag;  // diagonal implicitly all ones
  static constexpr bool HasZeroDiag = (Mode & ZeroDiag) == ZeroDiag;  // diagonal implicitly all zeros
  // res += alpha * triangular(lhs) * rhs. lhsStride is the outer (row)
  // stride; rhsIncr/resIncr are inner increments, all in number of elements.
  static EIGEN_DONT_INLINE void run(Index _rows, Index _cols, const LhsScalar* lhs_, Index lhsStride,
                                    const RhsScalar* rhs_, Index rhsIncr, ResScalar* res_, Index resIncr,
                                    const ResScalar& alpha);
};
96
// Row-major kernel: res += alpha * triangular(lhs) * rhs.
// Mirror of the column-major variant: the triangular part of each diagonal
// panel is applied row by row as dot products against the matching rhs
// segment, and the dense remainder of each panel row-block goes through the
// general matrix-vector kernel.
template <typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int Version>
EIGEN_DONT_INLINE void triangular_matrix_vector_product<Index, Mode, LhsScalar, ConjLhs, RhsScalar, ConjRhs, RowMajor,
                                                        Version>::run(Index _rows, Index _cols, const LhsScalar* lhs_,
                                                                      Index lhsStride, const RhsScalar* rhs_,
                                                                      Index rhsIncr, ResScalar* res_, Index resIncr,
                                                                      const ResScalar& alpha) {
  static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
  Index diagSize = (std::min)(_rows, _cols);  // length of the structural diagonal
  // Only the first diagSize rows (upper) resp. columns (lower) of the
  // triangle carry structural nonzeros; clamp the views accordingly.
  Index rows = IsLower ? _rows : diagSize;
  Index cols = IsLower ? diagSize : _cols;

  // Strided Eigen views over the raw BLAS-style pointer arguments.
  typedef Map<const Matrix<LhsScalar, Dynamic, Dynamic, RowMajor>, 0, OuterStride<> > LhsMap;
  const LhsMap lhs(lhs_, rows, cols, OuterStride<>(lhsStride));
  typename conj_expr_if<ConjLhs, LhsMap>::type cjLhs(lhs);  // conjugated view when ConjLhs

  // NOTE(review): the rhs is mapped with unit inner stride; rhsIncr is only
  // forwarded to the general kernel (the caller in this file passes 1).
  typedef Map<const Matrix<RhsScalar, Dynamic, 1> > RhsMap;
  const RhsMap rhs(rhs_, cols);
  typename conj_expr_if<ConjRhs, RhsMap>::type cjRhs(rhs);

  typedef Map<Matrix<ResScalar, Dynamic, 1>, 0, InnerStride<> > ResMap;
  ResMap res(res_, rows, InnerStride<>(resIncr));

  typedef const_blas_data_mapper<LhsScalar, Index, RowMajor> LhsMapper;
  typedef const_blas_data_mapper<RhsScalar, Index, RowMajor> RhsMapper;

  for (Index pi = 0; pi < diagSize; pi += PanelWidth) {
    Index actualPanelWidth = (std::min)(PanelWidth, diagSize - pi);  // last panel may be narrower
    // Triangular part of the panel: one dot product per row i.
    for (Index k = 0; k < actualPanelWidth; ++k) {
      Index i = pi + k;
      // s/r: start and length of the stored segment of row i inside the
      // panel; with a unit/zero diagonal the diagonal entry is skipped (--r).
      Index s = IsLower ? pi : ((HasUnitDiag || HasZeroDiag) ? i + 1 : i);
      Index r = IsLower ? k + 1 : actualPanelWidth - k;
      if ((!(HasUnitDiag || HasZeroDiag)) || (--r) > 0)
        res.coeffRef(i) += alpha * (cjLhs.row(i).segment(s, r).cwiseProduct(cjRhs.segment(s, r).transpose())).sum();
      // The implicit unit diagonal still contributes alpha * rhs(i).
      if (HasUnitDiag) res.coeffRef(i) += alpha * cjRhs.coeff(i);
    }
    // Dense remainder of the panel's rows: the block left of the diagonal
    // panel for lower, right of it for upper.
    Index r = IsLower ? pi : cols - pi - actualPanelWidth;
    if (r > 0) {
      Index s = IsLower ? 0 : pi + actualPanelWidth;
      general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjLhs, RhsScalar, RhsMapper, ConjRhs,
                                    BuiltIn>::run(actualPanelWidth, r, LhsMapper(&lhs.coeffRef(pi, s), lhsStride),
                                                  RhsMapper(&rhs.coeffRef(s), rhsIncr), &res.coeffRef(pi), resIncr,
                                                  alpha);
    }
  }
  // Lower case with more rows than columns: the trailing rows - diagSize rows
  // form a fully dense block, handled by a single gemv.
  if (IsLower && rows > diagSize) {
    general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjLhs, RhsScalar, RhsMapper, ConjRhs>::run(
        rows - diagSize, cols, LhsMapper(&lhs.coeffRef(diagSize, 0), lhsStride), RhsMapper(&rhs.coeffRef(0), rhsIncr),
        &res.coeffRef(diagSize), resIncr, alpha);
  }
}
147
148/***************************************************************************
149 * Wrapper to product_triangular_vector
150 ***************************************************************************/
151
// Selects and invokes the proper triangular matrix * vector kernel for the
// given storage order of the triangular operand; specialized further below.
template <int Mode, int StorageOrder>
struct trmv_selector;
154
155} // end namespace internal
156
157namespace internal {
158
159template <int Mode, typename Lhs, typename Rhs>
160struct triangular_product_impl<Mode, true, Lhs, false, Rhs, true> {
161 template <typename Dest>
162 static void run(Dest& dst, const Lhs& lhs, const Rhs& rhs, const typename Dest::Scalar& alpha) {
163 eigen_assert(dst.rows() == lhs.rows() && dst.cols() == rhs.cols());
164
165 internal::trmv_selector<Mode, (int(internal::traits<Lhs>::Flags) & RowMajorBit) ? RowMajor : ColMajor>::run(
166 lhs, rhs, dst, alpha);
167 }
168};
169
170template <int Mode, typename Lhs, typename Rhs>
171struct triangular_product_impl<Mode, false, Lhs, true, Rhs, false> {
172 template <typename Dest>
173 static void run(Dest& dst, const Lhs& lhs, const Rhs& rhs, const typename Dest::Scalar& alpha) {
174 eigen_assert(dst.rows() == lhs.rows() && dst.cols() == rhs.cols());
175
176 Transpose<Dest> dstT(dst);
177 internal::trmv_selector<(Mode & (UnitDiag | ZeroDiag)) | ((Mode & Lower) ? Upper : Lower),
178 (int(internal::traits<Rhs>::Flags) & RowMajorBit) ? ColMajor
179 : RowMajor>::run(rhs.transpose(),
180 lhs.transpose(), dstT,
181 alpha);
182 }
183};
184
185} // end namespace internal
186
187namespace internal {
188
189// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
190
template <int Mode>
struct trmv_selector<Mode, ColMajor> {
  // Evaluates dest += alpha * lhs * rhs for a column-major triangular lhs.
  // Unwraps nested transpose/conjugate/scalar-factor expressions via
  // blas_traits and, when the destination cannot be written directly
  // (strided, or real kernel with complex alpha), evaluates into an aligned
  // temporary first.
  template <typename Lhs, typename Rhs, typename Dest>
  static void run(const Lhs& lhs, const Rhs& rhs, Dest& dest, const typename Dest::Scalar& alpha) {
    typedef typename Lhs::Scalar LhsScalar;
    typedef typename Rhs::Scalar RhsScalar;
    typedef typename Dest::Scalar ResScalar;

    // blas_traits strips expression wrappers so the kernel can be fed raw
    // pointers plus conjugation flags and a scalar factor.
    typedef internal::blas_traits<Lhs> LhsBlasTraits;
    typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
    typedef internal::blas_traits<Rhs> RhsBlasTraits;
    typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
    // Alignment used when mapping the (possibly temporary) destination.
    constexpr int Alignment = (std::min)(int(AlignedMax), int(internal::packet_traits<ResScalar>::size));

    typedef Map<Matrix<ResScalar, Dynamic, 1>, Alignment> MappedDest;

    add_const_on_value_type_t<ActualLhsType> actualLhs = LhsBlasTraits::extract(lhs);
    add_const_on_value_type_t<ActualRhsType> actualRhs = RhsBlasTraits::extract(rhs);

    // Scalar factors folded into the operand expressions (e.g. s * mat).
    LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
    RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
    ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;

    // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
    // on, the other hand it is good for the cache to pack the vector anyways...
    constexpr bool EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime == 1;
    // A complex alpha cannot be passed to a kernel whose rhs scalar is real.
    constexpr bool ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex);
    constexpr bool MightCannotUseDest = (Dest::InnerStrideAtCompileTime != 1) || ComplexByReal;

    // Fixed-size scratch space, used only when a temporary may be needed and
    // the destination size is known small at compile time.
    gemv_static_vector_if<ResScalar, Dest::SizeAtCompileTime, Dest::MaxSizeAtCompileTime, MightCannotUseDest>
        static_dest;

    bool alphaIsCompatible = (!ComplexByReal) || numext::is_exactly_zero(numext::imag(actualAlpha));
    bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;

    RhsScalar compatibleAlpha = get_factor<ResScalar, RhsScalar>::run(actualAlpha);

    // Either alias the destination directly or set up an aligned temporary
    // (stack-allocated when small enough, heap otherwise).
    ei_declare_aligned_stack_constructed_variable(ResScalar, actualDestPtr, dest.size(),
                                                  evalToDest ? dest.data() : static_dest.data());

    if (!evalToDest) {
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
      constexpr int Size = Dest::SizeAtCompileTime;
      Index size = dest.size();
      EIGEN_DENSE_STORAGE_CTOR_PLUGIN
#endif
      if (!alphaIsCompatible) {
        // Accumulate with alpha == 1 into a zeroed temporary; the true
        // (complex) alpha is applied on copy-back below.
        MappedDest(actualDestPtr, dest.size()).setZero();
        compatibleAlpha = RhsScalar(1);
      } else
        MappedDest(actualDestPtr, dest.size()) = dest;
    }

    // Column-major triangular kernel; the result pointer is packed (unit
    // increment).
    internal::triangular_matrix_vector_product<Index, Mode, LhsScalar, LhsBlasTraits::NeedToConjugate, RhsScalar,
                                               RhsBlasTraits::NeedToConjugate, ColMajor>::run(actualLhs.rows(),
                                                                                              actualLhs.cols(),
                                                                                              actualLhs.data(),
                                                                                              actualLhs.outerStride(),
                                                                                              actualRhs.data(),
                                                                                              actualRhs.innerStride(),
                                                                                              actualDestPtr, 1,
                                                                                              compatibleAlpha);

    // Copy (or scale-accumulate) the temporary back into the real destination.
    if (!evalToDest) {
      if (!alphaIsCompatible)
        dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
      else
        dest = MappedDest(actualDestPtr, dest.size());
    }

    // With a unit diagonal the kernel scaled the implicit diagonal by
    // lhs_alpha as well; subtract the excess (lhs_alpha - 1) * rhs on the
    // diagonal part.
    if (((Mode & UnitDiag) == UnitDiag) && !numext::is_exactly_one(lhs_alpha)) {
      Index diagSize = (std::min)(lhs.rows(), lhs.cols());
      dest.head(diagSize) -= (lhs_alpha - LhsScalar(1)) * rhs.head(diagSize);
    }
  }
};
267
template <int Mode>
struct trmv_selector<Mode, RowMajor> {
  // Evaluates dest += alpha * lhs * rhs for a row-major triangular lhs.
  // The row-major kernel reads the rhs with unit stride, so a strided rhs is
  // first packed into a contiguous buffer (fixed-size, alloca, or heap).
  template <typename Lhs, typename Rhs, typename Dest>
  static void run(const Lhs& lhs, const Rhs& rhs, Dest& dest, const typename Dest::Scalar& alpha) {
    typedef typename Lhs::Scalar LhsScalar;
    typedef typename Rhs::Scalar RhsScalar;
    typedef typename Dest::Scalar ResScalar;

    // blas_traits strips expression wrappers so the kernel can be fed raw
    // pointers plus conjugation flags and a scalar factor.
    typedef internal::blas_traits<Lhs> LhsBlasTraits;
    typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
    typedef internal::blas_traits<Rhs> RhsBlasTraits;
    typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
    typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;

    std::add_const_t<ActualLhsType> actualLhs = LhsBlasTraits::extract(lhs);
    std::add_const_t<ActualRhsType> actualRhs = RhsBlasTraits::extract(rhs);

    // Scalar factors folded into the operand expressions (e.g. s * mat).
    LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
    RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
    ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;

    // The rhs can be used in place only if it is already contiguous.
    constexpr bool DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime == 1;

    const RhsScalar* actualRhsPtr = actualRhs.data();

    // Potentially create a temporary buffer to copy RHS to contiguous memory.
    gemv_static_vector_if<RhsScalar, ActualRhsTypeCleaned::SizeAtCompileTime,
                          ActualRhsTypeCleaned::MaxSizeAtCompileTime, !DirectlyUseRhs>
        static_rhs;  // Fixed-sized array.
    RhsScalar* buffer = nullptr;
    if (!DirectlyUseRhs) {
      // Maybe used fixed-sized buffer, otherwise allocate.
      if (static_rhs.data() != nullptr) {
        buffer = static_rhs.data();
      } else {
        // Allocate either with alloca or malloc.
        Eigen::internal::check_size_for_overflow<RhsScalar>(actualRhs.size());
#ifdef EIGEN_ALLOCA
        // Small buffers live on the stack; larger ones fall back to the heap.
        buffer = static_cast<RhsScalar*>((sizeof(RhsScalar) * actualRhs.size() <= EIGEN_STACK_ALLOCATION_LIMIT)
                                             ? EIGEN_ALIGNED_ALLOCA(sizeof(RhsScalar) * actualRhs.size())
                                             : Eigen::internal::aligned_malloc(sizeof(RhsScalar) * actualRhs.size()));
#else
        buffer = static_cast<RhsScalar*>(Eigen::internal::aligned_malloc(sizeof(RhsScalar) * actualRhs.size()));
#endif
      }
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
      constexpr int Size = ActualRhsTypeCleaned::SizeAtCompileTime;
      Index size = actualRhs.size();
      EIGEN_DENSE_STORAGE_CTOR_PLUGIN
#endif
      // Pack the strided rhs into the contiguous buffer.
      Map<typename ActualRhsTypeCleaned::PlainObject, Eigen::AlignedMax>(buffer, actualRhs.size()) = actualRhs;
      actualRhsPtr = buffer;
    }
    // Deallocate only if malloced.
    Eigen::internal::aligned_stack_memory_handler<RhsScalar> buffer_stack_memory_destructor(
        buffer, actualRhs.size(),
        !DirectlyUseRhs && static_rhs.data() == nullptr && actualRhs.size() > EIGEN_STACK_ALLOCATION_LIMIT);

    // Row-major triangular kernel; the rhs pointer is packed (unit
    // increment), the destination keeps its own inner stride.
    internal::triangular_matrix_vector_product<Index, Mode, LhsScalar, LhsBlasTraits::NeedToConjugate, RhsScalar,
                                               RhsBlasTraits::NeedToConjugate, RowMajor>::run(actualLhs.rows(),
                                                                                              actualLhs.cols(),
                                                                                              actualLhs.data(),
                                                                                              actualLhs.outerStride(),
                                                                                              actualRhsPtr, 1,
                                                                                              dest.data(),
                                                                                              dest.innerStride(),
                                                                                              actualAlpha);

    // With a unit diagonal the kernel scaled the implicit diagonal by
    // lhs_alpha as well; subtract the excess (lhs_alpha - 1) * rhs on the
    // diagonal part.
    if (((Mode & UnitDiag) == UnitDiag) && !numext::is_exactly_one(lhs_alpha)) {
      Index diagSize = (std::min)(lhs.rows(), lhs.cols());
      dest.head(diagSize) -= (lhs_alpha - LhsScalar(1)) * rhs.head(diagSize);
    }
  }
};
342
343} // end namespace internal
344
345} // end namespace Eigen
346
347#endif // EIGEN_TRIANGULARMATRIXVECTOR_H
@ UnitDiag
Definition Constants.h:215
@ ZeroDiag
Definition Constants.h:217
@ Lower
Definition Constants.h:211
@ Upper
Definition Constants.h:213
@ ColMajor
Definition Constants.h:318
@ RowMajor
Definition Constants.h:320
const unsigned int RowMajorBit
Definition Constants.h:70
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82