Eigen  5.0.1-dev+60122df6
 
Loading...
Searching...
No Matches
SelfadjointProduct.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_SELFADJOINT_PRODUCT_H
11#define EIGEN_SELFADJOINT_PRODUCT_H
12
/**********************************************************************
 * This file implements a selfadjoint rank-k update: C += alpha * A * A^*
 * (A^* is the adjoint, which reduces to A^T for real scalars), updating
 * only one triangular half (Lower or Upper) of the selfadjoint matrix C.
 * It corresponds to the level 3 SYRK/HERK and level 2 SYR/HER BLAS routines.
 **********************************************************************/
18
19// IWYU pragma: private
20#include "../InternalHeaderCheck.h"
21
22namespace Eigen {
23
24template <typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs>
25struct selfadjoint_rank1_update<Scalar, Index, ColMajor, UpLo, ConjLhs, ConjRhs> {
26 static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, const Scalar& alpha) {
27 internal::conj_if<ConjRhs> cj;
28 typedef Map<const Matrix<Scalar, Dynamic, 1> > OtherMap;
29 typedef std::conditional_t<ConjLhs, typename OtherMap::ConjugateReturnType, const OtherMap&> ConjLhsType;
30 for (Index i = 0; i < size; ++i) {
31 Map<Matrix<Scalar, Dynamic, 1> >(mat + stride * i + (UpLo == Lower ? i : 0),
32 (UpLo == Lower ? size - i : (i + 1))) +=
33 (alpha * cj(vecY[i])) *
34 ConjLhsType(OtherMap(vecX + (UpLo == Lower ? i : 0), UpLo == Lower ? size - i : (i + 1)));
35 }
36 }
37};
38
39template <typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs>
40struct selfadjoint_rank1_update<Scalar, Index, RowMajor, UpLo, ConjLhs, ConjRhs> {
41 static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, const Scalar& alpha) {
42 selfadjoint_rank1_update<Scalar, Index, ColMajor, UpLo == Lower ? Upper : Lower, ConjRhs, ConjLhs>::run(
43 size, mat, stride, vecY, vecX, alpha);
44 }
45};
46
// Dispatcher used by SelfAdjointView::rankUpdate. The defaulted OtherIsVector argument picks the
// rank-1 specialization (true) when OtherType is a vector at compile time, and the general
// matrix specialization (false) otherwise.
template <typename MatrixType, typename OtherType, int UpLo, bool OtherIsVector = OtherType::IsVectorAtCompileTime>
struct selfadjoint_product_selector;
49
// Rank-1 path, selected when OtherType is a vector at compile time: updates the UpLo triangle of
// `mat` with alpha * u * u^*, where u is the vector that `other` represents.
template <typename MatrixType, typename OtherType, int UpLo>
struct selfadjoint_product_selector<MatrixType, OtherType, UpLo, true> {
  static void run(MatrixType& mat, const OtherType& other, const typename MatrixType::Scalar& alpha) {
    typedef typename MatrixType::Scalar Scalar;
    typedef internal::blas_traits<OtherType> OtherBlasTraits;
    typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType;
    typedef internal::remove_all_t<ActualOtherType> ActualOtherType_;
    // Strip the expression wrappers the kernel cannot consume directly (presumably scalar multiples
    // and conjugation, per blas_traits). The stripped scalar factor is folded into actualAlpha and
    // the conjugation into the NeedToConjugate-dependent kernel flags below.
    internal::add_const_on_value_type_t<ActualOtherType> actualOther = OtherBlasTraits::extract(other.derived());

    Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived());

    enum {
      // Storage order of the destination selects the rank-1 kernel specialization.
      StorageOrder = (internal::traits<MatrixType>::Flags & RowMajorBit) ? RowMajor : ColMajor,
      // The kernel needs unit-stride vectors; with a compile-time unit inner stride the extracted
      // data can be used in place, otherwise it is gathered into a contiguous buffer first.
      UseOtherDirectly = ActualOtherType_::InnerStrideAtCompileTime == 1
    };
    // Optional fixed-size scratch buffer, enabled only when a copy is needed (!UseOtherDirectly).
    internal::gemv_static_vector_if<Scalar, OtherType::SizeAtCompileTime, OtherType::MaxSizeAtCompileTime,
                                    !UseOtherDirectly>
        static_other;

    // Declares actualOtherPtr for other.size() scalars, starting from actualOther's own data when
    // UseOtherDirectly and from static_other's storage otherwise. NOTE(review): presumably the macro
    // allocates an aligned temporary when the given pointer is null (e.g. dynamic sizes); confirm in
    // Memory.h. The const_cast does not lead to writes to actualOther: the buffer is written only
    // in the !UseOtherDirectly branch below.
    ei_declare_aligned_stack_constructed_variable(
        Scalar, actualOtherPtr, other.size(),
        (UseOtherDirectly ? const_cast<Scalar*>(actualOther.data()) : static_other.data()));

