10#ifndef EIGEN_SELFADJOINT_MATRIX_MATRIX_H
11#define EIGEN_SELFADJOINT_MATRIX_MATRIX_H
18template<
typename Scalar,
typename Index,
int Pack1,
int Pack2_dummy,
int StorageOrder>
21 template<
int BlockRows>
inline
22 void pack(Scalar* blockA,
const const_blas_data_mapper<Scalar,Index,StorageOrder>& lhs,
Index cols,
Index i,
Index& count)
25 for(
Index k=0; k<i; k++)
26 for(
Index w=0; w<BlockRows; w++)
27 blockA[count++] = lhs(i+w,k);
30 for(
Index k=i; k<i+BlockRows; k++)
32 for(
Index w=0; w<h; w++)
33 blockA[count++] = numext::conj(lhs(k, i+w));
35 blockA[count++] = numext::real(lhs(k,k));
37 for(
Index w=h+1; w<BlockRows; w++)
38 blockA[count++] = lhs(i+w, k);
42 for(
Index k=i+BlockRows; k<cols; k++)
43 for(
Index w=0; w<BlockRows; w++)
44 blockA[count++] = numext::conj(lhs(k, i+w));
46 void operator()(Scalar* blockA,
const Scalar* lhs_,
Index lhsStride,
Index cols,
Index rows)
48 typedef typename unpacket_traits<typename packet_traits<Scalar>::type>::half HalfPacket;
49 typedef typename unpacket_traits<typename unpacket_traits<typename packet_traits<Scalar>::type>::half>::half QuarterPacket;
50 enum { PacketSize = packet_traits<Scalar>::size,
51 HalfPacketSize = unpacket_traits<HalfPacket>::size,
52 QuarterPacketSize = unpacket_traits<QuarterPacket>::size,
53 HasHalf = (int)HalfPacketSize < (
int)PacketSize,
54 HasQuarter = (int)QuarterPacketSize < (
int)HalfPacketSize};
56 const_blas_data_mapper<Scalar,Index,StorageOrder> lhs(lhs_,lhsStride);
60 const Index peeled_mc3 = Pack1>=3*PacketSize ? (rows/(3*PacketSize))*(3*PacketSize) : 0;
61 const Index peeled_mc2 = Pack1>=2*PacketSize ? peeled_mc3+((rows-peeled_mc3)/(2*PacketSize))*(2*PacketSize) : 0;
62 const Index peeled_mc1 = Pack1>=1*PacketSize ? peeled_mc2+((rows-peeled_mc2)/(1*PacketSize))*(1*PacketSize) : 0;
63 const Index peeled_mc_half = Pack1>=HalfPacketSize ? peeled_mc1+((rows-peeled_mc1)/(HalfPacketSize))*(HalfPacketSize) : 0;
64 const Index peeled_mc_quarter = Pack1>=QuarterPacketSize ? peeled_mc_half+((rows-peeled_mc_half)/(QuarterPacketSize))*(QuarterPacketSize) : 0;
66 if(Pack1>=3*PacketSize)
67 for(
Index i=0; i<peeled_mc3; i+=3*PacketSize)
68 pack<3*PacketSize>(blockA, lhs, cols, i, count);
70 if(Pack1>=2*PacketSize)
71 for(
Index i=peeled_mc3; i<peeled_mc2; i+=2*PacketSize)
72 pack<2*PacketSize>(blockA, lhs, cols, i, count);
74 if(Pack1>=1*PacketSize)
75 for(
Index i=peeled_mc2; i<peeled_mc1; i+=1*PacketSize)
76 pack<1*PacketSize>(blockA, lhs, cols, i, count);
78 if(HasHalf && Pack1>=HalfPacketSize)
79 for(
Index i=peeled_mc1; i<peeled_mc_half; i+=HalfPacketSize)
80 pack<HalfPacketSize>(blockA, lhs, cols, i, count);
82 if(HasQuarter && Pack1>=QuarterPacketSize)
83 for(
Index i=peeled_mc_half; i<peeled_mc_quarter; i+=QuarterPacketSize)
84 pack<QuarterPacketSize>(blockA, lhs, cols, i, count);
87 for(
Index i=peeled_mc_quarter; i<rows; i++)
89 for(
Index k=0; k<i; k++)
90 blockA[count++] = lhs(i, k);
92 blockA[count++] = numext::real(lhs(i, i));
94 for(
Index k=i+1; k<cols; k++)
95 blockA[count++] = numext::conj(lhs(k, i));
100template<
typename Scalar,
typename Index,
int nr,
int StorageOrder>
103 enum { PacketSize = packet_traits<Scalar>::size };
104 void operator()(Scalar* blockB,
const Scalar* rhs_,
Index rhsStride,
Index rows,
Index cols,
Index k2)
106 Index end_k = k2 + rows;
108 const_blas_data_mapper<Scalar,Index,StorageOrder> rhs(rhs_,rhsStride);
109 Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0;
110 Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0;
113 for(
Index j2=0; j2<k2; j2+=nr)
115 for(
Index k=k2; k<end_k; k++)
117 blockB[count+0] = rhs(k,j2+0);
118 blockB[count+1] = rhs(k,j2+1);
121 blockB[count+2] = rhs(k,j2+2);
122 blockB[count+3] = rhs(k,j2+3);
126 blockB[count+4] = rhs(k,j2+4);
127 blockB[count+5] = rhs(k,j2+5);
128 blockB[count+6] = rhs(k,j2+6);
129 blockB[count+7] = rhs(k,j2+7);
136 Index end8 = nr>=8 ? (std::min)(k2+rows,packet_cols8) : k2;
139 for(
Index j2=k2; j2<end8; j2+=8)
143 for(
Index k=k2; k<j2; k++)
145 blockB[count+0] = numext::conj(rhs(j2+0,k));
146 blockB[count+1] = numext::conj(rhs(j2+1,k));
147 blockB[count+2] = numext::conj(rhs(j2+2,k));
148 blockB[count+3] = numext::conj(rhs(j2+3,k));
149 blockB[count+4] = numext::conj(rhs(j2+4,k));
150 blockB[count+5] = numext::conj(rhs(j2+5,k));
151 blockB[count+6] = numext::conj(rhs(j2+6,k));
152 blockB[count+7] = numext::conj(rhs(j2+7,k));
157 for(
Index k=j2; k<j2+8; k++)
160 for (
Index w=0 ; w<h; ++w)
161 blockB[count+w] = rhs(k,j2+w);
163 blockB[count+h] = numext::real(rhs(k,k));
166 for (
Index w=h+1 ; w<8; ++w)
167 blockB[count+w] = numext::conj(rhs(j2+w,k));
172 for(
Index k=j2+8; k<end_k; k++)
174 blockB[count+0] = rhs(k,j2+0);
175 blockB[count+1] = rhs(k,j2+1);
176 blockB[count+2] = rhs(k,j2+2);
177 blockB[count+3] = rhs(k,j2+3);
178 blockB[count+4] = rhs(k,j2+4);
179 blockB[count+5] = rhs(k,j2+5);
180 blockB[count+6] = rhs(k,j2+6);
181 blockB[count+7] = rhs(k,j2+7);
188 for(
Index j2=end8; j2<(std::min)(k2+rows,packet_cols4); j2+=4)
192 for(
Index k=k2; k<j2; k++)
194 blockB[count+0] = numext::conj(rhs(j2+0,k));
195 blockB[count+1] = numext::conj(rhs(j2+1,k));
196 blockB[count+2] = numext::conj(rhs(j2+2,k));
197 blockB[count+3] = numext::conj(rhs(j2+3,k));
202 for(
Index k=j2; k<j2+4; k++)
205 for (
Index w=0 ; w<h; ++w)
206 blockB[count+w] = rhs(k,j2+w);
208 blockB[count+h] = numext::real(rhs(k,k));
211 for (
Index w=h+1 ; w<4; ++w)
212 blockB[count+w] = numext::conj(rhs(j2+w,k));
217 for(
Index k=j2+4; k<end_k; k++)
219 blockB[count+0] = rhs(k,j2+0);
220 blockB[count+1] = rhs(k,j2+1);
221 blockB[count+2] = rhs(k,j2+2);
222 blockB[count+3] = rhs(k,j2+3);
231 for(
Index j2=k2+rows; j2<packet_cols8; j2+=8)
233 for(
Index k=k2; k<end_k; k++)
235 blockB[count+0] = numext::conj(rhs(j2+0,k));
236 blockB[count+1] = numext::conj(rhs(j2+1,k));
237 blockB[count+2] = numext::conj(rhs(j2+2,k));
238 blockB[count+3] = numext::conj(rhs(j2+3,k));
239 blockB[count+4] = numext::conj(rhs(j2+4,k));
240 blockB[count+5] = numext::conj(rhs(j2+5,k));
241 blockB[count+6] = numext::conj(rhs(j2+6,k));
242 blockB[count+7] = numext::conj(rhs(j2+7,k));
249 for(
Index j2=(std::max)(packet_cols8,k2+rows); j2<packet_cols4; j2+=4)
251 for(
Index k=k2; k<end_k; k++)
253 blockB[count+0] = numext::conj(rhs(j2+0,k));
254 blockB[count+1] = numext::conj(rhs(j2+1,k));
255 blockB[count+2] = numext::conj(rhs(j2+2,k));
256 blockB[count+3] = numext::conj(rhs(j2+3,k));
263 for(
Index j2=packet_cols4; j2<cols; ++j2)
266 Index half = (std::min)(end_k,j2);
267 for(
Index k=k2; k<half; k++)
269 blockB[count] = numext::conj(rhs(j2,k));
273 if(half==j2 && half<k2+rows)
275 blockB[count] = numext::real(rhs(j2,j2));
282 for(
Index k=half+1; k<k2+rows; k++)
284 blockB[count] = rhs(k,j2);
// Kernel for a product where one operand is self-adjoint and the other is a
// general dense matrix.
//
// This primary template is only declared. Partial specializations provide:
//   - an adapter for a RowMajor result, and
//   - ColMajor-result kernels for a self-adjoint lhs or a self-adjoint rhs.
//
// Conjugate{Lhs,Rhs} request conjugation of the corresponding operand.
// ResInnerStride is the compile-time inner stride of the result.
template <typename Scalar, typename Index,
          int LhsStorageOrder, bool LhsSelfAdjoint, bool ConjugateLhs,
          int RhsStorageOrder, bool RhsSelfAdjoint, bool ConjugateRhs,
          int ResStorageOrder, int ResInnerStride>
