TensorEvaluator.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H
#define EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H

namespace Eigen {

/** \class TensorEvaluator
  * \ingroup CXX11_Tensor_Module
  *
  * \brief The tensor evaluator classes.
  *
  * These classes are responsible for the evaluation of the tensor expression.
  *
  * TODO: add support for more types of expressions, in particular expressions
  * leading to lvalues (slicing, reshaping etc.)
  */

// Generic evaluator
template <typename Derived, typename Device>
struct TensorEvaluator {
  typedef typename Derived::Index Index;
  typedef typename Derived::Scalar Scalar;
  typedef typename Derived::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef typename Derived::Dimensions Dimensions;

  // NumDimensions is -1 for variable dim tensors
  static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
                               internal::traits<Derived>::NumDimensions : 0;

  enum {
    IsAligned = Derived::IsAligned,
    PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
    Layout = Derived::Layout,
    CoordAccess = NumCoords > 0,
    RawAccess = true
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
      : m_data(const_cast<typename internal::traits<Derived>::template MakePointer<Scalar>::Type>(m.data())),
        m_dims(m.dimensions()), m_device(device), m_impl(m)
  { }

  // Used for accessor extraction in SYCL Managed TensorMap:
  const Derived& derived() const { return m_impl; }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* dest) {
    if (dest) {
      m_device.memcpy((void*)dest, m_data, sizeof(Scalar) * m_dims.TotalSize());
      return false;
    }
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    eigen_assert(m_data);
    return m_data[index];
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
    eigen_assert(m_data);
    return m_data[index];
  }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  PacketReturnType packet(Index index) const
  {
    return internal::ploadt<PacketReturnType, LoadMode>(m_data + index);
  }

  template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  void writePacket(Index index, const PacketReturnType& x)
  {
    return internal::pstoret<Scalar, PacketReturnType, StoreMode>(m_data + index, x);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
    eigen_assert(m_data);
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return m_data[m_dims.IndexOfColMajor(coords)];
    } else {
      return m_data[m_dims.IndexOfRowMajor(coords)];
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<DenseIndex, NumCoords>& coords) {
    eigen_assert(m_data);
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return m_data[m_dims.IndexOfColMajor(coords)];
    } else {
      return m_data[m_dims.IndexOfRowMajor(coords)];
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
                        internal::unpacket_traits<PacketReturnType>::size);
  }

  EIGEN_DEVICE_FUNC typename internal::traits<Derived>::template MakePointer<Scalar>::Type data() const { return m_data; }

  /// required by sycl in order to construct sycl buffer from raw pointer
  const Device& device() const { return m_device; }

 protected:
  typename internal::traits<Derived>::template MakePointer<Scalar>::Type m_data;
  Dimensions m_dims;
  const Device& m_device;
  const Derived& m_impl;
};
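
/* Usage note (editorial sketch, not part of the original file): evaluators
   are internal machinery, normally constructed by TensorExecutor when an
   expression is assigned. Exercising this generic evaluator directly would
   look roughly like the following; the names `t`, `m`, `eval` and the
   include path are illustrative assumptions:

     #include <unsupported/Eigen/CXX11/Tensor>

     Eigen::Tensor<float, 2> t(2, 3);
     t.setRandom();
     Eigen::TensorMap<Eigen::Tensor<float, 2> > m(t.data(), 2, 3);
     Eigen::DefaultDevice device;
     Eigen::TensorEvaluator<decltype(m), Eigen::DefaultDevice> eval(m, device);
     eval.evalSubExprsIfNeeded(NULL);  // NULL dest: nothing to copy, returns true
     float v = eval.coeff(0);          // linear-index access into the raw buffer
     eval.cleanup();
*/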

namespace {
template <typename T> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T loadConstant(const T* address) {
  return *address;
}
// Use the texture cache on CUDA devices whenever possible
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float loadConstant(const float* address) {
  return __ldg(address);
}
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
double loadConstant(const double* address) {
  return __ldg(address);
}
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
Eigen::half loadConstant(const Eigen::half* address) {
  return Eigen::half(half_impl::raw_uint16_to_half(__ldg(&address->x)));
}
#endif
}
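
/* Note on the specializations above (editorial): __ldg loads through the
   read-only/texture data cache, which pays off when many threads re-read
   the same coefficients. A specialization for another scalar type that
   __ldg supports on sm_35+ would follow the same pattern; `int` here is
   merely an illustrative sketch, not part of the original file:

     template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
     int loadConstant(const int* address) {
       return __ldg(address);
     }
*/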


// Default evaluator for rvalues
template<typename Derived, typename Device>
struct TensorEvaluator<const Derived, Device>
{
  typedef typename Derived::Index Index;
  typedef typename Derived::Scalar Scalar;
  typedef typename Derived::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef typename Derived::Dimensions Dimensions;

  // NumDimensions is -1 for variable dim tensors
  static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
                               internal::traits<Derived>::NumDimensions : 0;

  enum {
    IsAligned = Derived::IsAligned,
    PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
    Layout = Derived::Layout,
    CoordAccess = NumCoords > 0,
    RawAccess = true
  };

  // Used for accessor extraction in SYCL Managed TensorMap:
  const Derived& derived() const { return m_impl; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
      : m_data(m.data()), m_dims(m.dimensions()), m_device(device), m_impl(m)
  { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
    if (!NumTraits<typename internal::remove_const<Scalar>::type>::RequireInitialization && data) {
      m_device.memcpy((void*)data, m_data, m_dims.TotalSize() * sizeof(Scalar));
      return false;
    }
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    eigen_assert(m_data);
    return loadConstant(m_data+index);
  }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  PacketReturnType packet(Index index) const
  {
    return internal::ploadt_ro<PacketReturnType, LoadMode>(m_data + index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
    eigen_assert(m_data);
    const Index index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? m_dims.IndexOfColMajor(coords)
                                                                                 : m_dims.IndexOfRowMajor(coords);
    return loadConstant(m_data+index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
                        internal::unpacket_traits<PacketReturnType>::size);
  }

  EIGEN_DEVICE_FUNC typename internal::traits<Derived>::template MakePointer<const Scalar>::Type data() const { return m_data; }

  /// required by sycl in order to construct sycl buffer from raw pointer
  const Device& device() const { return m_device; }

 protected:
  typename internal::traits<Derived>::template MakePointer<const Scalar>::Type m_data;
  Dimensions m_dims;
  const Device& m_device;
  const Derived& m_impl;
};
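
/* Usage note (editorial sketch): this const specialization is the read-only
   counterpart of the generic evaluator above. Coefficient reads go through
   loadConstant() (and hence the texture cache on recent CUDA devices),
   packet reads use ploadt_ro, and there is no coeffRef()/writePacket().
   It is the evaluator bound to rvalue leaves, such as the tensor leaves on
   the right-hand side below (variable names are illustrative):

     Eigen::Tensor<float, 1> t(8), u(8);
     t.setConstant(1.f);
     u = t * t.constant(2.f);  // the Tensor leaf binds TensorEvaluator<const Tensor<...>, Device>
*/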


// -------------------- CwiseNullaryOp --------------------

template<typename NullaryOp, typename ArgType, typename Device>
struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device>
{
  typedef TensorCwiseNullaryOp<NullaryOp, ArgType> XprType;

  enum {
    IsAligned = true,
    PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC
  TensorEvaluator(const XprType& op, const Device& device)
      : m_functor(op.functor()), m_argImpl(op.nestedExpression(), device), m_wrapper()
  { }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) { return true; }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_wrapper(m_functor, index);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_wrapper.template packetOp<PacketReturnType, Index>(m_functor, index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
                        internal::unpacket_traits<PacketReturnType>::size);
  }

  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_argImpl; }
  NullaryOp functor() const { return m_functor; }

 private:
  const NullaryOp m_functor;
  TensorEvaluator<ArgType, Device> m_argImpl;
  const internal::nullary_wrapper<CoeffReturnType, NullaryOp> m_wrapper;
};
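
/* Usage note (editorial sketch): a nullary expression produces coefficients
   from its functor alone, without reading any input tensor. For example,
   constant() builds a TensorCwiseNullaryOp whose evaluator invokes the
   constant-returning functor for every index (names are illustrative):

     Eigen::Tensor<float, 2> t(2, 2);
     t = t.constant(3.14f);  // coeff(i) == 3.14f for all i; no memory is read
*/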


// -------------------- CwiseUnaryOp --------------------

template<typename UnaryOp, typename ArgType, typename Device>
struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device>
{
  typedef TensorCwiseUnaryOp<UnaryOp, ArgType> XprType;

  enum {
    IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
      : m_functor(op.functor()),
        m_argImpl(op.nestedExpression(), device)
  { }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
    m_argImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_argImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_functor(m_argImpl.coeff(index));
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_functor.packetOp(m_argImpl.template packet<LoadMode>(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    const double functor_cost = internal::functor_traits<UnaryOp>::Cost;
    return m_argImpl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_argImpl; }
  UnaryOp functor() const { return m_functor; }

 private:
  const UnaryOp m_functor;
  TensorEvaluator<ArgType, Device> m_argImpl;
};
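
/* Usage note (editorial sketch): a unary expression applies its functor to
   each coefficient of the nested expression; with PacketAccess, packet()
   forwards whole SIMD packets to the functor's packetOp(). Illustrative use:

     Eigen::Tensor<float, 2> a(2, 2), b(2, 2);
     a.setRandom();
     b = a.square();  // per coefficient: coeff(i) == m_functor(m_argImpl.coeff(i))
*/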


// -------------------- CwiseBinaryOp --------------------

template<typename BinaryOp, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType>, Device>
{
  typedef TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType> XprType;

  enum {
    IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned,
    PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess &
                   internal::functor_traits<BinaryOp>::PacketAccess,
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
      : m_functor(op.functor()),
        m_leftImpl(op.lhsExpression(), device),
        m_rightImpl(op.rhsExpression(), device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
    eigen_assert(dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions()));
  }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<LeftArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // TODO: use right impl instead if right impl dimensions are known at compile time.
    return m_leftImpl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_functor(m_leftImpl.coeff(index), m_rightImpl.coeff(index));
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_functor.packetOp(m_leftImpl.template packet<LoadMode>(index), m_rightImpl.template packet<LoadMode>(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double functor_cost = internal::functor_traits<BinaryOp>::Cost;
    return m_leftImpl.costPerCoeff(vectorized) +
           m_rightImpl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }

  const TensorEvaluator<LeftArgType, Device>& left_impl() const { return m_leftImpl; }
  const TensorEvaluator<RightArgType, Device>& right_impl() const { return m_rightImpl; }
  BinaryOp functor() const { return m_functor; }

 private:
  const BinaryOp m_functor;
  TensorEvaluator<LeftArgType, Device> m_leftImpl;
  TensorEvaluator<RightArgType, Device> m_rightImpl;
};
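
/* Usage note (editorial sketch): a binary expression combines the matching
   coefficients of its two arguments, which must agree in layout and
   dimensions (checked in the constructor above). Illustrative use:

     Eigen::Tensor<float, 2> a(2, 2), b(2, 2), c(2, 2);
     a.setRandom(); b.setRandom();
     c = a + b;  // coeff(i) == m_functor(m_leftImpl.coeff(i), m_rightImpl.coeff(i))
*/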

// -------------------- CwiseTernaryOp --------------------

template<typename TernaryOp, typename Arg1Type, typename Arg2Type, typename Arg3Type, typename Device>
struct TensorEvaluator<const TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type>, Device>
{
  typedef TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type> XprType;

  enum {
    IsAligned = TensorEvaluator<Arg1Type, Device>::IsAligned & TensorEvaluator<Arg2Type, Device>::IsAligned & TensorEvaluator<Arg3Type, Device>::IsAligned,
    PacketAccess = TensorEvaluator<Arg1Type, Device>::PacketAccess & TensorEvaluator<Arg2Type, Device>::PacketAccess & TensorEvaluator<Arg3Type, Device>::PacketAccess &
                   internal::functor_traits<TernaryOp>::PacketAccess,
    Layout = TensorEvaluator<Arg1Type, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
      : m_functor(op.functor()),
        m_arg1Impl(op.arg1Expression(), device),
        m_arg2Impl(op.arg2Expression(), device),
        m_arg3Impl(op.arg3Expression(), device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<Arg1Type, Device>::Layout) == static_cast<int>(TensorEvaluator<Arg3Type, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);

    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::StorageKind,
                                           typename internal::traits<Arg2Type>::StorageKind>::value),
                        STORAGE_KIND_MUST_MATCH)
    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::StorageKind,
                                           typename internal::traits<Arg3Type>::StorageKind>::value),
                        STORAGE_KIND_MUST_MATCH)
    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::Index,
                                           typename internal::traits<Arg2Type>::Index>::value),
                        STORAGE_INDEX_MUST_MATCH)
    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::Index,
                                           typename internal::traits<Arg3Type>::Index>::value),
                        STORAGE_INDEX_MUST_MATCH)

    eigen_assert(dimensions_match(m_arg1Impl.dimensions(), m_arg2Impl.dimensions()) && dimensions_match(m_arg1Impl.dimensions(), m_arg3Impl.dimensions()));
  }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<Arg1Type, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // TODO: use arg2 or arg3 dimensions if they are known at compile time.
    return m_arg1Impl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
    m_arg1Impl.evalSubExprsIfNeeded(NULL);
    m_arg2Impl.evalSubExprsIfNeeded(NULL);
    m_arg3Impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_arg1Impl.cleanup();
    m_arg2Impl.cleanup();
    m_arg3Impl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index));
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_functor.packetOp(m_arg1Impl.template packet<LoadMode>(index),
                              m_arg2Impl.template packet<LoadMode>(index),
                              m_arg3Impl.template packet<LoadMode>(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double functor_cost = internal::functor_traits<TernaryOp>::Cost;
    return m_arg1Impl.costPerCoeff(vectorized) +
           m_arg2Impl.costPerCoeff(vectorized) +
           m_arg3Impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }

  const TensorEvaluator<Arg1Type, Device>& arg1Impl() const { return m_arg1Impl; }
  const TensorEvaluator<Arg2Type, Device>& arg2Impl() const { return m_arg2Impl; }
  const TensorEvaluator<Arg3Type, Device>& arg3Impl() const { return m_arg3Impl; }

 private:
  const TernaryOp m_functor;
  TensorEvaluator<Arg1Type, Device> m_arg1Impl;
  TensorEvaluator<Arg2Type, Device> m_arg2Impl;
  TensorEvaluator<Arg3Type, Device> m_arg3Impl;
};
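
/* Usage note (editorial sketch): ternary expressions extend the same scheme
   to three coefficient-wise arguments. In this version of the tensor module
   the main user-facing ternary op is, to the best of our knowledge, the
   regularized incomplete beta function (names are illustrative):

     Eigen::Tensor<float, 2> a(2, 2), b(2, 2), x(2, 2), r(2, 2);
     a.setRandom(); b.setRandom(); x.setRandom();
     r = Eigen::betainc(a, b, x);  // builds a TensorCwiseTernaryOp over three leaves
*/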


// -------------------- SelectOp --------------------

template<typename IfArgType, typename ThenArgType, typename ElseArgType, typename Device>
struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>, Device>
{
  typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
  typedef typename XprType::Scalar Scalar;

  enum {
    IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & TensorEvaluator<ElseArgType, Device>::IsAligned,
    PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess & TensorEvaluator<ElseArgType, Device>::PacketAccess &
                   internal::packet_traits<Scalar>::HasBlend,
    Layout = TensorEvaluator<IfArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
      : m_condImpl(op.ifExpression(), device),
        m_thenImpl(op.thenExpression(), device),
        m_elseImpl(op.elseExpression(), device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<IfArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<ThenArgType, Device>::Layout)), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<IfArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<ElseArgType, Device>::Layout)), YOU_MADE_A_PROGRAMMING_MISTAKE);
    eigen_assert(dimensions_match(m_condImpl.dimensions(), m_thenImpl.dimensions()));
    eigen_assert(dimensions_match(m_thenImpl.dimensions(), m_elseImpl.dimensions()));
  }

  typedef typename XprType::Index Index;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<IfArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // TODO: use then or else impl instead if they happen to be known at compile time.
    return m_condImpl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
    m_condImpl.evalSubExprsIfNeeded(NULL);
    m_thenImpl.evalSubExprsIfNeeded(NULL);
    m_elseImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_condImpl.cleanup();
    m_thenImpl.cleanup();
    m_elseImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
  {
    internal::Selector<PacketSize> select;
    for (Index i = 0; i < PacketSize; ++i) {
      select.select[i] = m_condImpl.coeff(index+i);
    }
    return internal::pblend(select,
                            m_thenImpl.template packet<LoadMode>(index),
                            m_elseImpl.template packet<LoadMode>(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    return m_condImpl.costPerCoeff(vectorized) +
           m_thenImpl.costPerCoeff(vectorized)
               .cwiseMax(m_elseImpl.costPerCoeff(vectorized));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType* data() const { return NULL; }

  const TensorEvaluator<IfArgType, Device>& cond_impl() const { return m_condImpl; }
  const TensorEvaluator<ThenArgType, Device>& then_impl() const { return m_thenImpl; }
  const TensorEvaluator<ElseArgType, Device>& else_impl() const { return m_elseImpl; }

 private:
  TensorEvaluator<IfArgType, Device> m_condImpl;
  TensorEvaluator<ThenArgType, Device> m_thenImpl;
  TensorEvaluator<ElseArgType, Device> m_elseImpl;
};
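
/* Usage note (editorial sketch): a select expression is a coefficient-wise
   conditional built with cond.select(then, else). The scalar path evaluates
   the ternary per coefficient; the vectorized path gathers PacketSize
   condition values into a Selector and blends whole packets via pblend.
   Illustrative use:

     Eigen::Tensor<float, 2> a(2, 2), b(2, 2), r(2, 2);
     a.setRandom(); b.setRandom();
     r = (a > a.constant(0.f)).select(a, b);  // r(i) = a(i) > 0 ? a(i) : b(i)
*/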


} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H