Loading...
Searching...
No Matches
TensorExpr.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
11#define EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
12
13namespace Eigen {
14
15namespace internal {
16template<typename NullaryOp, typename XprType>
17struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
18 : traits<XprType>
19{
20 typedef traits<XprType> XprTraits;
21 typedef typename XprType::Scalar Scalar;
22 typedef typename XprType::Nested XprTypeNested;
23 typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
24 static const int NumDimensions = XprTraits::NumDimensions;
25 static const int Layout = XprTraits::Layout;
26
27 enum {
28 Flags = 0
29 };
30};
31
32} // end namespace internal
33
42template <typename NullaryOp, typename XprType>
43class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType>, ReadOnlyAccessors> {
44 public:
45 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
46 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
47 typedef typename XprType::CoeffReturnType CoeffReturnType;
48 typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested;
49 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
50 typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
51
52 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp())
53 : m_xpr(xpr), m_functor(func) {}
54
55 EIGEN_DEVICE_FUNC
56 const typename internal::remove_all<typename XprType::Nested>::type&
57 nestedExpression() const { return m_xpr; }
58
59 EIGEN_DEVICE_FUNC
60 const NullaryOp& functor() const { return m_functor; }
61
62 protected:
63 typename XprType::Nested m_xpr;
64 const NullaryOp m_functor;
65};
66
67
68
69namespace internal {
70template<typename UnaryOp, typename XprType>
71struct traits<TensorCwiseUnaryOp<UnaryOp, XprType> >
72 : traits<XprType>
73{
74 // TODO(phli): Add InputScalar, InputPacket. Check references to
75 // current Scalar/Packet to see if the intent is Input or Output.
76 typedef typename result_of<UnaryOp(typename XprType::Scalar)>::type Scalar;
77 typedef traits<XprType> XprTraits;
78 typedef typename XprType::Nested XprTypeNested;
79 typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
80 static const int NumDimensions = XprTraits::NumDimensions;
81 static const int Layout = XprTraits::Layout;
82};
83
84template<typename UnaryOp, typename XprType>
85struct eval<TensorCwiseUnaryOp<UnaryOp, XprType>, Eigen::Dense>
86{
87 typedef const TensorCwiseUnaryOp<UnaryOp, XprType>& type;
88};
89
90template<typename UnaryOp, typename XprType>
91struct nested<TensorCwiseUnaryOp<UnaryOp, XprType>, 1, typename eval<TensorCwiseUnaryOp<UnaryOp, XprType> >::type>
92{
93 typedef TensorCwiseUnaryOp<UnaryOp, XprType> type;
94};
95
96} // end namespace internal
97
106template <typename UnaryOp, typename XprType>
107class TensorCwiseUnaryOp : public TensorBase<TensorCwiseUnaryOp<UnaryOp, XprType>, ReadOnlyAccessors> {
108 public:
109 // TODO(phli): Add InputScalar, InputPacket. Check references to
110 // current Scalar/Packet to see if the intent is Input or Output.
111 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Scalar Scalar;
112 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
113 typedef Scalar CoeffReturnType;
114 typedef typename Eigen::internal::nested<TensorCwiseUnaryOp>::type Nested;
115 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::StorageKind StorageKind;
116 typedef typename Eigen::internal::traits<TensorCwiseUnaryOp>::Index Index;
117
118 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseUnaryOp(const XprType& xpr, const UnaryOp& func = UnaryOp())
119 : m_xpr(xpr), m_functor(func) {}
120
121 EIGEN_DEVICE_FUNC
122 const UnaryOp& functor() const { return m_functor; }
123
125 EIGEN_DEVICE_FUNC
126 const typename internal::remove_all<typename XprType::Nested>::type&
127 nestedExpression() const { return m_xpr; }
128
129 protected:
130 typename XprType::Nested m_xpr;
131 const UnaryOp m_functor;
132};
133
134
135namespace internal {
136template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
137struct traits<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >
138{
139 // Type promotion to handle the case where the types of the lhs and the rhs
140 // are different.
141 // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
142 // current Scalar/Packet to see if the intent is Inputs or Output.
143 typedef typename result_of<
144 BinaryOp(typename LhsXprType::Scalar,
145 typename RhsXprType::Scalar)>::type Scalar;
146 typedef traits<LhsXprType> XprTraits;
147 typedef typename promote_storage_type<
148 typename traits<LhsXprType>::StorageKind,
149 typename traits<RhsXprType>::StorageKind>::ret StorageKind;
150 typedef typename promote_index_type<
151 typename traits<LhsXprType>::Index,
152 typename traits<RhsXprType>::Index>::type Index;
153 typedef typename LhsXprType::Nested LhsNested;
154 typedef typename RhsXprType::Nested RhsNested;
155 typedef typename remove_reference<LhsNested>::type _LhsNested;
156 typedef typename remove_reference<RhsNested>::type _RhsNested;
157 static const int NumDimensions = XprTraits::NumDimensions;
158 static const int Layout = XprTraits::Layout;
159
160 enum {
161 Flags = 0
162 };
163};
164
165template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
166struct eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, Eigen::Dense>
167{
168 typedef const TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>& type;
169};
170
171template<typename BinaryOp, typename LhsXprType, typename RhsXprType>
172struct nested<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, 1, typename eval<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> >::type>
173{
174 typedef TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType> type;
175};
176
177} // end namespace internal
178
187template <typename BinaryOp, typename LhsXprType, typename RhsXprType>
188class TensorCwiseBinaryOp
189 : public TensorBase<TensorCwiseBinaryOp<BinaryOp, LhsXprType, RhsXprType>, ReadOnlyAccessors> {
190 public:
191 // TODO(phli): Add Lhs/RhsScalar, Lhs/RhsPacket. Check references to
192 // current Scalar/Packet to see if the intent is Inputs or Output.
193 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Scalar Scalar;
194 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
195 typedef Scalar CoeffReturnType;
196 typedef typename Eigen::internal::nested<TensorCwiseBinaryOp>::type Nested;
197 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::StorageKind StorageKind;
198 typedef typename Eigen::internal::traits<TensorCwiseBinaryOp>::Index Index;
199
200 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseBinaryOp(const LhsXprType& lhs, const RhsXprType& rhs, const BinaryOp& func = BinaryOp())
201 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_functor(func) {}
202
203 EIGEN_DEVICE_FUNC
204 const BinaryOp& functor() const { return m_functor; }
205
207 EIGEN_DEVICE_FUNC
208 const typename internal::remove_all<typename LhsXprType::Nested>::type&
209 lhsExpression() const { return m_lhs_xpr; }
210
211 EIGEN_DEVICE_FUNC
212 const typename internal::remove_all<typename RhsXprType::Nested>::type&
213 rhsExpression() const { return m_rhs_xpr; }
214
215 protected:
216 typename LhsXprType::Nested m_lhs_xpr;
217 typename RhsXprType::Nested m_rhs_xpr;
218 const BinaryOp m_functor;
219};
220
221
222namespace internal {
223template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
224struct traits<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >
225{
226 // Type promotion to handle the case where the types of the args are different.
227 typedef typename result_of<
228 TernaryOp(typename Arg1XprType::Scalar,
229 typename Arg2XprType::Scalar,
230 typename Arg3XprType::Scalar)>::type Scalar;
231 typedef traits<Arg1XprType> XprTraits;
232 typedef typename traits<Arg1XprType>::StorageKind StorageKind;
233 typedef typename traits<Arg1XprType>::Index Index;
234 typedef typename Arg1XprType::Nested Arg1Nested;
235 typedef typename Arg2XprType::Nested Arg2Nested;
236 typedef typename Arg3XprType::Nested Arg3Nested;
237 typedef typename remove_reference<Arg1Nested>::type _Arg1Nested;
238 typedef typename remove_reference<Arg2Nested>::type _Arg2Nested;
239 typedef typename remove_reference<Arg3Nested>::type _Arg3Nested;
240 static const int NumDimensions = XprTraits::NumDimensions;
241 static const int Layout = XprTraits::Layout;
242
243 enum {
244 Flags = 0
245 };
246};
247
248template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
249struct eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, Eigen::Dense>
250{
251 typedef const TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>& type;
252};
253
254template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
255struct nested<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, 1, typename eval<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> >::type>
256{
257 typedef TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType> type;
258};
259
260} // end namespace internal
261
262
263
264template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType>
265class TensorCwiseTernaryOp : public TensorBase<TensorCwiseTernaryOp<TernaryOp, Arg1XprType, Arg2XprType, Arg3XprType>, ReadOnlyAccessors>
266{
267 public:
268 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Scalar Scalar;
269 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
270 typedef Scalar CoeffReturnType;
271 typedef typename Eigen::internal::nested<TensorCwiseTernaryOp>::type Nested;
272 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::StorageKind StorageKind;
273 typedef typename Eigen::internal::traits<TensorCwiseTernaryOp>::Index Index;
274
275 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseTernaryOp(const Arg1XprType& arg1, const Arg2XprType& arg2, const Arg3XprType& arg3, const TernaryOp& func = TernaryOp())
276 : m_arg1_xpr(arg1), m_arg2_xpr(arg2), m_arg3_xpr(arg3), m_functor(func) {}
277
278 EIGEN_DEVICE_FUNC
279 const TernaryOp& functor() const { return m_functor; }
280
282 EIGEN_DEVICE_FUNC
283 const typename internal::remove_all<typename Arg1XprType::Nested>::type&
284 arg1Expression() const { return m_arg1_xpr; }
285
286 EIGEN_DEVICE_FUNC
287 const typename internal::remove_all<typename Arg2XprType::Nested>::type&
288 arg2Expression() const { return m_arg2_xpr; }
289
290 EIGEN_DEVICE_FUNC
291 const typename internal::remove_all<typename Arg3XprType::Nested>::type&
292 arg3Expression() const { return m_arg3_xpr; }
293
294 protected:
295 typename Arg1XprType::Nested m_arg1_xpr;
296 typename Arg2XprType::Nested m_arg2_xpr;
297 typename Arg3XprType::Nested m_arg3_xpr;
298 const TernaryOp m_functor;
299};
300
301
302namespace internal {
303template<typename IfXprType, typename ThenXprType, typename ElseXprType>
304struct traits<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >
305 : traits<ThenXprType>
306{
307 typedef typename traits<ThenXprType>::Scalar Scalar;
308 typedef traits<ThenXprType> XprTraits;
309 typedef typename promote_storage_type<typename traits<ThenXprType>::StorageKind,
310 typename traits<ElseXprType>::StorageKind>::ret StorageKind;
311 typedef typename promote_index_type<typename traits<ElseXprType>::Index,
312 typename traits<ThenXprType>::Index>::type Index;
313 typedef typename IfXprType::Nested IfNested;
314 typedef typename ThenXprType::Nested ThenNested;
315 typedef typename ElseXprType::Nested ElseNested;
316 static const int NumDimensions = XprTraits::NumDimensions;
317 static const int Layout = XprTraits::Layout;
318};
319
320template<typename IfXprType, typename ThenXprType, typename ElseXprType>
321struct eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, Eigen::Dense>
322{
323 typedef const TensorSelectOp<IfXprType, ThenXprType, ElseXprType>& type;
324};
325
326template<typename IfXprType, typename ThenXprType, typename ElseXprType>
327struct nested<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, 1, typename eval<TensorSelectOp<IfXprType, ThenXprType, ElseXprType> >::type>
328{
329 typedef TensorSelectOp<IfXprType, ThenXprType, ElseXprType> type;
330};
331
332} // end namespace internal
333
334
335template<typename IfXprType, typename ThenXprType, typename ElseXprType>
336class TensorSelectOp : public TensorBase<TensorSelectOp<IfXprType, ThenXprType, ElseXprType>, ReadOnlyAccessors>
337{
338 public:
339 typedef typename Eigen::internal::traits<TensorSelectOp>::Scalar Scalar;
340 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
341 typedef typename internal::promote_storage_type<typename ThenXprType::CoeffReturnType,
342 typename ElseXprType::CoeffReturnType>::ret CoeffReturnType;
343 typedef typename Eigen::internal::nested<TensorSelectOp>::type Nested;
344 typedef typename Eigen::internal::traits<TensorSelectOp>::StorageKind StorageKind;
345 typedef typename Eigen::internal::traits<TensorSelectOp>::Index Index;
346
347 EIGEN_DEVICE_FUNC
348 TensorSelectOp(const IfXprType& a_condition,
349 const ThenXprType& a_then,
350 const ElseXprType& a_else)
351 : m_condition(a_condition), m_then(a_then), m_else(a_else)
352 { }
353
354 EIGEN_DEVICE_FUNC
355 const IfXprType& ifExpression() const { return m_condition; }
356
357 EIGEN_DEVICE_FUNC
358 const ThenXprType& thenExpression() const { return m_then; }
359
360 EIGEN_DEVICE_FUNC
361 const ElseXprType& elseExpression() const { return m_else; }
362
363 protected:
364 typename IfXprType::Nested m_condition;
365 typename ThenXprType::Nested m_then;
366 typename ElseXprType::Nested m_else;
367};
368
369
370} // end namespace Eigen
371
372#endif // EIGEN_CXX11_TENSOR_TENSOR_EXPR_H
The tensor base class.
Definition TensorForwardDeclarations.h:29
const internal::remove_all< typenameLhsXprType::Nested >::type & lhsExpression() const
Definition TensorExpr.h:209
Tensor unary expression.
Definition TensorExpr.h:107
const internal::remove_all< typenameXprType::Nested >::type & nestedExpression() const
Definition TensorExpr.h:127
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index