Eigen  5.0.1-dev+60122df6
 
Loading...
Searching...
No Matches
SelfadjointMatrixVector.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SELFADJOINT_MATRIX_VECTOR_H
11#define EIGEN_SELFADJOINT_MATRIX_VECTOR_H
12
13// IWYU pragma: private
14#include "../InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
/* Optimized selfadjoint matrix * vector product:
 * This algorithm processes two columns at once, which both halves the
 * number of loads/stores of the result and reduces instruction dependencies.
 */
25
26template <typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs,
27 int Version = Specialized>
28struct selfadjoint_matrix_vector_product;
29
30template <typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs,
31 int Version>
32struct selfadjoint_matrix_vector_product
33
34{
35 static EIGEN_DONT_INLINE EIGEN_DEVICE_FUNC void run(Index size, const Scalar* lhs, Index lhsStride, const Scalar* rhs,
36 Scalar* res, Scalar alpha);
37};
38
39template <typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs,
40 int Version>
41EIGEN_DONT_INLINE EIGEN_DEVICE_FUNC void
42selfadjoint_matrix_vector_product<Scalar, Index, StorageOrder, UpLo, ConjugateLhs, ConjugateRhs, Version>::run(
43 Index size, const Scalar* lhs, Index lhsStride, const Scalar* rhs, Scalar* res, Scalar alpha) {
44 typedef typename packet_traits<Scalar>::type Packet;
45 typedef typename NumTraits<Scalar>::Real RealScalar;
46 const Index PacketSize = sizeof(Packet) / sizeof(Scalar);
47
48 enum {
49 IsRowMajor = StorageOrder == RowMajor ? 1 : 0,
50 IsLower = UpLo == Lower ? 1 : 0,
51 FirstTriangular = IsRowMajor == IsLower
52 };
53
54 conj_helper<Scalar, Scalar, NumTraits<Scalar>::IsComplex && logical_xor(ConjugateLhs, IsRowMajor), ConjugateRhs> cj0;
55 conj_helper<Scalar, Scalar, NumTraits<Scalar>::IsComplex && logical_xor(ConjugateLhs, !IsRowMajor), ConjugateRhs> cj1;
56 conj_helper<RealScalar, Scalar, false, ConjugateRhs> cjd;
57
58 conj_helper<Packet, Packet, NumTraits<Scalar>::IsComplex && logical_xor(ConjugateLhs, IsRowMajor), ConjugateRhs> pcj0;
59 conj_helper<Packet, Packet, NumTraits<Scalar>::IsComplex && logical_xor(ConjugateLhs, !IsRowMajor), ConjugateRhs>
60 pcj1;
61
62 Scalar cjAlpha = ConjugateRhs ? numext::conj(alpha) : alpha;
63
64 Index bound = numext::maxi(Index(0), size - 8) & 0xfffffffe;
65 if (FirstTriangular) bound = size - bound;
66
67 for (Index j = FirstTriangular ? bound : 0; j < (FirstTriangular ? size : bound); j += 2) {
68 const Scalar* EIGEN_RESTRICT A0 = lhs + j * lhsStride;
69 const Scalar* EIGEN_RESTRICT A1 = lhs + (j + 1) * lhsStride;
70
71 Scalar t0 = cjAlpha * rhs[j];
72 Packet ptmp0 = pset1<Packet>(t0);
73 Scalar t1 = cjAlpha * rhs[j + 1];
74 Packet ptmp1 = pset1<Packet>(t1);
75
76 Scalar t2(0);
77 Packet ptmp2 = pset1<Packet>(t2);
78 Scalar t3(0);
79 Packet ptmp3 = pset1<Packet>(t3);
80
81 Index starti = FirstTriangular ? 0 : j + 2;
82 Index endi = FirstTriangular ? j : size;
83 Index alignedStart = (starti) + internal::first_default_aligned(&res[starti], endi - starti);
84 Index alignedEnd = alignedStart + ((endi - alignedStart) / (PacketSize)) * (PacketSize);
85
86 res[j] += cjd.pmul(numext::real(A0[j]), t0);
87 res[j + 1] += cjd.pmul(numext::real(A1[j + 1]), t1);
88 if (FirstTriangular) {
89 res[j] += cj0.pmul(A1[j], t1);
90 t3 += cj1.pmul(A1[j], rhs[j]);
91 } else {
92 res[j + 1] += cj0.pmul(A0[j + 1], t0);
93 t2 += cj1.pmul(A0[j + 1], rhs[j + 1]);
94 }
95
96 for (Index i = starti; i < alignedStart; ++i) {
97 res[i] += cj0.pmul(A0[i], t0) + cj0.pmul(A1[i], t1);
98 t2 += cj1.pmul(A0[i], rhs[i]);
99 t3 += cj1.pmul(A1[i], rhs[i]);
100 }
101 // Yes this an optimization for gcc 4.3 and 4.4 (=> huge speed up)
102 // gcc 4.2 does this optimization automatically.
103 const Scalar* EIGEN_RESTRICT a0It = A0 + alignedStart;
104 const Scalar* EIGEN_RESTRICT a1It = A1 + alignedStart;
105 const Scalar* EIGEN_RESTRICT rhsIt = rhs + alignedStart;
106 Scalar* EIGEN_RESTRICT resIt = res + alignedStart;
107 for (Index i = alignedStart; i < alignedEnd; i += PacketSize) {
108 Packet A0i = ploadu<Packet>(a0It);
109 a0It += PacketSize;
110 Packet A1i = ploadu<Packet>(a1It);
111 a1It += PacketSize;
112 Packet Bi = ploadu<Packet>(rhsIt);
113 rhsIt += PacketSize; // FIXME should be aligned in most cases
114 Packet Xi = pload<Packet>(resIt);
115
116 Xi = pcj0.pmadd(A0i, ptmp0, pcj0.pmadd(A1i, ptmp1, Xi));
117 ptmp2 = pcj1.pmadd(A0i, Bi, ptmp2);
118 ptmp3 = pcj1.pmadd(A1i, Bi, ptmp3);
119 pstore(resIt, Xi);
120 resIt += PacketSize;
121 }
122 for (Index i = alignedEnd; i < endi; i++) {
123 res[i] += cj0.pmul(A0[i], t0) + cj0.pmul(A1[i], t1);
124 t2 += cj1.pmul(A0[i], rhs[i]);
125 t3 += cj1.pmul(A1[i], rhs[i]);
126 }
127
128 res[j] += alpha * (t2 + predux(ptmp2));
129 res[j + 1] += alpha * (t3 + predux(ptmp3));
130 }
131 for (Index j = FirstTriangular ? 0 : bound; j < (FirstTriangular ? bound : size); j++) {
132 const Scalar* EIGEN_RESTRICT A0 = lhs + j * lhsStride;
133
134 Scalar t1 = cjAlpha * rhs[j];
135 Scalar t2(0);
136 res[j] += cjd.pmul(numext::real(A0[j]), t1);
137 for (Index i = FirstTriangular ? 0 : j + 1; i < (FirstTriangular ? j : size); i++) {
138 res[i] += cj0.pmul(A0[i], t1);
139 t2 += cj1.pmul(A0[i], rhs[i]);
140 }
141 res[j] += alpha * t2;
142 }
143}
144
145} // end namespace internal
146
/***************************************************************************
 * Wrappers dispatching selfadjoint * vector and vector * selfadjoint
 * products to selfadjoint_matrix_vector_product
 ***************************************************************************/
