Eigen-unsupported  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
TensorContractionBlocking.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17namespace internal {
18
// Sharding selectors used as the `ShardingType` template argument of
// TensorContractionBlocking; the value decides the order in which the m and n
// extents are handed to computeProductBlockingSizes.
enum {
  ShardByRow = 0,
  ShardByCol = 1
};
20
21// Default Blocking Strategy
22template <typename ResScalar, typename LhsScalar, typename RhsScalar, typename StorageIndex,
23 int ShardingType = ShardByCol>
24class TensorContractionBlocking {
25 public:
26 /*
27 adding EIGEN_DEVICE_FUNC unconditionally to 'TensorContractionBlocking' constructor in `TensorContractionBlocking.h`
28 requires adding EIGEN_DEVICE_FUNC to `computeProductBlockingSizes` in `GeneralBlockPanelKernel.h`
29 which in turn, requires adding EIGEN_DEVICE_FUNC to `evaluateProductBlockingSizesHeuristic` in
30 `GeneralBlockPanelKernel.h` which in turn, requires adding EIGEN_DEVICE_FUNC to `manage_caching_sizes` in
31 `GeneralBlockPanelKernel.h` (else HIPCC will error out)
32
33 However adding EIGEN_DEVICE_FUNC to `manage_caching_sizes` in `GeneralBlockPanelKernel.h`
34 results in NVCC erroring out with the following error
35
36 ../Eigen/src/Core/products/GeneralBlockPanelKernel.h(57): error #2901:
37 dynamic initialization is not supported for function-scope static variables within a __device__/__global__
38 function
39 */
40
41#if !defined(EIGEN_HIPCC)
42 EIGEN_DEVICE_FUNC
43#endif
44 TensorContractionBlocking(StorageIndex k, StorageIndex m, StorageIndex n, StorageIndex num_threads = 1)
45 : kc_(k), mc_(m), nc_(n) {
46 if (ShardingType == ShardByCol) {
47 computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, mc_, nc_, num_threads);
48 } else {
49 computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads);
50 }
51
52 const int rhs_packet_size = internal::packet_traits<RhsScalar>::size;
53 kc_ = (rhs_packet_size <= 8 || kc_ <= rhs_packet_size) ? kc_ : (kc_ / rhs_packet_size) * rhs_packet_size;
54 }
55
56 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; }
57 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex mc() const { return mc_; }
58 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex nc() const { return nc_; }
59
60 private:
61 StorageIndex kc_;
62 StorageIndex mc_;
63 StorageIndex nc_;
64};
65
66} // end namespace internal
67} // end namespace Eigen
68
69#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H
Namespace containing all symbols from the Eigen library.