10#ifndef EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
11#define EIGEN_GENERAL_MATRIX_MATRIX_TRIANGULAR_H
14#include "../InternalHeaderCheck.h"
// Forward declaration of the rank-1 update helper. It is invoked further down
// in this file as
//   selfadjoint_rank1_update<Scalar, Index, StorageOrder, UpLo, ConjLhs, ConjRhs>::run(...)
// by the outer-product specialization of general_product_to_triangular_selector.
// The definition is not part of this file.
//
// Template parameters:
//   Scalar        scalar type of the updated matrix and of both vectors
//   Index         integer type used for sizes and strides
//   StorageOrder  storage order of the updated matrix
//   UpLo          which triangular part is updated
//   ConjLhs       whether the lhs vector is conjugated
//   ConjRhs       whether the rhs vector is conjugated
template <typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjLhs, bool ConjRhs>
struct selfadjoint_rank1_update;
// Forward declaration of tribb_kernel, the kernel that
// general_matrix_matrix_triangular_product instantiates below as
//   tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr,
//                ConjugateLhs, ConjugateRhs, ResInnerStride, UpLo>
// The full definition appears later in this file.
//
// Template parameters:
//   LhsScalar, RhsScalar  scalar types of the two operand blocks
//   Index                 integer type used for sizes and strides
//   mr, nr                register-blocking sizes forwarded to gebp_kernel
//   ConjLhs, ConjRhs      whether each operand is conjugated
//   ResInnerStride        compile-time inner stride of the result mapper
//   UpLo                  which triangular part of the result is written
template <typename LhsScalar, typename RhsScalar, typename Index, int mr, int nr, bool ConjLhs, bool ConjRhs,
          int ResInnerStride, int UpLo>
