Eigen-unsupported  3.4.1 (git rev 28ded8800c26864e537852658428ab44c8399e87)
 
Loading...
Searching...
No Matches
TensorTrace.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2017 Gagan Goel <gagan.nith@gmail.com>
5// Copyright (C) 2017 Benoit Steiner <benoit.steiner.goog@gmail.com>
6//
7// This Source Code Form is subject to the terms of the Mozilla
8// Public License v. 2.0. If a copy of the MPL was not distributed
9// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11#ifndef EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
12#define EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
13
14namespace Eigen {
15
16namespace internal {
17template<typename Dims, typename XprType>
18struct traits<TensorTraceOp<Dims, XprType> > : public traits<XprType>
19{
20 typedef typename XprType::Scalar Scalar;
21 typedef traits<XprType> XprTraits;
22 typedef typename XprTraits::StorageKind StorageKind;
23 typedef typename XprTraits::Index Index;
24 typedef typename XprType::Nested Nested;
25 typedef typename remove_reference<Nested>::type _Nested;
26 static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
27 static const int Layout = XprTraits::Layout;
28};
29
30template<typename Dims, typename XprType>
31struct eval<TensorTraceOp<Dims, XprType>, Eigen::Dense>
32{
33 typedef const TensorTraceOp<Dims, XprType>& type;
34};
35
36template<typename Dims, typename XprType>
37struct nested<TensorTraceOp<Dims, XprType>, 1, typename eval<TensorTraceOp<Dims, XprType> >::type>
38{
39 typedef TensorTraceOp<Dims, XprType> type;
40};
41
42} // end namespace internal
43
49template <typename Dims, typename XprType>
50class TensorTraceOp : public TensorBase<TensorTraceOp<Dims, XprType> > {
51 public:
52 typedef typename Eigen::internal::traits<TensorTraceOp>::Scalar Scalar;
53 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
54 typedef typename XprType::CoeffReturnType CoeffReturnType;
55 typedef typename Eigen::internal::nested<TensorTraceOp>::type Nested;
56 typedef typename Eigen::internal::traits<TensorTraceOp>::StorageKind StorageKind;
57 typedef typename Eigen::internal::traits<TensorTraceOp>::Index Index;
58
59 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTraceOp(const XprType& expr, const Dims& dims)
60 : m_xpr(expr), m_dims(dims) {
61 }
62
63 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
64 const Dims& dims() const { return m_dims; }
65
66 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
67 const typename internal::remove_all<typename XprType::Nested>::type& expression() const { return m_xpr; }
68
69 protected:
70 typename XprType::Nested m_xpr;
71 const Dims m_dims;
72};
73
74
75// Eval as rvalue
76template<typename Dims, typename ArgType, typename Device>
77struct TensorEvaluator<const TensorTraceOp<Dims, ArgType>, Device>
78{
79 typedef TensorTraceOp<Dims, ArgType> XprType;
80 static const int NumInputDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
81 static const int NumReducedDims = internal::array_size<Dims>::value;
82 static const int NumOutputDims = NumInputDims - NumReducedDims;
83 typedef typename XprType::Index Index;
84 typedef DSizes<Index, NumOutputDims> Dimensions;
85 typedef typename XprType::Scalar Scalar;
86 typedef typename XprType::CoeffReturnType CoeffReturnType;
87 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
88 static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
89 typedef StorageMemory<CoeffReturnType, Device> Storage;
90 typedef typename Storage::Type EvaluatorPointerType;
91
92 enum {
93 IsAligned = false,
94 PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
95 BlockAccess = false,
96 PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
97 Layout = TensorEvaluator<ArgType, Device>::Layout,
98 CoordAccess = false,
99 RawAccess = false
100 };
101
102 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
103 typedef internal::TensorBlockNotImplemented TensorBlock;
104 //===--------------------------------------------------------------------===//
105
106 EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
107 : m_impl(op.expression(), device), m_traceDim(1), m_device(device)
108 {
109
110 EIGEN_STATIC_ASSERT((NumOutputDims >= 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
111 EIGEN_STATIC_ASSERT((NumReducedDims >= 2) || ((NumReducedDims == 0) && (NumInputDims == 0)), YOU_MADE_A_PROGRAMMING_MISTAKE);
112
113 for (int i = 0; i < NumInputDims; ++i) {
114 m_reduced[i] = false;
115 }
116
117 const Dims& op_dims = op.dims();
118 for (int i = 0; i < NumReducedDims; ++i) {
119 eigen_assert(op_dims[i] >= 0);
120 eigen_assert(op_dims[i] < NumInputDims);
121 m_reduced[op_dims[i]] = true;
122 }
123
124 // All the dimensions should be distinct to compute the trace
125 int num_distinct_reduce_dims = 0;
126 for (int i = 0; i < NumInputDims; ++i) {
127 if (m_reduced[i]) {
128 ++num_distinct_reduce_dims;
129 }
130 }
131
132 EIGEN_ONLY_USED_FOR_DEBUG(num_distinct_reduce_dims);
133 eigen_assert(num_distinct_reduce_dims == NumReducedDims);
134
135 // Compute the dimensions of the result.
136 const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
137
138 int output_index = 0;
139 int reduced_index = 0;
140 for (int i = 0; i < NumInputDims; ++i) {
141 if (m_reduced[i]) {
142 m_reducedDims[reduced_index] = input_dims[i];
143 if (reduced_index > 0) {
144 // All the trace dimensions must have the same size
145 eigen_assert(m_reducedDims[0] == m_reducedDims[reduced_index]);
146 }
147 ++reduced_index;
148 }
149 else {
150 m_dimensions[output_index] = input_dims[i];
151 ++output_index;
152 }
153 }
154
155 if (NumReducedDims != 0) {
156 m_traceDim = m_reducedDims[0];
157 }
158
159 // Compute the output strides
160 if (NumOutputDims > 0) {
161 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
162 m_outputStrides[0] = 1;
163 for (int i = 1; i < NumOutputDims; ++i) {
164 m_outputStrides[i] = m_outputStrides[i - 1] * m_dimensions[i - 1];
165 }
166 }
167 else {
168 m_outputStrides.back() = 1;
169 for (int i = NumOutputDims - 2; i >= 0; --i) {
170 m_outputStrides[i] = m_outputStrides[i + 1] * m_dimensions[i + 1];
171 }
172 }
173 }
174
175 // Compute the input strides
176 if (NumInputDims > 0) {
177 array<Index, NumInputDims> input_strides;
178 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
179 input_strides[0] = 1;
180 for (int i = 1; i < NumInputDims; ++i) {
181 input_strides[i] = input_strides[i - 1] * input_dims[i - 1];
182 }
183 }
184 else {
185 input_strides.back() = 1;
186 for (int i = NumInputDims - 2; i >= 0; --i) {
187 input_strides[i] = input_strides[i + 1] * input_dims[i + 1];
188 }
189 }
190
191 output_index = 0;
192 reduced_index = 0;
193 for (int i = 0; i < NumInputDims; ++i) {
194 if(m_reduced[i]) {
195 m_reducedStrides[reduced_index] = input_strides[i];
196 ++reduced_index;
197 }
198 else {
199 m_preservedStrides[output_index] = input_strides[i];
200 ++output_index;
201 }
202 }
203 }
204 }
205
206 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
207 return m_dimensions;
208 }
209
210 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
211 m_impl.evalSubExprsIfNeeded(NULL);
212 return true;
213 }
214
215 EIGEN_STRONG_INLINE void cleanup() {
216 m_impl.cleanup();
217 }
218
219 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
220 {
221 // Initialize the result
222 CoeffReturnType result = internal::cast<int, CoeffReturnType>(0);
223 Index index_stride = 0;
224 for (int i = 0; i < NumReducedDims; ++i) {
225 index_stride += m_reducedStrides[i];
226 }
227
228 // If trace is requested along all dimensions, starting index would be 0
229 Index cur_index = 0;
230 if (NumOutputDims != 0)
231 cur_index = firstInput(index);
232 for (Index i = 0; i < m_traceDim; ++i) {
233 result += m_impl.coeff(cur_index);
234 cur_index += index_stride;
235 }
236
237 return result;
238 }
239
240 template<int LoadMode>
241 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
242
243 EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
244 eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());
245
246 EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
247 for (int i = 0; i < PacketSize; ++i) {
248 values[i] = coeff(index + i);
249 }
250 PacketReturnType result = internal::ploadt<PacketReturnType, LoadMode>(values);
251 return result;
252 }
253
254#ifdef EIGEN_USE_SYCL
255 // binding placeholder accessors to a command group handler for SYCL
256 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
257 m_impl.bind(cgh);
258 }
259#endif
260
261 protected:
262 // Given the output index, finds the first index in the input tensor used to compute the trace
263 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index firstInput(Index index) const {
264 Index startInput = 0;
265 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
266 for (int i = NumOutputDims - 1; i > 0; --i) {
267 const Index idx = index / m_outputStrides[i];
268 startInput += idx * m_preservedStrides[i];
269 index -= idx * m_outputStrides[i];
270 }
271 startInput += index * m_preservedStrides[0];
272 }
273 else {
274 for (int i = 0; i < NumOutputDims - 1; ++i) {
275 const Index idx = index / m_outputStrides[i];
276 startInput += idx * m_preservedStrides[i];
277 index -= idx * m_outputStrides[i];
278 }
279 startInput += index * m_preservedStrides[NumOutputDims - 1];
280 }
281 return startInput;
282 }
283
284 Dimensions m_dimensions;
285 TensorEvaluator<ArgType, Device> m_impl;
286 // Initialize the size of the trace dimension
287 Index m_traceDim;
288 const Device EIGEN_DEVICE_REF m_device;
289 array<bool, NumInputDims> m_reduced;
290 array<Index, NumReducedDims> m_reducedDims;
291 array<Index, NumOutputDims> m_outputStrides;
292 array<Index, NumReducedDims> m_reducedStrides;
293 array<Index, NumOutputDims> m_preservedStrides;
294};
295
296
297} // End namespace Eigen
298
299#endif // EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
The tensor base class.
Definition TensorForwardDeclarations.h:56
Tensor Trace class.
Definition TensorTrace.h:50
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor evaluator class.
Definition TensorEvaluator.h:27