Eigen  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
SelfadjointMatrixMatrix.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SELFADJOINT_MATRIX_MATRIX_H
11#define EIGEN_SELFADJOINT_MATRIX_MATRIX_H
12
13// IWYU pragma: private
14#include "../InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20// pack a selfadjoint block diagonal for use with the gebp_kernel
21template <typename Scalar, typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
22struct symm_pack_lhs {
23 template <int BlockRows>
24 inline void pack(Scalar* blockA, const const_blas_data_mapper<Scalar, Index, StorageOrder>& lhs, Index cols, Index i,
25 Index& count) {
26 // normal copy
27 for (Index k = 0; k < i; k++)
28 for (Index w = 0; w < BlockRows; w++) blockA[count++] = lhs(i + w, k); // normal
29 // symmetric copy
30 Index h = 0;
31 for (Index k = i; k < i + BlockRows; k++) {
32 for (Index w = 0; w < h; w++) blockA[count++] = numext::conj(lhs(k, i + w)); // transposed
33
34 blockA[count++] = numext::real(lhs(k, k)); // real (diagonal)
35
36 for (Index w = h + 1; w < BlockRows; w++) blockA[count++] = lhs(i + w, k); // normal
37 ++h;
38 }
39 // transposed copy
40 for (Index k = i + BlockRows; k < cols; k++)
41 for (Index w = 0; w < BlockRows; w++) blockA[count++] = numext::conj(lhs(k, i + w)); // transposed
42 }
43 void operator()(Scalar* blockA, const Scalar* lhs_, Index lhsStride, Index cols, Index rows) {
44 typedef typename unpacket_traits<typename packet_traits<Scalar>::type>::half HalfPacket;
45 typedef typename unpacket_traits<typename unpacket_traits<typename packet_traits<Scalar>::type>::half>::half
46 QuarterPacket;
47 enum {
48 PacketSize = packet_traits<Scalar>::size,
49 HalfPacketSize = unpacket_traits<HalfPacket>::size,
50 QuarterPacketSize = unpacket_traits<QuarterPacket>::size,
51 HasHalf = (int)HalfPacketSize < (int)PacketSize,
52 HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize
53 };
54
55 const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(lhs_, lhsStride);
56 Index count = 0;
57 // Index peeled_mc3 = (rows/Pack1)*Pack1;
58
59 const Index peeled_mc3 = Pack1 >= 3 * PacketSize ? (rows / (3 * PacketSize)) * (3 * PacketSize) : 0;
60 const Index peeled_mc2 =
61 Pack1 >= 2 * PacketSize ? peeled_mc3 + ((rows - peeled_mc3) / (2 * PacketSize)) * (2 * PacketSize) : 0;
62 const Index peeled_mc1 =
63 Pack1 >= 1 * PacketSize ? peeled_mc2 + ((rows - peeled_mc2) / (1 * PacketSize)) * (1 * PacketSize) : 0;
64 const Index peeled_mc_half =
65 Pack1 >= HalfPacketSize ? peeled_mc1 + ((rows - peeled_mc1) / (HalfPacketSize)) * (HalfPacketSize) : 0;
66 const Index peeled_mc_quarter =
67 Pack1 >= QuarterPacketSize
68 ? peeled_mc_half + ((rows - peeled_mc_half) / (QuarterPacketSize)) * (QuarterPacketSize)
69 : 0;
70
71 if (Pack1 >= 3 * PacketSize)
72 for (Index i = 0; i < peeled_mc3; i += 3 * PacketSize) pack<3 * PacketSize>(blockA, lhs, cols, i, count);
73
74 if (Pack1 >= 2 * PacketSize)
75 for (Index i = peeled_mc3; i < peeled_mc2; i += 2 * PacketSize) pack<2 * PacketSize>(blockA, lhs, cols, i, count);
76
77 if (Pack1 >= 1 * PacketSize)
78 for (Index i = peeled_mc2; i < peeled_mc1; i += 1 * PacketSize) pack<1 * PacketSize>(blockA, lhs, cols, i, count);
79
80 if (HasHalf && Pack1 >= HalfPacketSize)
81 for (Index i = peeled_mc1; i < peeled_mc_half; i += HalfPacketSize)
82 pack<HalfPacketSize>(blockA, lhs, cols, i, count);
83
84 if (HasQuarter && Pack1 >= QuarterPacketSize)
85 for (Index i = peeled_mc_half; i < peeled_mc_quarter; i += QuarterPacketSize)
86 pack<QuarterPacketSize>(blockA, lhs, cols, i, count);
87
88 // do the same with mr==1
89 for (Index i = peeled_mc_quarter; i < rows; i++) {
90 for (Index k = 0; k < i; k++) blockA[count++] = lhs(i, k); // normal
91
92 blockA[count++] = numext::real(lhs(i, i)); // real (diagonal)
93
94 for (Index k = i + 1; k < cols; k++) blockA[count++] = numext::conj(lhs(k, i)); // transposed
95 }
96 }
97};
98
// Packs the horizontal panel made of rows [k2, k2 + rows) of a selfadjoint rhs, over all `cols`
// columns, into blockB in the layout consumed by gebp_kernel: column panels 8 wide (when nr >= 8),
// then 4 wide (when nr >= 4), then single columns; inside a panel, values are written row by row.
// Only entries with row >= column of the mapped rhs view are read; entries of the other half are
// obtained by conjugating their mirror image, and diagonal entries are reduced to their real part.
// NOTE(review): the column loops below only cover nr == 4 or nr == 8. With nr < 4, packet_cols4 is 0
// and the final single-column loop would revisit columns already packed by the first part.
template <typename Scalar, typename Index, int nr, int StorageOrder>
struct symm_pack_rhs {
  enum { PacketSize = packet_traits<Scalar>::size };
  // blockB: destination buffer; rhs_ / rhsStride: the selfadjoint rhs; rows: height of the panel;
  // cols: number of rhs columns; k2: first row of the panel.
  void operator()(Scalar* blockB, const Scalar* rhs_, Index rhsStride, Index rows, Index cols, Index k2) {
    Index end_k = k2 + rows;
    Index count = 0;
    const_blas_data_mapper<Scalar, Index, StorageOrder> rhs(rhs_, rhsStride);
    // Column ranges that can be packed as whole 8-wide / 4-wide panels.
    Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
    Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;

    // first part: normal case
    // Columns [0, k2) lie left of the diagonal block, so every entry of the panel is read directly.
    // NOTE(review): j2 steps by nr, which presumably relies on k2 being a multiple of nr
    // (guaranteed by the caller's blocking?) — confirm.
    for (Index j2 = 0; j2 < k2; j2 += nr) {
      for (Index k = k2; k < end_k; k++) {
        blockB[count + 0] = rhs(k, j2 + 0);
        blockB[count + 1] = rhs(k, j2 + 1);
        if (nr >= 4) {
          blockB[count + 2] = rhs(k, j2 + 2);
          blockB[count + 3] = rhs(k, j2 + 3);
        }
        if (nr >= 8) {
          blockB[count + 4] = rhs(k, j2 + 4);
          blockB[count + 5] = rhs(k, j2 + 5);
          blockB[count + 6] = rhs(k, j2 + 6);
          blockB[count + 7] = rhs(k, j2 + 7);
        }
        count += nr;
      }
    }

    // second part: diagonal block
    // Columns overlapping rows [k2, end_k): first as 8-wide panels up to end8...
    Index end8 = nr >= 8 ? (std::min)(k2 + rows, packet_cols8) : k2;
    if (nr >= 8) {
      for (Index j2 = k2; j2 < end8; j2 += 8) {
        // again we can split vertically in three different parts (transpose, symmetric, normal)
        // transpose: rows [k2, j2) are above the diagonal of these columns => mirrored
        for (Index k = k2; k < j2; k++) {
          blockB[count + 0] = numext::conj(rhs(j2 + 0, k));
          blockB[count + 1] = numext::conj(rhs(j2 + 1, k));
          blockB[count + 2] = numext::conj(rhs(j2 + 2, k));
          blockB[count + 3] = numext::conj(rhs(j2 + 3, k));
          blockB[count + 4] = numext::conj(rhs(j2 + 4, k));
          blockB[count + 5] = numext::conj(rhs(j2 + 5, k));
          blockB[count + 6] = numext::conj(rhs(j2 + 6, k));
          blockB[count + 7] = numext::conj(rhs(j2 + 7, k));
          count += 8;
        }
        // symmetric: the 8x8 square on the diagonal; h is the offset of row k inside the panel
        Index h = 0;
        for (Index k = j2; k < j2 + 8; k++) {
          // normal (columns left of the diagonal entry)
          for (Index w = 0; w < h; ++w) blockB[count + w] = rhs(k, j2 + w);

          blockB[count + h] = numext::real(rhs(k, k));

          // transpose (columns right of the diagonal entry)
          for (Index w = h + 1; w < 8; ++w) blockB[count + w] = numext::conj(rhs(j2 + w, k));
          count += 8;
          ++h;
        }
        // normal: rows below the diagonal square are read directly
        for (Index k = j2 + 8; k < end_k; k++) {
          blockB[count + 0] = rhs(k, j2 + 0);
          blockB[count + 1] = rhs(k, j2 + 1);
          blockB[count + 2] = rhs(k, j2 + 2);
          blockB[count + 3] = rhs(k, j2 + 3);
          blockB[count + 4] = rhs(k, j2 + 4);
          blockB[count + 5] = rhs(k, j2 + 5);
          blockB[count + 6] = rhs(k, j2 + 6);
          blockB[count + 7] = rhs(k, j2 + 7);
          count += 8;
        }
      }
    }
    // ...then as 4-wide panels from end8 on, with the same three-way split.
    if (nr >= 4) {
      for (Index j2 = end8; j2 < (std::min)(k2 + rows, packet_cols4); j2 += 4) {
        // again we can split vertically in three different parts (transpose, symmetric, normal)
        // transpose
        for (Index k = k2; k < j2; k++) {
          blockB[count + 0] = numext::conj(rhs(j2 + 0, k));
          blockB[count + 1] = numext::conj(rhs(j2 + 1, k));
          blockB[count + 2] = numext::conj(rhs(j2 + 2, k));
          blockB[count + 3] = numext::conj(rhs(j2 + 3, k));
          count += 4;
        }
        // symmetric
        Index h = 0;
        for (Index k = j2; k < j2 + 4; k++) {
          // normal
          for (Index w = 0; w < h; ++w) blockB[count + w] = rhs(k, j2 + w);

          blockB[count + h] = numext::real(rhs(k, k));

          // transpose
          for (Index w = h + 1; w < 4; ++w) blockB[count + w] = numext::conj(rhs(j2 + w, k));
          count += 4;
          ++h;
        }
        // normal
        for (Index k = j2 + 4; k < end_k; k++) {
          blockB[count + 0] = rhs(k, j2 + 0);
          blockB[count + 1] = rhs(k, j2 + 1);
          blockB[count + 2] = rhs(k, j2 + 2);
          blockB[count + 3] = rhs(k, j2 + 3);
          count += 4;
        }
      }
    }

    // third part: transposed
    // Columns at or right of end_k: every row of the panel is above their diagonal => mirrored.
    if (nr >= 8) {
      for (Index j2 = k2 + rows; j2 < packet_cols8; j2 += 8) {
        for (Index k = k2; k < end_k; k++) {
          blockB[count + 0] = numext::conj(rhs(j2 + 0, k));
          blockB[count + 1] = numext::conj(rhs(j2 + 1, k));
          blockB[count + 2] = numext::conj(rhs(j2 + 2, k));
          blockB[count + 3] = numext::conj(rhs(j2 + 3, k));
          blockB[count + 4] = numext::conj(rhs(j2 + 4, k));
          blockB[count + 5] = numext::conj(rhs(j2 + 5, k));
          blockB[count + 6] = numext::conj(rhs(j2 + 6, k));
          blockB[count + 7] = numext::conj(rhs(j2 + 7, k));
          count += 8;
        }
      }
    }
    if (nr >= 4) {
      for (Index j2 = (std::max)(packet_cols8, k2 + rows); j2 < packet_cols4; j2 += 4) {
        for (Index k = k2; k < end_k; k++) {
          blockB[count + 0] = numext::conj(rhs(j2 + 0, k));
          blockB[count + 1] = numext::conj(rhs(j2 + 1, k));
          blockB[count + 2] = numext::conj(rhs(j2 + 2, k));
          blockB[count + 3] = numext::conj(rhs(j2 + 3, k));
          count += 4;
        }
      }
    }

    // copy the remaining columns one at a time (=> the same with nr==1)
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      // transpose: panel rows [k2, half) are above the diagonal of column j2 => mirrored
      Index half = (std::min)(end_k, j2);
      for (Index k = k2; k < half; k++) {
        blockB[count] = numext::conj(rhs(j2, k));
        count += 1;
      }

      // If row j2 falls inside the panel, emit the real diagonal entry. Otherwise half == end_k, and
      // the decrement makes the normal loop below start at end_k, i.e. it writes nothing.
      if (half == j2 && half < k2 + rows) {
        blockB[count] = numext::real(rhs(j2, j2));
        count += 1;
      } else
        half--;

      // normal: panel rows below the diagonal of column j2 are read directly
      for (Index k = half + 1; k < k2 + rows; k++) {
        blockB[count] = rhs(k, j2);
        count += 1;
      }
    }
  }
};
258
259/* Optimized selfadjoint matrix * matrix (_SYMM) product built on top of
260 * the general matrix matrix product.
261 */
262template <typename Scalar, typename Index, int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
263 int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs, int ResStorageOrder, int ResInnerStride>
264struct product_selfadjoint_matrix;
265
266template <typename Scalar, typename Index, int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
267 int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs, int ResInnerStride>
268struct product_selfadjoint_matrix<Scalar, Index, LhsStorageOrder, LhsSelfAdjoint, ConjugateLhs, RhsStorageOrder,
269 RhsSelfAdjoint, ConjugateRhs, RowMajor, ResInnerStride> {
270 static EIGEN_STRONG_INLINE void run(Index rows, Index cols, const Scalar* lhs, Index lhsStride, const Scalar* rhs,
271 Index rhsStride, Scalar* res, Index resIncr, Index resStride, const Scalar& alpha,
272 level3_blocking<Scalar, Scalar>& blocking) {
273 product_selfadjoint_matrix<
274 Scalar, Index, logical_xor(RhsSelfAdjoint, RhsStorageOrder == RowMajor) ? ColMajor : RowMajor, RhsSelfAdjoint,
275 NumTraits<Scalar>::IsComplex && logical_xor(RhsSelfAdjoint, ConjugateRhs),
276 logical_xor(LhsSelfAdjoint, LhsStorageOrder == RowMajor) ? ColMajor : RowMajor, LhsSelfAdjoint,
277 NumTraits<Scalar>::IsComplex && logical_xor(LhsSelfAdjoint, ConjugateLhs), ColMajor,
278 ResInnerStride>::run(cols, rows, rhs, rhsStride, lhs, lhsStride, res, resIncr, resStride, alpha, blocking);
279 }
280};
281
282template <typename Scalar, typename Index, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder,
283 bool ConjugateRhs, int ResInnerStride>
284struct product_selfadjoint_matrix<Scalar, Index, LhsStorageOrder, true, ConjugateLhs, RhsStorageOrder, false,
285 ConjugateRhs, ColMajor, ResInnerStride> {
286 static EIGEN_DONT_INLINE void run(Index rows, Index cols, const Scalar* lhs_, Index lhsStride, const Scalar* rhs_,
287 Index rhsStride, Scalar* res, Index resIncr, Index resStride, const Scalar& alpha,
288 level3_blocking<Scalar, Scalar>& blocking);
289};
290
291template <typename Scalar, typename Index, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder,
292 bool ConjugateRhs, int ResInnerStride>
293EIGEN_DONT_INLINE void
294product_selfadjoint_matrix<Scalar, Index, LhsStorageOrder, true, ConjugateLhs, RhsStorageOrder, false, ConjugateRhs,
295 ColMajor, ResInnerStride>::run(Index rows, Index cols, const Scalar* lhs_, Index lhsStride,
296 const Scalar* rhs_, Index rhsStride, Scalar* res_,
297 Index resIncr, Index resStride, const Scalar& alpha,
298 level3_blocking<Scalar, Scalar>& blocking) {
299 Index size = rows;
300
301 typedef gebp_traits<Scalar, Scalar> Traits;
302
303 typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
304 typedef const_blas_data_mapper<Scalar, Index, (LhsStorageOrder == RowMajor) ? ColMajor : RowMajor> LhsTransposeMapper;
305 typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
306 typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
307 LhsMapper lhs(lhs_, lhsStride);
308 LhsTransposeMapper lhs_transpose(lhs_, lhsStride);
309 RhsMapper rhs(rhs_, rhsStride);
310 ResMapper res(res_, resStride, resIncr);
311
312 Index kc = blocking.kc(); // cache block size along the K direction
313 Index mc = (std::min)(rows, blocking.mc()); // cache block size along the M direction
314 // kc must be smaller than mc
315 kc = (std::min)(kc, mc);
316 std::size_t sizeA = kc * mc;
317 std::size_t sizeB = kc * cols;
318 ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
319 ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
320
321 gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
322 symm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
323 gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
324 gemm_pack_lhs<Scalar, Index, LhsTransposeMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing,
325 LhsStorageOrder == RowMajor ? ColMajor : RowMajor, true>
326 pack_lhs_transposed;
327
328 for (Index k2 = 0; k2 < size; k2 += kc) {
329 const Index actual_kc = (std::min)(k2 + kc, size) - k2;
330
331 // we have selected one row panel of rhs and one column panel of lhs
332 // pack rhs's panel into a sequential chunk of memory
333 // and expand each coeff to a constant packet for further reuse
334 pack_rhs(blockB, rhs.getSubMapper(k2, 0), actual_kc, cols);
335
336 // the select lhs's panel has to be split in three different parts:
337 // 1 - the transposed panel above the diagonal block => transposed packed copy
338 // 2 - the diagonal block => special packed copy
339 // 3 - the panel below the diagonal block => generic packed copy
340 for (Index i2 = 0; i2 < k2; i2 += mc) {
341 const Index actual_mc = (std::min)(i2 + mc, k2) - i2;
342 // transposed packed copy
343 pack_lhs_transposed(blockA, lhs_transpose.getSubMapper(i2, k2), actual_kc, actual_mc);
344
345 gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
346 }
347 // the block diagonal
348 {
349 const Index actual_mc = (std::min)(k2 + kc, size) - k2;
350 // symmetric packed copy
351 pack_lhs(blockA, &lhs(k2, k2), lhsStride, actual_kc, actual_mc);
352
353 gebp_kernel(res.getSubMapper(k2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
354 }
355
356 for (Index i2 = k2 + kc; i2 < size; i2 += mc) {
357 const Index actual_mc = (std::min)(i2 + mc, size) - i2;
358 gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing,
359 LhsStorageOrder, false>()(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
360
361 gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
362 }
363 }
364}
365
366// matrix * selfadjoint product
367template <typename Scalar, typename Index, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder,
368 bool ConjugateRhs, int ResInnerStride>
369struct product_selfadjoint_matrix<Scalar, Index, LhsStorageOrder, false, ConjugateLhs, RhsStorageOrder, true,
370 ConjugateRhs, ColMajor, ResInnerStride> {
371 static EIGEN_DONT_INLINE void run(Index rows, Index cols, const Scalar* lhs_, Index lhsStride, const Scalar* rhs_,
372 Index rhsStride, Scalar* res, Index resIncr, Index resStride, const Scalar& alpha,
373 level3_blocking<Scalar, Scalar>& blocking);
374};
375
376template <typename Scalar, typename Index, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder,
377 bool ConjugateRhs, int ResInnerStride>
378EIGEN_DONT_INLINE void
379product_selfadjoint_matrix<Scalar, Index, LhsStorageOrder, false, ConjugateLhs, RhsStorageOrder, true, ConjugateRhs,
380 ColMajor, ResInnerStride>::run(Index rows, Index cols, const Scalar* lhs_, Index lhsStride,
381 const Scalar* rhs_, Index rhsStride, Scalar* res_,
382 Index resIncr, Index resStride, const Scalar& alpha,
383 level3_blocking<Scalar, Scalar>& blocking) {
384 Index size = cols;
385
386 typedef gebp_traits<Scalar, Scalar> Traits;
387
388 typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
389 typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
390 LhsMapper lhs(lhs_, lhsStride);
391 ResMapper res(res_, resStride, resIncr);
392
393 Index kc = blocking.kc(); // cache block size along the K direction
394 Index mc = (std::min)(rows, blocking.mc()); // cache block size along the M direction
395 std::size_t sizeA = kc * mc;
396 std::size_t sizeB = kc * cols;
397 ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
398 ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
399
400 gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
401 gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing,
402 LhsStorageOrder>
403 pack_lhs;
404 symm_pack_rhs<Scalar, Index, Traits::nr, RhsStorageOrder> pack_rhs;
405
406 for (Index k2 = 0; k2 < size; k2 += kc) {
407 const Index actual_kc = (std::min)(k2 + kc, size) - k2;
408
409 pack_rhs(blockB, rhs_, rhsStride, actual_kc, cols, k2);
410
411 // => GEPP
412 for (Index i2 = 0; i2 < rows; i2 += mc) {
413 const Index actual_mc = (std::min)(i2 + mc, rows) - i2;
414 pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
415
416 gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
417 }
418 }
419}
420
421} // end namespace internal
422
423/***************************************************************************
424 * Wrapper to product_selfadjoint_matrix
425 ***************************************************************************/
426
427namespace internal {
428
429template <typename Lhs, int LhsMode, typename Rhs, int RhsMode>
430struct selfadjoint_product_impl<Lhs, LhsMode, false, Rhs, RhsMode, false> {
431 typedef typename Product<Lhs, Rhs>::Scalar Scalar;
432
433 typedef internal::blas_traits<Lhs> LhsBlasTraits;
434 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
435 typedef internal::blas_traits<Rhs> RhsBlasTraits;
436 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
437
438 enum {
439 LhsIsUpper = (LhsMode & (Upper | Lower)) == Upper,
440 LhsIsSelfAdjoint = (LhsMode & SelfAdjoint) == SelfAdjoint,
441 RhsIsUpper = (RhsMode & (Upper | Lower)) == Upper,
442 RhsIsSelfAdjoint = (RhsMode & SelfAdjoint) == SelfAdjoint
443 };
444
445 template <typename Dest>
446 static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) {
447 eigen_assert(dst.rows() == a_lhs.rows() && dst.cols() == a_rhs.cols());
448
449 add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
450 add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);
451
452 Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) * RhsBlasTraits::extractScalarFactor(a_rhs);
453
454 typedef internal::gemm_blocking_space<(Dest::Flags & RowMajorBit) ? RowMajor : ColMajor, Scalar, Scalar,
455 Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime,
456 Lhs::MaxColsAtCompileTime, 1>
457 BlockingType;
458
459 BlockingType blocking(lhs.rows(), rhs.cols(), lhs.cols(), 1, false);
460
461 internal::product_selfadjoint_matrix<
462 Scalar, Index,
463 internal::logical_xor(LhsIsUpper, internal::traits<Lhs>::Flags & RowMajorBit) ? RowMajor : ColMajor,
464 LhsIsSelfAdjoint,
465 NumTraits<Scalar>::IsComplex && internal::logical_xor(LhsIsUpper, bool(LhsBlasTraits::NeedToConjugate)),
466 internal::logical_xor(RhsIsUpper, internal::traits<Rhs>::Flags & RowMajorBit) ? RowMajor : ColMajor,
467 RhsIsSelfAdjoint,
468 NumTraits<Scalar>::IsComplex && internal::logical_xor(RhsIsUpper, bool(RhsBlasTraits::NeedToConjugate)),
469 internal::traits<Dest>::Flags & RowMajorBit ? RowMajor : ColMajor,
470 Dest::InnerStrideAtCompileTime>::run(lhs.rows(), rhs.cols(), // sizes
471 &lhs.coeffRef(0, 0), lhs.outerStride(), // lhs info
472 &rhs.coeffRef(0, 0), rhs.outerStride(), // rhs info
473 &dst.coeffRef(0, 0), dst.innerStride(), dst.outerStride(), // result info
474 actualAlpha, blocking // alpha
475 );
476 }
477};
478
479} // end namespace internal
480
481} // end namespace Eigen
482
483#endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_H
@ SelfAdjoint
Definition Constants.h:227
@ Lower
Definition Constants.h:211
@ Upper
Definition Constants.h:213
@ ColMajor
Definition Constants.h:318
@ RowMajor
Definition Constants.h:320
const unsigned int RowMajorBit
Definition Constants.h:70
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82