Eigen-unsupported  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
TensorArgMax.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2015 Eugene Brevdo <ebrevdo@gmail.com>
5// Benoit Steiner <benoit.steiner.goog@gmail.com>
6//
7// This Source Code Form is subject to the terms of the Mozilla
8// Public License v. 2.0. If a copy of the MPL was not distributed
9// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10
11#ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
12#define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
13
14// IWYU pragma: private
15#include "./InternalHeaderCheck.h"
16
17namespace Eigen {
18namespace internal {
19
20template <typename XprType>
21struct traits<TensorIndexPairOp<XprType>> : public traits<XprType> {
22 typedef traits<XprType> XprTraits;
23 typedef typename XprTraits::StorageKind StorageKind;
24 typedef typename XprTraits::Index Index;
25 typedef Pair<Index, typename XprTraits::Scalar> Scalar;
26 typedef typename XprType::Nested Nested;
27 typedef std::remove_reference_t<Nested> Nested_;
28 static constexpr int NumDimensions = XprTraits::NumDimensions;
29 static constexpr int Layout = XprTraits::Layout;
30};
31
32template <typename XprType>
33struct eval<TensorIndexPairOp<XprType>, Eigen::Dense> {
34 typedef const TensorIndexPairOp<XprType> EIGEN_DEVICE_REF type;
35};
36
37template <typename XprType>
38struct nested<TensorIndexPairOp<XprType>, 1, typename eval<TensorIndexPairOp<XprType>>::type> {
39 typedef TensorIndexPairOp<XprType> type;
40};
41
42} // end namespace internal
43
49template <typename XprType>
50class TensorIndexPairOp : public TensorBase<TensorIndexPairOp<XprType>, ReadOnlyAccessors> {
51 public:
52 typedef typename Eigen::internal::traits<TensorIndexPairOp>::Scalar Scalar;
53 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
54 typedef typename Eigen::internal::nested<TensorIndexPairOp>::type Nested;
55 typedef typename Eigen::internal::traits<TensorIndexPairOp>::StorageKind StorageKind;
56 typedef typename Eigen::internal::traits<TensorIndexPairOp>::Index Index;
57 typedef Pair<Index, typename XprType::CoeffReturnType> CoeffReturnType;
58
59 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexPairOp(const XprType& expr) : m_xpr(expr) {}
60
61 EIGEN_DEVICE_FUNC const internal::remove_all_t<typename XprType::Nested>& expression() const { return m_xpr; }
62
63 protected:
64 typename XprType::Nested m_xpr;
65};
66
67// Eval as rvalue
68template <typename ArgType, typename Device>
69struct TensorEvaluator<const TensorIndexPairOp<ArgType>, Device> {
70 typedef TensorIndexPairOp<ArgType> XprType;
71 typedef typename XprType::Index Index;
72 typedef typename XprType::Scalar Scalar;
73 typedef typename XprType::CoeffReturnType CoeffReturnType;
74
75 typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
76 static constexpr int NumDims = internal::array_size<Dimensions>::value;
77 typedef StorageMemory<CoeffReturnType, Device> Storage;
78 typedef typename Storage::Type EvaluatorPointerType;
79
80 enum {
81 IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
82 PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
83 BlockAccess = false,
85 CoordAccess = false, // to be implemented
86 RawAccess = false
87 };
88 static constexpr int Layout = TensorEvaluator<ArgType, Device>::Layout;
89
90 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
91 typedef internal::TensorBlockNotImplemented TensorBlock;
92 //===--------------------------------------------------------------------===//
93
94 EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) : m_impl(op.expression(), device) {}
95
96 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_impl.dimensions(); }
97
98 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
99 m_impl.evalSubExprsIfNeeded(NULL);
100 return true;
101 }
102 EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
103
104 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
105 return CoeffReturnType(index, m_impl.coeff(index));
106 }
107
108 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
109 return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, 1);
110 }
111
112 EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
113
114 protected:
115 TensorEvaluator<ArgType, Device> m_impl;
116};
117
118namespace internal {
119
126template <typename ReduceOp, typename Dims, typename XprType>
127struct traits<TensorPairReducerOp<ReduceOp, Dims, XprType>> : public traits<XprType> {
128 typedef traits<XprType> XprTraits;
129 typedef typename XprTraits::StorageKind StorageKind;
130 typedef typename XprTraits::Index Index;
131 typedef Index Scalar;
132 typedef typename XprType::Nested Nested;
133 typedef std::remove_reference_t<Nested> Nested_;
134 static constexpr int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
135 static constexpr int Layout = XprTraits::Layout;
136};
137
138template <typename ReduceOp, typename Dims, typename XprType>
139struct eval<TensorPairReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense> {
140 typedef const TensorPairReducerOp<ReduceOp, Dims, XprType> EIGEN_DEVICE_REF type;
141};
142
143template <typename ReduceOp, typename Dims, typename XprType>
144struct nested<TensorPairReducerOp<ReduceOp, Dims, XprType>, 1,
145 typename eval<TensorPairReducerOp<ReduceOp, Dims, XprType>>::type> {
146 typedef TensorPairReducerOp<ReduceOp, Dims, XprType> type;
147};
148
149} // end namespace internal
150
151template <typename ReduceOp, typename Dims, typename XprType>
152class TensorPairReducerOp : public TensorBase<TensorPairReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors> {
153 public:
154 typedef typename Eigen::internal::traits<TensorPairReducerOp>::Scalar Scalar;
155 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
156 typedef typename Eigen::internal::nested<TensorPairReducerOp>::type Nested;
157 typedef typename Eigen::internal::traits<TensorPairReducerOp>::StorageKind StorageKind;
158 typedef typename Eigen::internal::traits<TensorPairReducerOp>::Index Index;
159 typedef Index CoeffReturnType;
160
161 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPairReducerOp(const XprType& expr, const ReduceOp& reduce_op,
162 const Index return_dim, const Dims& reduce_dims)
163 : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}
164
165 EIGEN_DEVICE_FUNC const internal::remove_all_t<typename XprType::Nested>& expression() const { return m_xpr; }
166
167 EIGEN_DEVICE_FUNC const ReduceOp& reduce_op() const { return m_reduce_op; }
168
169 EIGEN_DEVICE_FUNC const Dims& reduce_dims() const { return m_reduce_dims; }
170
171 EIGEN_DEVICE_FUNC Index return_dim() const { return m_return_dim; }
172
173 protected:
174 typename XprType::Nested m_xpr;
175 const ReduceOp m_reduce_op;
176 const Index m_return_dim;
177 const Dims m_reduce_dims;
178};
179
180// Eval as rvalue
181template <typename ReduceOp, typename Dims, typename ArgType, typename Device>
182struct TensorEvaluator<const TensorPairReducerOp<ReduceOp, Dims, ArgType>, Device> {
183 typedef TensorPairReducerOp<ReduceOp, Dims, ArgType> XprType;
184 typedef typename XprType::Index Index;
185 typedef typename XprType::Scalar Scalar;
186 typedef typename XprType::CoeffReturnType CoeffReturnType;
187 typedef typename TensorIndexPairOp<ArgType>::CoeffReturnType PairType;
188 typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType>>,
189 Device>::Dimensions Dimensions;
190 typedef typename TensorEvaluator<const TensorIndexPairOp<ArgType>, Device>::Dimensions InputDimensions;
191 static constexpr int NumDims = internal::array_size<InputDimensions>::value;
192 typedef array<Index, NumDims> StrideDims;
193 typedef StorageMemory<CoeffReturnType, Device> Storage;
194 typedef typename Storage::Type EvaluatorPointerType;
195 typedef StorageMemory<PairType, Device> PairStorageMem;
196
197 enum {
198 IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
199 PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
200 BlockAccess = false,
201 PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
202 CoordAccess = false, // to be implemented
203 RawAccess = false
204 };
205 static constexpr int Layout =
206 TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType>>, Device>::Layout;
207 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
208 typedef internal::TensorBlockNotImplemented TensorBlock;
209 //===--------------------------------------------------------------------===//
210
211 EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
212 : m_orig_impl(op.expression(), device),
213 m_impl(op.expression().index_pairs().reduce(op.reduce_dims(), op.reduce_op()), device),
214 m_return_dim(op.return_dim()) {
215 gen_strides(m_orig_impl.dimensions(), m_strides);
216 if (Layout == static_cast<int>(ColMajor)) {
217 const Index total_size = internal::array_prod(m_orig_impl.dimensions());
218 m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
219 } else {
220 const Index total_size = internal::array_prod(m_orig_impl.dimensions());
221 m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
222 }
223 // If m_return_dim is not a valid index, returns 1 or this can crash on Windows.
224 m_stride_div =
225 ((m_return_dim >= 0) && (m_return_dim < static_cast<Index>(m_strides.size()))) ? m_strides[m_return_dim] : 1;
226 }
227
228 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_impl.dimensions(); }
229
230 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType /*data*/) {
231 m_impl.evalSubExprsIfNeeded(NULL);
232 return true;
233 }
234 EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
235
236 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
237 const PairType v = m_impl.coeff(index);
238 return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
239 }
240
241 EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
242
243 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
244 const double compute_cost =
245 1.0 + (m_return_dim < 0 ? 0.0 : (TensorOpCost::ModCost<Index>() + TensorOpCost::DivCost<Index>()));
246 return m_orig_impl.costPerCoeff(vectorized) + m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost);
247 }
248
249 private:
250 EIGEN_DEVICE_FUNC void gen_strides(const InputDimensions& dims, StrideDims& strides) {
251 if (m_return_dim < 0) {
252 return; // Won't be using the strides.
253 }
254 eigen_assert(m_return_dim < NumDims && "Asking to convert index to a dimension outside of the rank");
255
256 // Calculate m_stride_div and m_stride_mod, which are used to
257 // calculate the value of an index w.r.t. the m_return_dim.
258 if (Layout == static_cast<int>(ColMajor)) {
259 strides[0] = 1;
260 for (int i = 1; i < NumDims; ++i) {
261 strides[i] = strides[i - 1] * dims[i - 1];
262 }
263 } else {
264 strides[NumDims - 1] = 1;
265 for (int i = NumDims - 2; i >= 0; --i) {
266 strides[i] = strides[i + 1] * dims[i + 1];
267 }
268 }
269 }
270
271 protected:
272 TensorEvaluator<const TensorIndexPairOp<ArgType>, Device> m_orig_impl;
273 TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType>>, Device> m_impl;
274 const Index m_return_dim;
275 StrideDims m_strides;
276 Index m_stride_mod;
277 Index m_stride_div;
278};
279
280} // end namespace Eigen
281
282#endif // EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
The tensor base class.
Definition TensorForwardDeclarations.h:68
Tensor + Index Pair class.
Definition TensorArgMax.h:50
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor evaluator class.
Definition TensorEvaluator.h:30