Eigen-unsupported  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
TensorTrace.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2017 Gagan Goel <gagan.nith@gmail.com>
5// Copyright (C) 2017 Benoit Steiner <benoit.steiner.goog@gmail.com>
6//
7// This Source Code Form is subject to the terms of the Mozilla
8// Public License v. 2.0. If a copy of the MPL was not distributed
9// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11#ifndef EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
12#define EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
13
14// IWYU pragma: private
15#include "./InternalHeaderCheck.h"
16
17namespace Eigen {
18
19namespace internal {
20template <typename Dims, typename XprType>
21struct traits<TensorTraceOp<Dims, XprType> > : public traits<XprType> {
22 typedef typename XprType::Scalar Scalar;
23 typedef traits<XprType> XprTraits;
24 typedef typename XprTraits::StorageKind StorageKind;
25 typedef typename XprTraits::Index Index;
26 typedef typename XprType::Nested Nested;
27 typedef std::remove_reference_t<Nested> Nested_;
28 static constexpr int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
29 static constexpr int Layout = XprTraits::Layout;
30 enum {
31 // Trace is read-only.
32 Flags = traits<XprType>::Flags & ~LvalueBit
33 };
34};
35
36template <typename Dims, typename XprType>
37struct eval<TensorTraceOp<Dims, XprType>, Eigen::Dense> {
38 typedef const TensorTraceOp<Dims, XprType>& type;
39};
40
41template <typename Dims, typename XprType>
42struct nested<TensorTraceOp<Dims, XprType>, 1, typename eval<TensorTraceOp<Dims, XprType> >::type> {
43 typedef TensorTraceOp<Dims, XprType> type;
44};
45
46} // end namespace internal
47
53template <typename Dims, typename XprType>
54class TensorTraceOp : public TensorBase<TensorTraceOp<Dims, XprType> > {
55 public:
56 typedef typename Eigen::internal::traits<TensorTraceOp>::Scalar Scalar;
57 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
58 typedef typename XprType::CoeffReturnType CoeffReturnType;
59 typedef typename Eigen::internal::nested<TensorTraceOp>::type Nested;
60 typedef typename Eigen::internal::traits<TensorTraceOp>::StorageKind StorageKind;
61 typedef typename Eigen::internal::traits<TensorTraceOp>::Index Index;
62
63 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTraceOp(const XprType& expr, const Dims& dims)
64 : m_xpr(expr), m_dims(dims) {}
65
66 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dims& dims() const { return m_dims; }
67
68 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const internal::remove_all_t<typename XprType::Nested>& expression() const {
69 return m_xpr;
70 }
71
72 protected:
73 typename XprType::Nested m_xpr;
74 const Dims m_dims;
75};
76
77// Eval as rvalue
78template <typename Dims, typename ArgType, typename Device>
79struct TensorEvaluator<const TensorTraceOp<Dims, ArgType>, Device> {
80 typedef TensorTraceOp<Dims, ArgType> XprType;
81 static constexpr int NumInputDims =
82 internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
83 static constexpr int NumReducedDims = internal::array_size<Dims>::value;
84 static constexpr int NumOutputDims = NumInputDims - NumReducedDims;
85 typedef typename XprType::Index Index;
86 typedef DSizes<Index, NumOutputDims> Dimensions;
87 typedef typename XprType::Scalar Scalar;
88 typedef typename XprType::CoeffReturnType CoeffReturnType;
89 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
90 static constexpr int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
91 typedef StorageMemory<CoeffReturnType, Device> Storage;
92 typedef typename Storage::Type EvaluatorPointerType;
93
94 static constexpr int Layout = TensorEvaluator<ArgType, Device>::Layout;
95 enum {
96 IsAligned = false,
97 PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
98 BlockAccess = false,
100 CoordAccess = false,
101 RawAccess = false
102 };
103
104 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
105 typedef internal::TensorBlockNotImplemented TensorBlock;
106 //===--------------------------------------------------------------------===//
107
108 EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
109 : m_impl(op.expression(), device), m_traceDim(1), m_device(device) {
110 EIGEN_STATIC_ASSERT((NumOutputDims >= 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
111 EIGEN_STATIC_ASSERT((NumReducedDims >= 2) || ((NumReducedDims == 0) && (NumInputDims == 0)),
112 YOU_MADE_A_PROGRAMMING_MISTAKE);
113
114 for (int i = 0; i < NumInputDims; ++i) {
115 m_reduced[i] = false;
116 }
117
118 const Dims& op_dims = op.dims();
119 for (int i = 0; i < NumReducedDims; ++i) {
120 eigen_assert(op_dims[i] >= 0);
121 eigen_assert(op_dims[i] < NumInputDims);
122 m_reduced[op_dims[i]] = true;
123 }
124
125 // All the dimensions should be distinct to compute the trace
126 int num_distinct_reduce_dims = 0;
127 for (int i = 0; i < NumInputDims; ++i) {
128 if (m_reduced[i]) {
129 ++num_distinct_reduce_dims;
130 }
131 }
132
133 EIGEN_ONLY_USED_FOR_DEBUG(num_distinct_reduce_dims);
134 eigen_assert(num_distinct_reduce_dims == NumReducedDims);
135
136 // Compute the dimensions of the result.
137 const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
138
139 int output_index = 0;
140 int reduced_index = 0;
141 for (int i = 0; i < NumInputDims; ++i) {
142 if (m_reduced[i]) {
143 m_reducedDims[reduced_index] = input_dims[i];
144 if (reduced_index > 0) {
145 // All the trace dimensions must have the same size
146 eigen_assert(m_reducedDims[0] == m_reducedDims[reduced_index]);
147 }
148 ++reduced_index;
149 } else {
150 m_dimensions[output_index] = input_dims[i];
151 ++output_index;
152 }
153 }
154
155 if (NumReducedDims != 0) {
156 m_traceDim = m_reducedDims[0];
157 }
158
159 // Compute the output strides
160 if (NumOutputDims > 0) {
161 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
162 m_outputStrides[0] = 1;
163 for (int i = 1; i < NumOutputDims; ++i) {
164 m_outputStrides[i] = m_outputStrides[i - 1] * m_dimensions[i - 1];
165 }
166 } else {
167 m_outputStrides.back() = 1;
168 for (int i = NumOutputDims - 2; i >= 0; --i) {
169 m_outputStrides[i] = m_outputStrides[i + 1] * m_dimensions[i + 1];
170 }
171 }
172 }
173
174 // Compute the input strides
175 if (NumInputDims > 0) {
176 array<Index, NumInputDims> input_strides;
177 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
178 input_strides[0] = 1;
179 for (int i = 1; i < NumInputDims; ++i) {
180 input_strides[i] = input_strides[i - 1] * input_dims[i - 1];
181 }
182 } else {
183 input_strides.back() = 1;
184 for (int i = NumInputDims - 2; i >= 0; --i) {
185 input_strides[i] = input_strides[i + 1] * input_dims[i + 1];
186 }
187 }
188
189 output_index = 0;
190 reduced_index = 0;
191 for (int i = 0; i < NumInputDims; ++i) {
192 if (m_reduced[i]) {
193 m_reducedStrides[reduced_index] = input_strides[i];
194 ++reduced_index;
195 } else {
196 m_preservedStrides[output_index] = input_strides[i];
197 ++output_index;
198 }
199 }
200 }
201 }
202
203 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
204
205 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
206 m_impl.evalSubExprsIfNeeded(NULL);
207 return true;
208 }
209
210 EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return nullptr; }
211
212 EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
213
214 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
215 // Initialize the result
216 CoeffReturnType result = internal::cast<int, CoeffReturnType>(0);
217 Index index_stride = 0;
218 for (int i = 0; i < NumReducedDims; ++i) {
219 index_stride += m_reducedStrides[i];
220 }
221
222 // If trace is requested along all dimensions, starting index would be 0
223 Index cur_index = 0;
224 if (NumOutputDims != 0) cur_index = firstInput(index);
225 for (Index i = 0; i < m_traceDim; ++i) {
226 result += m_impl.coeff(cur_index);
227 cur_index += index_stride;
228 }
229
230 return result;
231 }
232
233 template <int LoadMode>
234 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
235 eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());
236
237 EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[PacketSize];
238 for (int i = 0; i < PacketSize; ++i) {
239 values[i] = coeff(index + i);
240 }
241 PacketReturnType result = internal::ploadt<PacketReturnType, LoadMode>(values);
242 return result;
243 }
244
245 protected:
246 // Given the output index, finds the first index in the input tensor used to compute the trace
247 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index firstInput(Index index) const {
248 Index startInput = 0;
249 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
250 for (int i = NumOutputDims - 1; i > 0; --i) {
251 const Index idx = index / m_outputStrides[i];
252 startInput += idx * m_preservedStrides[i];
253 index -= idx * m_outputStrides[i];
254 }
255 startInput += index * m_preservedStrides[0];
256 } else {
257 for (int i = 0; i < NumOutputDims - 1; ++i) {
258 const Index idx = index / m_outputStrides[i];
259 startInput += idx * m_preservedStrides[i];
260 index -= idx * m_outputStrides[i];
261 }
262 startInput += index * m_preservedStrides[NumOutputDims - 1];
263 }
264 return startInput;
265 }
266
267 Dimensions m_dimensions;
268 TensorEvaluator<ArgType, Device> m_impl;
269 // Initialize the size of the trace dimension
270 Index m_traceDim;
271 const Device EIGEN_DEVICE_REF m_device;
272 array<bool, NumInputDims> m_reduced;
273 array<Index, NumReducedDims> m_reducedDims;
274 array<Index, NumOutputDims> m_outputStrides;
275 array<Index, NumReducedDims> m_reducedStrides;
276 array<Index, NumOutputDims> m_preservedStrides;
277};
278
279} // End namespace Eigen
280
281#endif // EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
The tensor base class.
Definition TensorForwardDeclarations.h:68
Tensor Trace class.
Definition TensorTrace.h:54
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor evaluator class.
Definition TensorEvaluator.h:30