#ifndef EIGEN_GENERAL_MATRIX_MATRIX_H
#define EIGEN_GENERAL_MATRIX_MATRIX_H

#include "../InternalHeaderCheck.h"

namespace Eigen {

namespace internal {
template <typename LhsScalar_, typename RhsScalar_>
class level3_blocking;
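/* Specialization for a row-major destination matrix
 *   => the product is simply transposed: since (A*B)^T = B^T * A^T, the row-major result is
 *      computed by the column-major kernel below with the two operands (and their conjugation
 *      flags) swapped. */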
template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar,
          int RhsStorageOrder, bool ConjugateRhs, int ResInnerStride>
struct general_matrix_matrix_product<Index, LhsScalar, LhsStorageOrder, ConjugateLhs, RhsScalar, RhsStorageOrder,
                                     ConjugateRhs, RowMajor, ResInnerStride> {
  typedef gebp_traits<RhsScalar, LhsScalar> Traits;

  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  static EIGEN_STRONG_INLINE void run(Index rows, Index cols, Index depth, const LhsScalar* lhs, Index lhsStride,
                                      const RhsScalar* rhs, Index rhsStride, ResScalar* res, Index resIncr,
                                      Index resStride, ResScalar alpha, level3_blocking<RhsScalar, LhsScalar>& blocking,
                                      GemmParallelInfo<Index>* info = 0) {
    // transpose the product such that the result is column major
    general_matrix_matrix_product<Index, RhsScalar, RhsStorageOrder == RowMajor ? ColMajor : RowMajor, ConjugateRhs,
                                  LhsScalar, LhsStorageOrder == RowMajor ? ColMajor : RowMajor, ConjugateLhs, ColMajor,
                                  ResInnerStride>::run(cols, rows, depth, rhs, rhsStride, lhs, lhsStride, res, resIncr,
                                                       resStride, alpha, blocking, info);
  }
};
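/* Specialization for a col-major destination matrix
 *   => blocking algorithm following Goto's paper: the lhs is cut into mc x kc panels and the rhs
 *      into kc x nc panels; each panel is packed into a contiguous buffer sized for the CPU caches
 *      before being handed to the gebp micro-kernel. */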
template <typename Index, typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs, typename RhsScalar,
          int RhsStorageOrder, bool ConjugateRhs, int ResInnerStride>
struct general_matrix_matrix_product<Index, LhsScalar, LhsStorageOrder, ConjugateLhs, RhsScalar, RhsStorageOrder,
                                     ConjugateRhs, ColMajor, ResInnerStride> {
  typedef gebp_traits<LhsScalar, RhsScalar> Traits;

  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
  static void run(Index rows, Index cols, Index depth, const LhsScalar* lhs_, Index lhsStride, const RhsScalar* rhs_,
                  Index rhsStride, ResScalar* res_, Index resIncr, Index resStride, ResScalar alpha,
                  level3_blocking<LhsScalar, RhsScalar>& blocking, GemmParallelInfo<Index>* info = 0) {
    typedef const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> LhsMapper;
    typedef const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> RhsMapper;
    typedef blas_data_mapper<typename Traits::ResScalar, Index, ColMajor, Unaligned, ResInnerStride> ResMapper;
    LhsMapper lhs(lhs_, lhsStride);
    RhsMapper rhs(rhs_, rhsStride);
    ResMapper res(res_, resStride, resIncr);
    Index kc = blocking.kc();                    // cache block size along the K direction
    Index mc = (std::min)(rows, blocking.mc());  // cache block size along the M direction
    Index nc = (std::min)(cols, blocking.nc());  // cache block size along the N direction
    gemm_pack_lhs<LhsScalar, Index, LhsMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing,
                  LhsStorageOrder>
        pack_lhs;
    gemm_pack_rhs<RhsScalar, Index, RhsMapper, Traits::nr, RhsStorageOrder> pack_rhs;
    gebp_kernel<LhsScalar, RhsScalar, Index, ResMapper, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
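    // pack_lhs copies an (mc x kc) panel of the lhs into blockA and pack_rhs copies a (kc x nc)
    // panel of the rhs into blockB; gebp is the packed block-panel micro-kernel that accumulates
    // alpha * blockA * blockB into the destination mapper.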
#if !defined(EIGEN_USE_BLAS) && (defined(EIGEN_HAS_OPENMP) || defined(EIGEN_GEMM_THREADPOOL))
    if (info) {
      // this is the parallel version!
      int tid = info->logical_thread_id;
      int threads = info->num_threads;
      LhsScalar* blockA = blocking.blockA();
      eigen_internal_assert(blockA != 0);

      std::size_t sizeB = kc * nc;
      ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, 0);
      // For each horizontal panel of the rhs, and corresponding vertical panel of the lhs...
      for (Index k = 0; k < depth; k += kc) {
        const Index actual_kc = (std::min)(k + kc, depth) - k;  // => rows of B', and cols of A'

        // In order to reduce the chance that a thread has to wait for the others,
        // let's start by packing B'.
        pack_rhs(blockB, rhs.getSubMapper(k, 0), actual_kc, nc);
        // Pack A_k to A' in a parallel fashion:
        // each thread packs the sub block A_k,i to A'_i where i is the thread id.

        // However, before copying to A'_i, we have to make sure that no other thread is still using it,
        // i.e., we test that task_info[tid].users equals 0.
        // Then, we set task_info[tid].users to the number of threads to mark that all other
        // threads are going to use it.
        while (info->task_info[tid].users != 0) {
          std::this_thread::yield();
        }
        info->task_info[tid].users = threads;

        pack_lhs(blockA + info->task_info[tid].lhs_start * actual_kc,
                 lhs.getSubMapper(info->task_info[tid].lhs_start, k), actual_kc, info->task_info[tid].lhs_length);

        // Notify the other threads that the part A'_i is ready to go.
        info->task_info[tid].sync = k;
        // Computes C_i += A' * B' per A'_i
        for (int shift = 0; shift < threads; ++shift) {
          int i = (tid + shift) % threads;

          // At this point we have to make sure that A'_i has been updated by thread i;
          // there is no need to wait for the B' part, which has been packed by the current thread.
          if (shift > 0) {
            while (info->task_info[i].sync != k) {
              std::this_thread::yield();
            }
          }

          gebp(res.getSubMapper(info->task_info[i].lhs_start, 0), blockA + info->task_info[i].lhs_start * actual_kc,
               blockB, info->task_info[i].lhs_length, actual_kc, nc, alpha);
        }
        // Then keep going as usual with the remaining B'
        for (Index j = nc; j < cols; j += nc) {
          const Index actual_nc = (std::min)(j + nc, cols) - j;

          // pack B_k,j to B'
          pack_rhs(blockB, rhs.getSubMapper(k, j), actual_kc, actual_nc);

          // C_j += A' * B'
          gebp(res.getSubMapper(0, j), blockA, blockB, rows, actual_kc, actual_nc, alpha);
        }
        // Release all the sub blocks A'_i of A' for the current thread,
        // i.e., we simply decrement the number of users by 1
        for (Index i = 0; i < threads; ++i) info->task_info[i].users -= 1;
      }
    } else
#endif  // EIGEN_HAS_OPENMP || EIGEN_GEMM_THREADPOOL
    {
      EIGEN_UNUSED_VARIABLE(info);

      // this is the sequential version!
      std::size_t sizeA = kc * mc;
      std::size_t sizeB = kc * nc;

      ei_declare_aligned_stack_constructed_variable(LhsScalar, blockA, sizeA, blocking.blockA());
      ei_declare_aligned_stack_constructed_variable(RhsScalar, blockB, sizeB, blocking.blockB());

      const bool pack_rhs_once = mc != rows && kc == depth && nc == cols;
      // For each horizontal panel of the rhs, and corresponding panel of the lhs...
      for (Index i2 = 0; i2 < rows; i2 += mc) {
        const Index actual_mc = (std::min)(i2 + mc, rows) - i2;

        for (Index k2 = 0; k2 < depth; k2 += kc) {
          const Index actual_kc = (std::min)(k2 + kc, depth) - k2;

          // OK, here we have selected one horizontal panel of rhs and one vertical panel of lhs.
          // => Pack lhs's panel into a sequential chunk of memory (L2/L3 caching)
          // Note that this panel will be read as many times as the number of blocks in the rhs's
          // horizontal panel which is, in practice, a very low number.
          pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);

          // For each kc x nc block of the rhs's horizontal panel...
          for (Index j2 = 0; j2 < cols; j2 += nc) {
            const Index actual_nc = (std::min)(j2 + nc, cols) - j2;

            // We pack the rhs's block into a sequential chunk of memory (L2 caching)
            // Note that this block will be read a very high number of times, equal to the number of
            // blocks of the lhs's vertical panel which is, in practice, a very large number.
            if ((!pack_rhs_once) || i2 == 0) pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc);

            // Everything is packed, we can now call the panel * block kernel:
            gebp(res.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, alpha);
          }
        }
      }
    }
  }
};
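/*********************************************************************************
 *  Specialization of generic_product_impl for "large" GEMM, i.e.,
 *  implementation of the high level wrapper to general_matrix_matrix_product
 **********************************************************************************/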
template <typename Scalar, typename Index, typename Gemm, typename Lhs, typename Rhs, typename Dest,
          typename BlockingType>
struct gemm_functor {
  gemm_functor(const Lhs& lhs, const Rhs& rhs, Dest& dest, const Scalar& actualAlpha, BlockingType& blocking)
      : m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha), m_blocking(blocking) {}
  void initParallelSession(Index num_threads) const {
    m_blocking.initParallel(m_lhs.rows(), m_rhs.cols(), m_lhs.cols(), num_threads);
    m_blocking.allocateA();
  }
  void operator()(Index row, Index rows, Index col = 0, Index cols = -1, GemmParallelInfo<Index>* info = 0) const {
    if (cols == -1) cols = m_rhs.cols();

    Gemm::run(rows, cols, m_lhs.cols(), &m_lhs.coeffRef(row, 0), m_lhs.outerStride(), &m_rhs.coeffRef(0, col),
              m_rhs.outerStride(), (Scalar*)&(m_dest.coeffRef(row, col)), m_dest.innerStride(), m_dest.outerStride(),
              m_actualAlpha, m_blocking, info);
  }
  typedef typename Gemm::Traits Traits;

 protected:
  const Lhs& m_lhs;
  const Rhs& m_rhs;
  Dest& m_dest;
  Scalar m_actualAlpha;
  BlockingType& m_blocking;
};
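// gemm_functor bundles the operands, the destination, alpha and the blocking object into a callable
// for internal::parallelize_gemm: each worker invokes operator()(row, rows, ...) on its slice of the
// destination rows, and initParallelSession() prepares the shared blocking state beforehand.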
template <int StorageOrder, typename LhsScalar, typename RhsScalar, int MaxRows, int MaxCols, int MaxDepth,
          int KcFactor = 1,
          bool FiniteAtCompileTime = MaxRows != Dynamic && MaxCols != Dynamic && MaxDepth != Dynamic>
class gemm_blocking_space;
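// FiniteAtCompileTime selects between the two specializations below: when MaxRows, MaxCols and
// MaxDepth are all known at compile time, the packing buffers are fixed-size member arrays;
// otherwise the block sizes are computed at run time and the buffers are heap-allocated on demand.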
template <typename LhsScalar_, typename RhsScalar_>
class level3_blocking {
  typedef LhsScalar_ LhsScalar;
  typedef RhsScalar_ RhsScalar;

 protected:
  LhsScalar* m_blockA;
  RhsScalar* m_blockB;

  Index m_mc;
  Index m_nc;
  Index m_kc;

 public:
  level3_blocking() : m_blockA(0), m_blockB(0), m_mc(0), m_nc(0), m_kc(0) {}
  inline Index mc() const { return m_mc; }
  inline Index nc() const { return m_nc; }
  inline Index kc() const { return m_kc; }

  inline LhsScalar* blockA() { return m_blockA; }
  inline RhsScalar* blockB() { return m_blockB; }
};
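// level3_blocking is the base-class view through which the product kernels above see a blocking
// object: just the chosen block sizes kc/mc/nc and the (possibly not yet allocated) packing buffers.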
template <int StorageOrder, typename LhsScalar_, typename RhsScalar_, int MaxRows, int MaxCols, int MaxDepth,
          int KcFactor>
class gemm_blocking_space<StorageOrder, LhsScalar_, RhsScalar_, MaxRows, MaxCols, MaxDepth, KcFactor,
                          true /* == FiniteAtCompileTime */>
    : public level3_blocking<std::conditional_t<StorageOrder == RowMajor, RhsScalar_, LhsScalar_>,
                             std::conditional_t<StorageOrder == RowMajor, LhsScalar_, RhsScalar_>> {
  enum {
    Transpose = StorageOrder == RowMajor,
    ActualRows = Transpose ? MaxCols : MaxRows,
    ActualCols = Transpose ? MaxRows : MaxCols
  };
  typedef std::conditional_t<Transpose, RhsScalar_, LhsScalar_> LhsScalar;
  typedef std::conditional_t<Transpose, LhsScalar_, RhsScalar_> RhsScalar;
  enum { SizeA = ActualRows * MaxDepth, SizeB = ActualCols * MaxDepth };
#if EIGEN_MAX_STATIC_ALIGN_BYTES >= EIGEN_DEFAULT_ALIGN_BYTES
  EIGEN_ALIGN_MAX LhsScalar m_staticA[SizeA];
  EIGEN_ALIGN_MAX RhsScalar m_staticB[SizeB];
#else
  EIGEN_ALIGN_MAX char m_staticA[SizeA * sizeof(LhsScalar) + EIGEN_DEFAULT_ALIGN_BYTES - 1];
  EIGEN_ALIGN_MAX char m_staticB[SizeB * sizeof(RhsScalar) + EIGEN_DEFAULT_ALIGN_BYTES - 1];
#endif

 public:
  gemm_blocking_space(Index /*rows*/, Index /*cols*/, Index /*depth*/, Index /*num_threads*/, bool /*full_rows*/) {
    this->m_mc = ActualRows;
    this->m_nc = ActualCols;
    this->m_kc = MaxDepth;
#if EIGEN_MAX_STATIC_ALIGN_BYTES >= EIGEN_DEFAULT_ALIGN_BYTES
    this->m_blockA = m_staticA;
    this->m_blockB = m_staticB;
#else
    this->m_blockA = reinterpret_cast<LhsScalar*>((std::uintptr_t(m_staticA) + (EIGEN_DEFAULT_ALIGN_BYTES - 1)) &
                                                  ~std::size_t(EIGEN_DEFAULT_ALIGN_BYTES - 1));
    this->m_blockB = reinterpret_cast<RhsScalar*>((std::uintptr_t(m_staticB) + (EIGEN_DEFAULT_ALIGN_BYTES - 1)) &
                                                  ~std::size_t(EIGEN_DEFAULT_ALIGN_BYTES - 1));
#endif
  }

  void initParallel(Index, Index, Index, Index) {}

  inline void allocateA() {}
  inline void allocateB() {}
  inline void allocateAll() {}
};
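// With compile-time sizes, blocking is trivial: the whole product fits in a single mc x kc x nc
// block, the packing buffers are the static arrays above, and the allocate functions are no-ops.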
template <int StorageOrder, typename LhsScalar_, typename RhsScalar_, int MaxRows, int MaxCols, int MaxDepth,
          int KcFactor>
class gemm_blocking_space<StorageOrder, LhsScalar_, RhsScalar_, MaxRows, MaxCols, MaxDepth, KcFactor, false>
    : public level3_blocking<std::conditional_t<StorageOrder == RowMajor, RhsScalar_, LhsScalar_>,
                             std::conditional_t<StorageOrder == RowMajor, LhsScalar_, RhsScalar_>> {
  enum { Transpose = StorageOrder == RowMajor };
  typedef std::conditional_t<Transpose, RhsScalar_, LhsScalar_> LhsScalar;
  typedef std::conditional_t<Transpose, LhsScalar_, RhsScalar_> RhsScalar;

  Index m_sizeA;
  Index m_sizeB;
 public:
  gemm_blocking_space(Index rows, Index cols, Index depth, Index num_threads, bool l3_blocking) {
    this->m_mc = Transpose ? cols : rows;
    this->m_nc = Transpose ? rows : cols;
    this->m_kc = depth;

    if (l3_blocking) {
      computeProductBlockingSizes<LhsScalar, RhsScalar, KcFactor>(this->m_kc, this->m_mc, this->m_nc, num_threads);
    } else {
      // no l3 blocking: keep the full nc and only cap kc and mc
      Index n = this->m_nc;
      computeProductBlockingSizes<LhsScalar, RhsScalar, KcFactor>(this->m_kc, this->m_mc, n, num_threads);
    }

    m_sizeA = this->m_mc * this->m_kc;
    m_sizeB = this->m_kc * this->m_nc;
  }
  void initParallel(Index rows, Index cols, Index depth, Index num_threads) {
    this->m_mc = Transpose ? cols : rows;
    this->m_nc = Transpose ? rows : cols;
    this->m_kc = depth;

    eigen_internal_assert(this->m_blockA == 0 && this->m_blockB == 0);
    Index m = this->m_mc;
    computeProductBlockingSizes<LhsScalar, RhsScalar, KcFactor>(this->m_kc, m, this->m_nc, num_threads);
    m_sizeA = this->m_mc * this->m_kc;
    m_sizeB = this->m_kc * this->m_nc;
  }
  void allocateA() {
    if (this->m_blockA == 0) this->m_blockA = aligned_new<LhsScalar>(m_sizeA);
  }

  void allocateB() {
    if (this->m_blockB == 0) this->m_blockB = aligned_new<RhsScalar>(m_sizeB);
  }

  void allocateAll() {
    allocateA();
    allocateB();
  }
  ~gemm_blocking_space() {
    aligned_delete(this->m_blockA, m_sizeA);
    aligned_delete(this->m_blockB, m_sizeB);
  }
};
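// With run-time sizes, computeProductBlockingSizes caps kc/mc/nc according to the CPU cache sizes
// and the number of threads; the packing buffers are heap-allocated lazily via allocateA/allocateB
// and released in the destructor above.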
template <typename Lhs, typename Rhs>
struct generic_product_impl<Lhs, Rhs, DenseShape, DenseShape, GemmProduct>
    : generic_product_impl_base<Lhs, Rhs, generic_product_impl<Lhs, Rhs, DenseShape, DenseShape, GemmProduct>> {
  typedef typename Product<Lhs, Rhs>::Scalar Scalar;
  typedef typename Lhs::Scalar LhsScalar;
  typedef typename Rhs::Scalar RhsScalar;

  typedef internal::blas_traits<Lhs> LhsBlasTraits;
  typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
  typedef internal::remove_all_t<ActualLhsType> ActualLhsTypeCleaned;

  typedef internal::blas_traits<Rhs> RhsBlasTraits;
  typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
  typedef internal::remove_all_t<ActualRhsType> ActualRhsTypeCleaned;

  enum { MaxDepthAtCompileTime = min_size_prefer_fixed(Lhs::MaxColsAtCompileTime, Rhs::MaxRowsAtCompileTime) };

  typedef generic_product_impl<Lhs, Rhs, DenseShape, DenseShape, CoeffBasedProductMode> lazyproduct;
  template <typename Dst>
  static void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
    // For very small products, the coefficient-based (lazy) evaluation is faster than setting up
    // the blocked GEMM path, hence the threshold below.
    if ((rhs.rows() + dst.rows() + dst.cols()) < EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows() > 0)
      lazyproduct::eval_dynamic(dst, lhs, rhs, internal::assign_op<typename Dst::Scalar, Scalar>());
    else {
      dst.setZero();
      scaleAndAddTo(dst, lhs, rhs, Scalar(1));
    }
  }
  template <typename Dst>
  static void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
    if ((rhs.rows() + dst.rows() + dst.cols()) < EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows() > 0)
      lazyproduct::eval_dynamic(dst, lhs, rhs, internal::add_assign_op<typename Dst::Scalar, Scalar>());
    else
      scaleAndAddTo(dst, lhs, rhs, Scalar(1));
  }
  template <typename Dst>
  static void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) {
    if ((rhs.rows() + dst.rows() + dst.cols()) < EIGEN_GEMM_TO_COEFFBASED_THRESHOLD && rhs.rows() > 0)
      lazyproduct::eval_dynamic(dst, lhs, rhs, internal::sub_assign_op<typename Dst::Scalar, Scalar>());
    else
      scaleAndAddTo(dst, lhs, rhs, Scalar(-1));
  }
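  // evalTo/addTo/subTo only differ in the assignment functor (=, +=, -=) passed to the lazy product
  // for small sizes and in the sign of alpha forwarded to scaleAndAddTo, which does the actual GEMM.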
  template <typename Dest>
  static void scaleAndAddTo(Dest& dst, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) {
    eigen_assert(dst.rows() == a_lhs.rows() && dst.cols() == a_rhs.cols());
    if (a_lhs.cols() == 0 || a_lhs.rows() == 0 || a_rhs.cols() == 0) return;
    if (dst.cols() == 1) {
      // Fallback to GEMV if the destination is a runtime column vector
      typename Dest::ColXpr dst_vec(dst.col(0));
      return internal::generic_product_impl<Lhs, typename Rhs::ConstColXpr, DenseShape, DenseShape,
                                            GemvProduct>::scaleAndAddTo(dst_vec, a_lhs, a_rhs.col(0), alpha);
    } else if (dst.rows() == 1) {
      // Fallback to GEMV if the destination is a runtime row vector
      typename Dest::RowXpr dst_vec(dst.row(0));
      return internal::generic_product_impl<typename Lhs::ConstRowXpr, Rhs, DenseShape, DenseShape,
                                            GemvProduct>::scaleAndAddTo(dst_vec, a_lhs.row(0), a_rhs, alpha);
    }
    add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
    add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);

    Scalar actualAlpha = combine_scalar_factors(alpha, a_lhs, a_rhs);
    typedef internal::gemm_blocking_space<(Dest::Flags & RowMajorBit) ? RowMajor : ColMajor, LhsScalar, RhsScalar,
                                          Dest::MaxRowsAtCompileTime, Dest::MaxColsAtCompileTime, MaxDepthAtCompileTime>
        BlockingType;

    typedef internal::gemm_functor<
        Scalar, Index,
        internal::general_matrix_matrix_product<
            Index, LhsScalar, (ActualLhsTypeCleaned::Flags & RowMajorBit) ? RowMajor : ColMajor,
            bool(LhsBlasTraits::NeedToConjugate), RhsScalar,
            (ActualRhsTypeCleaned::Flags & RowMajorBit) ? RowMajor : ColMajor, bool(RhsBlasTraits::NeedToConjugate),
            (Dest::Flags & RowMajorBit) ? RowMajor : ColMajor, Dest::InnerStrideAtCompileTime>,
        ActualLhsTypeCleaned, ActualRhsTypeCleaned, Dest, BlockingType>
        GemmFunctor;

    BlockingType blocking(dst.rows(), dst.cols(), lhs.cols(), 1, true);

    internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime > 32 || Dest::MaxRowsAtCompileTime == Dynamic)>(
        GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), a_lhs.rows(), a_rhs.cols(), a_lhs.cols(),
        Dest::Flags & RowMajorBit);
  }
};

}  // end namespace internal

}  // end namespace Eigen

#endif  // EIGEN_GENERAL_MATRIX_MATRIX_H