    // Gather the strided vector into the contiguous buffer.
    if (!UseOtherDirectly)
      Map<typename ActualOtherType_::PlainObject>(actualOtherPtr, actualOther.size()) = actualOther;

    // Both kernel operands point at the same buffer. For complex scalars exactly one side is
    // conjugated: the left one if `other` carried a conjugation, the right one otherwise. Either
    // way the result is u * u^*. Real scalars need no conjugation.
    selfadjoint_rank1_update<
        Scalar, Index, StorageOrder, UpLo, OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
        (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex>::run(other.size(), mat.data(),
                                                                                  mat.outerStride(), actualOtherPtr,
                                                                                  actualOtherPtr, actualAlpha);
  }
};
83
// General path, selected when OtherType is not a vector at compile time: updates the UpLo
// triangle of `mat` with alpha * U * U^*, where U is the matrix that `other` represents and its
// column count is the depth (rank) of the update.
template <typename MatrixType, typename OtherType, int UpLo>
struct selfadjoint_product_selector<MatrixType, OtherType, UpLo, false> {
  static void run(MatrixType& mat, const OtherType& other, const typename MatrixType::Scalar& alpha) {
    typedef typename MatrixType::Scalar Scalar;
    typedef internal::blas_traits<OtherType> OtherBlasTraits;
    typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType;
    typedef internal::remove_all_t<ActualOtherType> ActualOtherType_;
    // Strip the expression wrappers the kernel cannot consume directly; the stripped scalar factor
    // is folded into actualAlpha and the conjugation into the NeedToConjugate-dependent flags below.
    internal::add_const_on_value_type_t<ActualOtherType> actualOther = OtherBlasTraits::extract(other.derived());

    Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived());

    enum {
      IsRowMajor = (internal::traits<MatrixType>::Flags & RowMajorBit) ? 1 : 0,
      OtherIsRowMajor = ActualOtherType_::Flags & RowMajorBit ? 1 : 0
    };

    // The destination is size x size; the factor U is size x depth.
    Index size = mat.cols();
    Index depth = actualOther.cols();

    typedef internal::gemm_blocking_space<IsRowMajor ? RowMajor : ColMajor, Scalar, Scalar,
                                          MatrixType::MaxColsAtCompileTime, MatrixType::MaxColsAtCompileTime,
                                          ActualOtherType_::MaxColsAtCompileTime>
        BlockingType;

    // NOTE(review): arguments are presumably (rows, cols, depth, num_threads = 1, l3_blocking = false);
    // confirm against gemm_blocking_space.
    BlockingType blocking(size, size, depth, 1, false);

    // The same buffer is passed as both factors. The lhs uses its own storage order (U), and the rhs
    // uses the opposite storage order, which reads the same data as U^T. For complex scalars exactly
    // one side is conjugated (lhs if `other` carried a conjugation, rhs otherwise), giving U * U^*.
    // UpLo is forwarded so the kernel restricts the update to that triangle of mat.
    internal::general_matrix_matrix_triangular_product<
        Index, Scalar, OtherIsRowMajor ? RowMajor : ColMajor,
        OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex, Scalar, OtherIsRowMajor ? ColMajor : RowMajor,
        (!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex, IsRowMajor ? RowMajor : ColMajor,
        MatrixType::InnerStrideAtCompileTime, UpLo>::run(size, depth, actualOther.data(), actualOther.outerStride(),
                                                          actualOther.data(), actualOther.outerStride(), mat.data(),
                                                          mat.innerStride(), mat.outerStride(), actualAlpha, blocking);
  }
};
119
120// high level API
121
122template <typename MatrixType, unsigned int UpLo>
123template <typename DerivedU>
124EIGEN_DEVICE_FUNC SelfAdjointView<MatrixType, UpLo>& SelfAdjointView<MatrixType, UpLo>::rankUpdate(
125 const MatrixBase<DerivedU>& u, const Scalar& alpha) {
126 selfadjoint_product_selector<MatrixType, DerivedU, UpLo>::run(_expression().const_cast_derived(), u.derived(), alpha);
127
128 return *this;
129}
130
131} // end namespace Eigen
132
133#endif // EIGEN_SELFADJOINT_PRODUCT_H
@ Lower
Definition Constants.h:211
@ ColMajor
Definition Constants.h:318
@ RowMajor
Definition Constants.h:320
const unsigned int RowMajorBit
Definition Constants.h:70
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82