10#ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_H
11#define EIGEN_TRIANGULAR_MATRIX_MATRIX_H
14#include "../InternalHeaderCheck.h"
47template <
typename Scalar,
typename Index,
int Mode,
bool LhsIsTriangular,
int LhsStorageOrder,
bool ConjugateLhs,
48 int RhsStorageOrder,
bool ConjugateRhs,
int ResStorageOrder,
int ResInnerStride,
int Version = Specialized>
49struct product_triangular_matrix_matrix;
51template <
typename Scalar,
typename Index,
int Mode,
bool LhsIsTriangular,
int LhsStorageOrder,
bool ConjugateLhs,
52 int RhsStorageOrder,
bool ConjugateRhs,
int ResInnerStride,
int Version>
53struct product_triangular_matrix_matrix<Scalar,
Index, Mode, LhsIsTriangular, LhsStorageOrder, ConjugateLhs,
54 RhsStorageOrder, ConjugateRhs,
RowMajor, ResInnerStride, Version> {
55 static EIGEN_STRONG_INLINE
void run(
Index rows,
Index cols,
Index depth,
const Scalar* lhs,
Index lhsStride,
56 const Scalar* rhs,
Index rhsStride, Scalar* res,
Index resIncr,
Index resStride,
57 const Scalar& alpha, level3_blocking<Scalar, Scalar>& blocking) {
61 ColMajor, ResInnerStride>::run(cols, rows, depth, rhs, rhsStride, lhs, lhsStride,
62 res, resIncr, resStride, alpha, blocking);
67template <
typename Scalar,
typename Index,
int Mode,
int LhsStorageOrder,
bool ConjugateLhs,
int RhsStorageOrder,
68 bool ConjugateRhs,
int ResInnerStride,
int Version>
69struct product_triangular_matrix_matrix<Scalar,
Index, Mode, true, LhsStorageOrder, ConjugateLhs, RhsStorageOrder,
70 ConjugateRhs,
ColMajor, ResInnerStride, Version> {
71 typedef gebp_traits<Scalar, Scalar> Traits;
73 SmallPanelWidth = 2 * plain_enum_max(Traits::mr, Traits::nr),
78 static EIGEN_DONT_INLINE
void run(
Index _rows,
Index _cols,
Index _depth,
const Scalar* lhs_,
Index lhsStride,
79 const Scalar* rhs_,
Index rhsStride, Scalar* res,
Index resIncr,
Index resStride,
80 const Scalar& alpha, level3_blocking<Scalar, Scalar>& blocking);
83template <
typename Scalar,
typename Index,
int Mode,
int LhsStorageOrder,
bool ConjugateLhs,
int RhsStorageOrder,
84 bool ConjugateRhs,
int ResInnerStride,
int Version>
85EIGEN_DONT_INLINE
void product_triangular_matrix_matrix<
86 Scalar,
Index, Mode,
true, LhsStorageOrder, ConjugateLhs, RhsStorageOrder, ConjugateRhs,
ColMajor, ResInnerStride,
87 Version>::run(
Index _rows,
Index _cols,
Index _depth,
const Scalar* lhs_,
Index lhsStride,
const Scalar* rhs_,
88 Index rhsStride, Scalar* res_,
Index resIncr,
Index resStride,
const Scalar& alpha,
89 level3_blocking<Scalar, Scalar>& blocking) {
91 Index diagSize = (std::min)(_rows, _depth);
92 Index rows = IsLower ? _rows : diagSize;
93 Index depth = IsLower ? diagSize : _depth;
96 typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
97 typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
98 typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
99 LhsMapper lhs(lhs_, lhsStride);
100 RhsMapper rhs(rhs_, rhsStride);
101 ResMapper res(res_, resStride, resIncr);
103 Index kc = blocking.kc();
104 Index mc = (std::min)(rows, blocking.mc());
108 Index panelWidth = (std::min)(
Index(SmallPanelWidth), (std::min)(kc, mc));
110 std::size_t sizeA = kc * mc;
111 std::size_t sizeB = kc * cols;
113 ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
114 ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
116 Matrix<Scalar, SmallPanelWidth, SmallPanelWidth, LhsStorageOrder> triangularBuffer;
117 triangularBuffer.setZero();
119 triangularBuffer.diagonal().setZero();
121 triangularBuffer.diagonal().setOnes();
123 gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
124 gemm_pack_lhs<Scalar,
Index, LhsMapper, Traits::mr, Traits::LhsProgress,
typename Traits::LhsPacket4Packing,
127 gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
129 for (
Index k2 = IsLower ? depth : 0; IsLower ? k2 > 0 : k2 < depth; IsLower ? k2 -= kc : k2 += kc) {
130 Index actual_kc = (std::min)(IsLower ? k2 : depth - k2, kc);
131 Index actual_k2 = IsLower ? k2 - actual_kc : k2;
134 if ((!IsLower) && (k2 < rows) && (k2 + actual_kc > rows)) {
135 actual_kc = rows - k2;
136 k2 = k2 + actual_kc - kc;
139 pack_rhs(blockB, rhs.getSubMapper(actual_k2, 0), actual_kc, cols);
147 if (IsLower || actual_k2 < rows) {
149 for (
Index k1 = 0; k1 < actual_kc; k1 += panelWidth) {
150 Index actualPanelWidth = std::min<Index>(actual_kc - k1, panelWidth);
151 Index lengthTarget = IsLower ? actual_kc - k1 - actualPanelWidth : k1;
152 Index startBlock = actual_k2 + k1;
153 Index blockBOffset = k1;
158 for (
Index k = 0; k < actualPanelWidth; ++k) {
159 if (SetDiag) triangularBuffer.coeffRef(k, k) = lhs(startBlock + k, startBlock + k);
160 for (
Index i = IsLower ? k + 1 : 0; IsLower ? i < actualPanelWidth : i < k; ++i)
161 triangularBuffer.coeffRef(i, k) = lhs(startBlock + i, startBlock + k);
163 pack_lhs(blockA, LhsMapper(triangularBuffer.data(), triangularBuffer.outerStride()), actualPanelWidth,
166 gebp_kernel(res.getSubMapper(startBlock, 0), blockA, blockB, actualPanelWidth, actualPanelWidth, cols, alpha,
167 actualPanelWidth, actual_kc, 0, blockBOffset);
170 if (lengthTarget > 0) {
171 Index startTarget = IsLower ? actual_k2 + k1 + actualPanelWidth : actual_k2;
173 pack_lhs(blockA, lhs.getSubMapper(startTarget, startBlock), actualPanelWidth, lengthTarget);
175 gebp_kernel(res.getSubMapper(startTarget, 0), blockA, blockB, lengthTarget, actualPanelWidth, cols, alpha,
176 actualPanelWidth, actual_kc, 0, blockBOffset);
182 Index start = IsLower ? k2 : 0;
183 Index end = IsLower ? rows : (std::min)(actual_k2, rows);
184 for (
Index i2 = start; i2 < end; i2 += mc) {
185 const Index actual_mc = (std::min)(i2 + mc, end) - i2;
186 gemm_pack_lhs<Scalar,
Index, LhsMapper, Traits::mr, Traits::LhsProgress,
typename Traits::LhsPacket4Packing,
187 LhsStorageOrder,
false>()(blockA, lhs.getSubMapper(i2, actual_k2), actual_kc, actual_mc);
189 gebp_kernel(res.getSubMapper(i2, 0), blockA, blockB, actual_mc, actual_kc, cols, alpha, -1, -1, 0, 0);
196template <
typename Scalar,
typename Index,
int Mode,
int LhsStorageOrder,
bool ConjugateLhs,
int RhsStorageOrder,
197 bool ConjugateRhs,
int ResInnerStride,
int Version>
198struct product_triangular_matrix_matrix<Scalar,
Index, Mode, false, LhsStorageOrder, ConjugateLhs, RhsStorageOrder,
199 ConjugateRhs,
ColMajor, ResInnerStride, Version> {
200 typedef gebp_traits<Scalar, Scalar> Traits;
202 SmallPanelWidth = plain_enum_max(Traits::mr, Traits::nr),
207 static EIGEN_DONT_INLINE
void run(
Index _rows,
Index _cols,
Index _depth,
const Scalar* lhs_,
Index lhsStride,
208 const Scalar* rhs_,
Index rhsStride, Scalar* res,
Index resIncr,
Index resStride,
209 const Scalar& alpha, level3_blocking<Scalar, Scalar>& blocking);
212template <
typename Scalar,
typename Index,
int Mode,
int LhsStorageOrder,
bool ConjugateLhs,
int RhsStorageOrder,
213 bool ConjugateRhs,
int ResInnerStride,
int Version>
214EIGEN_DONT_INLINE
void product_triangular_matrix_matrix<
215 Scalar,
Index, Mode,
false, LhsStorageOrder, ConjugateLhs, RhsStorageOrder, ConjugateRhs,
ColMajor, ResInnerStride,
216 Version>::run(
Index _rows,
Index _cols,
Index _depth,
const Scalar* lhs_,
Index lhsStride,
const Scalar* rhs_,
217 Index rhsStride, Scalar* res_,
Index resIncr,
Index resStride,
const Scalar& alpha,
218 level3_blocking<Scalar, Scalar>& blocking) {
219 const Index PacketBytes = packet_traits<Scalar>::size *
sizeof(Scalar);
221 Index diagSize = (std::min)(_cols, _depth);
223 Index depth = IsLower ? _depth : diagSize;
224 Index cols = IsLower ? diagSize : _cols;
226 typedef const_blas_data_mapper<Scalar, Index, LhsStorageOrder> LhsMapper;
227 typedef const_blas_data_mapper<Scalar, Index, RhsStorageOrder> RhsMapper;
228 typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
229 LhsMapper lhs(lhs_, lhsStride);
230 RhsMapper rhs(rhs_, rhsStride);
231 ResMapper res(res_, resStride, resIncr);
233 Index kc = blocking.kc();
234 Index mc = (std::min)(rows, blocking.mc());
236 std::size_t sizeA = kc * mc;
237 std::size_t sizeB = kc * cols + EIGEN_MAX_ALIGN_BYTES /
sizeof(Scalar);
239 ei_declare_aligned_stack_constructed_variable(Scalar, blockA, sizeA, blocking.blockA());
240 ei_declare_aligned_stack_constructed_variable(Scalar, blockB, sizeB, blocking.blockB());
242 Matrix<Scalar, SmallPanelWidth, SmallPanelWidth, RhsStorageOrder> triangularBuffer;
243 triangularBuffer.setZero();
245 triangularBuffer.diagonal().setZero();
247 triangularBuffer.diagonal().setOnes();
249 gebp_kernel<Scalar, Scalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp_kernel;
250 gemm_pack_lhs<Scalar,
Index, LhsMapper, Traits::mr, Traits::LhsProgress,
typename Traits::LhsPacket4Packing,
253 gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
254 gemm_pack_rhs<Scalar, Index, RhsMapper, Traits::nr, RhsStorageOrder, false, true> pack_rhs_panel;
256 for (
Index k2 = IsLower ? 0 : depth; IsLower ? k2 < depth : k2 > 0; IsLower ? k2 += kc : k2 -= kc) {
257 Index actual_kc = (std::min)(IsLower ? depth - k2 : k2, kc);
258 Index actual_k2 = IsLower ? k2 : k2 - actual_kc;
261 if (IsLower && (k2 < cols) && (actual_k2 + actual_kc > cols)) {
262 actual_kc = cols - k2;
263 k2 = actual_k2 + actual_kc - kc;
267 Index rs = IsLower ? (std::min)(cols, actual_k2) : cols - k2;
269 Index ts = (IsLower && actual_k2 >= cols) ? 0 : actual_kc;
271 Scalar* geb = blockB + ts * ts;
272 geb = geb + internal::first_aligned<PacketBytes>(geb, PacketBytes /
sizeof(Scalar));
274 pack_rhs(geb, rhs.getSubMapper(actual_k2, IsLower ? 0 : k2), actual_kc, rs);
278 for (
Index j2 = 0; j2 < actual_kc; j2 += SmallPanelWidth) {
279 Index actualPanelWidth = std::min<Index>(actual_kc - j2, SmallPanelWidth);
280 Index actual_j2 = actual_k2 + j2;
281 Index panelOffset = IsLower ? j2 + actualPanelWidth : 0;
282 Index panelLength = IsLower ? actual_kc - j2 - actualPanelWidth : j2;
284 pack_rhs_panel(blockB + j2 * actual_kc, rhs.getSubMapper(actual_k2 + panelOffset, actual_j2), panelLength,
285 actualPanelWidth, actual_kc, panelOffset);
288 for (
Index j = 0; j < actualPanelWidth; ++j) {
289 if (SetDiag) triangularBuffer.coeffRef(j, j) = rhs(actual_j2 + j, actual_j2 + j);
290 for (
Index k = IsLower ? j + 1 : 0; IsLower ? k < actualPanelWidth : k < j; ++k)
291 triangularBuffer.coeffRef(k, j) = rhs(actual_j2 + k, actual_j2 + j);
294 pack_rhs_panel(blockB + j2 * actual_kc, RhsMapper(triangularBuffer.data(), triangularBuffer.outerStride()),
295 actualPanelWidth, actualPanelWidth, actual_kc, j2);
299 for (
Index i2 = 0; i2 < rows; i2 += mc) {
300 const Index actual_mc = (std::min)(mc, rows - i2);
301 pack_lhs(blockA, lhs.getSubMapper(i2, actual_k2), actual_kc, actual_mc);
305 for (
Index j2 = 0; j2 < actual_kc; j2 += SmallPanelWidth) {
306 Index actualPanelWidth = std::min<Index>(actual_kc - j2, SmallPanelWidth);
307 Index panelLength = IsLower ? actual_kc - j2 : j2 + actualPanelWidth;
308 Index blockOffset = IsLower ? j2 : 0;
310 gebp_kernel(res.getSubMapper(i2, actual_k2 + j2), blockA, blockB + j2 * actual_kc, actual_mc, panelLength,
311 actualPanelWidth, alpha, actual_kc, actual_kc,
312 blockOffset, blockOffset);
315 gebp_kernel(res.getSubMapper(i2, IsLower ? 0 : k2), blockA, geb, actual_mc, actual_kc, rs, alpha, -1, -1, 0, 0);
327template <
int Mode,
bool LhsIsTriangular,
typename Lhs,
typename Rhs>
328struct triangular_product_impl<Mode, LhsIsTriangular, Lhs, false, Rhs, false> {
329 template <
typename Dest>
330 static void run(Dest& dst,
const Lhs& a_lhs,
const Rhs& a_rhs,
const typename Dest::Scalar& alpha) {
331 typedef typename Lhs::Scalar LhsScalar;
332 typedef typename Rhs::Scalar RhsScalar;
333 typedef typename Dest::Scalar Scalar;
335 typedef internal::blas_traits<Lhs> LhsBlasTraits;
336 typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
337 typedef internal::remove_all_t<ActualLhsType> ActualLhsTypeCleaned;
338 typedef internal::blas_traits<Rhs> RhsBlasTraits;
339 typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
340 typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;
342 internal::add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
343 internal::add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);
347 if (lhs.size() == 0 || rhs.size() == 0) {
351 LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(a_lhs);
352 RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(a_rhs);
353 Scalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
356 Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime,
357 Lhs::MaxColsAtCompileTime, 4>
361 Index stripedRows = ((!LhsIsTriangular) || (IsLower)) ? lhs.rows() : (std::min)(lhs.rows(), lhs.cols());
362 Index stripedCols = ((LhsIsTriangular) || (!IsLower)) ? rhs.cols() : (std::min)(rhs.cols(), rhs.rows());
363 Index stripedDepth = LhsIsTriangular ? ((!IsLower) ? lhs.cols() : (std::min)(lhs.cols(), lhs.rows()))
364 : ((IsLower) ? rhs.rows() : (std::min)(rhs.rows(), rhs.cols()));
366 BlockingType blocking(stripedRows, stripedCols, stripedDepth, 1,
false);
368 internal::product_triangular_matrix_matrix<
369 Scalar,
Index, Mode, LhsIsTriangular,
371 LhsBlasTraits::NeedToConjugate,
374 Dest::InnerStrideAtCompileTime>::run(stripedRows, stripedCols, stripedDepth,
375 &lhs.coeffRef(0, 0), lhs.outerStride(),
376 &rhs.coeffRef(0, 0), rhs.outerStride(),
377 &dst.coeffRef(0, 0), dst.innerStride(), dst.outerStride(),
378 actualAlpha, blocking);
382 if (LhsIsTriangular && !numext::is_exactly_one(lhs_alpha)) {
383 Index diagSize = (std::min)(lhs.rows(), lhs.cols());
384 dst.topRows(diagSize) -= ((lhs_alpha - LhsScalar(1)) * a_rhs).topRows(diagSize);
385 }
else if ((!LhsIsTriangular) && !numext::is_exactly_one(rhs_alpha)) {
386 Index diagSize = (std::min)(rhs.rows(), rhs.cols());
387 dst.leftCols(diagSize) -= (rhs_alpha - RhsScalar(1)) * a_lhs.leftCols(diagSize);
@ UnitDiag
Definition Constants.h:215
@ ZeroDiag
Definition Constants.h:217
@ Lower
Definition Constants.h:211
@ Upper
Definition Constants.h:213
@ ColMajor
Definition Constants.h:318
@ RowMajor
Definition Constants.h:320
const unsigned int RowMajorBit
Definition Constants.h:70
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82