Loading...
Searching...
No Matches
TensorCustomOp.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
12
13namespace Eigen {
14
15namespace internal {
16template<typename CustomUnaryFunc, typename XprType>
17struct traits<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
18{
19 typedef typename XprType::Scalar Scalar;
20 typedef typename XprType::StorageKind StorageKind;
21 typedef typename XprType::Index Index;
22 typedef typename XprType::Nested Nested;
23 typedef typename remove_reference<Nested>::type _Nested;
24 static const int NumDimensions = traits<XprType>::NumDimensions;
25 static const int Layout = traits<XprType>::Layout;
26};
27
28template<typename CustomUnaryFunc, typename XprType>
29struct eval<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Eigen::Dense>
30{
31 typedef const TensorCustomUnaryOp<CustomUnaryFunc, XprType>& type;
32};
33
34template<typename CustomUnaryFunc, typename XprType>
35struct nested<TensorCustomUnaryOp<CustomUnaryFunc, XprType> >
36{
37 typedef TensorCustomUnaryOp<CustomUnaryFunc, XprType> type;
38};
39
40} // end namespace internal
41
47template<typename CustomUnaryFunc, typename XprType>
48class TensorCustomUnaryOp : public TensorBase<TensorCustomUnaryOp<CustomUnaryFunc, XprType>, ReadOnlyAccessors>
49{
50 public:
51 typedef typename internal::traits<TensorCustomUnaryOp>::Scalar Scalar;
52 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
53 typedef typename XprType::CoeffReturnType CoeffReturnType;
54 typedef typename internal::nested<TensorCustomUnaryOp>::type Nested;
55 typedef typename internal::traits<TensorCustomUnaryOp>::StorageKind StorageKind;
56 typedef typename internal::traits<TensorCustomUnaryOp>::Index Index;
57
58 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomUnaryOp(const XprType& expr, const CustomUnaryFunc& func)
59 : m_expr(expr), m_func(func) {}
60
61 EIGEN_DEVICE_FUNC
62 const CustomUnaryFunc& func() const { return m_func; }
63
64 EIGEN_DEVICE_FUNC
65 const typename internal::remove_all<typename XprType::Nested>::type&
66 expression() const { return m_expr; }
67
68 protected:
69 typename XprType::Nested m_expr;
70 const CustomUnaryFunc m_func;
71};
72
73
74// Eval as rvalue
75template<typename CustomUnaryFunc, typename XprType, typename Device>
76struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Device>
77{
79 typedef typename internal::traits<ArgType>::Index Index;
80 static const int NumDims = internal::traits<ArgType>::NumDimensions;
81 typedef DSizes<Index, NumDims> Dimensions;
82 typedef typename internal::remove_const<typename ArgType::Scalar>::type Scalar;
83 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
84 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
85 static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
86
87 enum {
88 IsAligned = false,
89 PacketAccess = (internal::packet_traits<Scalar>::size > 1),
90 BlockAccess = false,
91 Layout = TensorEvaluator<XprType, Device>::Layout,
92 CoordAccess = false, // to be implemented
93 RawAccess = false
94 };
95
96 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const ArgType& op, const Device& device)
97 : m_op(op), m_device(device), m_result(NULL)
98 {
99 m_dimensions = op.func().dimensions(op.expression());
100 }
101
102 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
103
104 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
105 if (data) {
106 evalTo(data);
107 return false;
108 } else {
109 m_result = static_cast<CoeffReturnType*>(
110 m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
111 evalTo(m_result);
112 return true;
113 }
114 }
115
116 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
117 if (m_result != NULL) {
118 m_device.deallocate(m_result);
119 m_result = NULL;
120 }
121 }
122
123 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
124 return m_result[index];
125 }
126
127 template<int LoadMode>
128 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
129 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
130 }
131
132 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
133 // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
134 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
135 }
136
137 EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
138
139 protected:
140 EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
141 TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result(
142 data, m_dimensions);
143 m_op.func().eval(m_op.expression(), result, m_device);
144 }
145
146 Dimensions m_dimensions;
147 const ArgType m_op;
148 const Device& m_device;
149 CoeffReturnType* m_result;
150};
151
152
153
161namespace internal {
162template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
163struct traits<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
164{
165 typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
166 typename RhsXprType::Scalar>::ret Scalar;
167 typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
168 typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
169 typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
170 typename traits<RhsXprType>::StorageKind>::ret StorageKind;
171 typedef typename promote_index_type<typename traits<LhsXprType>::Index,
172 typename traits<RhsXprType>::Index>::type Index;
173 typedef typename LhsXprType::Nested LhsNested;
174 typedef typename RhsXprType::Nested RhsNested;
175 typedef typename remove_reference<LhsNested>::type _LhsNested;
176 typedef typename remove_reference<RhsNested>::type _RhsNested;
177 static const int NumDimensions = traits<LhsXprType>::NumDimensions;
178 static const int Layout = traits<LhsXprType>::Layout;
179};
180
181template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
182struct eval<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Eigen::Dense>
183{
184 typedef const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>& type;
185};
186
187template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
188struct nested<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> >
189{
190 typedef TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType> type;
191};
192
193} // end namespace internal
194
195
196
197template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType>
198class TensorCustomBinaryOp : public TensorBase<TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, ReadOnlyAccessors>
199{
200 public:
201 typedef typename internal::traits<TensorCustomBinaryOp>::Scalar Scalar;
202 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
203 typedef typename internal::traits<TensorCustomBinaryOp>::CoeffReturnType CoeffReturnType;
204 typedef typename internal::nested<TensorCustomBinaryOp>::type Nested;
205 typedef typename internal::traits<TensorCustomBinaryOp>::StorageKind StorageKind;
206 typedef typename internal::traits<TensorCustomBinaryOp>::Index Index;
207
208 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCustomBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const CustomBinaryFunc& func)
209
210 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_func(func) {}
211
212 EIGEN_DEVICE_FUNC
213 const CustomBinaryFunc& func() const { return m_func; }
214
215 EIGEN_DEVICE_FUNC
216 const typename internal::remove_all<typename LhsXprType::Nested>::type&
217 lhsExpression() const { return m_lhs_xpr; }
218
219 EIGEN_DEVICE_FUNC
220 const typename internal::remove_all<typename RhsXprType::Nested>::type&
221 rhsExpression() const { return m_rhs_xpr; }
222
223 protected:
224 typename LhsXprType::Nested m_lhs_xpr;
225 typename RhsXprType::Nested m_rhs_xpr;
226 const CustomBinaryFunc m_func;
227};
228
229
230// Eval as rvalue
231template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType, typename Device>
232struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType, RhsXprType>, Device>
233{
235 typedef typename internal::traits<XprType>::Index Index;
236 static const int NumDims = internal::traits<XprType>::NumDimensions;
237 typedef DSizes<Index, NumDims> Dimensions;
238 typedef typename XprType::Scalar Scalar;
239 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
240 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
241 static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
242
243 enum {
244 IsAligned = false,
245 PacketAccess = (internal::packet_traits<Scalar>::size > 1),
246 BlockAccess = false,
247 Layout = TensorEvaluator<LhsXprType, Device>::Layout,
248 CoordAccess = false, // to be implemented
249 RawAccess = false
250 };
251
252 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
253 : m_op(op), m_device(device), m_result(NULL)
254 {
255 m_dimensions = op.func().dimensions(op.lhsExpression(), op.rhsExpression());
256 }
257
258 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
259
260 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
261 if (data) {
262 evalTo(data);
263 return false;
264 } else {
265 m_result = static_cast<Scalar *>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
266 evalTo(m_result);
267 return true;
268 }
269 }
270
271 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
272 if (m_result != NULL) {
273 m_device.deallocate(m_result);
274 m_result = NULL;
275 }
276 }
277
278 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
279 return m_result[index];
280 }
281
282 template<int LoadMode>
283 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const {
284 return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
285 }
286
287 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
288 // TODO(rmlarsen): Extend CustomOp API to return its cost estimate.
289 return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
290 }
291
292 EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return m_result; }
293
294 protected:
295 EIGEN_DEVICE_FUNC void evalTo(Scalar* data) {
296 TensorMap<Tensor<Scalar, NumDims, Layout> > result(data, m_dimensions);
297 m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
298 }
299
300 Dimensions m_dimensions;
301 const XprType m_op;
302 const Device& m_device;
303 CoeffReturnType* m_result;
304};
305
306
307} // end namespace Eigen
308
309#endif // EIGEN_CXX11_TENSOR_TENSOR_CUSTOM_OP_H
The tensor base class.
Definition TensorForwardDeclarations.h:29
Tensor custom class.
Definition TensorCustomOp.h:199
Tensor custom class.
Definition TensorCustomOp.h:49
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor evaluator class.
Definition TensorEvaluator.h:27
const Device & device() const
required by sycl in order to construct sycl buffer from raw pointer
Definition TensorEvaluator.h:112