Eigen  5.0.1-dev+60122df6
 
Loading...
Searching...
No Matches
TriangularMatrixVector_BLAS.h
1/*
2 Copyright (c) 2011, Intel Corporation. All rights reserved.
3
4 Redistribution and use in source and binary forms, with or without modification,
5 are permitted provided that the following conditions are met:
6
7 * Redistributions of source code must retain the above copyright notice, this
8 list of conditions and the following disclaimer.
9 * Redistributions in binary form must reproduce the above copyright notice,
10 this list of conditions and the following disclaimer in the documentation
11 and/or other materials provided with the distribution.
12 * Neither the name of Intel Corporation nor the names of its contributors may
13 be used to endorse or promote products derived from this software without
14 specific prior written permission.
15
16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17 ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18 WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20 ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
23 ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26
27 ********************************************************************************
28 * Content : Eigen bindings to BLAS F77
29 * Triangular matrix-vector product functionality based on ?TRMV.
30 ********************************************************************************
31*/
32
33#ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
34#define EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
35
36// IWYU pragma: private
37#include "../InternalHeaderCheck.h"
38
39namespace Eigen {
40
41namespace internal {
42
43/**********************************************************************
44 * This file implements triangular matrix-vector multiplication using BLAS
45 **********************************************************************/
46
47// trmv/hemv specialization
48
49template <typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs,
50 int StorageOrder>
51struct triangular_matrix_vector_product_trmv
52 : triangular_matrix_vector_product<Index, Mode, LhsScalar, ConjLhs, RhsScalar, ConjRhs, StorageOrder, BuiltIn> {};
53
// Specializes the `Specialized` triangular matrix-vector kernel for `Scalar`
// (same scalar type on both sides) in both storage orders. Each specialization
// simply forwards all of its arguments to triangular_matrix_vector_product_trmv.
#define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar) \
  /* Column-major lhs. */ \
  template <typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
  struct triangular_matrix_vector_product<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, ColMajor, Specialized> { \
    static void run(Index rows_, Index cols_, const Scalar* lhs_, Index lhsStride, const Scalar* rhs_, \
                    Index rhsIncr, Scalar* res_, Index resIncr, Scalar alpha) { \
      typedef triangular_matrix_vector_product_trmv<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, ColMajor> \
          TrmvKernel; \
      TrmvKernel::run(rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
    } \
  }; \
  /* Row-major lhs. */ \
  template <typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
  struct triangular_matrix_vector_product<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, RowMajor, Specialized> { \
    static void run(Index rows_, Index cols_, const Scalar* lhs_, Index lhsStride, const Scalar* rhs_, \
                    Index rhsIncr, Scalar* res_, Index resIncr, Scalar alpha) { \
      typedef triangular_matrix_vector_product_trmv<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, RowMajor> \
          TrmvKernel; \
      TrmvKernel::run(rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
    } \
  };
71
72EIGEN_BLAS_TRMV_SPECIALIZE(double)
73EIGEN_BLAS_TRMV_SPECIALIZE(float)
74EIGEN_BLAS_TRMV_SPECIALIZE(dcomplex)
75EIGEN_BLAS_TRMV_SPECIALIZE(scomplex)
76
// Column-major kernel: res += alpha * op(triangular) * vector.
// The square triangular part is multiplied with ?TRMV and accumulated into res
// with ?AXPY; when the operand is not square, the remaining rectangular block is
// added with ?GEMV.
//   EIGTYPE     - Eigen scalar type of lhs, rhs and res
//   BLASTYPE    - pointer element type expected by the BLAS prototypes
//   EIGPREFIX   - suffix selecting the matching VectorX<suffix> typedef
//   BLASPREFIX  - BLAS routine prefix (s, d, c, z)
//   BLASPOSTFIX - suffix appended to the BLAS routine name (may be empty)
#define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \
  template <typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
  struct triangular_matrix_vector_product_trmv<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, ColMajor> { \
    enum { \
      IsLower = (Mode & Lower) == Lower, \
      SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1, \
      IsUnitDiag = (Mode & UnitDiag) ? 1 : 0, \
      IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0, \
      LowUp = IsLower ? Lower : Upper \
    }; \
    static void run(Index rows_, Index cols_, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \
                    Index rhsIncr, EIGTYPE* res_, Index resIncr, EIGTYPE alpha) { \
      /* An empty operand contributes nothing to res. */ \
      if (rows_ == 0 || cols_ == 0) return; \
      \
      /* A conjugated lhs or a zero diagonal is handed to the built-in kernel. */ \
      if (ConjLhs || IsZeroDiag) { \
        typedef triangular_matrix_vector_product<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, ColMajor, \
                                                 BuiltIn> \
            BuiltInKernel; \
        BuiltInKernel::run(rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
        return; \
      } \
      \
      /* Square triangular part plus (possibly) a rectangular remainder. */ \
      const Index triSize = (std::min)(rows_, cols_); \
      const Index usedRows = IsLower ? rows_ : triSize; \
      const Index usedCols = IsLower ? triSize : cols_; \
      \
      /* Contiguous working copy of the strided rhs, conjugated on request. */ \
      typedef VectorX##EIGPREFIX VectorRhs; \
      Map<const VectorRhs, 0, InnerStride<> > rhs(rhs_, usedCols, InnerStride<>(rhsIncr)); \
      VectorRhs rhsCopy; \
      if (ConjRhs) { \
        rhsCopy = rhs.conjugate(); \
      } else { \
        rhsCopy = rhs; \
      } \
      EIGTYPE* xPtr = rhsCopy.data(); \
      \
      /* BLAS arguments for the square part. */ \
      BlasIndex blasN = convert_index<BlasIndex>(triSize); \
      BlasIndex lda = convert_index<BlasIndex>(lhsStride); \
      BlasIndex incx = 1; \
      BlasIndex incy = convert_index<BlasIndex>(resIncr); \
      char trans = 'N'; \
      char uplo = IsLower ? 'L' : 'U'; \
      char diag = IsUnitDiag ? 'U' : 'N'; \
      \
      /* ?TRMV: the working copy is the in/out vector of the triangular product. */ \
      BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &blasN, (const BLASTYPE*)lhs_, &lda, (BLASTYPE*)xPtr, \
                                    &incx); \
      \
      /* ?AXPY: res += alpha * (triangular product). */ \
      BLASPREFIX##axpy##BLASPOSTFIX(&blasN, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)xPtr, \
                                    &incx, (BLASTYPE*)res_, &incy); \
      \
      /* Square operand: the triangular part was everything. */ \
      if (triSize >= (std::max)(usedRows, usedCols)) return; \
      \
      /* Non-square case: the block outside the square part does not fit ?TRMV and */ \
      /* is added with ?GEMV. Reload the rhs first, the working copy was used as   */ \
      /* the in/out vector above.                                                  */ \
      if (ConjRhs) { \
        rhsCopy = rhs.conjugate(); \
      } else { \
        rhsCopy = rhs; \
      } \
      xPtr = rhsCopy.data(); \
      \
      EIGTYPE* yPtr; \
      EIGTYPE const* blockPtr; \
      BlasIndex blockRows, blockCols; \
      if (triSize < usedRows) { \
        /* Extra rows below the square part: they update the tail of res. */ \
        yPtr = res_ + triSize * resIncr; \
        blockPtr = lhs_ + triSize; \
        blockRows = convert_index<BlasIndex>(usedRows - triSize); \
        blockCols = convert_index<BlasIndex>(triSize); \
      } else { \
        /* Extra columns right of the square part: they consume the tail of rhs. */ \
        xPtr += triSize; \
        yPtr = res_; \
        blockPtr = lhs_ + triSize * lda; \
        blockRows = convert_index<BlasIndex>(triSize); \
        blockCols = convert_index<BlasIndex>(usedCols - triSize); \
      } \
      \
      /* beta == 1 makes ?GEMV accumulate into res instead of overwriting it. */ \
      EIGTYPE beta(1); \
      BLASPREFIX##gemv##BLASPOSTFIX(&trans, &blockRows, &blockCols, (const BLASTYPE*)&numext::real_ref(alpha), \
                                    (const BLASTYPE*)blockPtr, &lda, (const BLASTYPE*)xPtr, &incx, \
                                    (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)yPtr, &incy); \
    } \
  };
161
162#ifdef EIGEN_USE_MKL
163EIGEN_BLAS_TRMV_CM(double, double, d, d, )
164EIGEN_BLAS_TRMV_CM(dcomplex, MKL_Complex16, cd, z, )
165EIGEN_BLAS_TRMV_CM(float, float, f, s, )
166EIGEN_BLAS_TRMV_CM(scomplex, MKL_Complex8, cf, c, )
167#else
168EIGEN_BLAS_TRMV_CM(double, double, d, d, _)
169EIGEN_BLAS_TRMV_CM(dcomplex, double, cd, z, _)
170EIGEN_BLAS_TRMV_CM(float, float, f, s, _)
171EIGEN_BLAS_TRMV_CM(scomplex, float, cf, c, _)
172#endif
173
// Row-major kernel: res += alpha * op(triangular) * vector.
// The row-major lhs is passed to the BLAS routines with a transposing `trans`
// ('T', or 'C' when the lhs must also be conjugated) and with `uplo` swapped,
// i.e. it is treated as the column-major storage of its transpose. The square
// triangular part goes through ?TRMV and ?AXPY; a non-square remainder is added
// with ?GEMV.
//   EIGTYPE     - Eigen scalar type of lhs, rhs and res
//   BLASTYPE    - pointer element type expected by the BLAS prototypes
//   EIGPREFIX   - suffix selecting the matching VectorX<suffix> typedef
//   BLASPREFIX  - BLAS routine prefix (s, d, c, z)
//   BLASPOSTFIX - suffix appended to the BLAS routine name (may be empty)
#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX) \
  template <typename Index, int Mode, bool ConjLhs, bool ConjRhs> \
  struct triangular_matrix_vector_product_trmv<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, RowMajor> { \
    enum { \
      IsLower = (Mode & Lower) == Lower, \
      SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1, \
      IsUnitDiag = (Mode & UnitDiag) ? 1 : 0, \
      IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0, \
      LowUp = IsLower ? Lower : Upper \
    }; \
    static void run(Index rows_, Index cols_, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \
                    Index rhsIncr, EIGTYPE* res_, Index resIncr, EIGTYPE alpha) { \
      /* An empty operand contributes nothing to res. */ \
      if (rows_ == 0 || cols_ == 0) return; \
      \
      /* A zero diagonal is handed to the built-in kernel. */ \
      if (IsZeroDiag) { \
        typedef triangular_matrix_vector_product<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, RowMajor, \
                                                 BuiltIn> \
            BuiltInKernel; \
        BuiltInKernel::run(rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha); \
        return; \
      } \
      \
      /* Square triangular part plus (possibly) a rectangular remainder. */ \
      const Index triSize = (std::min)(rows_, cols_); \
      const Index usedRows = IsLower ? rows_ : triSize; \
      const Index usedCols = IsLower ? triSize : cols_; \
      \
      /* Contiguous working copy of the strided rhs, conjugated on request. */ \
      typedef VectorX##EIGPREFIX VectorRhs; \
      Map<const VectorRhs, 0, InnerStride<> > rhs(rhs_, usedCols, InnerStride<>(rhsIncr)); \
      VectorRhs rhsCopy; \
      if (ConjRhs) { \
        rhsCopy = rhs.conjugate(); \
      } else { \
        rhsCopy = rhs; \
      } \
      EIGTYPE* xPtr = rhsCopy.data(); \
      \
      /* BLAS arguments for the square part. uplo is swapped and trans is a */ \
      /* transposing mode because the lhs is stored row-major.              */ \
      BlasIndex blasN = convert_index<BlasIndex>(triSize); \
      BlasIndex lda = convert_index<BlasIndex>(lhsStride); \
      BlasIndex incx = 1; \
      BlasIndex incy = convert_index<BlasIndex>(resIncr); \
      char trans = ConjLhs ? 'C' : 'T'; \
      char uplo = IsLower ? 'U' : 'L'; \
      char diag = IsUnitDiag ? 'U' : 'N'; \
      \
      /* ?TRMV: the working copy is the in/out vector of the triangular product. */ \
      BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &blasN, (const BLASTYPE*)lhs_, &lda, (BLASTYPE*)xPtr, \
                                    &incx); \
      \
      /* ?AXPY: res += alpha * (triangular product). */ \
      BLASPREFIX##axpy##BLASPOSTFIX(&blasN, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)xPtr, \
                                    &incx, (BLASTYPE*)res_, &incy); \
      \
      /* Square operand: the triangular part was everything. */ \
      if (triSize >= (std::max)(usedRows, usedCols)) return; \
      \
      /* Non-square case: the block outside the square part does not fit ?TRMV and */ \
      /* is added with ?GEMV. Reload the rhs first, the working copy was used as   */ \
      /* the in/out vector above.                                                  */ \
      if (ConjRhs) { \
        rhsCopy = rhs.conjugate(); \
      } else { \
        rhsCopy = rhs; \
      } \
      xPtr = rhsCopy.data(); \
      \
      EIGTYPE* yPtr; \
      EIGTYPE const* blockPtr; \
      BlasIndex blockRows, blockCols; \
      if (triSize < usedRows) { \
        /* Extra rows below the square part: they update the tail of res. */ \
        yPtr = res_ + triSize * resIncr; \
        blockPtr = lhs_ + triSize * lda; \
        blockRows = convert_index<BlasIndex>(usedRows - triSize); \
        blockCols = convert_index<BlasIndex>(triSize); \
      } else { \
        /* Extra columns right of the square part: they consume the tail of rhs. */ \
        xPtr += triSize; \
        yPtr = res_; \
        blockPtr = lhs_ + triSize; \
        blockRows = convert_index<BlasIndex>(triSize); \
        blockCols = convert_index<BlasIndex>(usedCols - triSize); \
      } \
      \
      /* The dimensions are passed swapped (columns first) to match the transposed */ \
      /* view of the row-major block; beta == 1 makes ?GEMV accumulate into res.   */ \
      EIGTYPE beta(1); \
      BLASPREFIX##gemv##BLASPOSTFIX(&trans, &blockCols, &blockRows, (const BLASTYPE*)&numext::real_ref(alpha), \
                                    (const BLASTYPE*)blockPtr, &lda, (const BLASTYPE*)xPtr, &incx, \
                                    (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)yPtr, &incy); \
    } \
  };
258
259#ifdef EIGEN_USE_MKL
260EIGEN_BLAS_TRMV_RM(double, double, d, d, )
261EIGEN_BLAS_TRMV_RM(dcomplex, MKL_Complex16, cd, z, )
262EIGEN_BLAS_TRMV_RM(float, float, f, s, )
263EIGEN_BLAS_TRMV_RM(scomplex, MKL_Complex8, cf, c, )
264#else
265EIGEN_BLAS_TRMV_RM(double, double, d, d, _)
266EIGEN_BLAS_TRMV_RM(dcomplex, double, cd, z, _)
267EIGEN_BLAS_TRMV_RM(float, float, f, s, _)
268EIGEN_BLAS_TRMV_RM(scomplex, float, cf, c, _)
269#endif
270
271} // namespace internal
272
273} // end namespace Eigen
274
275#endif // EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82