Eigen-unsupported  3.4.1 (git rev 28ded8800c26864e537852658428ab44c8399e87)
 
Loading...
Searching...
No Matches
TensorCustomOp.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
12
13namespace Eigen {
14
15namespace internal {
16template<typename CustomUnaryFunc, typename XprType>
17struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
18{
19 typedef typename XprType::Scalar Scalar;
20 typedef typename XprType::StorageKind StorageKind;
21 typedef typename XprType::Index Index;
22 typedef typename XprType::Nested Nested;
23 typedef typename remove_reference<Nested>::type _Nested;
24 static const int NumDimensions = traits<XprType>::NumDimensions;
25 static const int Layout = traits<XprType>::Layout;
26 typedef typename traits<XprType>::PointerType PointerType;
27};
28
29template<typename CustomUnaryFunc, typename XprType>
30struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense>
31{
32 typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>EIGEN_DEVICE_REF type;
33};
34
35template<typename CustomUnaryFunc, typename XprType>
36struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
37{
38 typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type;
39};
40
41} // end namespace internal
42
48template <typename CustomUnaryFunc, typename XprType>
49class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors> {
50 public:
51 typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
52 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
53 typedef typename XprType::CoeffReturnType CoeffReturnType;
54 typedef typename internal::nested<TensorCustomUnaryOp>::type Nested;
55 typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind;
56 typedef typename internal::traits<TensorCustomUnaryOp>::Index Index;
57
58 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func)
59 : m_expr(expr), m_func(func) {}
60
61 EIGEN_DEVICE_FUNC
62 const CustomUnaryFunc& func() const { return m_func; }
63
64 EIGEN_DEVICE_FUNC
65 const typename internal::remove_all<typename XprType::Nested>::type&
66 expression() const { return m_expr; }
67
68 protected:
69 typename XprType::Nested m_expr;
70 const CustomUnaryFunc m_func;
71};
72
73
74// Eval as rvalue
75template<typename CustomUnaryFunc, typename XprType, typename Device>
76struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device>
77{
79 typedef typename internal::traits<ArgType>::Index Index;
80 static const int NumDims = internal::traits<ArgType>::NumDimensions;
81 typedef DSizes<Index, NumDims> Dimensions;
82 typedef typename internal::remove_const<typename ArgType::Scalar>::type Scalar;
83 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
84 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
85 static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
86 typedef typename Eigen::internal::traits<XprType>::PointerType TensorPointerType;
87 typedef StorageMemory<CoeffReturnType, Device> Storage;
88 typedef typename Storage::Type EvaluatorPointerType;
89
90 enum {
91 IsAligned = false,
92 PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
93 BlockAccess = false,
94 PreferBlockAccess = false,
95 Layout = TensorEvaluator<XprType, Device>::Layout,
96 CoordAccess = false, // to be implemented
97 RawAccess = false
98 };
99
100 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
101 typedef internal::TensorBlockNotImplemented TensorBlock;
102 //===--------------------------------------------------------------------===//
103
104 EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device)
105 : m_op(op), m_device(device), m_result(NULL)
106 {
107 m_dimensions = op.func().dimensions(op.expression());
108 }
109
110 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
111
112 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
113 if (data) {
114 evalTo(data);
115 return false;
116 } else {
117 m_result = static_cast<EvaluatorPointerType>(m_device.get( (CoeffReturnType*)
118 m_device.allocate_temp(dimensions().TotalSize() * sizeof(Scalar))));
119 evalTo(m_result);
120 return true;
121 }
122 }
123
124 EIGEN_STRONG_INLINE void cleanup() {
125 if (m_result) {
126 m_device.deallocate_temp(m_result);
127 m_result = NULL;
128 }
129 }
130
131 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
132 return m_result[index];
133 }
134
135 template<int LoadMode>
136 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
137 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
138 }
139
140 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
141 // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
142 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
143 }
144
145 EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return m_result; }
146
147#ifdef EIGEN_USE_SYCL
148 // binding placeholder accessors to a command group handler for SYCL
149 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
150 m_result.bind(cgh);
151 }
152#endif
153
154 protected:
155 void evalTo(EvaluatorPointerType data) {
156 TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result(m_device.get(data), m_dimensions);
157 m_op.func().eval(m_op.expression(), result, m_device);
158 }
159
160 Dimensions m_dimensions;
161 const ArgType m_op;
162 const Device EIGEN_DEVICE_REF m_device;
163 EvaluatorPointerType m_result;
164};
165
166
167
175namespace internal {
176template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
177struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
178{
179 typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
180 typename RhsXprType::Scalar>::ret Scalar;
181 typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
182 typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
183 typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
184 typename traits<RhsXprType>::StorageKind>::ret StorageKind;
185 typedef typename promote_index_type<typename traits<LhsXprType>::Index,
186 typename traits<RhsXprType>::Index>::type Index;
187 typedef typename LhsXprType::Nested LhsNested;
188 typedef typename RhsXprType::Nested RhsNested;
189 typedef typename remove_reference<LhsNested>::type _LhsNested;
190 typedef typename remove_reference<RhsNested>::type _RhsNested;
191 static const int NumDimensions = traits<LhsXprType>::NumDimensions;
192 static const int Layout = traits<LhsXprType>::Layout;
193 typedef typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
194 typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>::type PointerType;
195};
196
197template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
198struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense>
199{
200 typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type;
201};
202
203template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
204struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
205{
206 typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type;
207};
208
209} // end namespace internal
210
211
212
213template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
214class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors>
215{
216 public:
217 typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
218 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
219 typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType;
220 typedef typename internal::nested<TensorCustomBinaryOp>::type Nested;
221 typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind;
222 typedef typename internal::traits<TensorCustomBinaryOp>::Index Index;
223
224 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func)
225
226 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {}
227
228 EIGEN_DEVICE_FUNC
229 const CustomBinaryFunc& func() const { return m_func; }
230
231 EIGEN_DEVICE_FUNC
232 const typename internal::remove_all<typename LhsXprType::Nested>::type&
233 lhsExpression() const { return m_lhs_xpr; }
234
235 EIGEN_DEVICE_FUNC
236 const typename internal::remove_all<typename RhsXprType::Nested>::type&
237 rhsExpression() const { return m_rhs_xpr; }
238
239 protected:
240 typename LhsXprType::Nested m_lhs_xpr;
241 typename RhsXprType::Nested m_rhs_xpr;
242 const CustomBinaryFunc m_func;
243};
244
245
246// Eval as rvalue
247template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device>
248struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device>
249{
251 typedef typename internal::traits<XprType>::Index Index;
252 static const int NumDims = internal::traits<XprType>::NumDimensions;
253 typedef DSizes<Index, NumDims> Dimensions;
254 typedef typename XprType::Scalar Scalar;
255 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
256 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
257 static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
258
259 typedef typename Eigen::internal::traits<XprType>::PointerType TensorPointerType;
260 typedef StorageMemory<CoeffReturnType, Device> Storage;
261 typedef typename Storage::Type EvaluatorPointerType;
262
263 enum {
264 IsAligned = false,
265 PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
266 BlockAccess = false,
267 PreferBlockAccess = false,
268 Layout = TensorEvaluator<LhsXprType, Device>::Layout,
269 CoordAccess = false, // to be implemented
270 RawAccess = false
271 };
272
273 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
274 typedef internal::TensorBlockNotImplemented TensorBlock;
275 //===--------------------------------------------------------------------===//
276
277 EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
278 : m_op(op), m_device(device), m_result(NULL)
279 {
280 m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression());
281 }
282
283 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
284
285 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
286 if (data) {
287 evalTo(data);
288 return false;
289 } else {
290 m_result = static_cast<EvaluatorPointerType>(m_device.get( (CoeffReturnType*)
291 m_device.allocate_temp(dimensions().TotalSize() * sizeof(CoeffReturnType))));
292 evalTo(m_result);
293 return true;
294 }
295 }
296
297 EIGEN_STRONG_INLINE void cleanup() {
298 if (m_result != NULL) {
299 m_device.deallocate_temp(m_result);
300 m_result = NULL;
301 }
302 }
303
304 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
305 return m_result[index];
306 }
307
308 template<int LoadMode>
309 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
310 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
311 }
312
313 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
314 // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
315 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
316 }
317
318 EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return m_result; }
319
320#ifdef EIGEN_USE_SYCL
321 // binding placeholder accessors to a command group handler for SYCL
322 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(cl::sycl::handler &cgh) const {
323 m_result.bind(cgh);
324 }
325#endif
326
327 protected:
328 void evalTo(EvaluatorPointerType data) {
329 TensorMap<Tensor<CoeffReturnType, NumDims, Layout> > result(m_device.get(data), m_dimensions);
330 m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
331 }
332
333 Dimensions m_dimensions;
334 const XprType m_op;
335 const Device EIGEN_DEVICE_REF m_device;
336 EvaluatorPointerType m_result;
337};
338
339
340} // end namespace Eigen
341
342#endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
The tensor base class.
Definition TensorForwardDeclarations.h:56
Tensor custom class.
Definition TensorCustomOp.h:215
Tensor custom class.
Definition TensorCustomOp.h:49
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor evaluator class.
Definition TensorEvaluator.h:27