33#ifndef EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H
34#define EIGEN_GENERAL_MATRIX_MATRIX_BLAS_H
37#include "../InternalHeaderCheck.h"
// GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, BLASTYPE, BLASFUNC)
//
// Generates a partial specialization of general_matrix_matrix_product for the
// case where both operands have scalar type EIGTYPE and the result is
// column-major, whose run() forwards the product to the BLAS routine BLASFUNC.
//   EIGTYPE   - Eigen scalar type of lhs, rhs and result.
//   EIGPREFIX - suffix pasted onto MatrixX to name the temporary matrix type
//               (MatrixX##EIGPREFIX).
//   BLASTYPE  - scalar type the pointers are cast to for the BLAS call.
//   BLASFUNC  - the BLAS gemm entry point to call.
// The trailing template argument is fixed to 1 (presumably the result's inner
// stride; run() asserts resIncr == 1 — confirm against the primary template).
//
// NOTE(review): this listing appears to have lost lines. `beta` is used in the
// BLAS call but never declared here, `a` and `b` are never assigned, and the
// closing braces of run()/the struct are not visible. Consult the complete
// header before changing any code in this macro.
52#define GEMM_SPECIALIZATION(EIGTYPE, EIGPREFIX, BLASTYPE, BLASFUNC) \
53 template <typename Index, int LhsStorageOrder, bool ConjugateLhs, int RhsStorageOrder, bool ConjugateRhs> \
54 struct general_matrix_matrix_product<Index, EIGTYPE, LhsStorageOrder, ConjugateLhs, EIGTYPE, RhsStorageOrder, \
55 ConjugateRhs, ColMajor, 1> { \
56 typedef gebp_traits<EIGTYPE, EIGTYPE> Traits; \
58 static void run(Index rows, Index cols, Index depth, const EIGTYPE* lhs_, Index lhsStride, const EIGTYPE* rhs_, \
59 Index rhsStride, EIGTYPE* res, Index resIncr, Index resStride, EIGTYPE alpha, \
60 level3_blocking<EIGTYPE, EIGTYPE>& , GemmParallelInfo<Index>* ) { /* last two parameters are unnamed and unused in the visible code */ \
62 if (rows == 0 || cols == 0 || depth == 0) return; /* any zero dimension: return without calling BLAS or touching res */ \
63 EIGEN_ONLY_USED_FOR_DEBUG(resIncr); \
64 eigen_assert(resIncr == 1); \
65 char transa, transb; \
66 BlasIndex m, n, k, lda, ldb, ldc; \
67 const EIGTYPE *a, *b; \
69 MatrixX##EIGPREFIX a_tmp, b_tmp; /* hold conjugated copies of column-major operands when needed */ \
72 transa = (LhsStorageOrder == RowMajor) ? ((ConjugateLhs) ? 'C' : 'T') : 'N'; /* column-major: 'N'; row-major: 'T', or 'C' when conjugated */ \
73 transb = (RhsStorageOrder == RowMajor) ? ((ConjugateRhs) ? 'C' : 'T') : 'N'; /* same mapping for the rhs */ \
76 m = convert_index<BlasIndex>(rows); /* dimensions and strides are converted to BlasIndex for the BLAS call */ \
77 n = convert_index<BlasIndex>(cols); \
78 k = convert_index<BlasIndex>(depth); \
81 lda = convert_index<BlasIndex>(lhsStride); \
82 ldb = convert_index<BlasIndex>(rhsStride); \
83 ldc = convert_index<BlasIndex>(resStride); \
86 if ((LhsStorageOrder == ColMajor) && (ConjugateLhs)) { /* column-major + conjugate is not covered by transa above, so conjugate explicitly */ \
87 Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > lhs(lhs_, m, k, OuterStride<>(lhsStride)); \
88 a_tmp = lhs.conjugate(); \
90 lda = convert_index<BlasIndex>(a_tmp.outerStride()); /* leading dimension now refers to the temporary */ \
94 if ((RhsStorageOrder == ColMajor) && (ConjugateRhs)) { /* same treatment for a conjugated column-major rhs */ \
95 Map<const MatrixX##EIGPREFIX, 0, OuterStride<> > rhs(rhs_, k, n, OuterStride<>(rhsStride)); \
96 b_tmp = rhs.conjugate(); \
98 ldb = convert_index<BlasIndex>(b_tmp.outerStride()); \
102 BLASFUNC(&transa, &transb, &m, &n, &k, (const BLASTYPE*)&numext::real_ref(alpha), (const BLASTYPE*)a, &lda, \
103 (const BLASTYPE*)b, &ldb, (const BLASTYPE*)&numext::real_ref(beta), (BLASTYPE*)res, &ldc); \
108GEMM_SPECIALIZATION(  // first set, double: BLAS type double, routine dgemm
double, d,
double, dgemm)
// float: BLAS type float, routine sgemm
109GEMM_SPECIALIZATION(
float, f,
float, sgemm)
// Complex scalars in this set are passed using the MKL_Complex16 / MKL_Complex8
// types and routed to zgemm / cgemm.
110GEMM_SPECIALIZATION(dcomplex, cd, MKL_Complex16, zgemm)
111GEMM_SPECIALIZATION(scomplex, cf, MKL_Complex8, cgemm)
// Second set: the same four scalar types bound to the trailing-underscore
// symbols (dgemm_, sgemm_, zgemm_, cgemm_); complex data is passed as pointers
// to double / float.
// NOTE(review): both sets instantiate the same specializations, so they
// presumably belong to alternative preprocessor branches that are not visible
// in this listing — confirm against the complete header.
113GEMM_SPECIALIZATION(
double, d,
double, dgemm_)
114GEMM_SPECIALIZATION(
float, f,
float, sgemm_)
115GEMM_SPECIALIZATION(dcomplex, cd,
double, zgemm_)
116GEMM_SPECIALIZATION(scomplex, cf,
float, cgemm_)
121#if EIGEN_USE_OPENBLAS_BFLOAT16
// Declaration of the external bfloat16 GEMM routine called by the bfloat16
// specialization in this file. It is only declared when
// EIGEN_USE_OPENBLAS_BFLOAT16 is enabled.
//   trans_a, trans_b - per-operand transpose flags (the caller in this file
//                      passes 'N', 'T' or 'C').
//   M, N, K          - problem dimensions (the caller passes rows, cols, depth).
//   alpha, beta      - float scaling factors.
//   A, lda / B, ldb  - bfloat16 input matrices and their leading dimensions.
//   C, ldc           - float output matrix and its leading dimension.
// NOTE(review): presumably follows the standard BLAS gemm contract
// C = alpha*op(A)*op(B) + beta*C, and presumably needs C linkage (no
// extern "C" is visible in this listing) — confirm against the OpenBLAS headers.
125void sbgemm_(
const char* trans_a,
const char* trans_b,
const int* M,
const int* N,
const int* K,
const float* alpha,
126 const Eigen::bfloat16* A,
const int* lda,
const Eigen::bfloat16* B,
const int* ldb,
const float* beta,
127 float* C,
const int* ldc);
// Partial specialization of general_matrix_matrix_product for bfloat16
// operands with a column-major result; run() forwards the product to sbgemm_,
// which writes a float result that is then cast back to bfloat16.
//
// NOTE(review): this listing appears to have lost lines. transa, transb and
// r_tmp are used without a visible declaration, a and b are never assigned,
// and the closing braces are missing. Consult the complete header before
// changing any code here.
130template <
typename Index,
int LhsStorageOrder,
bool ConjugateLhs,
int RhsStorageOrder,
bool ConjugateRhs>
131struct general_matrix_matrix_product<
Index, Eigen::bfloat16, LhsStorageOrder, ConjugateLhs, Eigen::bfloat16,
132 RhsStorageOrder, ConjugateRhs,
ColMajor, 1> {
133 typedef gebp_traits<Eigen::bfloat16, Eigen::bfloat16> Traits;
// Computes the product via sbgemm_ and stores it into res.
//   rows, cols, depth - product dimensions, passed to BLAS as m, n, k.
//   lhs_, lhsStride   - left operand data and its outer stride.
//   rhs_, rhsStride   - right operand data and its outer stride.
//   res, resStride    - destination data and its outer stride.
//   resIncr           - must be 1 (checked with eigen_assert).
//   alpha             - scaling factor, converted to float for the BLAS call.
// The two trailing unnamed parameters are not used in the visible code.
135 static void run(
Index rows,
Index cols,
Index depth,
const Eigen::bfloat16* lhs_,
Index lhsStride,
136 const Eigen::bfloat16* rhs_,
Index rhsStride, Eigen::bfloat16* res,
Index resIncr,
Index resStride,
137 Eigen::bfloat16 alpha, level3_blocking<Eigen::bfloat16, Eigen::bfloat16>& ,
138 GemmParallelInfo<Index>* ) {
// Any zero dimension: return without calling BLAS or touching res.
140 if (rows == 0 || cols == 0 || depth == 0)
return;
141 EIGEN_ONLY_USED_FOR_DEBUG(resIncr);
142 eigen_assert(resIncr == 1);
144 BlasIndex m, n, k, lda, ldb, ldc;
145 const Eigen::bfloat16 *a, *b;
// sbgemm_ takes float scaling factors, so alpha is widened to float.
147 float falpha =
static_cast<float>(alpha);
// beta is fixed to 1; presumably so that sbgemm_ accumulates onto the existing
// contents of r_tmp — confirm against the complete header.
148 float fbeta = float(1.0);
150 using MatrixXbf16 = Matrix<Eigen::bfloat16, Dynamic, Dynamic>;
// Hold conjugated copies of column-major operands when needed.
151 MatrixXbf16 a_tmp, b_tmp;
// Transpose flags: 'N' for a column-major operand; a row-major operand maps to
// 'T', or to 'C' when it is also conjugated.
155 transa = (LhsStorageOrder ==
RowMajor) ? ((ConjugateLhs) ?
'C' :
'T') :
'N';
156 transb = (RhsStorageOrder ==
RowMajor) ? ((ConjugateRhs) ?
'C' :
'T') :
'N';
// Dimensions and strides are converted to BlasIndex for the BLAS call.
159 m = convert_index<BlasIndex>(rows);
160 n = convert_index<BlasIndex>(cols);
161 k = convert_index<BlasIndex>(depth);
164 lda = convert_index<BlasIndex>(lhsStride);
165 ldb = convert_index<BlasIndex>(rhsStride);
// ldc is m rather than resStride: the BLAS output goes to r_tmp.data(), not to
// res. Presumably r_tmp is a dense m x n float matrix (its declaration is not
// visible in this listing) — confirm.
166 ldc = convert_index<BlasIndex>(m);
// Column-major + conjugate is not covered by the transpose flag, so build an
// explicitly conjugated copy and use its outer stride as the leading dimension.
169 if ((LhsStorageOrder ==
ColMajor) && (ConjugateLhs)) {
170 Map<const MatrixXbf16, 0, OuterStride<> > lhs(lhs_, m, k, OuterStride<>(lhsStride));
171 a_tmp = lhs.conjugate();
173 lda = convert_index<BlasIndex>(a_tmp.outerStride());
// Same treatment for a conjugated column-major rhs.
178 if ((RhsStorageOrder ==
ColMajor) && (ConjugateRhs)) {
179 Map<const MatrixXbf16, 0, OuterStride<> > rhs(rhs_, k, n, OuterStride<>(rhsStride));
180 b_tmp = rhs.conjugate();
182 ldb = convert_index<BlasIndex>(b_tmp.outerStride());
// Run the bfloat16 GEMM; the float output is written to r_tmp with leading
// dimension ldc.
190 sbgemm_(&transa, &transb, &m, &n, &k, (
const float*)&numext::real_ref(falpha), a, &lda, b, &ldb,
191 (
const float*)&numext::real_ref(fbeta), r_tmp.data(), &ldc);
// View res as an m x n matrix with outer stride resStride and store the float
// result into it, cast back to bfloat16.
194 Map<MatrixXbf16, 0, OuterStride<> > result(res, m, n, OuterStride<>(resStride));
195 result = r_tmp.cast<Eigen::bfloat16>();
@ ColMajor
Definition Constants.h:318
@ RowMajor
Definition Constants.h:320
Matrix< float, Dynamic, Dynamic > MatrixXf
Dynamic×Dynamic matrix of type float.
Definition Matrix.h:478
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82