struct product_selfadjoint_matrix;
300template <
typename Scalar,
typename Index,
301 int LhsStorageOrder,
bool LhsSelfAdjoint,
bool ConjugateLhs,
302 int RhsStorageOrder,
bool RhsSelfAdjoint,
bool ConjugateRhs,
304struct product_selfadjoint_matrix<Scalar,
Index,LhsStorageOrder,LhsSelfAdjoint,ConjugateLhs, RhsStorageOrder,RhsSelfAdjoint,ConjugateRhs,
RowMajor,ResInnerStride>
307 static EIGEN_STRONG_INLINE
void run(
309 const Scalar* lhs,
Index lhsStride,
310 const Scalar* rhs,
Index rhsStride,
312 const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
314 product_selfadjoint_matrix<Scalar,
Index,
316 RhsSelfAdjoint, NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsSelfAdjoint,ConjugateRhs),
318 LhsSelfAdjoint, NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsSelfAdjoint,ConjugateLhs),
320 ::run(cols, rows, rhs, rhsStride, lhs, lhsStride, res, resIncr, resStride, alpha, blocking);
324template <
typename Scalar,
typename Index,
325 int LhsStorageOrder,
bool ConjugateLhs,
326 int RhsStorageOrder,
bool ConjugateRhs,
328struct product_selfadjoint_matrix<Scalar,
Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,
ColMajor,ResInnerStride>
331 static EIGEN_DONT_INLINE
void run(
333 const Scalar* lhs_,
Index lhsStride,
334 const Scalar* rhs_,
Index rhsStride,
336 const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
339template <
typename Scalar,
typename Index,
340 int LhsStorageOrder,
bool ConjugateLhs,
341 int RhsStorageOrder,
bool ConjugateRhs,
343EIGEN_DONT_INLINE
void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,ConjugateLhs, RhsStorageOrder,false,ConjugateRhs,ColMajor,ResInnerStride>::run(
345 const Scalar* lhs_,
Index lhsStride,
346 const Scalar* rhs_,
Index rhsStride,
348 const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
352 typedef gebp_traits<Scalar,Scalar> Traits;
354 typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
356 typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
357 typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
358 LhsMapper lhs(lhs_,lhsStride);
359 LhsTransposeMapper lhs_transpose(lhs_,lhsStride);
360 RhsMapper rhs(rhs_,rhsStride);
361 ResMapper res(res_, resStride, resIncr);
363 Index kc = blocking.kc();
364 Index mc = (std::min)(rows,blocking.mc());
366 kc = (std::min)(kc,mc);
367 std::size_t sizeA = kc*mc;
368 std::size_t sizeB = kc*cols;
369 ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
370 ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
372 gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
373 symm_pack_lhs<Scalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
374 gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr,RhsStorageOrder> pack_rhs;
375 gemm_pack_lhs<Scalar, Index, LhsTransposeMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder==RowMajor?ColMajor:RowMajor, true> pack_lhs_transposed;
377 for(
Index k2=0; k2<size; k2+=kc)
379 const Index actual_kc = (std::min)(k2+kc,size)-k2;
384 pack_rhs(blockB, rhs.getSubMapper(k2,0), actual_kc, cols);
390 for(
Index i2=0; i2<k2; i2+=mc)
392 const Index actual_mc = (std::min)(i2+mc,k2)-i2;
394 pack_lhs_transposed(blockA, lhs_transpose.getSubMapper(i2, k2), actual_kc, actual_mc);
396 gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
400 const Index actual_mc = (std::min)(k2+kc,size)-k2;
402 pack_lhs(blockA, &lhs(k2,k2), lhsStride, actual_kc, actual_mc);
404 gebp_kernel(res.getSubMapper(k2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
407 for(
Index i2=k2+kc; i2<size; i2+=mc)
409 const Index actual_mc = (std::min)(i2+mc,size)-i2;
410 gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder,false>()
411 (blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
413 gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
419template <
typename Scalar,
typename Index,
420 int LhsStorageOrder,
bool ConjugateLhs,
421 int RhsStorageOrder,
bool ConjugateRhs,
423struct product_selfadjoint_matrix<Scalar,
Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,
ColMajor,ResInnerStride>
426 static EIGEN_DONT_INLINE
void run(
428 const Scalar* lhs_,
Index lhsStride,
429 const Scalar* rhs_,
Index rhsStride,
431 const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking);
434template <
typename Scalar,
typename Index,
435 int LhsStorageOrder,
bool ConjugateLhs,
436 int RhsStorageOrder,
bool ConjugateRhs,
438EIGEN_DONT_INLINE
void product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,ConjugateLhs, RhsStorageOrder,true,ConjugateRhs,ColMajor,ResInnerStride>::run(
440 const Scalar* lhs_,
Index lhsStride,
441 const Scalar* rhs_,
Index rhsStride,
443 const Scalar& alpha, level3_blocking<Scalar,Scalar>& blocking)
447 typedef gebp_traits<Scalar,Scalar> Traits;
449 typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
450 typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
451 LhsMapper lhs(lhs_,lhsStride);
452 ResMapper res(res_,resStride, resIncr);
454 Index kc = blocking.kc();
455 Index mc = (std::min)(rows,blocking.mc());
456 std::size_t sizeA = kc*mc;
457 std::size_t sizeB = kc*cols;
458 ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
459 ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
461 gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
462 gemm_pack_lhs<Scalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, LhsStorageOrder> pack_lhs;
463 symm_pack_rhs<Scalar, Index, Traits::nr,RhsStorageOrder> pack_rhs;
465 for(
Index k2=0; k2<size; k2+=kc)
467 const Index actual_kc = (std::min)(k2+kc,size)-k2;
469 pack_rhs(blockB, rhs_, rhsStride, actual_kc, cols, k2);
472 for(
Index i2=0; i2<rows; i2+=mc)
474 const Index actual_mc = (std::min)(i2+mc,rows)-i2;
475 pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
477 gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha);
490template<
typename Lhs,
int LhsMode,
typename Rhs,
int RhsMode>
491struct selfadjoint_product_impl<Lhs,LhsMode,false,Rhs,RhsMode,false>
493 typedef typename Product<Lhs,Rhs>::Scalar Scalar;
495 typedef internal::blas_traits<Lhs> LhsBlasTraits;
496 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
497 typedef internal::blas_traits<Rhs> RhsBlasTraits;
498 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
507 template<
typename Dest>
508 static void run(Dest &dst,
const Lhs &a_lhs,
const Rhs &a_rhs,
const Scalar& alpha)
510 eigen_assert(dst.rows()==a_lhs.rows() && dst.cols()==a_rhs.cols());
512 typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
513 typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
515 Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
516 * RhsBlasTraits::extractScalarFactor(a_rhs);
519 Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxColsAtCompileTime,1> BlockingType;
521 BlockingType blocking(lhs.rows(), rhs.cols(), lhs.cols(), 1,
false);
523 internal::product_selfadjoint_matrix<Scalar,
Index,
525 NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(LhsIsUpper,
bool(LhsBlasTraits::NeedToConjugate)),
527 NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(RhsIsUpper,
bool(RhsBlasTraits::NeedToConjugate)),
529 Dest::InnerStrideAtCompileTime>
531 lhs.rows(), rhs.cols(),
532 &lhs.coeffRef(0,0), lhs.outerStride(),
533 &rhs.coeffRef(0,0), rhs.outerStride(),
534 &dst.coeffRef(0,0), dst.innerStride(), dst.outerStride(),
535 actualAlpha, blocking
// Symbols referenced in this file (documentation cross-references):
//   SelfAdjoint  - enum value, defined in Constants.h:225
//   Lower        - enum value, defined in Constants.h:209
//   Upper        - enum value, defined in Constants.h:211
//   ColMajor     - enum value, defined in Constants.h:319
//   RowMajor     - enum value, defined in Constants.h:321
//   RowMajorBit  - const unsigned int, defined in Constants.h:66
//   Eigen        - namespace containing all symbols from the Eigen library
//                  (documented in B01_Experimental.dox:1)
//   Index        - EIGEN_DEFAULT_DENSE_INDEX_TYPE; the Index type as used for
//                  the API, defined in Meta.h:74