150
151namespace internal {
152
153template <typename Lhs, int LhsMode, typename Rhs>
154struct selfadjoint_product_impl<Lhs, LhsMode, false, Rhs, 0, true> {
155 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
156
157 typedef internal::blas_traits<Lhs> LhsBlasTraits;
158 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
159 typedef internal::remove_all_t<ActualLhsType> ActualLhsTypeCleaned;
160
161 typedef internal::blas_traits<Rhs> RhsBlasTraits;
162 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
163 typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;
164
165 enum { LhsUpLo = LhsMode & (Upper | Lower) };
166
167 // Verify that the Rhs is a vector in the correct orientation.
168 // Otherwise, we break the assumption that we are multiplying
169 // MxN * Nx1.
170 static_assert(Rhs::ColsAtCompileTime == 1, "The RHS must be a column vector.");
171
172 template <typename Dest>
173 static EIGEN_DEVICE_FUNC void run(Dest& dest, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) {
174 typedef typename Dest::Scalar ResScalar;
175 typedef typename Rhs::Scalar RhsScalar;
176 typedef Map<Matrix<ResScalar, Dynamic, 1>, plain_enum_min(AlignedMax, internal::packet_traits<ResScalar>::size)>
177 MappedDest;
178
179 eigen_assert(dest.rows() == a_lhs.rows() && dest.cols() == a_rhs.cols());
180
181 add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
182 add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);
183
184 Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) * RhsBlasTraits::extractScalarFactor(a_rhs);
185
186 enum {
187 EvalToDest = (Dest::InnerStrideAtCompileTime == 1),
188 UseRhs = (ActualRhsTypeCleaned::InnerStrideAtCompileTime == 1)
189 };
190
191 internal::gemv_static_vector_if<ResScalar, Dest::SizeAtCompileTime, Dest::MaxSizeAtCompileTime, !EvalToDest>
192 static_dest;
193 internal::gemv_static_vector_if<RhsScalar, ActualRhsTypeCleaned::SizeAtCompileTime,
194 ActualRhsTypeCleaned::MaxSizeAtCompileTime, !UseRhs>
195 static_rhs;
196
197 ei_declare_aligned_stack_constructed_variable(ResScalar, actualDestPtr, dest.size(),
198 EvalToDest ? dest.data() : static_dest.data());
199
200 ei_declare_aligned_stack_constructed_variable(RhsScalar, actualRhsPtr, rhs.size(),
201 UseRhs ? const_cast<RhsScalar*>(rhs.data()) : static_rhs.data());
202
203 if (!EvalToDest) {
204#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
205 constexpr int Size = Dest::SizeAtCompileTime;
206 Index size = dest.size();
207 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
208#endif
209 MappedDest(actualDestPtr, dest.size()) = dest;
210 }
211
212 if (!UseRhs) {
213#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
214 constexpr int Size = ActualRhsTypeCleaned::SizeAtCompileTime;
215 Index size = rhs.size();
216 EIGEN_DENSE_STORAGE_CTOR_PLUGIN
217#endif
218 Map<typename ActualRhsTypeCleaned::PlainObject>(actualRhsPtr, rhs.size()) = rhs;
219 }
220
221 internal::selfadjoint_matrix_vector_product<
222 Scalar, Index, (internal::traits<ActualLhsTypeCleaned>::Flags & RowMajorBit) ? RowMajor : ColMajor,
223 int(LhsUpLo), bool(LhsBlasTraits::NeedToConjugate),
224 bool(RhsBlasTraits::NeedToConjugate)>::run(lhs.rows(), // size
225 &lhs.coeffRef(0, 0), lhs.outerStride(), // lhs info
226 actualRhsPtr, // rhs info
227 actualDestPtr, // result info
228 actualAlpha // scale factor
229 );
230
231 if (!EvalToDest) dest = MappedDest(actualDestPtr, dest.size());
232 }
233};
234
235template <typename Lhs, typename Rhs, int RhsMode>
236struct selfadjoint_product_impl<Lhs, 0, true, Rhs, RhsMode, false> {
237 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
238 enum { RhsUpLo = RhsMode & (Upper | Lower) };
239
240 template <typename Dest>
241 static void run(Dest& dest, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) {
242 // let's simply transpose the product
243 Transpose<Dest> destT(dest);
244 selfadjoint_product_impl<Transpose<const Rhs>, int(RhsUpLo) == Upper ? Lower : Upper, false, Transpose<const Lhs>,
245 0, true>::run(destT, a_rhs.transpose(), a_lhs.transpose(), alpha);
246 }
247};
248
249} // end namespace internal
250
251} // end namespace Eigen
252
253#endif // EIGEN_SELFADJOINT_MATRIX_VECTOR_H
@ Lower
Definition Constants.h:211
@ Upper
Definition Constants.h:213
@ ColMajor
Definition Constants.h:318
@ RowMajor
Definition Constants.h:320
const unsigned int RowMajorBit
Definition Constants.h:70
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82