#ifndef EIGEN_SELFADJOINT_MATRIX_MATRIX_H
#define EIGEN_SELFADJOINT_MATRIX_MATRIX_H

#include "../InternalHeaderCheck.h"

namespace Eigen {

namespace internal {
// pack a selfadjoint block diagonal for use with the gebp_kernel
template <typename Scalar, typename Index, int Pack1, int Pack2_dummy, int StorageOrder>
struct symm_pack_lhs {
  template <int BlockRows>
  inline void pack(Scalar* blockA, const const_blas_data_mapper<Scalar, Index, StorageOrder>& lhs, Index cols, Index i,
                   Index& count) {
    // normal copy: the panel left of the diagonal block
    for (Index k = 0; k < i; k++)
      for (Index w = 0; w < BlockRows; w++) blockA[count++] = lhs(i + w, k);

    // symmetric copy: the diagonal block itself
    Index h = 0;
    for (Index k = i; k < i + BlockRows; k++) {
      // transposed copy above the diagonal
      for (Index w = 0; w < h; w++) blockA[count++] = numext::conj(lhs(k, i + w));

      // the diagonal coefficient of a self-adjoint matrix is real
      blockA[count++] = numext::real(lhs(k, k));

      // normal copy below the diagonal
      for (Index w = h + 1; w < BlockRows; w++) blockA[count++] = lhs(i + w, k);
      ++h;
    }
    // transposed copy: the panel right of the diagonal block
    for (Index k = i + BlockRows; k < cols; k++)
      for (Index w = 0; w < BlockRows; w++) blockA[count++] = numext::conj(lhs(k, i + w));
  }
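  // Layout note: for the block of rows [i, i+BlockRows) the three loops above
  // reconstruct a dense BlockRows x cols panel from a single stored triangle:
  //
  //   columns k < i                -> read as stored:          lhs(i+w, k)
  //   columns i <= k < i+BlockRows -> mixed: conjugate-transposed reads on one
  //                                   side of the diagonal, real diagonal,
  //                                   stored reads on the other side
  //   columns k >= i+BlockRows     -> read conjugate-transposed: conj(lhs(k, i+w))
  //
  // so the gebp kernel downstream can consume the packed panel exactly as if
  // the full self-adjoint matrix had been materialized.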
  void operator()(Scalar* blockA, const Scalar* lhs_, Index lhsStride, Index cols, Index rows) {
    typedef typename unpacket_traits<typename packet_traits<Scalar>::type>::half HalfPacket;
    typedef typename unpacket_traits<typename unpacket_traits<typename packet_traits<Scalar>::type>::half>::half
        QuarterPacket;
    enum {
      PacketSize = packet_traits<Scalar>::size,
      HalfPacketSize = unpacket_traits<HalfPacket>::size,
      QuarterPacketSize = unpacket_traits<QuarterPacket>::size,
      HasHalf = (int)HalfPacketSize < (int)PacketSize,
      HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize
    };

    const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(lhs_, lhsStride);
    Index count = 0;
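    // HasHalf/HasQuarter above are only true when the target actually provides
    // narrower packet types: on architectures without them, unpacket_traits
    // reports the same size for the "half" packet as for the full one, and the
    // corresponding peeling stages below compile away.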
    const Index peeled_mc3 = Pack1 >= 3 * PacketSize ? (rows / (3 * PacketSize)) * (3 * PacketSize) : 0;
    const Index peeled_mc2 =
        Pack1 >= 2 * PacketSize ? peeled_mc3 + ((rows - peeled_mc3) / (2 * PacketSize)) * (2 * PacketSize) : 0;
    const Index peeled_mc1 =
        Pack1 >= 1 * PacketSize ? peeled_mc2 + ((rows - peeled_mc2) / (1 * PacketSize)) * (1 * PacketSize) : 0;
    const Index peeled_mc_half =
        Pack1 >= HalfPacketSize ? peeled_mc1 + ((rows - peeled_mc1) / (HalfPacketSize)) * (HalfPacketSize) : 0;
    const Index peeled_mc_quarter =
        Pack1 >= QuarterPacketSize
            ? peeled_mc_half + ((rows - peeled_mc_half) / (QuarterPacketSize)) * (QuarterPacketSize)
            : 0;
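    // Each peeled_mc* bound extends the previous one by as many rows as can
    // still be processed with that (strictly decreasing) packet width, so the
    // loops below cover every row exactly once, always with the widest
    // applicable vectorized pack; the final scalar loop handles the remainder.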
    if (Pack1 >= 3 * PacketSize)
      for (Index i = 0; i < peeled_mc3; i += 3 * PacketSize) pack<3 * PacketSize>(blockA, lhs, cols, i, count);

    if (Pack1 >= 2 * PacketSize)
      for (Index i = peeled_mc3; i < peeled_mc2; i += 2 * PacketSize) pack<2 * PacketSize>(blockA, lhs, cols, i, count);

    if (Pack1 >= 1 * PacketSize)
      for (Index i = peeled_mc2; i < peeled_mc1; i += 1 * PacketSize) pack<1 * PacketSize>(blockA, lhs, cols, i, count);

    if (HasHalf && Pack1 >= HalfPacketSize)
      for (Index i = peeled_mc1; i < peeled_mc_half; i += HalfPacketSize)
        pack<HalfPacketSize>(blockA, lhs, cols, i, count);

    if (HasQuarter && Pack1 >= QuarterPacketSize)
      for (Index i = peeled_mc_half; i < peeled_mc_quarter; i += QuarterPacketSize)
        pack<QuarterPacketSize>(blockA, lhs, cols, i, count);
    // do the same with mr==1
    for (Index i = peeled_mc_quarter; i < rows; i++) {
      for (Index k = 0; k < i; k++) blockA[count++] = lhs(i, k);  // normal

      blockA[count++] = numext::real(lhs(i, i));  // real (diagonal)

      for (Index k = i + 1; k < cols; k++) blockA[count++] = numext::conj(lhs(k, i));  // transposed
    }
  }
};
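// Invariant of symm_pack_lhs worth keeping in mind: every packed row
// contributes exactly `cols` coefficients (its stored entries, one real
// diagonal entry, and the conjugated mirror entries), so on return
// `count == rows * cols` and blockA holds a dense rows x cols panel.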
template <typename Scalar, typename Index, int nr, int StorageOrder>
struct symm_pack_rhs {
  enum { PacketSize = packet_traits<Scalar>::size };
  void operator()(Scalar* blockB, const Scalar* rhs_, Index rhsStride, Index rows, Index cols, Index k2) {
    Index end_k = k2 + rows;
    Index count = 0;
    const_blas_data_mapper<Scalar, Index, StorageOrder> rhs(rhs_, rhsStride);
    Index packet_cols8 = nr >= 8 ? (cols / 8) * 8 : 0;
    Index packet_cols4 = nr >= 4 ? (cols / 4) * 4 : 0;
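    // The rhs rows [k2, end_k) are packed in three column ranges, handled in
    // turn below: columns j2 < k2 left of the diagonal band are copied as
    // stored; columns whose index range overlaps [k2, end_k) get the three-way
    // transposed/diagonal/stored split; columns j2 >= end_k right of the band
    // are read conjugate-transposed. A scalar loop packs the column tail.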
    // first part: normal case
    for (Index j2 = 0; j2 < k2; j2 += nr) {
      for (Index k = k2; k < end_k; k++) {
        blockB[count + 0] = rhs(k, j2 + 0);
        blockB[count + 1] = rhs(k, j2 + 1);
        if (nr >= 4) {
          blockB[count + 2] = rhs(k, j2 + 2);
          blockB[count + 3] = rhs(k, j2 + 3);
        }
        if (nr >= 8) {
          blockB[count + 4] = rhs(k, j2 + 4);
          blockB[count + 5] = rhs(k, j2 + 5);
          blockB[count + 6] = rhs(k, j2 + 6);
          blockB[count + 7] = rhs(k, j2 + 7);
        }
        count += nr;
      }
    }
    // second part: diagonal block
    Index end8 = nr >= 8 ? (std::min)(k2 + rows, packet_cols8) : k2;
    if (nr >= 8) {
      for (Index j2 = k2; j2 < end8; j2 += 8) {
        // again we can split vertically in three different parts (transpose, symmetric, normal)
        // transpose
        for (Index k = k2; k < j2; k++) {
          blockB[count + 0] = numext::conj(rhs(j2 + 0, k));
          blockB[count + 1] = numext::conj(rhs(j2 + 1, k));
          blockB[count + 2] = numext::conj(rhs(j2 + 2, k));
          blockB[count + 3] = numext::conj(rhs(j2 + 3, k));
          blockB[count + 4] = numext::conj(rhs(j2 + 4, k));
          blockB[count + 5] = numext::conj(rhs(j2 + 5, k));
          blockB[count + 6] = numext::conj(rhs(j2 + 6, k));
          blockB[count + 7] = numext::conj(rhs(j2 + 7, k));
          count += 8;
        }
        // symmetric
        Index h = 0;
        for (Index k = j2; k < j2 + 8; k++) {
          // normal
          for (Index w = 0; w < h; ++w) blockB[count + w] = rhs(k, j2 + w);

          blockB[count + h] = numext::real(rhs(k, k));

          // transpose
          for (Index w = h + 1; w < 8; ++w) blockB[count + w] = numext::conj(rhs(j2 + w, k));
          count += 8;
          ++h;
        }
        // normal
        for (Index k = j2 + 8; k < end_k; k++) {
          blockB[count + 0] = rhs(k, j2 + 0);
          blockB[count + 1] = rhs(k, j2 + 1);
          blockB[count + 2] = rhs(k, j2 + 2);
          blockB[count + 3] = rhs(k, j2 + 3);
          blockB[count + 4] = rhs(k, j2 + 4);
          blockB[count + 5] = rhs(k, j2 + 5);
          blockB[count + 6] = rhs(k, j2 + 6);
          blockB[count + 7] = rhs(k, j2 + 7);
          count += 8;
        }
      }
    }
    if (nr >= 4) {
      for (Index j2 = end8; j2 < (std::min)(k2 + rows, packet_cols4); j2 += 4) {
        // again we can split vertically in three different parts (transpose, symmetric, normal)
        // transpose
        for (Index k = k2; k < j2; k++) {
          blockB[count + 0] = numext::conj(rhs(j2 + 0, k));
          blockB[count + 1] = numext::conj(rhs(j2 + 1, k));
          blockB[count + 2] = numext::conj(rhs(j2 + 2, k));
          blockB[count + 3] = numext::conj(rhs(j2 + 3, k));
          count += 4;
        }
        // symmetric
        Index h = 0;
        for (Index k = j2; k < j2 + 4; k++) {
          // normal
          for (Index w = 0; w < h; ++w) blockB[count + w] = rhs(k, j2 + w);

          blockB[count + h] = numext::real(rhs(k, k));

          // transpose
          for (Index w = h + 1; w < 4; ++w) blockB[count + w] = numext::conj(rhs(j2 + w, k));
          count += 4;
          ++h;
        }
        // normal
        for (Index k = j2 + 4; k < end_k; k++) {
          blockB[count + 0] = rhs(k, j2 + 0);
          blockB[count + 1] = rhs(k, j2 + 1);
          blockB[count + 2] = rhs(k, j2 + 2);
          blockB[count + 3] = rhs(k, j2 + 3);
          count += 4;
        }
      }
    }
    // third part: transposed
    if (nr >= 8) {
      for (Index j2 = k2 + rows; j2 < packet_cols8; j2 += 8) {
        for (Index k = k2; k < end_k; k++) {
          blockB[count + 0] = numext::conj(rhs(j2 + 0, k));
          blockB[count + 1] = numext::conj(rhs(j2 + 1, k));
          blockB[count + 2] = numext::conj(rhs(j2 + 2, k));
          blockB[count + 3] = numext::conj(rhs(j2 + 3, k));
          blockB[count + 4] = numext::conj(rhs(j2 + 4, k));
          blockB[count + 5] = numext::conj(rhs(j2 + 5, k));
          blockB[count + 6] = numext::conj(rhs(j2 + 6, k));
          blockB[count + 7] = numext::conj(rhs(j2 + 7, k));
          count += 8;
        }
      }
    }
    if (nr >= 4) {
      for (Index j2 = (std::max)(packet_cols8, k2 + rows); j2 < packet_cols4; j2 += 4) {
        for (Index k = k2; k < end_k; k++) {
          blockB[count + 0] = numext::conj(rhs(j2 + 0, k));
          blockB[count + 1] = numext::conj(rhs(j2 + 1, k));
          blockB[count + 2] = numext::conj(rhs(j2 + 2, k));
          blockB[count + 3] = numext::conj(rhs(j2 + 3, k));
          count += 4;
        }
      }
    }
    // copy the remaining columns one at a time (=> the same with nr==1)
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      // transpose
      Index half = (std::min)(end_k, j2);
      for (Index k = k2; k < half; k++) {
        blockB[count] = numext::conj(rhs(j2, k));
        count += 1;
      }

      if (half == j2 && half < k2 + rows) {
        blockB[count] = numext::real(rhs(j2, j2));
        count += 1;
      } else
        half--;

      // normal
      for (Index k = half + 1; k < k2 + rows; k++) {
        blockB[count] = rhs(k, j2);
        count += 1;
      }
    }
  }
};
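// Sanity check on the packing above: every column j2 contributes exactly
// `rows` coefficients whichever branch it takes, so on return
// `count == rows * cols` and blockB holds a dense rows x cols panel of the
// logically full (un-triangular) self-adjoint rhs.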
/* Optimized selfadjoint matrix * matrix (_SYMM) product built on top of
 * the general matrix matrix product.
 */
template <typename Scalar, typename Index, int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
          int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs, int ResStorageOrder, int ResInnerStride>
struct product_selfadjoint_matrix;
template <typename Scalar, typename Index, int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
          int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs, int ResInnerStride>
struct product_selfadjoint_matrix<Scalar, Index, LhsStorageOrder, LhsSelfAdjoint, ConjugateLhs, RhsStorageOrder,
                                  RhsSelfAdjoint, ConjugateRhs, RowMajor, ResInnerStride> {
  static EIGEN_STRONG_INLINE void run(Index rows, Index cols, const Scalar* lhs, Index lhsStride, const Scalar* rhs,
                                      Index rhsStride, Scalar* res, Index resIncr, Index resStride,
                                      const Scalar& alpha, level3_blocking<Scalar, Scalar>& blocking) {
    product_selfadjoint_matrix<
        Scalar, Index, logical_xor(RhsSelfAdjoint, RhsStorageOrder == RowMajor) ? ColMajor : RowMajor, RhsSelfAdjoint,
        NumTraits<Scalar>::IsComplex && logical_xor(RhsSelfAdjoint, ConjugateRhs),
        logical_xor(LhsSelfAdjoint, LhsStorageOrder == RowMajor) ? ColMajor : RowMajor, LhsSelfAdjoint,
        NumTraits<Scalar>::IsComplex && logical_xor(LhsSelfAdjoint, ConjugateLhs), ColMajor,
        ResInnerStride>::run(cols, rows, rhs, rhsStride, lhs, lhsStride, res, resIncr, resStride, alpha, blocking);
  }
};
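// The specialization above reduces a row-major destination to the col-major
// kernels: since C = A*B implies C^T = B^T * A^T, viewing the row-major C as
// a col-major C^T lets us swap the operands, flip each one's storage order,
// and toggle conjugation for complex self-adjoint operands (the transpose of
// a Hermitian matrix equals its conjugate). Only the two ColMajor-destination
// specializations below carry a real implementation.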
template <typename Scalar, typename Index, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder,
          bool ConjugateRhs, int ResInnerStride>
struct product_selfadjoint_matrix<Scalar, Index, LhsStorageOrder, true, ConjugateLhs, RhsStorageOrder, false,
                                  ConjugateRhs, ColMajor, ResInnerStride> {
  static EIGEN_DONT_INLINE void run(Index rows, Index cols, const Scalar* lhs_, Index lhsStride, const Scalar* rhs_,
                                    Index rhsStride, Scalar* res, Index resIncr, Index resStride, const Scalar& alpha,
                                    level3_blocking<Scalar, Scalar>& blocking);
};
template <typename Scalar, typename Index, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder,
          bool ConjugateRhs, int ResInnerStride>
EIGEN_DONT_INLINE void
product_selfadjoint_matrix<Scalar, Index, LhsStorageOrder, true, ConjugateLhs, RhsStorageOrder, false, ConjugateRhs,
                           ColMajor, ResInnerStride>::run(Index rows, Index cols, const Scalar* lhs_, Index lhsStride,
                                                          const Scalar* rhs_, Index rhsStride, Scalar* res_,
                                                          Index resIncr, Index resStride, const Scalar& alpha,
                                                          level3_blocking<Scalar, Scalar>& blocking) {
  Index size = rows;
  typedef gebp_traits<Scalar, Scalar> Traits;

  typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
  typedef const_blas_data_mapper<Scalar, Index, (LhsStorageOrder == RowMajor) ? ColMajor : RowMajor>
      LhsTransposeMapper;
  typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
  typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
  LhsMapper lhs(lhs_, lhsStride);
  LhsTransposeMapper lhs_transpose(lhs_, lhsStride);
  RhsMapper rhs(rhs_, rhsStride);
  ResMapper res(res_, resStride, resIncr);

  Index kc = blocking.kc();                    // cache block size along the K direction
  Index mc = (std::min)(rows, blocking.mc());  // cache block size along the M direction
  // kc must be smaller than mc
  kc = (std::min)(kc, mc);
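  // (The clamp above matters because the diagonal sweep below packs a
  // kc x kc diagonal block into blockA, which only has room for kc * mc
  // scalars: kc <= mc guarantees it fits.)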
  std::size_t sizeA = kc * mc;
  std::size_t sizeB = kc * cols;
  ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
  ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());

  gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
  symm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
  gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
  gemm_pack_lhs<Scalar, Index, LhsTransposeMapper, Traits::mr, Traits::LhsProgress,
                typename Traits::LhsPacket4Packing, LhsStorageOrder == RowMajor ? ColMajor : RowMajor, true>
      pack_lhs_transposed;
  for (Index k2 = 0; k2 < size; k2 += kc) {
    const Index actual_kc = (std::min)(k2 + kc, size) - k2;

    // we have selected one row panel of rhs and one column panel of lhs;
    // pack rhs's panel into a sequential chunk of memory
    pack_rhs(blockB, rhs.getSubMapper(k2, 0), actual_kc, cols);

    // the selected lhs's panel has to be split in three different parts:
    //  1 - the panel above the diagonal block => transposed packed copy
    //  2 - the diagonal block => special packed copy
    //  3 - the panel below the diagonal block => generic packed copy
    for (Index i2 = 0; i2 < k2; i2 += mc) {
      const Index actual_mc = (std::min)(i2 + mc, k2) - i2;
      // transposed packed copy
      pack_lhs_transposed(blockA, lhs_transpose.getSubMapper(i2, k2), actual_kc, actual_mc);

      gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
    }
    // the block diagonal
    {
      const Index actual_mc = (std::min)(k2 + kc, size) - k2;
      // symmetric packed copy
      pack_lhs(blockA, &lhs(k2, k2), lhsStride, actual_kc, actual_mc);

      gebp_kernel(res.getSubMapper(k2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
    }
    // the panel below the diagonal block => generic packed copy
    for (Index i2 = k2 + kc; i2 < size; i2 += mc) {
      const Index actual_mc = (std::min)(i2 + mc, size) - i2;
      gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing,
                    LhsStorageOrder, false>()(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);

      gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
    }
  }
}
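// Note that each k2 iteration updates the *entire* result: the three lhs
// sub-panels (above, on, and below the diagonal of the k-th block column)
// together form one full size x actual_kc slice of the logically dense
// self-adjoint lhs, and gebp accumulates the product slice by slice into res.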
// matrix * selfadjoint product
template <typename Scalar, typename Index, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder,
          bool ConjugateRhs, int ResInnerStride>
struct product_selfadjoint_matrix<Scalar, Index, LhsStorageOrder, false, ConjugateLhs, RhsStorageOrder, true,
                                  ConjugateRhs, ColMajor, ResInnerStride> {
  static EIGEN_DONT_INLINE void run(Index rows, Index cols, const Scalar* lhs_, Index lhsStride, const Scalar* rhs_,
                                    Index rhsStride, Scalar* res, Index resIncr, Index resStride, const Scalar& alpha,
                                    level3_blocking<Scalar, Scalar>& blocking);
};
template <typename Scalar, typename Index, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder,
          bool ConjugateRhs, int ResInnerStride>
EIGEN_DONT_INLINE void
product_selfadjoint_matrix<Scalar, Index, LhsStorageOrder, false, ConjugateLhs, RhsStorageOrder, true, ConjugateRhs,
                           ColMajor, ResInnerStride>::run(Index rows, Index cols, const Scalar* lhs_, Index lhsStride,
                                                          const Scalar* rhs_, Index rhsStride, Scalar* res_,
                                                          Index resIncr, Index resStride, const Scalar& alpha,
                                                          level3_blocking<Scalar, Scalar>& blocking) {
  Index size = cols;
  typedef gebp_traits<Scalar, Scalar> Traits;

  typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
  typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
  LhsMapper lhs(lhs_, lhsStride);
  ResMapper res(res_, resStride, resIncr);

  Index kc = blocking.kc();                    // cache block size along the K direction
  Index mc = (std::min)(rows, blocking.mc());  // cache block size along the M direction
  std::size_t sizeA = kc * mc;
  std::size_t sizeB = kc * cols;
  ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
  ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());

  gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
  gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing,
                LhsStorageOrder>
      pack_lhs;
  symm_pack_rhs<Scalar, Index, Traits::nr, RhsStorageOrder> pack_rhs;

  for (Index k2 = 0; k2 < size; k2 += kc) {
    const Index actual_kc = (std::min)(k2 + kc, size) - k2;

    pack_rhs(blockB, rhs_, rhsStride, actual_kc, cols, k2);

    // => GEPP: with the self-adjoint handling confined to symm_pack_rhs,
    // the lhs side is packed with the ordinary gemm packer and the loop is a
    // plain panel-panel sweep.
    for (Index i2 = 0; i2 < rows; i2 += mc) {
      const Index actual_mc = (std::min)(i2 + mc, rows) - i2;
      pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);

      gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
    }
  }
}
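// How user code reaches the wrapper below (a usage sketch with hypothetical
// variable names, not part of this file): the kernels above are selected
// through SelfAdjointView's product path, e.g.
//
//   Eigen::MatrixXd A(n, n), B(n, m);
//   Eigen::MatrixXd C = A.selfadjointView<Eigen::Lower>() * B;
//
// blas_traits then strips scalar factors and conjugations from each operand,
// folding them into actualAlpha and the Conjugate* template flags.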
/***************************************************************************
 * Wrapper to product_selfadjoint_matrix
 ***************************************************************************/

template <typename Lhs, int LhsMode, typename Rhs, int RhsMode>
struct selfadjoint_product_impl<Lhs, LhsMode, false, Rhs, RhsMode, false> {
  typedef typename Product<Lhs, Rhs>::Scalar Scalar;

  typedef internal::blas_traits<Lhs> LhsBlasTraits;
  typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
  typedef internal::blas_traits<Rhs> RhsBlasTraits;
  typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;

  enum {
    LhsIsUpper = (LhsMode & (Upper | Lower)) == Upper,
    LhsIsSelfAdjoint = (LhsMode & SelfAdjoint) == SelfAdjoint,
    RhsIsUpper = (RhsMode & (Upper | Lower)) == Upper,
    RhsIsSelfAdjoint = (RhsMode & SelfAdjoint) == SelfAdjoint
  };
  template <typename Dest>
  static void run(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) {
    eigen_assert(dst.rows() == a_lhs.rows() && dst.cols() == a_rhs.cols());

    add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
    add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);

    Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs) * RhsBlasTraits::extractScalarFactor(a_rhs);

    typedef internal::gemm_blocking_space<(Dest::Flags & RowMajorBit) ? RowMajor : ColMajor, Scalar, Scalar,
                                          Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime,
                                          Lhs::MaxColsAtCompileTime, 1>
        BlockingType;

    BlockingType blocking(lhs.rows(), rhs.cols(), lhs.cols(), 1, false);

    internal::product_selfadjoint_matrix<
        Scalar, Index,
        internal::logical_xor(LhsIsUpper, internal::traits<Lhs>::Flags & RowMajorBit) ? RowMajor : ColMajor,
        LhsIsSelfAdjoint,
        NumTraits<Scalar>::IsComplex && internal::logical_xor(LhsIsUpper, bool(LhsBlasTraits::NeedToConjugate)),
        internal::logical_xor(RhsIsUpper, internal::traits<Rhs>::Flags & RowMajorBit) ? RowMajor : ColMajor,
        RhsIsSelfAdjoint,
        NumTraits<Scalar>::IsComplex && internal::logical_xor(RhsIsUpper, bool(RhsBlasTraits::NeedToConjugate)),
        internal::traits<Dest>::Flags & RowMajorBit ? RowMajor : ColMajor,
        Dest::InnerStrideAtCompileTime>::run(lhs.rows(), rhs.cols(),                                     // sizes
                                             &lhs.coeffRef(0, 0), lhs.outerStride(),                     // lhs info
                                             &rhs.coeffRef(0, 0), rhs.outerStride(),                     // rhs info
                                             &dst.coeffRef(0, 0), dst.innerStride(), dst.outerStride(),  // result info
                                             actualAlpha, blocking                                       // alpha
    );
  }
};

}  // end namespace internal

}  // end namespace Eigen

#endif  // EIGEN_SELFADJOINT_MATRIX_MATRIX_H