Eigen-unsupported  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
TensorCustomOp.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19template <typename CustomUnaryFunc, typename XprType>
20struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> > {
21 typedef typename XprType::Scalar Scalar;
22 typedef typename XprType::StorageKind StorageKind;
23 typedef typename XprType::Index Index;
24 typedef typename XprType::Nested Nested;
25 typedef std::remove_reference_t<Nested> Nested_;
26 static constexpr int NumDimensions = traits<XprType>::NumDimensions;
27 static constexpr int Layout = traits<XprType>::Layout;
28 typedef typename traits<XprType>::PointerType PointerType;
29};
30
31template <typename CustomUnaryFunc, typename XprType>
32struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense> {
33 typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType> EIGEN_DEVICE_REF type;
34};
35
36template <typename CustomUnaryFunc, typename XprType>
37struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> > {
38 typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type;
39};
40
41} // end namespace internal
42
48template <typename CustomUnaryFunc, typename XprType>
49class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors> {
50 public:
51 typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
52 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
53 typedef typename XprType::CoeffReturnType CoeffReturnType;
54 typedef typename internal::nested<TensorCustomUnaryOp>::type Nested;
55 typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind;
56 typedef typename internal::traits<TensorCustomUnaryOp>::Index Index;
57
58 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func)
59 : m_expr(expr), m_func(func) {}
60
61 EIGEN_DEVICE_FUNC const CustomUnaryFunc& func() const { return m_func; }
62
63 EIGEN_DEVICE_FUNC const internal::remove_all_t<typename XprType::Nested>& expression() const { return m_expr; }
64
65 protected:
66 typename XprType::Nested m_expr;
67 const CustomUnaryFunc m_func;
68};
69
70// Eval as rvalue
71template <typename CustomUnaryFunc, typename XprType, typename Device>
72struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device> {
74 typedef typename internal::traits<ArgType>::Index Index;
75 static constexpr int NumDims = internal::traits<ArgType>::NumDimensions;
76 typedef DSizes<Index, NumDims> Dimensions;
77 typedef std::remove_const_t<typename ArgType::Scalar> Scalar;
78 typedef std::remove_const_t<typename XprType::CoeffReturnType> CoeffReturnType;
79 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
80 static constexpr int PacketSize = PacketType<CoeffReturnType, Device>::size;
81 typedef typename Eigen::internal::traits<XprType>::PointerType TensorPointerType;
82 typedef StorageMemory<CoeffReturnType, Device> Storage;
83 typedef typename Storage::Type EvaluatorPointerType;
84
85 static constexpr int Layout = TensorEvaluator<XprType, Device>::Layout;
86 enum {
87 IsAligned = false,
88 PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
89 BlockAccess = false,
90 PreferBlockAccess = false,
91 CoordAccess = false, // to be implemented
92 RawAccess = false
93 };
94
95 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
96 typedef internal::TensorBlockNotImplemented TensorBlock;
97 //===--------------------------------------------------------------------===//
98
99 EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device)
100 : m_op(op), m_device(device), m_result(NULL) {
101 m_dimensions = op.func().dimensions(op.expression());
102 }
103
104 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
105
106 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
107 if (data) {
108 evalTo(data);
109 return false;
110 } else {
111 m_result = static_cast<EvaluatorPointerType>(
112 m_device.get((CoeffReturnType*)m_device.allocate_temp(dimensions().TotalSize() * sizeof(Scalar))));
113 evalTo(m_result);
114 return true;
115 }
116 }
117
118 EIGEN_STRONG_INLINE void cleanup() {
119 if (m_result) {
120 m_device.deallocate_temp(m_result);
121 m_result = NULL;
122 }
123 }
124
125 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { return m_result[index]; }
126
127 template <int LoadMode>
128 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
129 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
130 }
131
132 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
133 // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
134 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
135 }
136
137 EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return m_result; }
138
139 protected:
140 void evalTo(EvaluatorPointerType data) {
141 TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result(m_device.get(data), m_dimensions);
142 m_op.func().eval(m_op.expression(), result, m_device);
143 }
144
145 Dimensions m_dimensions;
146 const ArgType m_op;
147 const Device EIGEN_DEVICE_REF m_device;
148 EvaluatorPointerType m_result;
149};
150
158namespace internal {
159template <typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
160struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> > {
161 typedef typename internal::promote_storage_type<typename LhsXprType::Scalar, typename RhsXprType::Scalar>::ret Scalar;
162 typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
163 typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
164 typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
165 typename traits<RhsXprType>::StorageKind>::ret StorageKind;
166 typedef
167 typename promote_index_type<typename traits<LhsXprType>::Index, typename traits<RhsXprType>::Index>::type Index;
168 typedef typename LhsXprType::Nested LhsNested;
169 typedef typename RhsXprType::Nested RhsNested;
170 typedef std::remove_reference_t<LhsNested> LhsNested_;
171 typedef std::remove_reference_t<RhsNested> RhsNested_;
172 static constexpr int NumDimensions = traits<LhsXprType>::NumDimensions;
173 static constexpr int Layout = traits<LhsXprType>::Layout;
174 typedef std::conditional_t<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
175 typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>
176 PointerType;
177};
178
179template <typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
180struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense> {
181 typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type;
182};
183
184template <typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
185struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> > {
186 typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type;
187};
188
189} // end namespace internal
190
191template <typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
192class TensorCustomBinaryOp
193 : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors> {
194 public:
195 typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
196 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
197 typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType;
198 typedef typename internal::nested<TensorCustomBinaryOp>::type Nested;
199 typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind;
200 typedef typename internal::traits<TensorCustomBinaryOp>::Index Index;
201
202 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs,
203 const CustomBinaryFunc& func)
204
205 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {}
206
207 EIGEN_DEVICE_FUNC const CustomBinaryFunc& func() const { return m_func; }
208
209 EIGEN_DEVICE_FUNC const internal::remove_all_t<typename LhsXprType::Nested>& lhsExpression() const {
210 return m_lhs_xpr;
211 }
212
213 EIGEN_DEVICE_FUNC const internal::remove_all_t<typename RhsXprType::Nested>& rhsExpression() const {
214 return m_rhs_xpr;
215 }
216
217 protected:
218 typename LhsXprType::Nested m_lhs_xpr;
219 typename RhsXprType::Nested m_rhs_xpr;
220 const CustomBinaryFunc m_func;
221};
222
223// Eval as rvalue
224template <typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device>
225struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device> {
227 typedef typename internal::traits<XprType>::Index Index;
228 static constexpr int NumDims = internal::traits<XprType>::NumDimensions;
229 typedef DSizes<Index, NumDims> Dimensions;
230 typedef typename XprType::Scalar Scalar;
231 typedef std::remove_const_t<typename XprType::CoeffReturnType> CoeffReturnType;
232 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
233 static constexpr int PacketSize = PacketType<CoeffReturnType, Device>::size;
234
235 typedef typename Eigen::internal::traits<XprType>::PointerType TensorPointerType;
236 typedef StorageMemory<CoeffReturnType, Device> Storage;
237 typedef typename Storage::Type EvaluatorPointerType;
238
239 static constexpr int Layout = TensorEvaluator<LhsXprType, Device>::Layout;
240 enum {
241 IsAligned = false,
242 PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
243 BlockAccess = false,
244 PreferBlockAccess = false,
245 CoordAccess = false, // to be implemented
246 RawAccess = false
247 };
248
249 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
250 typedef internal::TensorBlockNotImplemented TensorBlock;
251 //===--------------------------------------------------------------------===//
252
253 EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
254 : m_op(op), m_device(device), m_result(NULL) {
255 m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression());
256 }
257
258 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
259
260 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
261 if (data) {
262 evalTo(data);
263 return false;
264 } else {
265 m_result = static_cast<EvaluatorPointerType>(
266 m_device.get((CoeffReturnType*)m_device.allocate_temp(dimensions().TotalSize() * sizeof(CoeffReturnType))));
267 evalTo(m_result);
268 return true;
269 }
270 }
271
272 EIGEN_STRONG_INLINE void cleanup() {
273 if (m_result != NULL) {
274 m_device.deallocate_temp(m_result);
275 m_result = NULL;
276 }
277 }
278
279 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { return m_result[index]; }
280
281 template <int LoadMode>
282 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
283 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
284 }
285
286 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
287 // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
288 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
289 }
290
291 EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return m_result; }
292
293 protected:
294 void evalTo(EvaluatorPointerType data) {
295 TensorMap<Tensor<CoeffReturnType, NumDims, Layout> > result(m_device.get(data), m_dimensions);
296 m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
297 }
298
299 Dimensions m_dimensions;
300 const XprType m_op;
301 const Device EIGEN_DEVICE_REF m_device;
302 EvaluatorPointerType m_result;
303};
304
305} // end namespace Eigen
306
307#endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
The tensor base class.
Definition TensorForwardDeclarations.h:68
Tensor custom class (TensorCustomBinaryOp).
Definition TensorCustomOp.h:193
Tensor custom class (TensorCustomUnaryOp).
Definition TensorCustomOp.h:49
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor evaluator class.
Definition TensorEvaluator.h:30