33#ifndef EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
34#define EIGEN_TRIANGULAR_MATRIX_VECTOR_BLAS_H
37#include "../InternalHeaderCheck.h"
49template <
typename Index,
int Mode,
typename LhsScalar,
bool ConjLhs,
typename RhsScalar,
bool ConjRhs,
51struct triangular_matrix_vector_product_trmv
52 : triangular_matrix_vector_product<Index, Mode, LhsScalar, ConjLhs, RhsScalar, ConjRhs, StorageOrder, BuiltIn> {};
// Routes the `Specialized` triangular_matrix_vector_product for a scalar type `Scalar`
// (lhs and rhs both of type Scalar) to triangular_matrix_vector_product_trmv, for both
// ColMajor and RowMajor lhs storage. Whether that ends up in BLAS or in the BuiltIn kernel
// is decided by which triangular_matrix_vector_product_trmv specialization matches.
#define EIGEN_BLAS_TRMV_SPECIALIZE(Scalar)                                                                            \
  template <typename Index, int Mode, bool ConjLhs, bool ConjRhs>                                                     \
  struct triangular_matrix_vector_product<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, ColMajor, Specialized> {     \
    static void run(Index rows_, Index cols_, const Scalar* lhs_, Index lhsStride, const Scalar* rhs_, Index rhsIncr, \
                    Scalar* res_, Index resIncr, Scalar alpha) {                                                      \
      triangular_matrix_vector_product_trmv<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, ColMajor>::run(            \
          rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha);                                        \
    }                                                                                                                 \
  };                                                                                                                  \
  template <typename Index, int Mode, bool ConjLhs, bool ConjRhs>                                                     \
  struct triangular_matrix_vector_product<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, RowMajor, Specialized> {     \
    static void run(Index rows_, Index cols_, const Scalar* lhs_, Index lhsStride, const Scalar* rhs_, Index rhsIncr, \
                    Scalar* res_, Index resIncr, Scalar alpha) {                                                      \
      triangular_matrix_vector_product_trmv<Index, Mode, Scalar, ConjLhs, Scalar, ConjRhs, RowMajor>::run(            \
          rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha);                                        \
    }                                                                                                                 \
  };
72EIGEN_BLAS_TRMV_SPECIALIZE(
double)
73EIGEN_BLAS_TRMV_SPECIALIZE(
float)
74EIGEN_BLAS_TRMV_SPECIALIZE(dcomplex)
75EIGEN_BLAS_TRMV_SPECIALIZE(scomplex)
// Column-major BLAS kernel: res += alpha * triangular(lhs) * op(rhs).
// The leading min(rows_, cols_) square triangle is handled with ?TRMV (in place on a copy of
// rhs) followed by ?AXPY into res; the rectangular remainder of a trapezoidal operand is
// handled with ?GEMV. A conjugated lhs (BLAS op() offers A, A^T, A^H but not conj(A)) or a
// ZeroDiag mode is delegated to the BuiltIn kernel.
#define EIGEN_BLAS_TRMV_CM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX)                                     \
  template <typename Index, int Mode, bool ConjLhs, bool ConjRhs>                                                      \
  struct triangular_matrix_vector_product_trmv<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, ColMajor> {            \
    enum {                                                                                                             \
      IsLower = (Mode & Lower) == Lower,                                                                               \
      SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1,                                                                \
      IsUnitDiag = (Mode & UnitDiag) ? 1 : 0,                                                                          \
      IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0,                                                                          \
      LowUp = IsLower ? Lower : Upper                                                                                  \
    };                                                                                                                 \
    static void run(Index rows_, Index cols_, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_,               \
                    Index rhsIncr, EIGTYPE* res_, Index resIncr, EIGTYPE alpha) {                                      \
      if (rows_ == 0 || cols_ == 0) return;                                                                            \
      /* Column-major ?TRMV cannot apply conj(A) and has no zero-diagonal mode. */                                     \
      if (ConjLhs || IsZeroDiag) {                                                                                     \
        triangular_matrix_vector_product<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, ColMajor, BuiltIn>::run(     \
            rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha);                                       \
        return;                                                                                                        \
      }                                                                                                                \
      Index size = (std::min)(rows_, cols_);                                                                           \
      Index rows = IsLower ? rows_ : size;                                                                             \
      Index cols = IsLower ? size : cols_;                                                                             \
                                                                                                                       \
      typedef VectorX##EIGPREFIX VectorRhs;                                                                            \
      EIGTYPE *x, *y;                                                                                                  \
                                                                                                                       \
      /* Contiguous (optionally conjugated) copy of rhs; ?TRMV overwrites it in place. */                              \
      Map<const VectorRhs, 0, InnerStride<> > rhs(rhs_, cols, InnerStride<>(rhsIncr));                                 \
      VectorRhs x_tmp;                                                                                                 \
      if (ConjRhs)                                                                                                     \
        x_tmp = rhs.conjugate();                                                                                       \
      else                                                                                                             \
        x_tmp = rhs;                                                                                                   \
      x = x_tmp.data();                                                                                                \
                                                                                                                       \
      /* Square part handling */                                                                                       \
      char trans, uplo, diag;                                                                                          \
      BlasIndex m, n, lda, incx, incy;                                                                                 \
      EIGTYPE const* a;                                                                                                \
      EIGTYPE beta(1);                                                                                                 \
                                                                                                                       \
      n = convert_index<BlasIndex>(size);                                                                              \
      lda = convert_index<BlasIndex>(lhsStride);                                                                       \
      incx = 1;                                                                                                        \
      incy = convert_index<BlasIndex>(resIncr);                                                                        \
                                                                                                                       \
      trans = 'N';                                                                                                     \
      uplo = IsLower ? 'L' : 'U';                                                                                      \
      diag = IsUnitDiag ? 'U' : 'N';                                                                                   \
                                                                                                                       \
      /* x[0:size) := tri(A[0:size, 0:size]) * x[0:size) */                                                            \
      BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)lhs_, &lda, (BLASTYPE*)x, &incx);       \
                                                                                                                       \
      /* res[0:size) += alpha * x[0:size) */                                                                           \
      BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)x, &incx,          \
                                    (BLASTYPE*)res_, &incy);                                                           \
      /* Non-square case - doesn't fit to BLAS ?TRMV; the rectangular remainder goes through ?GEMV. */                 \
      if (size < (std::max)(rows, cols)) {                                                                             \
        /* ?TRMV clobbered x above, so rebuild it from rhs. */                                                         \
        if (ConjRhs)                                                                                                   \
          x_tmp = rhs.conjugate();                                                                                     \
        else                                                                                                           \
          x_tmp = rhs;                                                                                                 \
        x = x_tmp.data();                                                                                              \
        if (size < rows) {                                                                                             \
          /* Lower trapezoid: rows [size, rows) x columns [0, size), below the square. */                              \
          y = res_ + size * resIncr;                                                                                   \
          a = lhs_ + size;                                                                                             \
          m = convert_index<BlasIndex>(rows - size);                                                                   \
          n = convert_index<BlasIndex>(size);                                                                          \
        } else {                                                                                                       \
          /* Upper trapezoid: rows [0, size) x columns [size, cols), right of the square. */                           \
          x += size;                                                                                                   \
          y = res_;                                                                                                    \
          a = lhs_ + size * lda;                                                                                       \
          m = convert_index<BlasIndex>(size);                                                                          \
          n = convert_index<BlasIndex>(cols - size);                                                                   \
        }                                                                                                              \
        BLASPREFIX##gemv##BLASPOSTFIX(&trans, &m, &n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a,   \
                                      &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta),       \
                                      (BLASTYPE*)y, &incy);                                                            \
      }                                                                                                                \
    }                                                                                                                  \
  };
163EIGEN_BLAS_TRMV_CM(
double,
double, d, d, )
164EIGEN_BLAS_TRMV_CM(dcomplex, MKL_Complex16, cd, z, )
165EIGEN_BLAS_TRMV_CM(
float,
float, f, s, )
166EIGEN_BLAS_TRMV_CM(scomplex, MKL_Complex8, cf, c, )
168EIGEN_BLAS_TRMV_CM(
double,
double, d, d, _)
169EIGEN_BLAS_TRMV_CM(dcomplex,
double, cd, z, _)
170EIGEN_BLAS_TRMV_CM(
float,
float, f, s, _)
171EIGEN_BLAS_TRMV_CM(scomplex,
float, cf, c, _)
// Row-major BLAS kernel: res += alpha * op(triangular(lhs)) * op(rhs).
// A row-major lhs with stride lhsStride is, in memory, the transpose of a column-major matrix
// with the same leading dimension, so ?TRMV/?GEMV are called with trans='T' ('C' when ConjLhs,
// which applies conj(lhs)) and with uplo swapped. Only a ZeroDiag mode needs the BuiltIn kernel.
#define EIGEN_BLAS_TRMV_RM(EIGTYPE, BLASTYPE, EIGPREFIX, BLASPREFIX, BLASPOSTFIX)                                     \
  template <typename Index, int Mode, bool ConjLhs, bool ConjRhs>                                                      \
  struct triangular_matrix_vector_product_trmv<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, RowMajor> {            \
    enum {                                                                                                             \
      IsLower = (Mode & Lower) == Lower,                                                                               \
      SetDiag = (Mode & (ZeroDiag | UnitDiag)) ? 0 : 1,                                                                \
      IsUnitDiag = (Mode & UnitDiag) ? 1 : 0,                                                                          \
      IsZeroDiag = (Mode & ZeroDiag) ? 1 : 0,                                                                          \
      LowUp = IsLower ? Lower : Upper                                                                                  \
    };                                                                                                                 \
    static void run(Index rows_, Index cols_, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_,               \
                    Index rhsIncr, EIGTYPE* res_, Index resIncr, EIGTYPE alpha) {                                      \
      if (rows_ == 0 || cols_ == 0) return;                                                                            \
      /* BLAS ?TRMV has no zero-diagonal mode. */                                                                      \
      if (IsZeroDiag) {                                                                                                \
        triangular_matrix_vector_product<Index, Mode, EIGTYPE, ConjLhs, EIGTYPE, ConjRhs, RowMajor, BuiltIn>::run(     \
            rows_, cols_, lhs_, lhsStride, rhs_, rhsIncr, res_, resIncr, alpha);                                       \
        return;                                                                                                        \
      }                                                                                                                \
      Index size = (std::min)(rows_, cols_);                                                                           \
      Index rows = IsLower ? rows_ : size;                                                                             \
      Index cols = IsLower ? size : cols_;                                                                             \
                                                                                                                       \
      typedef VectorX##EIGPREFIX VectorRhs;                                                                            \
      EIGTYPE *x, *y;                                                                                                  \
                                                                                                                       \
      /* Contiguous (optionally conjugated) copy of rhs; ?TRMV overwrites it in place. */                              \
      Map<const VectorRhs, 0, InnerStride<> > rhs(rhs_, cols, InnerStride<>(rhsIncr));                                 \
      VectorRhs x_tmp;                                                                                                 \
      if (ConjRhs)                                                                                                     \
        x_tmp = rhs.conjugate();                                                                                       \
      else                                                                                                             \
        x_tmp = rhs;                                                                                                   \
      x = x_tmp.data();                                                                                                \
                                                                                                                       \
      /* Square part handling */                                                                                       \
      char trans, uplo, diag;                                                                                          \
      BlasIndex m, n, lda, incx, incy;                                                                                 \
      EIGTYPE const* a;                                                                                                \
      EIGTYPE beta(1);                                                                                                 \
                                                                                                                       \
      n = convert_index<BlasIndex>(size);                                                                              \
      lda = convert_index<BlasIndex>(lhsStride);                                                                       \
      incx = 1;                                                                                                        \
      incy = convert_index<BlasIndex>(resIncr);                                                                        \
                                                                                                                       \
      /* Transposed view of the row-major storage: swap uplo, use 'T' or 'C'. */                                       \
      trans = ConjLhs ? 'C' : 'T';                                                                                     \
      uplo = IsLower ? 'U' : 'L';                                                                                      \
      diag = IsUnitDiag ? 'U' : 'N';                                                                                   \
                                                                                                                       \
      /* x[0:size) := op(tri(A[0:size, 0:size])) * x[0:size) */                                                        \
      BLASPREFIX##trmv##BLASPOSTFIX(&uplo, &trans, &diag, &n, (const BLASTYPE*)lhs_, &lda, (BLASTYPE*)x, &incx);       \
                                                                                                                       \
      /* res[0:size) += alpha * x[0:size) */                                                                           \
      BLASPREFIX##axpy##BLASPOSTFIX(&n, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)x, &incx,          \
                                    (BLASTYPE*)res_, &incy);                                                           \
      /* Non-square case - doesn't fit to BLAS ?TRMV; the rectangular remainder goes through ?GEMV. */                 \
      if (size < (std::max)(rows, cols)) {                                                                             \
        /* ?TRMV clobbered x above, so rebuild it from rhs. */                                                         \
        if (ConjRhs)                                                                                                   \
          x_tmp = rhs.conjugate();                                                                                     \
        else                                                                                                           \
          x_tmp = rhs;                                                                                                 \
        x = x_tmp.data();                                                                                              \
        if (size < rows) {                                                                                             \
          /* Lower trapezoid: row-major rows [size, rows) start at lhs_ + size * lda. */                               \
          y = res_ + size * resIncr;                                                                                   \
          a = lhs_ + size * lda;                                                                                       \
          m = convert_index<BlasIndex>(rows - size);                                                                   \
          n = convert_index<BlasIndex>(size);                                                                          \
        } else {                                                                                                       \
          /* Upper trapezoid: row-major columns [size, cols) start at lhs_ + size. */                                  \
          x += size;                                                                                                   \
          y = res_;                                                                                                    \
          a = lhs_ + size;                                                                                             \
          m = convert_index<BlasIndex>(size);                                                                          \
          n = convert_index<BlasIndex>(cols - size);                                                                   \
        }                                                                                                              \
        /* The column-major view of this block is n x m, hence (n, m) with trans='T'/'C'. */                           \
        BLASPREFIX##gemv##BLASPOSTFIX(&trans, &n, &m, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a,   \
                                      &lda, (const BLASTYPE*)x, &incx, (const BLASTYPE*)&numext::real_ref(beta),       \
                                      (BLASTYPE*)y, &incy);                                                            \
      }                                                                                                                \
    }                                                                                                                  \
  };
260EIGEN_BLAS_TRMV_RM(
double,
double, d, d, )
261EIGEN_BLAS_TRMV_RM(dcomplex, MKL_Complex16, cd, z, )
262EIGEN_BLAS_TRMV_RM(
float,
float, f, s, )
263EIGEN_BLAS_TRMV_RM(scomplex, MKL_Complex8, cf, c, )
265EIGEN_BLAS_TRMV_RM(
double,
double, d, d, _)
266EIGEN_BLAS_TRMV_RM(dcomplex,
double, cd, z, _)
267EIGEN_BLAS_TRMV_RM(
float,
float, f, s, _)
268EIGEN_BLAS_TRMV_RM(scomplex,
float, cf, c, _)
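/* Doxygen cross-reference notes (extraction residue, kept for reference):
 *   Eigen  - Namespace containing all symbols from the Eigen library. (Definition B01_Experimental.dox:1)
 *   Index  - EIGEN_DEFAULT_DENSE_INDEX_TYPE Index: the Index type as used for the API. (Definition Meta.h:82)
 */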