struct tribb_kernel;
36template <
typename Index,
typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
typename RhsScalar,
37 int RhsStorageOrder,
bool ConjugateRhs,
int ResStorageOrder,
int ResInnerStride,
int UpLo,
38 int Version = Specialized>
39struct general_matrix_matrix_triangular_product;
// Partial specialization of general_matrix_matrix_triangular_product for a
// RowMajor result (the ResStorageOrder slot of the argument list below is
// fixed to RowMajor).
//
// NOTE(review): this listing is missing lines of the definition: the end of
// the specialization argument list and its opening brace, most of the call
// made inside run(), and every closing brace. The surviving fragments of that
// call name ColMajor as the result order, so run() presumably forwards to the
// ColMajor specialization - confirm against the upstream header.
42template <
typename Index,
typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
typename RhsScalar,
43 int RhsStorageOrder,
bool ConjugateRhs,
int ResInnerStride,
int UpLo,
int Version>
44struct general_matrix_matrix_triangular_product<
Index, LhsScalar, LhsStorageOrder, ConjugateLhs, RhsScalar,
45 RhsStorageOrder, ConjugateRhs,
RowMajor, ResInnerStride, UpLo,
// Scalar type of the result, derived from the two operand scalar types.
47 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
48 static EIGEN_STRONG_INLINE
void run(
Index size,
Index depth,
const LhsScalar* lhs,
Index lhsStride,
49 const RhsScalar* rhs,
Index rhsStride, ResScalar* res,
Index resIncr,
50 Index resStride,
const ResScalar& alpha,
// The blocking object is typed <RhsScalar, LhsScalar> here, i.e. with the
// scalar types in the opposite order to the ColMajor specialization's
// level3_blocking<LhsScalar, RhsScalar>.
51 level3_blocking<RhsScalar, LhsScalar>& blocking) {
// Fragments of the call made by run(); the lines in between are missing.
54 ConjugateLhs,
ColMajor, ResInnerStride,
56 lhsStride, res, resIncr, resStride,
// Partial specialization of general_matrix_matrix_triangular_product for a
// ColMajor result.
//
// run() walks the depth dimension in panels of at most kc and the rows in
// panels of at most mc. For each depth panel it packs the rhs once, then for
// each row panel it packs the lhs and issues the kernel calls visible below:
//   - gebp on the columns to the left of column i2,
//   - sybb (a tribb_kernel) on the block whose top-left element is (i2, i2),
//   - gebp on the columns starting at j2 = i2 + actual_mc.
//
// NOTE(review): this listing is missing lines of the definition: the end of
// the specialization argument list, the declarator names after the
// gemm_pack_lhs and tribb_kernel types, the conditions guarding the two gebp
// calls, trailing call arguments and every closing brace. Presumably the gebp
// calls are selected by UpLo - confirm against the upstream header.
61template <
typename Index,
typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
typename RhsScalar,
62 int RhsStorageOrder,
bool ConjugateRhs,
int ResInnerStride,
int UpLo,
int Version>
63struct general_matrix_matrix_triangular_product<
Index, LhsScalar, LhsStorageOrder, ConjugateLhs, RhsScalar,
64 RhsStorageOrder, ConjugateRhs,
ColMajor, ResInnerStride, UpLo,
// Scalar type of the result, derived from the two operand scalar types.
66 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
67 static EIGEN_STRONG_INLINE
void run(
Index size,
Index depth,
const LhsScalar* lhs_,
Index lhsStride,
68 const RhsScalar* rhs_,
Index rhsStride, ResScalar* res_,
Index resIncr,
69 Index resStride,
const ResScalar& alpha,
70 level3_blocking<LhsScalar, RhsScalar>& blocking) {
75 typedef gebp_traits<LhsScalar, RhsScalar> Traits;
// Wrap the raw pointers and strides into data mappers.
77 typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
78 typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
79 typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
80 LhsMapper lhs(lhs_, lhsStride);
81 RhsMapper rhs(rhs_, rhsStride);
82 ResMapper res(res_, resStride, resIncr);
84 Index kc = blocking.kc();
// mc = min(size, max(Traits::nr, blocking.mc()))
86 Index mc = (std::min)(size, (std::max)(
static_cast<decltype(blocking.mc())
>(Traits::nr), blocking.mc()));
// When mc exceeds Traits::nr, round it down to a multiple of Traits::nr
// (the division is carried out on the unsigned counterpart of Index).
89 if (mc > Traits::nr) {
90 using UnsignedIndex =
typename make_unsigned<Index>::type;
91 mc = (UnsignedIndex(mc) / Traits::nr) * Traits::nr;
// Element counts of the lhs (kc * mc) and rhs (kc * size) work buffers.
94 std::size_t sizeA = kc * mc;
95 std::size_t sizeB = kc * size;
// Work buffers; blocking.blockA()/blockB() are passed as the candidate storage.
97 ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, blocking.blockA());
98 ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, blocking.blockB());
// Packing routines and compute kernels used in the loops below.
100 gemm_pack_lhs<LhsScalar,
Index, LhsMapper, Traits::mr, Traits::LhsProgress,
typename Traits::LhsPacket4Packing,
103 gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
104 gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
105 tribb_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs, ResInnerStride, UpLo>
// Outer loop: depth panels of at most kc.
108 for (
Index k2 = 0; k2 < depth; k2 += kc) {
109 const Index actual_kc = (std::min)(k2 + kc, depth) - k2;
// Pack the rhs panel starting at row k2, over all `size` columns.
112 pack_rhs(blockB, rhs.getSubMapper(k2, 0), actual_kc, size);
// Inner loop: row panels of at most mc.
114 for (
Index i2 = 0; i2 < size; i2 += mc) {
115 const Index actual_mc = (std::min)(i2 + mc, size) - i2;
117 pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
// Columns left of i2: min(size, i2) columns starting at column 0.
124 gebp(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, (std::min)(size, i2), alpha, -1, -1, 0,
// Block whose top-left element is (i2, i2), addressed through the raw pointer.
127 sybb(res_ + resStride * i2 + resIncr * i2, resIncr, resStride, blockA, blockB + actual_kc * i2, actual_mc,
// Columns from j2 onward: max(0, size - j2) columns.
131 Index j2 = i2 + actual_mc;
132 gebp(res.getSubMapper(i2, j2), blockA, blockB + actual_kc * j2, actual_mc, actual_kc,
133 (std::max)(
Index(0), size - j2), alpha, -1, -1, 0, 0);
// Body of the kernel instantiated as tribb_kernel by
// general_matrix_matrix_triangular_product (same template parameter list).
//
// operator() processes the result in column panels of width BlockSize. For
// each panel starting at column j the visible code:
//   - calls gebp_kernel1 directly on the result for the j rows above the panel,
//   - calls gebp_kernel2 to compute an actualBlockSize x actualBlockSize block
//     into a temporary buffer, then adds only one triangle of that buffer to
//     the result (i1 >= j1 when UpLo == Lower, i1 <= j1 otherwise),
//   - calls gebp_kernel1 directly on the result for the rows from
//     j + actualBlockSize to size.
//
// NOTE(review): this listing is missing lines of the definition: the
// `struct tribb_kernel {` line, the declaration of the `i` used by the
// gebp_kernel2 call, anything that resets `buffer` between panels, the
// conditions guarding the two gebp_kernel1 calls, and every closing brace.
// Presumably the two direct calls are selected by UpLo - confirm against the
// upstream header.
149template <
typename LhsScalar,
typename RhsScalar,
typename Index,
int mr,
int nr,
bool ConjLhs,
bool ConjRhs,
150 int ResInnerStride,
int UpLo>
152 typedef gebp_traits<LhsScalar, RhsScalar, ConjLhs, ConjRhs> Traits;
153 typedef typename Traits::ResScalar ResScalar;
// Panel width: least common multiple of mr and nr.
155 enum { BlockSize = meta_least_common_multiple<plain_enum_max(mr, nr), plain_enum_min(mr, nr)>::ret };
156 void operator()(ResScalar* res_,
Index resIncr,
Index resStride,
const LhsScalar* blockA,
const RhsScalar* blockB,
157 Index size,
Index depth,
const ResScalar& alpha) {
// ResMapper addresses the real result (honours ResInnerStride); BufferMapper
// addresses the local temporary buffer.
158 typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
159 typedef blas_data_mapper<ResScalar, Index, ColMajor, Unaligned> BufferMapper;
160 ResMapper res(res_, resStride, resIncr);
161 gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel1;
162 gebp_kernel<LhsScalar, RhsScalar, Index, BufferMapper, mr, nr, ConjLhs, ConjRhs> gebp_kernel2;
// Fixed-size temporary holding one BlockSize x BlockSize block.
164 Matrix<ResScalar, BlockSize, BlockSize, ColMajor> buffer;
// Column panels of width BlockSize; the last one may be narrower.
168 for (
Index j = 0; j < size; j += BlockSize) {
169 Index actualBlockSize = std::min<Index>(BlockSize, size - j);
170 const RhsScalar* actual_b = blockB + j * depth;
// Rows [0, j) of this column panel, written straight into the result.
173 gebp_kernel1(res.getSubMapper(0, j), blockA, actual_b, j, depth, actualBlockSize, alpha, -1, -1, 0, 0);
// Square block computed into the temporary buffer.
180 gebp_kernel2(BufferMapper(buffer.data(), BlockSize), blockA + depth * i, actual_b, actualBlockSize, depth,
181 actualBlockSize, alpha, -1, -1, 0, 0);
// Add only the requested triangle of the buffer to the result.
184 for (
Index j1 = 0; j1 < actualBlockSize; ++j1) {
185 typename ResMapper::LinearMapper r = res.getLinearMapper(i, j + j1);
186 for (
Index i1 = UpLo ==
Lower ? j1 : 0; UpLo ==
Lower ? i1 < actualBlockSize : i1 <= j1; ++i1)
187 r(i1) += buffer(i1, j1);
// Rows from j + actualBlockSize to size, written straight into the result.
192 Index i = j + actualBlockSize;
193 gebp_kernel1(res.getSubMapper(i, j), blockA + depth * i, actual_b, size - i, depth, actualBlockSize, alpha, -1,
// Primary template of the dispatcher used by TriangularViewImpl::_assignProduct.
// Two partial specializations follow, one for IsOuterProduct == true and one
// for IsOuterProduct == false; each provides a static run(mat, prod, alpha, beta).
//
// Template parameters:
//   MatrixType      type of the destination matrix
//   ProductType     type of the product expression being assigned
//   UpLo            which triangular part of the destination is written
//   IsOuterProduct  selects between the two specializations
template <typename MatrixType, typename ProductType, int UpLo, bool IsOuterProduct>
struct general_product_to_triangular_selector;
// Specialization of general_product_to_triangular_selector for
// IsOuterProduct == true.
//
// run() does the following, as far as the visible code shows:
//   1. extracts the lhs/rhs operands of `prod` through blas_traits and folds
//      their scalar factors into actualAlpha,
//   2. zeroes the UpLo triangular part of `mat` when `beta` is false,
//   3. obtains a raw pointer for each operand: the operand's own data when its
//      compile-time inner stride is 1, otherwise a temporary that the operand
//      is copied into,
//   4. calls selfadjoint_rank1_update<...>::run on mat's data.
//
// NOTE(review): this listing is missing lines of the definition: the enum head
// that introduces StorageOrder / UseLhsDirectly / UseRhsDirectly, the
// static_lhs / static_rhs declarator names after their types, and every
// closing brace.
207template <
typename MatrixType,
typename ProductType,
int UpLo>
208struct general_product_to_triangular_selector<
MatrixType, ProductType, UpLo, true> {
209 static void run(MatrixType& mat,
const ProductType& prod,
const typename MatrixType::Scalar& alpha,
bool beta) {
210 typedef typename MatrixType::Scalar Scalar;
// Lhs operand with any wrapper handled by blas_traits removed.
212 typedef internal::remove_all_t<typename ProductType::LhsNested> Lhs;
213 typedef internal::blas_traits<Lhs> LhsBlasTraits;
214 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
215 typedef internal::remove_all_t<ActualLhs> ActualLhs_;
216 internal::add_const_on_value_type_t<ActualLhs> actualLhs = LhsBlasTraits::extract(prod.lhs());
// Same extraction for the rhs operand.
218 typedef internal::remove_all_t<typename ProductType::RhsNested> Rhs;
219 typedef internal::blas_traits<Rhs> RhsBlasTraits;
220 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
221 typedef internal::remove_all_t<ActualRhs> ActualRhs_;
222 internal::add_const_on_value_type_t<ActualRhs> actualRhs = RhsBlasTraits::extract(prod.rhs());
// Combine the caller's alpha with the scalar factors of both operands.
224 Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) *
225 RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
// beta == false: clear the destination triangle before accumulating.
227 if (!beta) mat.template triangularView<UpLo>().setZero();
// An operand is used in place only when its compile-time inner stride is 1.
231 UseLhsDirectly = ActualLhs_::InnerStrideAtCompileTime == 1,
232 UseRhsDirectly = ActualRhs_::InnerStrideAtCompileTime == 1
235 internal::gemv_static_vector_if<Scalar, Lhs::SizeAtCompileTime, Lhs::MaxSizeAtCompileTime, !UseLhsDirectly>
// actualLhsPtr: the operand's own data when usable directly, else static_lhs.
237 ei_declare_aligned_stack_constructed_variable(
238 Scalar, actualLhsPtr, actualLhs.size(),
239 (UseLhsDirectly ?
const_cast<Scalar*
>(actualLhs.data()) : static_lhs.data()));
// Copy the operand into the temporary when it cannot be used in place.
240 if (!UseLhsDirectly) Map<typename ActualLhs_::PlainObject>(actualLhsPtr, actualLhs.size()) = actualLhs;
242 internal::gemv_static_vector_if<Scalar, Rhs::SizeAtCompileTime, Rhs::MaxSizeAtCompileTime, !UseRhsDirectly>
// Same pointer selection and optional copy for the rhs operand.
244 ei_declare_aligned_stack_constructed_variable(
245 Scalar, actualRhsPtr, actualRhs.size(),
246 (UseRhsDirectly ?
const_cast<Scalar*
>(actualRhs.data()) : static_rhs.data()));
247 if (!UseRhsDirectly) Map<typename ActualRhs_::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
// Conjugation flags are only set when Scalar is complex.
249 selfadjoint_rank1_update<
250 Scalar,
Index, StorageOrder, UpLo, LhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
251 RhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex>::run(actualLhs.size(), mat.data(),
252 mat.outerStride(), actualLhsPtr,
253 actualRhsPtr, actualAlpha);
// Specialization of general_product_to_triangular_selector for
// IsOuterProduct == false.
//
// run() does the following, as far as the visible code shows:
//   1. extracts the lhs/rhs operands of `prod` through blas_traits and folds
//      their scalar factors into actualAlpha,
//   2. zeroes the UpLo triangular part of `mat` when `beta` is false,
//   3. takes size from mat.cols() (reduced by one when SkipDiag is set) and
//      depth from actualLhs.cols(),
//   4. builds a blocking object and calls
//      general_matrix_matrix_triangular_product<...>::run with the operand
//      pointers, strides, actualAlpha and the blocking object.
//
// NOTE(review): this listing is missing lines of the definition: the enum head
// and the SkipDiag entry, the BlockingType typedef name, several template and
// call arguments of the final run() call, and every closing brace.
257template <
typename MatrixType,
typename ProductType,
int UpLo>
258struct general_product_to_triangular_selector<MatrixType, ProductType, UpLo, false> {
259 static void run(MatrixType& mat,
const ProductType& prod,
const typename MatrixType::Scalar& alpha,
bool beta) {
// Lhs operand with any wrapper handled by blas_traits removed.
260 typedef internal::remove_all_t<typename ProductType::LhsNested> Lhs;
261 typedef internal::blas_traits<Lhs> LhsBlasTraits;
262 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
263 typedef internal::remove_all_t<ActualLhs> ActualLhs_;
264 internal::add_const_on_value_type_t<ActualLhs> actualLhs = LhsBlasTraits::extract(prod.lhs());
// Same extraction for the rhs operand.
266 typedef internal::remove_all_t<typename ProductType::RhsNested> Rhs;
267 typedef internal::blas_traits<Rhs> RhsBlasTraits;
268 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
269 typedef internal::remove_all_t<ActualRhs> ActualRhs_;
270 internal::add_const_on_value_type_t<ActualRhs> actualRhs = RhsBlasTraits::extract(prod.rhs());
// Combine the caller's alpha with the scalar factors of both operands.
272 typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) *
273 RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
// beta == false: clear the destination triangle before accumulating.
275 if (!beta) mat.template triangularView<UpLo>().setZero();
// Storage-order flags (0 or 1) of the destination and of both operands.
278 IsRowMajor = (internal::traits<MatrixType>::Flags &
RowMajorBit) ? 1 : 0,
279 LhsIsRowMajor = ActualLhs_::Flags &
RowMajorBit ? 1 : 0,
280 RhsIsRowMajor = ActualRhs_::Flags &
RowMajorBit ? 1 : 0,
// Problem size: one less than mat.cols() when SkipDiag is set.
284 Index size = mat.cols();
285 if (SkipDiag) size--;
286 Index depth = actualLhs.cols();
288 typedef internal::gemm_blocking_space<IsRowMajor ?
RowMajor :
ColMajor,
typename Lhs::Scalar,
typename Rhs::Scalar,
289 MatrixType::MaxColsAtCompileTime, MatrixType::MaxColsAtCompileTime,
290 ActualRhs_::MaxColsAtCompileTime>
293 BlockingType blocking(size, size, depth, 1,
false);
// Only the Lower/Upper bits of UpLo are forwarded to the kernel. With SkipDiag
// the lhs pointer starts one row down for Lower, and the rhs pointer one
// column right for Upper.
295 internal::general_matrix_matrix_triangular_product<
297 typename Rhs::Scalar, RhsIsRowMajor ?
RowMajor :
ColMajor, RhsBlasTraits::NeedToConjugate,
299 UpLo&(
Lower |
Upper)>::run(size, depth, &actualLhs.coeffRef(SkipDiag && (UpLo & Lower) == Lower ? 1 : 0, 0),
300 actualLhs.outerStride(),
301 &actualRhs.coeffRef(0, SkipDiag && (UpLo & Upper) == Upper ? 1 : 0),
302 actualRhs.outerStride(),
304 (SkipDiag ? (
bool(IsRowMajor) != ((UpLo & Lower) == Lower) ? mat.innerStride()
307 mat.innerStride(), mat.outerStride(), actualAlpha, blocking);
// Out-of-class definition of TriangularViewImpl<MatrixType_, Mode_, Dense>::_assignProduct.
//
// Rejects at compile time any Mode_ that has the UnitDiag bit set, asserts
// that the nested expression's row count and this view's column count match
// those of `prod`, then dispatches to general_product_to_triangular_selector.
// The outer-product specialization is chosen when the product's compile-time
// InnerSize is 1.
//
// NOTE(review): the listing ends in the middle of the selector call; the
// remaining arguments, the return statement and the closing brace are missing.
311template <
typename MatrixType_,
unsigned int Mode_>
312template <
typename ProductType>
313EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
typename TriangularViewImpl<MatrixType_, Mode_, Dense>::TriangularViewType&
314TriangularViewImpl<MatrixType_, Mode_, Dense>::_assignProduct(
315 const ProductType& prod,
const typename TriangularViewImpl<MatrixType_, Mode_, Dense>::Scalar& alpha,
bool beta) {
316 EIGEN_STATIC_ASSERT((Mode_ & UnitDiag) == 0, WRITING_TO_TRIANGULAR_PART_WITH_UNIT_DIAGONAL_IS_NOT_SUPPORTED);
317 eigen_assert(derived().nestedExpression().rows() == prod.rows() && derived().cols() == prod.cols());
// The last template argument selects the outer-product path when InnerSize == 1.
319 general_product_to_triangular_selector<MatrixType_, ProductType, Mode_,
320 internal::traits<ProductType>::InnerSize == 1>::run(derived()
322 .const_cast_derived(),
@ UnitDiag
Definition Constants.h:215
@ ZeroDiag
Definition Constants.h:217
@ Lower
Definition Constants.h:211
@ Upper
Definition Constants.h:213
@ ColMajor
Definition Constants.h:318
@ RowMajor
Definition Constants.h:320
const unsigned int RowMajorBit
Definition Constants.h:70
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82