Eigen  5.0.1-dev+60122df6
 
Loading...
Searching...
No Matches
TriangularMatrixMatrix_BLAS.h
1/*
2 Copyright (c) 2011, Intel Corporation. All rights reserved.
3
4 Redistribution and use in source and binary forms, with or without modification,
5 are permitted provided that the following conditions are met:
6
7 * Redistributions of source code must retain the above copyright notice, this
8 list of conditions and the following disclaimer.
9 * Redistributions in binary form must reproduce the above copyright notice,
10 this list of conditions and the following disclaimer in the documentation
11 and/or other materials provided with the distribution.
12 * Neither the name of Intel Corporation nor the names of its contributors may
13 be used to endorse or promote products derived from this software without
14 specific prior written permission.
15
16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
23 ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27 ********************************************************************************
28 * Content : Eigen bindings to BLAS F77
29 * Triangular matrix * matrix product functionality based on ?TRMM.
30 ********************************************************************************
31*/
32
33#ifndef EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
34#define EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
35
36// IWYU pragma: private
37#include "../InternalHeaderCheck.h"
38
39namespace Eigen {
40
41namespace internal {
42
43template <typename Scalar, typename Index, int Mode, bool LhsIsTriangular, int LhsStorageOrder, bool ConjugateLhs,
44 int RhsStorageOrder, bool ConjugateRhs, int ResStorageOrder>
45struct product_triangular_matrix_matrix_trmm
46 : product_triangular_matrix_matrix<Scalar, Index, Mode, LhsIsTriangular, LhsStorageOrder, ConjugateLhs,
47 RhsStorageOrder, ConjugateRhs, ResStorageOrder, 1, BuiltIn> {};
48
// Redirect the `Specialized` triangular * general product to the TRMM helper.
//
// EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) partially specializes
// product_triangular_matrix_matrix for the given scalar type and triangular
// side, restricted to a ColMajor result. Its run() drops the resIncr argument
// (after asserting that it equals 1) and forwards everything else unchanged to
// product_triangular_matrix_matrix_trmm<...>::run().
#define EIGEN_BLAS_TRMM_SPECIALIZE(Scalar, LhsIsTriangular) \
  template <typename Index, int Mode, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, \
            bool ConjugateRhs> \
  struct product_triangular_matrix_matrix<Scalar, Index, Mode, LhsIsTriangular, LhsStorageOrder, ConjugateLhs, \
                                          RhsStorageOrder, ConjugateRhs, ColMajor, 1, Specialized> { \
    static inline void run(Index _rows, Index _cols, Index _depth, const Scalar* _lhs, Index lhsStride, \
                           const Scalar* _rhs, Index rhsStride, Scalar* res, Index resIncr, Index resStride, \
                           Scalar alpha, level3_blocking<Scalar, Scalar>& blocking) { \
      /* Helper that either calls BLAS or falls back to the built-in product. */ \
      typedef product_triangular_matrix_matrix_trmm<Scalar, Index, Mode, LhsIsTriangular, LhsStorageOrder, \
                                                    ConjugateLhs, RhsStorageOrder, ConjugateRhs, ColMajor> \
          TrmmHelper; \
      /* The helper takes no increment argument: only resIncr == 1 is accepted. */ \
      EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
      eigen_assert(resIncr == 1); \
      TrmmHelper::run(_rows, _cols, _depth, _lhs, lhsStride, _rhs, rhsStride, res, resStride, alpha, blocking); \
    } \
  };
66
67EIGEN_BLAS_TRMM_SPECIALIZE(double, true)
68EIGEN_BLAS_TRMM_SPECIALIZE(double, false)
69EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, true)
70EIGEN_BLAS_TRMM_SPECIALIZE(dcomplex, false)
71EIGEN_BLAS_TRMM_SPECIALIZE(float, true)
72EIGEN_BLAS_TRMM_SPECIALIZE(float, false)
73EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, true)
74EIGEN_BLAS_TRMM_SPECIALIZE(scomplex, false)
75
76// implements col-major += alpha * op(triangular) * op(general)
//
// EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) defines the partial specialization of
// product_triangular_matrix_matrix_trmm for a triangular LEFT-hand side (LhsIsTriangular == true)
// and a ColMajor result.
//   EIGTYPE   : Eigen scalar type of lhs, rhs, res and alpha.
//   BLASTYPE  : type the alpha and matrix pointers are cast to when calling BLASFUNC.
//   EIGPREFIX : suffix pasted after MatrixX (MatrixX##EIGPREFIX) to name the dense temporary type.
//   BLASFUNC  : the ?trmm routine to call.
//
// What the generated run() does:
//   * returns immediately when any of _rows, _cols, _depth is zero;
//   * clips the triangular factor to rows x depth using diagSize = min(_rows, _depth);
//   * if that factor is not square, ?trmm is not used: depending on a size heuristic the work goes
//     either to the BuiltIn product_triangular_matrix_matrix or to general_matrix_matrix_product
//     applied to a dense copy of the triangular view;
//   * otherwise copies rhs into b_tmp, calls BLASFUNC with side L on that copy and adds b_tmp to res.
// diag is always passed as N: UnitDiag / ZeroDiag modes are handled by overwriting the diagonal of a
// copy of lhs before the call.
77#define EIGEN_BLAS_TRMM_L(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
78 template <typename Index, int Mode, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs> \
79 struct product_triangular_matrix_matrix_trmm<EIGTYPE, Index, Mode, true, LhsStorageOrder, ConjugateLhs, \
80 RhsStorageOrder, ConjugateRhs, ColMajor> { \
81 enum { \
82 IsLower = (Mode & Lower) == Lower, \
83 SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1, \
84 IsUnitDiag = (Mode & UnitDiag) ? 1 : 0, \
85 IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0, \
86 LowUp = IsLower ? Lower : Upper, \
87 conjA = ((LhsStorageOrder == ColMajor) && ConjugateLhs) ? 1 : 0 \
88 }; \
89 \
90 static void run(Index _rows, Index _cols, Index _depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
91 Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha, \
92 level3_blocking<EIGTYPE, EIGTYPE>& blocking) { \
93 if (_rows == 0 || _cols == 0 || _depth == 0) return; \
94 Index diagSize = (std::min)(_rows, _depth); \
95 Index rows = IsLower ? _rows : diagSize; \
96 Index depth = IsLower ? diagSize : _depth; \
97 Index cols = _cols; \
98 \
99 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
100 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
101 \
102 /* Non-square triangular factor (rows != depth): ?TRMM is not used here. Fall back to the BuiltIn triangular product or to a general matrix product on a dense copy. */ \
103 if (rows != depth) { \
104 /* FIXME handle mkl_domain_get_max_threads */ \
105 /*int nthr = mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS);*/ int nthr = 1; \
106 \
107 if (((nthr == 1) && (((std::max)(rows, depth) - diagSize) / (double)diagSize < 0.5))) { \
108 /* The part exceeding the square block is less than half of diagSize: most likely no benefit from TRMM or GEMM, use the BuiltIn product on the original arguments */ \
109 product_triangular_matrix_matrix<EIGTYPE, Index, Mode, true, LhsStorageOrder, ConjugateLhs, RhsStorageOrder, \
110 ConjugateRhs, ColMajor, 1, BuiltIn>::run(_rows, _cols, _depth, _lhs, \
111 lhsStride, _rhs, rhsStride, res, \
112 1, resStride, alpha, blocking); \
113 /*std::cout << "TRMM_L: A is not square! Go to Eigen TRMM implementation!\n";*/ \
114 } else { \
115 /* Otherwise materialize triangularView<Mode>() of lhs as a dense matrix (aa_tmp) and run general_matrix_matrix_product on it with a freshly built blocking object */ \
116 Map<const MatrixLhs, 0, OuterStride<> > lhsMap(_lhs, rows, depth, OuterStride<>(lhsStride)); \
117 MatrixLhs aa_tmp = lhsMap.template triangularView<Mode>(); \
118 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
119 gemm_blocking_space<ColMajor, EIGTYPE, EIGTYPE, Dynamic, Dynamic, Dynamic> gemm_blocking(_rows, _cols, \
120 _depth, 1, true); \
121 general_matrix_matrix_product<Index, EIGTYPE, LhsStorageOrder, ConjugateLhs, EIGTYPE, RhsStorageOrder, \
122 ConjugateRhs, ColMajor, 1>::run(rows, cols, depth, aa_tmp.data(), aStride, \
123 _rhs, rhsStride, res, 1, resStride, alpha, \
124 gemm_blocking, 0); \
125 \
126 /*std::cout << "TRMM_L: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
127 } \
128 return; \
129 } \
130 char side = 'L', transa, uplo, diag = 'N'; \
131 EIGTYPE* b; \
132 const EIGTYPE* a; \
133 BlasIndex m, n, lda, ldb; \
134 \
135 /* Set m, n: sizes handed to ?trmm, m = diagSize (here rows == depth == diagSize) and n = cols */ \
136 m = convert_index<BlasIndex>(diagSize); \
137 n = convert_index<BlasIndex>(cols); \
138 \
139 /* Set trans: N for a column-major lhs; a row-major lhs is passed as T, or as C when ConjugateLhs is set */ \
140 transa = (LhsStorageOrder == RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; \
141 \
142 /* Set b, ldb: b_tmp is a copy of rhs (conjugated when ConjugateRhs) and is the non-const matrix argument of ?trmm */ \
143 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs, depth, cols, OuterStride<>(rhsStride)); \
144 MatrixX##EIGPREFIX b_tmp; \
145 \
146 if (ConjugateRhs) \
147 b_tmp = rhs.conjugate(); \
148 else \
149 b_tmp = rhs; \
150 b = b_tmp.data(); \
151 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
152 \
153 /* Set uplo: L or U from the Mode, swapped when lhs is row-major (it is passed with a T or C transa in that case) */ \
154 uplo = IsLower ? 'L' : 'U'; \
155 if (LhsStorageOrder == RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
156 /* Set a, lda: lhs is copied into a_tmp only when it has to be conjugated (conjA) or its diagonal overwritten (SetDiag == 0); otherwise _lhs is passed as is */ \
157 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs, rows, depth, OuterStride<>(lhsStride)); \
158 MatrixLhs a_tmp; \
159 \
160 if ((conjA != 0) || (SetDiag == 0)) { \
161 if (conjA) \
162 a_tmp = lhs.conjugate(); \
163 else \
164 a_tmp = lhs; \
165 if (IsZeroDiag) \
166 a_tmp.diagonal().setZero(); \
167 else if (IsUnitDiag) \
168 a_tmp.diagonal().setOnes(); \
169 a = a_tmp.data(); \
170 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
171 } else { \
172 a = _lhs; \
173 lda = convert_index<BlasIndex>(lhsStride); \
174 } \
175 /*std::cout << "TRMM_L: A is square! Go to BLAS TRMM implementation! \n";*/ \
176 /* call ?trmm; alpha is passed by address, b (the storage of b_tmp) is the in/out matrix argument */ \
177 BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, \
178 &lda, (BLASTYPE*)b, &ldb); \
179 \
180 /* Accumulate b_tmp, as left by the ?trmm call, into res */ \
181 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res, rows, cols, OuterStride<>(resStride)); \
182 res_tmp = res_tmp + b_tmp; \
183 } \
184 };
185
186#ifdef EIGEN_USE_MKL
187EIGEN_BLAS_TRMM_L(double, double, d, dtrmm)
188EIGEN_BLAS_TRMM_L(dcomplex, MKL_Complex16, cd, ztrmm)
189EIGEN_BLAS_TRMM_L(float, float, f, strmm)
190EIGEN_BLAS_TRMM_L(scomplex, MKL_Complex8, cf, ctrmm)
191#else
192EIGEN_BLAS_TRMM_L(double, double, d, dtrmm_)
193EIGEN_BLAS_TRMM_L(dcomplex, double, cd, ztrmm_)
194EIGEN_BLAS_TRMM_L(float, float, f, strmm_)
195EIGEN_BLAS_TRMM_L(scomplex, float, cf, ctrmm_)
196#endif
197
198// implements col-major += alpha * op(general) * op(triangular)
//
// EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) defines the partial specialization of
// product_triangular_matrix_matrix_trmm for a triangular RIGHT-hand side (LhsIsTriangular == false)
// and a ColMajor result.
//   EIGTYPE   : Eigen scalar type of lhs, rhs, res and alpha.
//   BLASTYPE  : type the alpha and matrix pointers are cast to when calling BLASFUNC.
//   EIGPREFIX : suffix pasted after MatrixX (MatrixX##EIGPREFIX) to name the dense temporary type.
//   BLASFUNC  : the ?trmm routine to call.
//
// What the generated run() does:
//   * returns immediately when any of _rows, _cols, _depth is zero;
//   * clips the triangular factor (rhs) to depth x cols using diagSize = min(_cols, _depth);
//   * if that factor is not square, ?trmm is not used: depending on a size heuristic the work goes
//     either to the BuiltIn product_triangular_matrix_matrix or to general_matrix_matrix_product
//     applied to a dense copy of the triangular view;
//   * otherwise copies lhs into b_tmp, calls BLASFUNC with side R on that copy and adds b_tmp to res.
// diag is always passed as N: UnitDiag / ZeroDiag modes are handled by overwriting the diagonal of a
// copy of rhs before the call.
199#define EIGEN_BLAS_TRMM_R(EIGTYPE, BLASTYPE, EIGPREFIX, BLASFUNC) \
200 template <typename Index, int Mode, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs> \
201 struct product_triangular_matrix_matrix_trmm<EIGTYPE, Index, Mode, false, LhsStorageOrder, ConjugateLhs, \
202 RhsStorageOrder, ConjugateRhs, ColMajor> { \
203 enum { \
204 IsLower = (Mode & Lower) == Lower, \
205 SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1, \
206 IsUnitDiag = (Mode & UnitDiag) ? 1 : 0, \
207 IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0, \
208 LowUp = IsLower ? Lower : Upper, \
209 conjA = ((RhsStorageOrder == ColMajor) && ConjugateRhs) ? 1 : 0 \
210 }; \
211 \
212 static void run(Index _rows, Index _cols, Index _depth, const EIGTYPE* _lhs, Index lhsStride, const EIGTYPE* _rhs, \
213 Index rhsStride, EIGTYPE* res, Index resStride, EIGTYPE alpha, \
214 level3_blocking<EIGTYPE, EIGTYPE>& blocking) { \
215 if (_rows == 0 || _cols == 0 || _depth == 0) return; \
216 Index diagSize = (std::min)(_cols, _depth); \
217 Index rows = _rows; \
218 Index depth = IsLower ? _depth : diagSize; \
219 Index cols = IsLower ? diagSize : _cols; \
220 \
221 typedef Matrix<EIGTYPE, Dynamic, Dynamic, LhsStorageOrder> MatrixLhs; \
222 typedef Matrix<EIGTYPE, Dynamic, Dynamic, RhsStorageOrder> MatrixRhs; \
223 \
224 /* Non-square triangular factor (cols != depth): ?TRMM is not used here. Fall back to the BuiltIn triangular product or to a general matrix product on a dense copy. */ \
225 if (cols != depth) { \
226 int nthr = 1 /*mkl_domain_get_max_threads(EIGEN_BLAS_DOMAIN_BLAS)*/; \
227 \
228 if ((nthr == 1) && (((std::max)(cols, depth) - diagSize) / (double)diagSize < 0.5)) { \
229 /* The part exceeding the square block is less than half of diagSize: most likely no benefit from TRMM or GEMM, use the BuiltIn product on the original arguments */ \
230 product_triangular_matrix_matrix<EIGTYPE, Index, Mode, false, LhsStorageOrder, ConjugateLhs, \
231 RhsStorageOrder, ConjugateRhs, ColMajor, 1, BuiltIn>::run(_rows, _cols, \
232 _depth, _lhs, \
233 lhsStride, _rhs, \
234 rhsStride, res, \
235 1, resStride, \
236 alpha, blocking); \
237 /*std::cout << "TRMM_R: A is not square! Go to Eigen TRMM implementation!\n";*/ \
238 } else { \
239 /* Otherwise materialize triangularView<Mode>() of rhs as a dense matrix (aa_tmp) and run general_matrix_matrix_product on it with a freshly built blocking object */ \
240 Map<const MatrixRhs, 0, OuterStride<> > rhsMap(_rhs, depth, cols, OuterStride<>(rhsStride)); \
241 MatrixRhs aa_tmp = rhsMap.template triangularView<Mode>(); \
242 BlasIndex aStride = convert_index<BlasIndex>(aa_tmp.outerStride()); \
243 gemm_blocking_space<ColMajor, EIGTYPE, EIGTYPE, Dynamic, Dynamic, Dynamic> gemm_blocking(_rows, _cols, \
244 _depth, 1, true); \
245 general_matrix_matrix_product<Index, EIGTYPE, LhsStorageOrder, ConjugateLhs, EIGTYPE, RhsStorageOrder, \
246 ConjugateRhs, ColMajor, 1>::run(rows, cols, depth, _lhs, lhsStride, \
247 aa_tmp.data(), aStride, res, 1, resStride, \
248 alpha, gemm_blocking, 0); \
249 \
250 /*std::cout << "TRMM_R: A is not square! Go to BLAS GEMM implementation! " << nthr<<" \n";*/ \
251 } \
252 return; \
253 } \
254 char side = 'R', transa, uplo, diag = 'N'; \
255 EIGTYPE* b; \
256 const EIGTYPE* a; \
257 BlasIndex m, n, lda, ldb; \
258 \
259 /* Set m, n: sizes handed to ?trmm, m = rows and n = diagSize (here cols == depth == diagSize) */ \
260 m = convert_index<BlasIndex>(rows); \
261 n = convert_index<BlasIndex>(diagSize); \
262 \
263 /* Set trans: N for a column-major rhs; a row-major rhs is passed as T, or as C when ConjugateRhs is set */ \
264 transa = (RhsStorageOrder == RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; \
265 \
266 /* Set b, ldb: b_tmp is a copy of lhs (conjugated when ConjugateLhs) and is the non-const matrix argument of ?trmm */ \
267 Map<const MatrixLhs, 0, OuterStride<> > lhs(_lhs, rows, depth, OuterStride<>(lhsStride)); \
268 MatrixX##EIGPREFIX b_tmp; \
269 \
270 if (ConjugateLhs) \
271 b_tmp = lhs.conjugate(); \
272 else \
273 b_tmp = lhs; \
274 b = b_tmp.data(); \
275 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
276 \
277 /* Set uplo: L or U from the Mode, swapped when rhs is row-major (it is passed with a T or C transa in that case) */ \
278 uplo = IsLower ? 'L' : 'U'; \
279 if (RhsStorageOrder == RowMajor) uplo = (uplo == 'L') ? 'U' : 'L'; \
280 /* Set a, lda: rhs is copied into a_tmp only when it has to be conjugated (conjA) or its diagonal overwritten (SetDiag == 0); otherwise _rhs is passed as is */ \
281 Map<const MatrixRhs, 0, OuterStride<> > rhs(_rhs, depth, cols, OuterStride<>(rhsStride)); \
282 MatrixRhs a_tmp; \
283 \
284 if ((conjA != 0) || (SetDiag == 0)) { \
285 if (conjA) \
286 a_tmp = rhs.conjugate(); \
287 else \
288 a_tmp = rhs; \
289 if (IsZeroDiag) \
290 a_tmp.diagonal().setZero(); \
291 else if (IsUnitDiag) \
292 a_tmp.diagonal().setOnes(); \
293 a = a_tmp.data(); \
294 lda = convert_index<BlasIndex>(a_tmp.outerStride()); \
295 } else { \
296 a = _rhs; \
297 lda = convert_index<BlasIndex>(rhsStride); \
298 } \
299 /*std::cout << "TRMM_R: A is square! Go to BLAS TRMM implementation! \n";*/ \
300 /* call ?trmm; alpha is passed by address, b (the storage of b_tmp) is the in/out matrix argument */ \
301 BLASFUNC(&side, &uplo, &transa, &diag, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, \
302 &lda, (BLASTYPE*)b, &ldb); \
303 \
304 /* Accumulate b_tmp, as left by the ?trmm call (general factor times triangular factor), into res */ \
305 Map<MatrixX##EIGPREFIX, 0, OuterStride<> > res_tmp(res, rows, cols, OuterStride<>(resStride)); \
306 res_tmp = res_tmp + b_tmp; \
307 } \
308 };
309
310#ifdef EIGEN_USE_MKL
311EIGEN_BLAS_TRMM_R(double, double, d, dtrmm)
312EIGEN_BLAS_TRMM_R(dcomplex, MKL_Complex16, cd, ztrmm)
313EIGEN_BLAS_TRMM_R(float, float, f, strmm)
314EIGEN_BLAS_TRMM_R(scomplex, MKL_Complex8, cf, ctrmm)
315#else
316EIGEN_BLAS_TRMM_R(double, double, d, dtrmm_)
317EIGEN_BLAS_TRMM_R(dcomplex, double, cd, ztrmm_)
318EIGEN_BLAS_TRMM_R(float, float, f, strmm_)
319EIGEN_BLAS_TRMM_R(scomplex, float, cf, ctrmm_)
320#endif
321} // end namespace internal
322
323} // end namespace Eigen
324
325#endif // EIGEN_TRIANGULAR_MATRIX_MATRIX_BLAS_H
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82