11#ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
12#define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H
15#include "./InternalHeaderCheck.h"
// Traits specialization for TensorIndexPairOp<XprType>.
// Inherits from traits<XprType> and re-exports its storage kind, index type,
// rank and layout; only Scalar is changed (see below).
// NOTE(review): the numeric prefixes fused onto the lines (20, 21, ...) look
// like line numbers from a rendered listing, and the closing `};` of this
// struct is not visible here -- confirm against the real header.
20template <
typename XprType>
21struct traits<TensorIndexPairOp<XprType>> :
public traits<XprType> {
22 typedef traits<XprType> XprTraits;
23 typedef typename XprTraits::StorageKind StorageKind;
24 typedef typename XprTraits::Index
Index;
// The op's scalar is a Pair of the wrapped expression's Index and Scalar types.
25 typedef Pair<Index, typename XprTraits::Scalar> Scalar;
26 typedef typename XprType::Nested Nested;
27 typedef std::remove_reference_t<Nested> Nested_;
// Rank and layout are taken unchanged from the wrapped expression's traits.
28 static constexpr int NumDimensions = XprTraits::NumDimensions;
29 static constexpr int Layout = XprTraits::Layout;
// eval specialization for TensorIndexPairOp<XprType> with Eigen::Dense:
// `type` is the const-qualified op, decorated with EIGEN_DEVICE_REF.
// NOTE(review): the closing `};` is not visible in this listing.
32template <
typename XprType>
33struct eval<TensorIndexPairOp<XprType>, Eigen::Dense> {
34 typedef const TensorIndexPairOp<XprType> EIGEN_DEVICE_REF type;
// nested specialization for TensorIndexPairOp<XprType> (second argument 1,
// third argument the corresponding eval<...>::type): `type` is the op type
// itself, without const or reference qualifiers.
// NOTE(review): the closing `};` is not visible in this listing.
37template <
typename XprType>
38struct nested<TensorIndexPairOp<XprType>, 1, typename eval<TensorIndexPairOp<XprType>>::type> {
39 typedef TensorIndexPairOp<XprType> type;
// Expression node that wraps an expression of type XprType.
// Derives from TensorBase with ReadOnlyAccessors. Its coefficient type is a
// Pair of (Index, XprType::CoeffReturnType).
// NOTE(review): access specifiers (public/protected) and the closing `};`
// are not visible in this listing -- confirm against the real header.
49template <
typename XprType>
50class TensorIndexPairOp :
public TensorBase<TensorIndexPairOp<XprType>, ReadOnlyAccessors> {
// Type aliases pulled from the traits / nested specializations for this op.
52 typedef typename Eigen::internal::traits<TensorIndexPairOp>::Scalar Scalar;
54 typedef typename Eigen::internal::nested<TensorIndexPairOp>::type Nested;
55 typedef typename Eigen::internal::traits<TensorIndexPairOp>::StorageKind StorageKind;
56 typedef typename Eigen::internal::traits<TensorIndexPairOp>::Index Index;
57 typedef Pair<Index, typename XprType::CoeffReturnType> CoeffReturnType;
// Constructs the op from the expression to wrap; the expression is stored in m_xpr.
59 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexPairOp(
const XprType& expr) : m_xpr(expr) {}
// Returns the wrapped expression (m_xpr) by const reference.
61 EIGEN_DEVICE_FUNC
const internal::remove_all_t<typename XprType::Nested>& expression()
const {
return m_xpr; }
// The wrapped expression, stored as XprType::Nested.
64 typename XprType::Nested m_xpr;
// Evaluator whose coeff(index) returns a pair of (index, inner coefficient).
// NOTE(review): the line naming the struct after this template header is not
// visible in this listing, and neither are the declarations of XprType and
// CoeffReturnType used below. Presumably this is the TensorEvaluator
// specialization for TensorIndexPairOp<ArgType> -- confirm against the real
// header. The leading numbers (68, 71, ...) look like listing line numbers.
68template <
typename ArgType,
typename Device>
71 typedef typename XprType::Index
Index;
72 typedef typename XprType::Scalar
Scalar;
// Dimensions are those of the evaluator of the wrapped argument.
75 typedef typename TensorEvaluator<ArgType, Device>::Dimensions
Dimensions;
76 static constexpr int NumDims = internal::array_size<Dimensions>::value;
77 typedef StorageMemory<CoeffReturnType, Device> Storage;
78 typedef typename Storage::Type EvaluatorPointerType;
// Layout is forwarded from the argument's evaluator.
88 static constexpr int Layout = TensorEvaluator<ArgType, Device>::Layout;
// Block evaluation is marked as not implemented for this evaluator.
91 typedef internal::TensorBlockNotImplemented TensorBlock;
// Builds the inner evaluator (m_impl) from the op's wrapped expression and the device.
94 EIGEN_STRONG_INLINE TensorEvaluator(
const XprType& op,
const Device& device) : m_impl(op.expression(), device) {}
// Forwards the dimensions of the inner evaluator.
96 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions()
const {
return m_impl.dimensions(); }
// Asks the inner evaluator to evaluate its sub-expressions, passing NULL as
// its destination; the pointer argument of this function is unnamed and unused.
// NOTE(review): the return statement and closing brace are not visible here.
98 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType ) {
99 m_impl.evalSubExprsIfNeeded(NULL);
// Forwards cleanup to the inner evaluator.
102 EIGEN_STRONG_INLINE
void cleanup() { m_impl.cleanup(); }
// Returns the pair (index, inner coefficient at index).
// NOTE(review): the closing brace of this function is not visible here.
104 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index)
const {
105 return CoeffReturnType(index, m_impl.coeff(index));
// Cost of the inner evaluator plus TensorOpCost(0, 0, 1) for building the pair.
108 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(
bool vectorized)
const {
109 return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, 1);
// Always returns NULL: this evaluator exposes no data pointer.
112 EIGEN_DEVICE_FUNC EvaluatorPointerType data()
const {
return NULL; }
// Evaluator of the wrapped argument expression.
115 TensorEvaluator<ArgType, Device> m_impl;
// Traits specialization for TensorPairReducerOp<ReduceOp, Dims, XprType>.
// Inherits from traits<XprType>; the op's Scalar is the Index type, and its
// rank is the input rank minus the number of entries in Dims.
// NOTE(review): the closing `};` is not visible in this listing; the leading
// numbers (126, 127, ...) look like listing line numbers.
126template <
typename ReduceOp,
typename Dims,
typename XprType>
127struct traits<TensorPairReducerOp<ReduceOp, Dims, XprType>> :
public traits<XprType> {
128 typedef traits<XprType> XprTraits;
129 typedef typename XprTraits::StorageKind StorageKind;
130 typedef typename XprTraits::Index Index;
// The result scalar is an index, not the input's scalar type.
131 typedef Index Scalar;
132 typedef typename XprType::Nested Nested;
133 typedef std::remove_reference_t<Nested> Nested_;
// array_size<Dims>::value dimensions are subtracted from the input rank.
134 static constexpr int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
135 static constexpr int Layout = XprTraits::Layout;
// eval specialization for TensorPairReducerOp<ReduceOp, Dims, XprType> with
// Eigen::Dense: `type` is the const-qualified op, decorated with EIGEN_DEVICE_REF.
// NOTE(review): the closing `};` is not visible in this listing.
138template <
typename ReduceOp,
typename Dims,
typename XprType>
139struct eval<TensorPairReducerOp<ReduceOp, Dims, XprType>,
Eigen::Dense> {
140 typedef const TensorPairReducerOp<ReduceOp, Dims, XprType> EIGEN_DEVICE_REF type;
// nested specialization for TensorPairReducerOp<ReduceOp, Dims, XprType>
// (second argument 1, third argument the corresponding eval<...>::type):
// `type` is the op type itself, without const or reference qualifiers.
// NOTE(review): the closing `};` is not visible in this listing.
143template <
typename ReduceOp,
typename Dims,
typename XprType>
144struct nested<TensorPairReducerOp<ReduceOp, Dims, XprType>, 1,
145 typename eval<TensorPairReducerOp<ReduceOp, Dims, XprType>>::type> {
146 typedef TensorPairReducerOp<ReduceOp, Dims, XprType> type;
// Expression node bundling an input expression, a reduction functor, the set
// of dimensions to reduce, and a "return dimension" index.
// Derives from TensorBase with ReadOnlyAccessors; its coefficient type is Index.
// NOTE(review): access specifiers and the closing `};` are not visible in
// this listing -- confirm against the real header.
151template <
typename ReduceOp,
typename Dims,
typename XprType>
152class TensorPairReducerOp :
public TensorBase<TensorPairReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors> {
// Type aliases pulled from the traits / nested specializations for this op.
154 typedef typename Eigen::internal::traits<TensorPairReducerOp>::Scalar Scalar;
155 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
156 typedef typename Eigen::internal::nested<TensorPairReducerOp>::type Nested;
157 typedef typename Eigen::internal::traits<TensorPairReducerOp>::StorageKind StorageKind;
158 typedef typename Eigen::internal::traits<TensorPairReducerOp>::Index Index;
159 typedef Index CoeffReturnType;
// Stores copies of the reducer, the return dimension and the reduce
// dimensions, and keeps the input expression as XprType::Nested.
161 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPairReducerOp(
const XprType& expr,
const ReduceOp& reduce_op,
162 const Index return_dim,
const Dims& reduce_dims)
163 : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}
// Returns the input expression (m_xpr) by const reference.
165 EIGEN_DEVICE_FUNC
const internal::remove_all_t<typename XprType::Nested>& expression()
const {
return m_xpr; }
// Returns the stored reduction functor.
167 EIGEN_DEVICE_FUNC
const ReduceOp& reduce_op()
const {
return m_reduce_op; }
// Returns the stored set of dimensions to reduce.
169 EIGEN_DEVICE_FUNC
const Dims& reduce_dims()
const {
return m_reduce_dims; }
// Returns the stored return dimension. The evaluator in this file treats a
// negative value as "return the index unchanged" (see its coeff()).
171 EIGEN_DEVICE_FUNC Index return_dim()
const {
return m_return_dim; }
// Stored state; everything except the nested expression is const.
174 typename XprType::Nested m_xpr;
175 const ReduceOp m_reduce_op;
176 const Index m_return_dim;
177 const Dims m_reduce_dims;
// Evaluator for TensorPairReducerOp<ReduceOp, Dims, ArgType>.
// It reduces the (index, value) pairs of the argument with ReduceOp over Dims
// (m_impl) and returns the `first` member of each reduced pair, optionally
// converted with a modulo and a division by precomputed strides (see coeff()).
// NOTE(review): this listing is missing lines inside the struct -- `else`
// branches, closing braces, `return` statements, the enum header, and the
// declarations of m_stride_mod / m_stride_div are not visible. The leading
// numbers (181, 182, ...) look like listing line numbers. Confirm every
// hedged remark below against the real header.
181template <
typename ReduceOp,
typename Dims,
typename ArgType,
typename Device>
182struct TensorEvaluator<const TensorPairReducerOp<ReduceOp, Dims, ArgType>, Device> {
183 typedef TensorPairReducerOp<ReduceOp, Dims, ArgType> XprType;
184 typedef typename XprType::Index Index;
185 typedef typename XprType::Scalar Scalar;
186 typedef typename XprType::CoeffReturnType CoeffReturnType;
// PairType is the coefficient type of the index-pair op: Pair<Index, value>.
187 typedef typename TensorIndexPairOp<ArgType>::CoeffReturnType PairType;
// Output dimensions are those of the reduction over the index-pair expression.
188 typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType>>,
189 Device>::Dimensions Dimensions;
// Input dimensions are those of the (unreduced) index-pair expression.
190 typedef typename TensorEvaluator<const TensorIndexPairOp<ArgType>, Device>::Dimensions InputDimensions;
191 static constexpr int NumDims = internal::array_size<InputDimensions>::value;
// One stride per input dimension.
192 typedef array<Index, NumDims> StrideDims;
193 typedef StorageMemory<CoeffReturnType, Device> Storage;
194 typedef typename Storage::Type EvaluatorPointerType;
195 typedef StorageMemory<PairType, Device> PairStorageMem;
// Enumerator fragment: packet access is disabled; the block-access preference
// is forwarded from the argument's evaluator.
// NOTE(review): the enclosing enum header and other enumerators are not visible.
199 PacketAccess =
false,
201 PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
// Layout is that of the reduction evaluator.
205 static constexpr int Layout =
206 TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType>>, Device>::Layout;
// Block evaluation is marked as not implemented for this evaluator.
208 typedef internal::TensorBlockNotImplemented TensorBlock;
// Constructor.
//  - m_orig_impl evaluates the op's input expression as written here.
//  - m_impl evaluates input.index_pairs().reduce(reduce_dims, reduce_op).
//  - the strides of the input dimensions are generated, then the modulus and
//    divisor used by coeff() are derived from them.
211 EIGEN_STRONG_INLINE TensorEvaluator(
const XprType& op,
const Device& device)
212 : m_orig_impl(op.expression(), device),
213 m_impl(op.expression().index_pairs().reduce(op.reduce_dims(), op.reduce_op()), device),
214 m_return_dim(op.return_dim()) {
215 gen_strides(m_orig_impl.dimensions(), m_strides);
216 if (Layout ==
static_cast<int>(
ColMajor)) {
// Column-major: the modulus is the stride of the dimension after
// m_return_dim, or the total element count when m_return_dim is not below
// NumDims - 1.
217 const Index total_size = internal::array_prod(m_orig_impl.dimensions());
218 m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
// NOTE(review): the `} else {` line is not visible; presumably the two lines
// below are the non-column-major branch, where the modulus is the stride of
// the dimension before m_return_dim, or the total element count when
// m_return_dim is not above 0.
220 const Index total_size = internal::array_prod(m_orig_impl.dimensions());
221 m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
// NOTE(review): the start of this statement is not visible; presumably the
// value is assigned to m_stride_div (used as the divisor in coeff()). It is
// the stride of m_return_dim when that index is within [0, m_strides.size()),
// and 1 otherwise.
225 ((m_return_dim >= 0) && (m_return_dim < static_cast<Index>(m_strides.size()))) ? m_strides[m_return_dim] : 1;
// Forwards the dimensions of the reduction evaluator.
228 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions()
const {
return m_impl.dimensions(); }
// Asks the reduction evaluator to evaluate its sub-expressions, passing NULL
// as its destination; the pointer argument of this function is unnamed and unused.
// NOTE(review): the return statement and closing brace are not visible here.
230 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType ) {
231 m_impl.evalSubExprsIfNeeded(NULL);
// Forwards cleanup to the reduction evaluator (m_orig_impl is not cleaned up here).
234 EIGEN_STRONG_INLINE
void cleanup() { m_impl.cleanup(); }
// Reads the reduced pair at `index` and returns its `first` member:
//  - unchanged when m_return_dim is negative;
//  - otherwise as (first % m_stride_mod) / m_stride_div.
236 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index)
const {
237 const PairType v = m_impl.coeff(index);
238 return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
// Always returns NULL: this evaluator exposes no data pointer.
241 EIGEN_DEVICE_FUNC EvaluatorPointerType data()
const {
return NULL; }
// Cost estimate: costs of both inner evaluators plus a compute cost of 1.0,
// to which one modulo and one division on Index are added when m_return_dim
// is non-negative (the case in which coeff() performs them).
243 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(
bool vectorized)
const {
244 const double compute_cost =
245 1.0 + (m_return_dim < 0 ? 0.0 : (TensorOpCost::ModCost<Index>() + TensorOpCost::DivCost<Index>()));
246 return m_orig_impl.costPerCoeff(vectorized) + m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost);
// Fills `strides` with the stride of each dimension of `dims` for the
// evaluator's Layout.
250 EIGEN_DEVICE_FUNC
void gen_strides(
const InputDimensions& dims, StrideDims& strides) {
// NOTE(review): the body of this `if` is not visible; presumably it returns
// early because no strides are needed for a negative m_return_dim -- confirm.
251 if (m_return_dim < 0) {
// The requested dimension must be below the input rank.
254 eigen_assert(m_return_dim < NumDims &&
"Asking to convert index to a dimension outside of the rank");
258 if (Layout ==
static_cast<int>(
ColMajor)) {
// Column-major: each stride is the previous stride times the previous extent.
// NOTE(review): the initialization of strides[0] is not visible here.
260 for (
int i = 1; i < NumDims; ++i) {
261 strides[i] = strides[i - 1] * dims[i - 1];
// NOTE(review): the `} else {` line is not visible; presumably the lines
// below are the non-column-major branch: the last stride is 1 and each
// earlier stride is the next stride times the next extent.
264 strides[NumDims - 1] = 1;
265 for (
int i = NumDims - 2; i >= 0; --i) {
266 strides[i] = strides[i + 1] * dims[i + 1];
// m_orig_impl: evaluator of the index-pair expression (used for its dimensions and cost).
272 TensorEvaluator<const TensorIndexPairOp<ArgType>, Device> m_orig_impl;
// m_impl: evaluator of the reduction over the index-pair expression.
273 TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType>>, Device> m_impl;
// Dimension requested by the op; negative values make coeff() return the index unchanged.
274 const Index m_return_dim;
// Strides of the input dimensions, filled by gen_strides().
275 StrideDims m_strides;
Eigen::TensorBase: The tensor base class.
Definition: TensorForwardDeclarations.h:68
Eigen::TensorIndexPairOp: Tensor + Index Pair class.
Definition: TensorArgMax.h:50
Eigen: Namespace containing all symbols from the Eigen library.
Eigen::Index: typedef of EIGEN_DEFAULT_DENSE_INDEX_TYPE.
Eigen::TensorEvaluator: The tensor evaluator class.
Definition: TensorEvaluator.h:30