Eigen  5.0.1-dev+60122df6
 
Loading...
Searching...
No Matches
GeneralMatrixMatrixTriangular.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
11#define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
12
13// IWYU pragma: private
14#include "../InternalHeaderCheck.h"
15
16namespace Eigen {
17
// Forward declaration only: the rank-1 selfadjoint update kernel is invoked by the
// outer-product specialization of general_product_to_triangular_selector below.
template <typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjLhs, bool ConjRhs>
struct selfadjoint_rank1_update;
20
21namespace internal {
22
/**********************************************************************
 * This file implements a general A * B product while
 * evaluating only one triangular part of the product.
 * It generalizes the selfadjoint product (C += A A^T) that is
 * performed by the level 3 BLAS routine SYRK.
 **********************************************************************/
29
// Forward declaration of the packed-block triangular kernel; its definition follows
// the blocked product implementation later in this file.
template <typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs,
          int ResInnerStride, int UpLo>
struct tribb_kernel;
34
35/* Optimized matrix-matrix product evaluating only one triangular half */
36template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar,
37 int RhsStorageOrder, bool ConjugateRhs, int ResStorageOrder, int ResInnerStride, int UpLo,
38 int Version = Specialized>
39struct general_matrix_matrix_triangular_product;
40
41// as usual if the result is row major => we transpose the product
42template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar,
43 int RhsStorageOrder, bool ConjugateRhs, int ResInnerStride, int UpLo, int Version>
44struct general_matrix_matrix_triangular_product<Index, LhsScalar, LhsStorageOrder, ConjugateLhs, RhsScalar,
45 RhsStorageOrder, ConjugateRhs, RowMajor, ResInnerStride, UpLo,
46 Version> {
47 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
48 static EIGEN_STRONG_INLINE void run(Index size, Index depth, const LhsScalar* lhs, Index lhsStride,
49 const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resIncr,
50 Index resStride, const ResScalar& alpha,
51 level3_blocking<RhsScalar, LhsScalar>& blocking) {
52 general_matrix_matrix_triangular_product<Index, RhsScalar, RhsStorageOrder == RowMajor ? ColMajor : RowMajor,
53 ConjugateRhs, LhsScalar, LhsStorageOrder == RowMajor ? ColMajor : RowMajor,
54 ConjugateLhs, ColMajor, ResInnerStride,
55 UpLo == Lower ? Upper : Lower>::run(size, depth, rhs, rhsStride, lhs,
56 lhsStride, res, resIncr, resStride,
57 alpha, blocking);
58 }
59};
60
61template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar,
62 int RhsStorageOrder, bool ConjugateRhs, int ResInnerStride, int UpLo, int Version>
63struct general_matrix_matrix_triangular_product<Index, LhsScalar, LhsStorageOrder, ConjugateLhs, RhsScalar,
64 RhsStorageOrder, ConjugateRhs, ColMajor, ResInnerStride, UpLo,
65 Version> {
66 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
67 static EIGEN_STRONG_INLINE void run(Index size, Index depth, const LhsScalar* lhs_, Index lhsStride,
68 const RhsScalar* rhs_, Index rhsStride, ResScalar* res_, Index resIncr,
69 Index resStride, const ResScalar& alpha,
70 level3_blocking<LhsScalar, RhsScalar>& blocking) {
71 if (size == 0) {
72 return;
73 }
74
75 typedef gebp_traits<LhsScalar, RhsScalar> Traits;
76
77 typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
78 typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
79 typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
80 LhsMapper lhs(lhs_, lhsStride);
81 RhsMapper rhs(rhs_, rhsStride);
82 ResMapper res(res_, resStride, resIncr);
83
84 Index kc = blocking.kc();
85 // Ensure that mc >= nr and <= size
86 Index mc = (std::min)(size, (std::max)(static_cast<decltype(blocking.mc())>(Traits::nr), blocking.mc()));
87
88 // !!! mc must be a multiple of nr
89 if (mc > Traits::nr) {
90 using UnsignedIndex = typename make_unsigned<Index>::type;
91 mc = (UnsignedIndex(mc) / Traits::nr) * Traits::nr;
92 }
93
94 std::size_t sizeA = kc * mc;
95 std::size_t sizeB = kc * size;
96
97 ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, blocking.blockA());
98 ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, blocking.blockB());
99
100 gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing,
101 LhsStorageOrder>
102 pack_lhs;
103 gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
104 gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
105 tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, ResInnerStride, UpLo>
106 sybb;
107
108 for (Index k2 = 0; k2 < depth; k2 += kc) {
109 const Index actual_kc = (std::min)(k2 + kc, depth) - k2;
110
111 // note that the actual rhs is the transpose/adjoint of mat
112 pack_rhs(blockB, rhs.getSubMapper(k2, 0), actual_kc, size);
113
114 for (Index i2 = 0; i2 < size; i2 += mc) {
115 const Index actual_mc = (std::min)(i2 + mc, size) - i2;
116
117 pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
118
119 // the selected actual_mc * size panel of res is split into three different part:
120 // 1 - before the diagonal => processed with gebp or skipped
121 // 2 - the actual_mc x actual_mc symmetric block => processed with a special kernel
122 // 3 - after the diagonal => processed with gebp or skipped
123 if (UpLo == Lower)
124 gebp(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, (std::min)(size, i2), alpha, -1, -1, 0,
125 0);
126
127 sybb(res_ + resStride * i2 + resIncr * i2, resIncr, resStride, blockA, blockB + actual_kc * i2, actual_mc,
128 actual_kc, alpha);
129
130 if (UpLo == Upper) {
131 Index j2 = i2 + actual_mc;
132 gebp(res.getSubMapper(i2, j2), blockA, blockB + actual_kc * j2, actual_mc, actual_kc,
133 (std::max)(Index(0), size - j2), alpha, -1, -1, 0, 0);
134 }
135 }
136 }
137 }
138};
139
140// Optimized packed Block * packed Block product kernel evaluating only one given triangular part
141// This kernel is built on top of the gebp kernel:
142// - the current destination block is processed per panel of actual_mc x BlockSize
143// where BlockSize is set to the minimal value allowing gebp to be as fast as possible
144// - then, as usual, each panel is split into three parts along the diagonal,
145// the sub blocks above and below the diagonal are processed as usual,
146// while the triangular block overlapping the diagonal is evaluated into a
147// small temporary buffer which is then accumulated into the result using a
148// triangular traversal.
149template <typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs,
150 int ResInnerStride, int UpLo>
151struct tribb_kernel {
152 typedef gebp_traits<LhsScalar, RhsScalar, ConjLhs, ConjRhs> Traits;
153 typedef typename Traits::ResScalar ResScalar;
154
155 enum { BlockSize = meta_least_common_multiple<plain_enum_max(mr, nr), plain_enum_min(mr, nr)>::ret };
156 void operator()(ResScalar* res_, Index resIncr, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB,
157 Index size, Index depth, const ResScalar& alpha) {
158 typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
159 typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned> BufferMapper;
160 ResMapper res(res_, resStride, resIncr);
161 gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel1;
162 gebp_kernel<LhsScalar, RhsScalar, Index, BufferMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel2;
163
164 Matrix<ResScalar, BlockSize, BlockSize, ColMajor> buffer;
165
166 // let's process the block per panel of actual_mc x BlockSize,
167 // again, each is split into three parts, etc.
168 for (Index j = 0; j < size; j += BlockSize) {
169 Index actualBlockSize = std::min<Index>(BlockSize, size - j);
170 const RhsScalar* actual_b = blockB + j * depth;
171
172 if (UpLo == Upper)
173 gebp_kernel1(res.getSubMapper(0, j), blockA, actual_b, j, depth, actualBlockSize, alpha, -1, -1, 0, 0);
174
175 // selfadjoint micro block
176 {
177 Index i = j;
178 buffer.setZero();
179 // 1 - apply the kernel on the temporary buffer
180 gebp_kernel2(BufferMapper(buffer.data(), BlockSize), blockA + depth * i, actual_b, actualBlockSize, depth,
181 actualBlockSize, alpha, -1, -1, 0, 0);
182
183 // 2 - triangular accumulation
184 for (Index j1 = 0; j1 < actualBlockSize; ++j1) {
185 typename ResMapper::LinearMapper r = res.getLinearMapper(i, j + j1);
186 for (Index i1 = UpLo == Lower ? j1 : 0; UpLo == Lower ? i1 < actualBlockSize : i1 <= j1; ++i1)
187 r(i1) += buffer(i1, j1);
188 }
189 }
190
191 if (UpLo == Lower) {
192 Index i = j + actualBlockSize;
193 gebp_kernel1(res.getSubMapper(i, j), blockA + depth * i, actual_b, size - i, depth, actualBlockSize, alpha, -1,
194 -1, 0, 0);
195 }
196 }
197 }
198};
199
200} // end namespace internal
201
202// high level API
203
// Dispatcher writing a product into a triangular part of a matrix. The boolean
// IsOuterProduct selects between the two specializations defined below.
template <typename MatrixType, typename ProductType, int UpLo, bool IsOuterProduct>
struct general_product_to_triangular_selector;
206
207template <typename MatrixType, typename ProductType, int UpLo>
208struct general_product_to_triangular_selector<MatrixType, ProductType, UpLo, true> {
209 static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha, bool beta) {
210 typedef typename MatrixType::Scalar Scalar;
211
212 typedef internal::remove_all_t<typename ProductType::LhsNested> Lhs;
213 typedef internal::blas_traits<Lhs> LhsBlasTraits;
214 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
215 typedef internal::remove_all_t<ActualLhs> ActualLhs_;
216 internal::add_const_on_value_type_t<ActualLhs> actualLhs = LhsBlasTraits::extract(prod.lhs());
217
218 typedef internal::remove_all_t<typename ProductType::RhsNested> Rhs;
219 typedef internal::blas_traits<Rhs> RhsBlasTraits;
220 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
221 typedef internal::remove_all_t<ActualRhs> ActualRhs_;
222 internal::add_const_on_value_type_t<ActualRhs> actualRhs = RhsBlasTraits::extract(prod.rhs());
223
224 Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) *
225 RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
226
227 if (!beta) mat.template triangularView<UpLo>().setZero();
228
229 enum {
230 StorageOrder = (internal::traits<MatrixType>::Flags & RowMajorBit) ? RowMajor : ColMajor,
231 UseLhsDirectly = ActualLhs_::InnerStrideAtCompileTime == 1,
232 UseRhsDirectly = ActualRhs_::InnerStrideAtCompileTime == 1
233 };
234
235 internal::gemv_static_vector_if<Scalar, Lhs::SizeAtCompileTime, Lhs::MaxSizeAtCompileTime, !UseLhsDirectly>
236 static_lhs;
237 ei_declare_aligned_stack_constructed_variable(
238 Scalar, actualLhsPtr, actualLhs.size(),
239 (UseLhsDirectly ? const_cast<Scalar*>(actualLhs.data()) : static_lhs.data()));
240 if (!UseLhsDirectly) Map<typename ActualLhs_::PlainObject>(actualLhsPtr, actualLhs.size()) = actualLhs;
241
242 internal::gemv_static_vector_if<Scalar, Rhs::SizeAtCompileTime, Rhs::MaxSizeAtCompileTime, !UseRhsDirectly>
243 static_rhs;
244 ei_declare_aligned_stack_constructed_variable(
245 Scalar, actualRhsPtr, actualRhs.size(),
246 (UseRhsDirectly ? const_cast<Scalar*>(actualRhs.data()) : static_rhs.data()));
247 if (!UseRhsDirectly) Map<typename ActualRhs_::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
248
249 selfadjoint_rank1_update<
250 Scalar, Index, StorageOrder, UpLo, LhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
251 RhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex>::run(actualLhs.size(), mat.data(),
252 mat.outerStride(), actualLhsPtr,
253 actualRhsPtr, actualAlpha);
254 }
255};
256
257template <typename MatrixType, typename ProductType, int UpLo>
258struct general_product_to_triangular_selector<MatrixType, ProductType, UpLo, false> {
259 static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha, bool beta) {
260 typedef internal::remove_all_t<typename ProductType::LhsNested> Lhs;
261 typedef internal::blas_traits<Lhs> LhsBlasTraits;
262 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
263 typedef internal::remove_all_t<ActualLhs> ActualLhs_;
264 internal::add_const_on_value_type_t<ActualLhs> actualLhs = LhsBlasTraits::extract(prod.lhs());
265
266 typedef internal::remove_all_t<typename ProductType::RhsNested> Rhs;
267 typedef internal::blas_traits<Rhs> RhsBlasTraits;
268 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
269 typedef internal::remove_all_t<ActualRhs> ActualRhs_;
270 internal::add_const_on_value_type_t<ActualRhs> actualRhs = RhsBlasTraits::extract(prod.rhs());
271
272 typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) *
273 RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
274
275 if (!beta) mat.template triangularView<UpLo>().setZero();
276
277 enum {
278 IsRowMajor = (internal::traits<MatrixType>::Flags & RowMajorBit) ? 1 : 0,
279 LhsIsRowMajor = ActualLhs_::Flags & RowMajorBit ? 1 : 0,
280 RhsIsRowMajor = ActualRhs_::Flags & RowMajorBit ? 1 : 0,
281 SkipDiag = (UpLo & (UnitDiag | ZeroDiag)) != 0
282 };
283
284 Index size = mat.cols();
285 if (SkipDiag) size--;
286 Index depth = actualLhs.cols();
287
288 typedef internal::gemm_blocking_space<IsRowMajor ? RowMajor : ColMajor, typename Lhs::Scalar, typename Rhs::Scalar,
289 MatrixType::MaxColsAtCompileTime, MatrixType::MaxColsAtCompileTime,
290 ActualRhs_::MaxColsAtCompileTime>
291 BlockingType;
292
293 BlockingType blocking(size, size, depth, 1, false);
294
295 internal::general_matrix_matrix_triangular_product<
296 Index, typename Lhs::Scalar, LhsIsRowMajor ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
297 typename Rhs::Scalar, RhsIsRowMajor ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
298 IsRowMajor ? RowMajor : ColMajor, MatrixType::InnerStrideAtCompileTime,
299 UpLo&(Lower | Upper)>::run(size, depth, &actualLhs.coeffRef(SkipDiag && (UpLo & Lower) == Lower ? 1 : 0, 0),
300 actualLhs.outerStride(),
301 &actualRhs.coeffRef(0, SkipDiag && (UpLo & Upper) == Upper ? 1 : 0),
302 actualRhs.outerStride(),
303 mat.data() +
304 (SkipDiag ? (bool(IsRowMajor) != ((UpLo & Lower) == Lower) ? mat.innerStride()
305 : mat.outerStride())
306 : 0),
307 mat.innerStride(), mat.outerStride(), actualAlpha, blocking);
308 }
309};
310
311template <typename MatrixType_, unsigned int Mode_>
312template <typename ProductType>
313EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename TriangularViewImpl<MatrixType_, Mode_, Dense>::TriangularViewType&
314TriangularViewImpl<MatrixType_, Mode_, Dense>::_assignProduct(
315 const ProductType& prod, const typename TriangularViewImpl<MatrixType_, Mode_, Dense>::Scalar& alpha, bool beta) {
316 EIGEN_STATIC_ASSERT((Mode_ & UnitDiag) == 0, WRITING_TO_TRIANGULAR_PART_WITH_UNIT_DIAGONAL_IS_NOT_SUPPORTED);
317 eigen_assert(derived().nestedExpression().rows() == prod.rows() && derived().cols() == prod.cols());
318
319 general_product_to_triangular_selector<MatrixType_, ProductType, Mode_,
320 internal::traits<ProductType>::InnerSize == 1>::run(derived()
321 .nestedExpression()
322 .const_cast_derived(),
323 prod, alpha, beta);
324
325 return derived();
326}
327
328} // end namespace Eigen
329
330#endif // EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
@ UnitDiag
Definition Constants.h:215
@ ZeroDiag
Definition Constants.h:217
@ Lower
Definition Constants.h:211
@ Upper
Definition Constants.h:213
@ ColMajor
Definition Constants.h:318
@ RowMajor
Definition Constants.h:320
const unsigned int RowMajorBit
Definition Constants.h:70
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82