10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
14#include "./InternalHeaderCheck.h"
/// Sharding modes accepted as the `ShardingType` template argument of
/// TensorContractionBlocking (which defaults to ShardByCol).
enum {
  ShardByRow = 0,
  /// Selected mode makes the TensorContractionBlocking constructor pass its
  /// sizes to computeProductBlockingSizes in (kc, mc, nc) order.
  ShardByCol = 1
};
22template <
typename ResScalar,
typename LhsScalar,
typename RhsScalar,
typename StorageIndex,
23 int ShardingType = ShardByCol>
24class TensorContractionBlocking {
// Constructor: seeds kc_, mc_ and nc_ from k, m and n, then hands them to
// computeProductBlockingSizes together with num_threads (which defaults to 1).
// NOTE(review): no #endif matching the #if below is visible in this listing, and
// whatever it originally guarded seems to be missing — confirm against upstream.
41#if !defined(EIGEN_HIPCC)
44 TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n, StorageIndex num_threads = 1)
45 : kc_(k), mc_(m), nc_(n) {
// With ShardByCol the sizes are passed in (kc_, mc_, nc_) order.
46 if (ShardingType == ShardByCol) {
47 computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, mc_, nc_, num_threads);
// Same call with mc_ and nc_ swapped.
// NOTE(review): presumably this is the non-ShardByCol branch — the `else` and the
// closing braces are not visible in this listing; verify against upstream.
49 computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads);
// kc_ is left unchanged when the RHS packet size is at most 8 or kc_ does not
// exceed the packet size; otherwise it is rounded down to a multiple of the
// packet size.
52 const int rhs_packet_size = internal::packet_traits<RhsScalar>::size;
53 kc_ = (rhs_packet_size <= 8 || kc_ <= rhs_packet_size) ? kc_ : (kc_ / rhs_packet_size) * rhs_packet_size;
56 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex kc()
const {
return kc_; }
57 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex mc()
const {
return mc_; }
58 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex nc()
const {
return nc_; }
Namespace containing all symbols from the Eigen library.