Eigen-unsupported  3.4.1 (git rev 28ded8800c26864e537852658428ab44c8399e87)
 
Loading...
Searching...
No Matches
TensorArgMax.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2015 Eugene Brevdo <ebrevdo@gmail.com>
5// Benoit Steiner <benoit.steiner.goog@gmail.com>
6//
7// This Source Code Form is subject to the terms of the Mozilla
8// Public License v. 2.0. If a copy of the MPL was not distributed
9// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11#ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
12#define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
13
14namespace Eigen {
15namespace internal {
16
17template<typename XprType>
18struct traits<TensorIndexTupleOp<XprType> > : public traits<XprType>
19{
20 typedef traits<XprType> XprTraits;
21 typedef typename XprTraits::StorageKind StorageKind;
22 typedef typename XprTraits::Index Index;
23 typedef Tuple<Index, typename XprTraits::Scalar> Scalar;
24 typedef typename XprType::Nested Nested;
25 typedef typename remove_reference<Nested>::type _Nested;
26 static const int NumDimensions = XprTraits::NumDimensions;
27 static const int Layout = XprTraits::Layout;
28};
29
30template<typename XprType>
31struct eval<TensorIndexTupleOp<XprType>, Eigen::Dense>
32{
33 typedef const TensorIndexTupleOp<XprType>EIGEN_DEVICE_REF type;
34};
35
36template<typename XprType>
37struct nested<TensorIndexTupleOp<XprType>, 1,
38 typename eval<TensorIndexTupleOp<XprType> >::type>
39{
40 typedef TensorIndexTupleOp<XprType> type;
41};
42
43} // end namespace internal
44
50template<typename XprType>
51class TensorIndexTupleOp : public TensorBase<TensorIndexTupleOp<XprType>, ReadOnlyAccessors>
52{
53 public:
54 typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Scalar Scalar;
55 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
56 typedef typename Eigen::internal::nested<TensorIndexTupleOp>::type Nested;
57 typedef typename Eigen::internal::traits<TensorIndexTupleOp>::StorageKind StorageKind;
58 typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Index Index;
59 typedef Tuple<Index, typename XprType::CoeffReturnType> CoeffReturnType;
60
61 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexTupleOp(const XprType& expr)
62 : m_xpr(expr) {}
63
64 EIGEN_DEVICE_FUNC
65 const typename internal::remove_all<typename XprType::Nested>::type&
66 expression() const { return m_xpr; }
67
68 protected:
69 typename XprType::Nested m_xpr;
70};
71
72// Eval as rvalue
73template<typename ArgType, typename Device>
74struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>
75{
76 typedef TensorIndexTupleOp<ArgType> XprType;
77 typedef typename XprType::Index Index;
78 typedef typename XprType::Scalar Scalar;
79 typedef typename XprType::CoeffReturnType CoeffReturnType;
80
81 typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
82 static const int NumDims = internal::array_size<Dimensions>::value;
83 typedef StorageMemory<CoeffReturnType, Device> Storage;
84 typedef typename Storage::Type EvaluatorPointerType;
85
86 enum {
87 IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
88 PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
89 BlockAccess = false,
90 PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
91 Layout = TensorEvaluator<ArgType, Device>::Layout,
92 CoordAccess = false, // to be implemented
93 RawAccess = false
94 };
95
96 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
97 typedef internal::TensorBlockNotImplemented TensorBlock;
98 //===--------------------------------------------------------------------===//
99
100 EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
101 : m_impl(op.expression(), device) { }
102
103 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
104 return m_impl.dimensions();
105 }
106
107 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
108 m_impl.evalSubExprsIfNeeded(NULL);
109 return true;
110 }
111 EIGEN_STRONG_INLINE void cleanup() {
112 m_impl.cleanup();
113 }
114
115 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
116 {
117 return CoeffReturnType(index, m_impl.coeff(index));
118 }
119
120 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
121 costPerCoeff(bool vectorized) const {
122 return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, 1);
123 }
124
125 EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
126
127#ifdef EIGEN_USE_SYCL
128 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
129 m_impl.bind(cgh);
130 }
131#endif
132
133 protected:
134 TensorEvaluator<ArgType, Device> m_impl;
135};
136
137namespace internal {
138
145template<typename ReduceOp, typename Dims, typename XprType>
146struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<XprType>
147{
148 typedef traits<XprType> XprTraits;
149 typedef typename XprTraits::StorageKind StorageKind;
150 typedef typename XprTraits::Index Index;
151 typedef Index Scalar;
152 typedef typename XprType::Nested Nested;
153 typedef typename remove_reference<Nested>::type _Nested;
154 static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
155 static const int Layout = XprTraits::Layout;
156};
157
158template<typename ReduceOp, typename Dims, typename XprType>
159struct eval<TensorTupleReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense>
160{
161 typedef const TensorTupleReducerOp<ReduceOp, Dims, XprType>EIGEN_DEVICE_REF type;
162};
163
164template<typename ReduceOp, typename Dims, typename XprType>
165struct nested<TensorTupleReducerOp<ReduceOp, Dims, XprType>, 1,
166 typename eval<TensorTupleReducerOp<ReduceOp, Dims, XprType> >::type>
167{
168 typedef TensorTupleReducerOp<ReduceOp, Dims, XprType> type;
169};
170
171} // end namespace internal
172
173template<typename ReduceOp, typename Dims, typename XprType>
174class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors>
175{
176 public:
177 typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Scalar Scalar;
178 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
179 typedef typename Eigen::internal::nested<TensorTupleReducerOp>::type Nested;
180 typedef typename Eigen::internal::traits<TensorTupleReducerOp>::StorageKind StorageKind;
181 typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Index Index;
182 typedef Index CoeffReturnType;
183
184 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTupleReducerOp(const XprType& expr,
185 const ReduceOp& reduce_op,
186 const Index return_dim,
187 const Dims& reduce_dims)
188 : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}
189
190 EIGEN_DEVICE_FUNC
191 const typename internal::remove_all<typename XprType::Nested>::type&
192 expression() const { return m_xpr; }
193
194 EIGEN_DEVICE_FUNC
195 const ReduceOp& reduce_op() const { return m_reduce_op; }
196
197 EIGEN_DEVICE_FUNC
198 const Dims& reduce_dims() const { return m_reduce_dims; }
199
200 EIGEN_DEVICE_FUNC
201 Index return_dim() const { return m_return_dim; }
202
203 protected:
204 typename XprType::Nested m_xpr;
205 const ReduceOp m_reduce_op;
206 const Index m_return_dim;
207 const Dims m_reduce_dims;
208};
209
210// Eval as rvalue
211template<typename ReduceOp, typename Dims, typename ArgType, typename Device>
212struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device>
213{
214 typedef TensorTupleReducerOp<ReduceOp, Dims, ArgType> XprType;
215 typedef typename XprType::Index Index;
216 typedef typename XprType::Scalar Scalar;
217 typedef typename XprType::CoeffReturnType CoeffReturnType;
218 typedef typename TensorIndexTupleOp<ArgType>::CoeffReturnType TupleType;
219 typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Dimensions Dimensions;
220 typedef typename TensorEvaluator<const TensorIndexTupleOp<ArgType> , Device>::Dimensions InputDimensions;
221 static const int NumDims = internal::array_size<InputDimensions>::value;
222 typedef array<Index, NumDims> StrideDims;
223 typedef StorageMemory<CoeffReturnType, Device> Storage;
224 typedef typename Storage::Type EvaluatorPointerType;
225 typedef StorageMemory<TupleType, Device> TupleStorageMem;
226
227 enum {
228 IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
229 PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
230 BlockAccess = false,
231 PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
232 Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Layout,
233 CoordAccess = false, // to be implemented
234 RawAccess = false
235 };
236
237 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
238 typedef internal::TensorBlockNotImplemented TensorBlock;
239 //===--------------------------------------------------------------------===//
240
241 EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
242 : m_orig_impl(op.expression(), device),
243 m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device),
244 m_return_dim(op.return_dim())
245 {
246 gen_strides(m_orig_impl.dimensions(), m_strides);
247 if (Layout == static_cast<int>(ColMajor)) {
248 const Index total_size = internal::array_prod(m_orig_impl.dimensions());
249 m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
250 } else {
251 const Index total_size = internal::array_prod(m_orig_impl.dimensions());
252 m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
253 }
254 // If m_return_dim is not a valid index, returns 1 or this can crash on Windows.
255 m_stride_div = ((m_return_dim >= 0) &&
256 (m_return_dim < static_cast<Index>(m_strides.size())))
257 ? m_strides[m_return_dim] : 1;
258 }
259
260 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
261 return m_impl.dimensions();
262 }
263
264 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
265 m_impl.evalSubExprsIfNeeded(NULL);
266 return true;
267 }
268 EIGEN_STRONG_INLINE void cleanup() {
269 m_impl.cleanup();
270 }
271
272 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
273 const TupleType v = m_impl.coeff(index);
274 return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
275 }
276
277 EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
278#ifdef EIGEN_USE_SYCL
279 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
280 m_impl.bind(cgh);
281 m_orig_impl.bind(cgh);
282 }
283#endif
284
285 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
286 costPerCoeff(bool vectorized) const {
287 const double compute_cost = 1.0 +
288 (m_return_dim < 0 ? 0.0 : (TensorOpCost::ModCost<Index>() + TensorOpCost::DivCost<Index>()));
289 return m_orig_impl.costPerCoeff(vectorized) +
290 m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost);
291 }
292
293 private:
294 EIGEN_DEVICE_FUNC void gen_strides(const InputDimensions& dims, StrideDims& strides) {
295 if (m_return_dim < 0) {
296 return; // Won't be using the strides.
297 }
298 eigen_assert(m_return_dim < NumDims &&
299 "Asking to convert index to a dimension outside of the rank");
300
301 // Calculate m_stride_div and m_stride_mod, which are used to
302 // calculate the value of an index w.r.t. the m_return_dim.
303 if (Layout == static_cast<int>(ColMajor)) {
304 strides[0] = 1;
305 for (int i = 1; i < NumDims; ++i) {
306 strides[i] = strides[i-1] * dims[i-1];
307 }
308 } else {
309 strides[NumDims-1] = 1;
310 for (int i = NumDims - 2; i >= 0; --i) {
311 strides[i] = strides[i+1] * dims[i+1];
312 }
313 }
314 }
315
316 protected:
317 TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl;
318 TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl;
319 const Index m_return_dim;
320 StrideDims m_strides;
321 Index m_stride_mod;
322 Index m_stride_div;
323};
324
325} // end namespace Eigen
326
327#endif // EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
The tensor base class.
Definition TensorForwardDeclarations.h:56
Tensor + Index Pair class.
Definition TensorArgMax.h:52
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor evaluator class.
Definition TensorEvaluator.h:27