10#ifndef EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H
11#define EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H
14#include "./InternalHeaderCheck.h"
29template <
typename Derived,
typename Device>
30struct TensorEvaluator {
31 typedef typename Derived::Index Index;
32 typedef typename Derived::Scalar Scalar;
33 typedef typename Derived::Scalar CoeffReturnType;
34 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
35 typedef typename Derived::Dimensions Dimensions;
36 typedef Derived XprType;
37 static constexpr int PacketSize = PacketType<CoeffReturnType, Device>::size;
38 typedef typename internal::traits<Derived>::template MakePointer<Scalar>::Type TensorPointerType;
39 typedef StorageMemory<Scalar, Device> Storage;
40 typedef typename Storage::Type EvaluatorPointerType;
43 static constexpr int NumCoords =
44 internal::traits<Derived>::NumDimensions > 0 ? internal::traits<Derived>::NumDimensions : 0;
45 static constexpr int Layout = Derived::Layout;
48 IsAligned = Derived::IsAligned,
49 PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
50 BlockAccess = internal::is_arithmetic<std::remove_const_t<Scalar>>::value,
51 PreferBlockAccess =
false,
52 CoordAccess = NumCoords > 0,
56 typedef std::remove_const_t<Scalar> ScalarNoConst;
59 typedef internal::TensorBlockDescriptor<NumCoords, Index> TensorBlockDesc;
60 typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
62 typedef typename internal::TensorMaterializedBlock<ScalarNoConst, NumCoords, Layout, Index> TensorBlock;
65 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(
const Derived& m,
const Device& device)
66 : m_data(device.get((
const_cast<TensorPointerType
>(m.data())))), m_dims(m.dimensions()), m_device(device) {}
68 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions()
const {
return m_dims; }
70 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType dest) {
71 if (!
NumTraits<std::remove_const_t<Scalar>>::RequireInitialization && dest) {
72 m_device.memcpy((
void*)(m_device.get(dest)), m_device.get(m_data), m_dims.TotalSize() *
sizeof(Scalar));
78#ifdef EIGEN_USE_THREADS
79 template <
typename EvalSubExprsCallback>
80 EIGEN_STRONG_INLINE
void evalSubExprsIfNeededAsync(EvaluatorPointerType dest, EvalSubExprsCallback done) {
82 done(evalSubExprsIfNeeded(dest));
86 EIGEN_STRONG_INLINE
void cleanup() {}
88 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index)
const {
89 eigen_assert(m_data != NULL);
93 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
const {
94 eigen_assert(m_data != NULL);
98 template <
int LoadMode>
99 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index)
const {
100 return internal::ploadt<PacketReturnType, LoadMode>(m_data + index);
108 template <
typename PacketReturnTypeT>
109 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
110 std::enable_if_t<internal::unpacket_traits<PacketReturnTypeT>::masked_load_available, PacketReturnTypeT>
111 partialPacket(Index index,
typename internal::unpacket_traits<PacketReturnTypeT>::mask_t umask)
const {
112 return internal::ploadu<PacketReturnTypeT>(m_data + index, umask);
115 template <
int StoreMode>
116 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void writePacket(Index index,
const PacketReturnType& x)
const {
117 return internal::pstoret<Scalar, PacketReturnType, StoreMode>(m_data + index, x);
120 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(
const array<DenseIndex, NumCoords>& coords)
const {
121 eigen_assert(m_data != NULL);
122 if (
static_cast<int>(Layout) ==
static_cast<int>(
ColMajor)) {
123 return m_data[m_dims.IndexOfColMajor(coords)];
125 return m_data[m_dims.IndexOfRowMajor(coords)];
129 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(
const array<DenseIndex, NumCoords>& coords)
const {
130 eigen_assert(m_data != NULL);
131 if (
static_cast<int>(Layout) ==
static_cast<int>(
ColMajor)) {
132 return m_data[m_dims.IndexOfColMajor(coords)];
134 return m_data[m_dims.IndexOfRowMajor(coords)];
138 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(
bool vectorized)
const {
139 return TensorOpCost(
sizeof(CoeffReturnType), 0, 0, vectorized, PacketType<CoeffReturnType, Device>::size);
142 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE internal::TensorBlockResourceRequirements getResourceRequirements()
const {
143 return internal::TensorBlockResourceRequirements::any();
146 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
147 bool =
false)
const {
148 eigen_assert(m_data != NULL);
149 return TensorBlock::materialize(m_data, m_dims, desc, scratch);
152 template <
typename TensorBlock>
153 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void writeBlock(
const TensorBlockDesc& desc,
const TensorBlock& block) {
154 eigen_assert(m_data != NULL);
156 typedef typename TensorBlock::XprType TensorBlockExpr;
157 typedef internal::TensorBlockAssignment<Scalar, NumCoords, TensorBlockExpr, Index> TensorBlockAssign;
159 TensorBlockAssign::Run(
160 TensorBlockAssign::target(desc.dimensions(), internal::strides<Layout>(m_dims), m_data, desc.offset()),
164 EIGEN_DEVICE_FUNC EvaluatorPointerType data()
const {
return m_data; }
167 EvaluatorPointerType m_data;
169 const Device EIGEN_DEVICE_REF m_device;
174EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T loadConstant(
const T* address) {
178#if defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 350
180EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float loadConstant(
const float* address) {
181 return __ldg(address);
184EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
double loadConstant(
const double* address) {
185 return __ldg(address);
188EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Eigen::half loadConstant(
const Eigen::half* address) {
189 return Eigen::half(half_impl::raw_uint16_to_half(__ldg(&address->x)));
196template <
typename Derived,
typename Device>
198 typedef typename Derived::Index Index;
199 typedef typename Derived::Scalar Scalar;
200 typedef typename Derived::Scalar CoeffReturnType;
201 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
202 typedef typename Derived::Dimensions Dimensions;
203 typedef const Derived XprType;
204 typedef typename internal::traits<Derived>::template MakePointer<const Scalar>::Type TensorPointerType;
205 typedef StorageMemory<const Scalar, Device> Storage;
206 typedef typename Storage::Type EvaluatorPointerType;
208 typedef std::remove_const_t<Scalar> ScalarNoConst;
211 static constexpr int NumCoords =
212 internal::traits<Derived>::NumDimensions > 0 ? internal::traits<Derived>::NumDimensions : 0;
213 static constexpr int PacketSize = PacketType<CoeffReturnType, Device>::size;
214 static constexpr int Layout = Derived::Layout;
217 IsAligned = Derived::IsAligned,
218 PacketAccess = (PacketType<CoeffReturnType, Device>::size > 1),
219 BlockAccess = internal::is_arithmetic<ScalarNoConst>::value,
220 PreferBlockAccess =
false,
221 CoordAccess = NumCoords > 0,
226 typedef internal::TensorBlockDescriptor<NumCoords, Index> TensorBlockDesc;
227 typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
229 typedef typename internal::TensorMaterializedBlock<ScalarNoConst, NumCoords, Layout, Index> TensorBlock;
232 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC TensorEvaluator(
const Derived& m,
const Device& device)
233 : m_data(device.get(m.data())), m_dims(m.dimensions()), m_device(device) {}
235 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions()
const {
return m_dims; }
237 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
238 if (!NumTraits<std::remove_const_t<Scalar>>::RequireInitialization && data) {
239 m_device.memcpy((
void*)(m_device.get(data)), m_device.get(m_data), m_dims.TotalSize() *
sizeof(Scalar));
245#ifdef EIGEN_USE_THREADS
246 template <
typename EvalSubExprsCallback>
247 EIGEN_STRONG_INLINE
void evalSubExprsIfNeededAsync(EvaluatorPointerType dest, EvalSubExprsCallback done) {
249 done(evalSubExprsIfNeeded(dest));
253 EIGEN_STRONG_INLINE
void cleanup() {}
255 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index)
const {
256 eigen_assert(m_data != NULL);
257 return internal::loadConstant(m_data + index);
260 template <
int LoadMode>
261 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index)
const {
262 return internal::ploadt_ro<PacketReturnType, LoadMode>(m_data + index);
270 template <
typename PacketReturnTypeT>
271 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
272 std::enable_if_t<internal::unpacket_traits<PacketReturnTypeT>::masked_load_available, PacketReturnTypeT>
273 partialPacket(Index index,
typename internal::unpacket_traits<PacketReturnTypeT>::mask_t umask)
const {
274 return internal::ploadu<PacketReturnTypeT>(m_data + index, umask);
277 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(
const array<DenseIndex, NumCoords>& coords)
const {
278 eigen_assert(m_data != NULL);
279 const Index index = (
static_cast<int>(Layout) ==
static_cast<int>(
ColMajor)) ? m_dims.IndexOfColMajor(coords)
280 : m_dims.IndexOfRowMajor(coords);
281 return internal::loadConstant(m_data + index);
284 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(
bool vectorized)
const {
285 return TensorOpCost(
sizeof(CoeffReturnType), 0, 0, vectorized, PacketType<CoeffReturnType, Device>::size);
288 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE internal::TensorBlockResourceRequirements getResourceRequirements()
const {
289 return internal::TensorBlockResourceRequirements::any();
292 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
293 bool =
false)
const {
294 eigen_assert(m_data != NULL);
295 return TensorBlock::materialize(m_data, m_dims, desc, scratch);
298 EIGEN_DEVICE_FUNC EvaluatorPointerType data()
const {
return m_data; }
301 EvaluatorPointerType m_data;
303 const Device EIGEN_DEVICE_REF m_device;
308template <
typename NullaryOp,
typename ArgType,
typename Device>
310 typedef TensorCwiseNullaryOp<NullaryOp, ArgType> XprType;
312 EIGEN_DEVICE_FUNC TensorEvaluator(
const XprType& op,
const Device& device)
313 : m_functor(op.functor()), m_argImpl(op.nestedExpression(), device), m_wrapper() {}
315 typedef typename XprType::Index Index;
316 typedef typename XprType::Scalar Scalar;
317 typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
318 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
319 static constexpr int PacketSize = PacketType<CoeffReturnType, Device>::size;
320 typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
321 typedef StorageMemory<CoeffReturnType, Device> Storage;
322 typedef typename Storage::Type EvaluatorPointerType;
324 static constexpr int Layout = TensorEvaluator<ArgType, Device>::Layout;
327 PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess
329 && (PacketType<CoeffReturnType, Device>::size > 1)
333 PreferBlockAccess =
false,
339 typedef internal::TensorBlockNotImplemented TensorBlock;
342 EIGEN_DEVICE_FUNC
const Dimensions& dimensions()
const {
return m_argImpl.dimensions(); }
344 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType) {
return true; }
346#ifdef EIGEN_USE_THREADS
347 template <
typename EvalSubExprsCallback>
348 EIGEN_STRONG_INLINE
void evalSubExprsIfNeededAsync(EvaluatorPointerType, EvalSubExprsCallback done) {
353 EIGEN_STRONG_INLINE
void cleanup() {}
355 EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index)
const {
return m_wrapper(m_functor, index); }
357 template <
int LoadMode>
358 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index)
const {
359 return m_wrapper.template packetOp<PacketReturnType, Index>(m_functor, index);
362 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(
bool vectorized)
const {
363 return TensorOpCost(
sizeof(CoeffReturnType), 0, 0, vectorized, PacketType<CoeffReturnType, Device>::size);
366 EIGEN_DEVICE_FUNC EvaluatorPointerType data()
const {
return NULL; }
369 const NullaryOp m_functor;
370 TensorEvaluator<ArgType, Device> m_argImpl;
371 const internal::nullary_wrapper<CoeffReturnType, NullaryOp> m_wrapper;
376template <
typename UnaryOp,
typename ArgType,
typename Device>
378 typedef TensorCwiseUnaryOp<UnaryOp, ArgType> XprType;
380 static constexpr int Layout = TensorEvaluator<ArgType, Device>::Layout;
382 IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
384 int(TensorEvaluator<ArgType, Device>::PacketAccess) & int(internal::functor_traits<UnaryOp>::PacketAccess),
385 BlockAccess = TensorEvaluator<ArgType, Device>::BlockAccess,
386 PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
391 EIGEN_DEVICE_FUNC TensorEvaluator(
const XprType& op,
const Device& device)
392 : m_device(device), m_functor(op.functor()), m_argImpl(op.nestedExpression(), device) {}
394 typedef typename XprType::Index Index;
395 typedef typename XprType::Scalar Scalar;
396 typedef std::remove_const_t<Scalar> ScalarNoConst;
397 typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
398 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
399 static constexpr int PacketSize = PacketType<CoeffReturnType, Device>::size;
400 typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
401 typedef StorageMemory<CoeffReturnType, Device> Storage;
402 typedef typename Storage::Type EvaluatorPointerType;
403 static constexpr int NumDims = internal::array_size<Dimensions>::value;
406 typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
407 typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
409 typedef typename TensorEvaluator<const ArgType, Device>::TensorBlock ArgTensorBlock;
411 typedef internal::TensorCwiseUnaryBlock<UnaryOp, ArgTensorBlock> TensorBlock;
414 EIGEN_DEVICE_FUNC
const Dimensions& dimensions()
const {
return m_argImpl.dimensions(); }
416 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType) {
417 m_argImpl.evalSubExprsIfNeeded(NULL);
421#ifdef EIGEN_USE_THREADS
422 template <
typename EvalSubExprsCallback>
423 EIGEN_STRONG_INLINE
void evalSubExprsIfNeededAsync(EvaluatorPointerType, EvalSubExprsCallback done) {
424 m_argImpl.evalSubExprsIfNeededAsync(
nullptr, [done](
bool) { done(
true); });
428 EIGEN_STRONG_INLINE
void cleanup() { m_argImpl.cleanup(); }
430 EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index)
const {
return m_functor(m_argImpl.coeff(index)); }
432 template <
int LoadMode>
433 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index)
const {
434 return m_functor.packetOp(m_argImpl.template packet<LoadMode>(index));
437 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(
bool vectorized)
const {
438 const double functor_cost = internal::functor_traits<UnaryOp>::Cost;
439 return m_argImpl.costPerCoeff(vectorized) + TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
442 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE internal::TensorBlockResourceRequirements getResourceRequirements()
const {
443 static const double functor_cost = internal::functor_traits<UnaryOp>::Cost;
444 return m_argImpl.getResourceRequirements().addCostPerCoeff({0, 0, functor_cost / PacketSize});
447 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
448 bool =
false)
const {
449 return TensorBlock(m_argImpl.block(desc, scratch), m_functor);
452 EIGEN_DEVICE_FUNC EvaluatorPointerType data()
const {
return NULL; }
455 const Device EIGEN_DEVICE_REF m_device;
456 const UnaryOp m_functor;
457 TensorEvaluator<ArgType, Device> m_argImpl;
462template <
typename BinaryOp,
typename LeftArgType,
typename RightArgType,
typename Device>
464 typedef TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType> XprType;
466 static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
469 int(TensorEvaluator<LeftArgType, Device>::IsAligned) & int(TensorEvaluator<RightArgType, Device>::IsAligned),
470 PacketAccess = int(TensorEvaluator<LeftArgType, Device>::PacketAccess) &
471 int(TensorEvaluator<RightArgType, Device>::PacketAccess) &
472 int(internal::functor_traits<BinaryOp>::PacketAccess),
473 BlockAccess = int(TensorEvaluator<LeftArgType, Device>::BlockAccess) &
474 int(TensorEvaluator<RightArgType, Device>::BlockAccess),
475 PreferBlockAccess = int(TensorEvaluator<LeftArgType, Device>::PreferBlockAccess) |
476 int(TensorEvaluator<RightArgType, Device>::PreferBlockAccess),
481 EIGEN_DEVICE_FUNC TensorEvaluator(
const XprType& op,
const Device& device)
483 m_functor(op.functor()),
484 m_leftImpl(op.lhsExpression(), device),
485 m_rightImpl(op.rhsExpression(), device) {
486 EIGEN_STATIC_ASSERT((
static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
487 static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout) ||
488 internal::traits<XprType>::NumDimensions <= 1),
489 YOU_MADE_A_PROGRAMMING_MISTAKE);
490 eigen_assert(dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions()));
493 typedef typename XprType::Index Index;
494 typedef typename XprType::Scalar Scalar;
495 typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
496 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
497 static constexpr int PacketSize = PacketType<CoeffReturnType, Device>::size;
498 typedef typename TensorEvaluator<LeftArgType, Device>::Dimensions Dimensions;
499 typedef StorageMemory<CoeffReturnType, Device> Storage;
500 typedef typename Storage::Type EvaluatorPointerType;
502 static constexpr int NumDims = internal::array_size<typename TensorEvaluator<LeftArgType, Device>::Dimensions>::value;
505 typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
506 typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
508 typedef typename TensorEvaluator<const LeftArgType, Device>::TensorBlock LeftTensorBlock;
509 typedef typename TensorEvaluator<const RightArgType, Device>::TensorBlock RightTensorBlock;
511 typedef internal::TensorCwiseBinaryBlock<BinaryOp, LeftTensorBlock, RightTensorBlock> TensorBlock;
514 EIGEN_DEVICE_FUNC
const Dimensions& dimensions()
const {
516 return m_leftImpl.dimensions();
519 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType) {
520 m_leftImpl.evalSubExprsIfNeeded(NULL);
521 m_rightImpl.evalSubExprsIfNeeded(NULL);
525#ifdef EIGEN_USE_THREADS
526 template <
typename EvalSubExprsCallback>
527 EIGEN_STRONG_INLINE
void evalSubExprsIfNeededAsync(EvaluatorPointerType, EvalSubExprsCallback done) {
529 m_leftImpl.evalSubExprsIfNeededAsync(
530 nullptr, [
this, done](
bool) { m_rightImpl.evalSubExprsIfNeededAsync(
nullptr, [done](
bool) { done(
true); }); });
534 EIGEN_STRONG_INLINE
void cleanup() {
535 m_leftImpl.cleanup();
536 m_rightImpl.cleanup();
539 EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index)
const {
540 return m_functor(m_leftImpl.coeff(index), m_rightImpl.coeff(index));
542 template <
int LoadMode>
543 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index)
const {
544 return m_functor.packetOp(m_leftImpl.template packet<LoadMode>(index),
545 m_rightImpl.template packet<LoadMode>(index));
548 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(
bool vectorized)
const {
549 const double functor_cost = internal::functor_traits<BinaryOp>::Cost;
550 return m_leftImpl.costPerCoeff(vectorized) + m_rightImpl.costPerCoeff(vectorized) +
551 TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
554 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE internal::TensorBlockResourceRequirements getResourceRequirements()
const {
555 static const double functor_cost = internal::functor_traits<BinaryOp>::Cost;
556 return internal::TensorBlockResourceRequirements::merge(m_leftImpl.getResourceRequirements(),
557 m_rightImpl.getResourceRequirements())
558 .addCostPerCoeff({0, 0, functor_cost / PacketSize});
561 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
562 bool =
false)
const {
563 desc.DropDestinationBuffer();
564 return TensorBlock(m_leftImpl.block(desc, scratch), m_rightImpl.block(desc, scratch), m_functor);
567 EIGEN_DEVICE_FUNC EvaluatorPointerType data()
const {
return NULL; }
570 const Device EIGEN_DEVICE_REF m_device;
571 const BinaryOp m_functor;
572 TensorEvaluator<LeftArgType, Device> m_leftImpl;
573 TensorEvaluator<RightArgType, Device> m_rightImpl;
578template <
typename TernaryOp,
typename Arg1Type,
typename Arg2Type,
typename Arg3Type,
typename Device>
579struct TensorEvaluator<const TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type>, Device> {
580 typedef TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type> XprType;
582 static constexpr int Layout = TensorEvaluator<Arg1Type, Device>::Layout;
584 IsAligned = TensorEvaluator<Arg1Type, Device>::IsAligned & TensorEvaluator<Arg2Type, Device>::IsAligned &
585 TensorEvaluator<Arg3Type, Device>::IsAligned,
586 PacketAccess = TensorEvaluator<Arg1Type, Device>::PacketAccess && TensorEvaluator<Arg2Type, Device>::PacketAccess &&
587 TensorEvaluator<Arg3Type, Device>::PacketAccess && internal::functor_traits<TernaryOp>::PacketAccess,
589 PreferBlockAccess = TensorEvaluator<Arg1Type, Device>::PreferBlockAccess ||
590 TensorEvaluator<Arg2Type, Device>::PreferBlockAccess ||
591 TensorEvaluator<Arg3Type, Device>::PreferBlockAccess,
596 EIGEN_DEVICE_FUNC TensorEvaluator(
const XprType& op,
const Device& device)
597 : m_functor(op.functor()),
598 m_arg1Impl(op.arg1Expression(), device),
599 m_arg2Impl(op.arg2Expression(), device),
600 m_arg3Impl(op.arg3Expression(), device) {
601 EIGEN_STATIC_ASSERT((
static_cast<int>(TensorEvaluator<Arg1Type, Device>::Layout) ==
602 static_cast<int>(TensorEvaluator<Arg3Type, Device>::Layout) ||
603 internal::traits<XprType>::NumDimensions <= 1),
604 YOU_MADE_A_PROGRAMMING_MISTAKE);
606 EIGEN_STATIC_ASSERT((internal::is_same<
typename internal::traits<Arg1Type>::StorageKind,
607 typename internal::traits<Arg2Type>::StorageKind>::value),
608 STORAGE_KIND_MUST_MATCH)
609 EIGEN_STATIC_ASSERT((internal::is_same<
typename internal::traits<Arg1Type>::StorageKind,
610 typename internal::traits<Arg3Type>::StorageKind>::value),
611 STORAGE_KIND_MUST_MATCH)
612 EIGEN_STATIC_ASSERT((internal::is_same<
typename internal::traits<Arg1Type>::Index,
613 typename internal::traits<Arg2Type>::Index>::value),
614 STORAGE_INDEX_MUST_MATCH)
615 EIGEN_STATIC_ASSERT((internal::is_same<
typename internal::traits<Arg1Type>::Index,
616 typename internal::traits<Arg3Type>::Index>::value),
617 STORAGE_INDEX_MUST_MATCH)
619 eigen_assert(dimensions_match(m_arg1Impl.dimensions(), m_arg2Impl.dimensions()) &&
620 dimensions_match(m_arg1Impl.dimensions(), m_arg3Impl.dimensions()));
623 typedef typename XprType::Index Index;
624 typedef typename XprType::Scalar Scalar;
625 typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
626 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
627 static constexpr int PacketSize = PacketType<CoeffReturnType, Device>::size;
628 typedef typename TensorEvaluator<Arg1Type, Device>::Dimensions Dimensions;
629 typedef StorageMemory<CoeffReturnType, Device> Storage;
630 typedef typename Storage::Type EvaluatorPointerType;
633 typedef internal::TensorBlockNotImplemented TensorBlock;
636 EIGEN_DEVICE_FUNC
const Dimensions& dimensions()
const {
638 return m_arg1Impl.dimensions();
641 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType) {
642 m_arg1Impl.evalSubExprsIfNeeded(NULL);
643 m_arg2Impl.evalSubExprsIfNeeded(NULL);
644 m_arg3Impl.evalSubExprsIfNeeded(NULL);
647 EIGEN_STRONG_INLINE
void cleanup() {
648 m_arg1Impl.cleanup();
649 m_arg2Impl.cleanup();
650 m_arg3Impl.cleanup();
653 EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index)
const {
654 return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index));
656 template <
int LoadMode>
657 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index)
const {
658 return m_functor.packetOp(m_arg1Impl.template packet<LoadMode>(index), m_arg2Impl.template packet<LoadMode>(index),
659 m_arg3Impl.template packet<LoadMode>(index));
662 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(
bool vectorized)
const {
663 const double functor_cost = internal::functor_traits<TernaryOp>::Cost;
664 return m_arg1Impl.costPerCoeff(vectorized) + m_arg2Impl.costPerCoeff(vectorized) +
665 m_arg3Impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
668 EIGEN_DEVICE_FUNC EvaluatorPointerType data()
const {
return NULL; }
671 const TernaryOp m_functor;
672 TensorEvaluator<Arg1Type, Device> m_arg1Impl;
673 TensorEvaluator<Arg2Type, Device> m_arg2Impl;
674 TensorEvaluator<Arg3Type, Device> m_arg3Impl;
679template <
typename IfArgType,
typename ThenArgType,
typename ElseArgType,
typename Device>
680struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>, Device> {
681 typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
682 typedef typename XprType::Scalar Scalar;
684 using TernarySelectOp = internal::scalar_boolean_select_op<typename internal::traits<ThenArgType>::Scalar,
685 typename internal::traits<ElseArgType>::Scalar,
686 typename internal::traits<IfArgType>::Scalar>;
687 static constexpr bool TernaryPacketAccess =
688 TensorEvaluator<ThenArgType, Device>::PacketAccess && TensorEvaluator<ElseArgType, Device>::PacketAccess &&
689 TensorEvaluator<IfArgType, Device>::PacketAccess && internal::functor_traits<TernarySelectOp>::PacketAccess;
691 static constexpr int Layout = TensorEvaluator<IfArgType, Device>::Layout;
693 IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & TensorEvaluator<ElseArgType, Device>::IsAligned,
694 PacketAccess = (TensorEvaluator<ThenArgType, Device>::PacketAccess &&
695 TensorEvaluator<ElseArgType, Device>::PacketAccess && PacketType<Scalar, Device>::HasBlend) ||
697 BlockAccess = TensorEvaluator<IfArgType, Device>::BlockAccess &&
698 TensorEvaluator<ThenArgType, Device>::BlockAccess &&
699 TensorEvaluator<ElseArgType, Device>::BlockAccess,
700 PreferBlockAccess = TensorEvaluator<IfArgType, Device>::PreferBlockAccess ||
701 TensorEvaluator<ThenArgType, Device>::PreferBlockAccess ||
702 TensorEvaluator<ElseArgType, Device>::PreferBlockAccess,
707 EIGEN_DEVICE_FUNC TensorEvaluator(
const XprType& op,
const Device& device)
708 : m_condImpl(op.ifExpression(), device),
709 m_thenImpl(op.thenExpression(), device),
710 m_elseImpl(op.elseExpression(), device) {
711 EIGEN_STATIC_ASSERT((
static_cast<int>(TensorEvaluator<IfArgType, Device>::Layout) ==
712 static_cast<int>(TensorEvaluator<ThenArgType, Device>::Layout)),
713 YOU_MADE_A_PROGRAMMING_MISTAKE);
714 EIGEN_STATIC_ASSERT((
static_cast<int>(TensorEvaluator<IfArgType, Device>::Layout) ==
715 static_cast<int>(TensorEvaluator<ElseArgType, Device>::Layout)),
716 YOU_MADE_A_PROGRAMMING_MISTAKE);
717 eigen_assert(dimensions_match(m_condImpl.dimensions(), m_thenImpl.dimensions()));
718 eigen_assert(dimensions_match(m_thenImpl.dimensions(), m_elseImpl.dimensions()));
721 typedef typename XprType::Index Index;
722 typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
723 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
724 static constexpr int PacketSize = PacketType<CoeffReturnType, Device>::size;
725 typedef typename TensorEvaluator<IfArgType, Device>::Dimensions Dimensions;
726 typedef StorageMemory<CoeffReturnType, Device> Storage;
727 typedef typename Storage::Type EvaluatorPointerType;
729 static constexpr int NumDims = internal::array_size<Dimensions>::value;
732 typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
733 typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;
735 typedef typename TensorEvaluator<const IfArgType, Device>::TensorBlock IfArgTensorBlock;
736 typedef typename TensorEvaluator<const ThenArgType, Device>::TensorBlock ThenArgTensorBlock;
737 typedef typename TensorEvaluator<const ElseArgType, Device>::TensorBlock ElseArgTensorBlock;
739 struct TensorSelectOpBlockFactory {
740 template <
typename IfArgXprType,
typename ThenArgXprType,
typename ElseArgXprType>
742 typedef TensorSelectOp<const IfArgXprType, const ThenArgXprType, const ElseArgXprType> type;
745 template <
typename IfArgXprType,
typename ThenArgXprType,
typename ElseArgXprType>
746 typename XprType<IfArgXprType, ThenArgXprType, ElseArgXprType>::type expr(
const IfArgXprType& if_expr,
747 const ThenArgXprType& then_expr,
748 const ElseArgXprType& else_expr)
const {
749 return typename XprType<IfArgXprType, ThenArgXprType, ElseArgXprType>::type(if_expr, then_expr, else_expr);
753 typedef internal::TensorTernaryExprBlock<TensorSelectOpBlockFactory, IfArgTensorBlock, ThenArgTensorBlock,
758 EIGEN_DEVICE_FUNC
const Dimensions& dimensions()
const {
760 return m_condImpl.dimensions();
763 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType) {
764 m_condImpl.evalSubExprsIfNeeded(NULL);
765 m_thenImpl.evalSubExprsIfNeeded(NULL);
766 m_elseImpl.evalSubExprsIfNeeded(NULL);
770#ifdef EIGEN_USE_THREADS
771 template <
typename EvalSubExprsCallback>
772 EIGEN_STRONG_INLINE
void evalSubExprsIfNeededAsync(EvaluatorPointerType, EvalSubExprsCallback done) {
773 m_condImpl.evalSubExprsIfNeeded(
nullptr, [
this, done](
bool) {
774 m_thenImpl.evalSubExprsIfNeeded(
775 nullptr, [
this, done](
bool) { m_elseImpl.evalSubExprsIfNeeded(
nullptr, [done](
bool) { done(
true); }); });
780 EIGEN_STRONG_INLINE
void cleanup() {
781 m_condImpl.cleanup();
782 m_thenImpl.cleanup();
783 m_elseImpl.cleanup();
786 EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index)
const {
787 return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
790 template <
int LoadMode,
bool UseTernary = TernaryPacketAccess, std::enable_if_t<!UseTernary,
bool> = true>
791 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index)
const {
792 internal::Selector<PacketSize> select;
794 for (Index i = 0; i < PacketSize; ++i) {
795 select.select[i] = m_condImpl.coeff(index + i);
797 return internal::pblend(select, m_thenImpl.template packet<LoadMode>(index),
798 m_elseImpl.template packet<LoadMode>(index));
801 template <
int LoadMode,
bool UseTernary = TernaryPacketAccess, std::enable_if_t<UseTernary,
bool> = true>
802 EIGEN_DEVICE_FUNC PacketReturnType packet(Index index)
const {
803 return TernarySelectOp().template packetOp<PacketReturnType>(m_thenImpl.template packet<LoadMode>(index),
804 m_elseImpl.template packet<LoadMode>(index),
805 m_condImpl.template packet<LoadMode>(index));
808 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(
bool vectorized)
const {
809 return m_condImpl.costPerCoeff(vectorized) +
810 m_thenImpl.costPerCoeff(vectorized).cwiseMax(m_elseImpl.costPerCoeff(vectorized));
813 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE internal::TensorBlockResourceRequirements getResourceRequirements()
const {
814 auto then_req = m_thenImpl.getResourceRequirements();
815 auto else_req = m_elseImpl.getResourceRequirements();
817 auto merged_req = internal::TensorBlockResourceRequirements::merge(then_req, else_req);
818 merged_req.cost_per_coeff = then_req.cost_per_coeff.cwiseMax(else_req.cost_per_coeff);
820 return internal::TensorBlockResourceRequirements::merge(m_condImpl.getResourceRequirements(), merged_req);
823 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
824 bool =
false)
const {
827 desc.DropDestinationBuffer();
829 return TensorBlock(m_condImpl.block(desc, scratch), m_thenImpl.block(desc, scratch),
830 m_elseImpl.block(desc, scratch), TensorSelectOpBlockFactory());
833 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EvaluatorPointerType data()
const {
return NULL; }
837 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void bind(cl::sycl::handler& cgh)
const {
838 m_condImpl.bind(cgh);
839 m_thenImpl.bind(cgh);
840 m_elseImpl.bind(cgh);
844 TensorEvaluator<IfArgType, Device> m_condImpl;
845 TensorEvaluator<ThenArgType, Device> m_thenImpl;
846 TensorEvaluator<ElseArgType, Device> m_elseImpl;
#if defined(EIGEN_USE_SYCL) && defined(SYCL_COMPILER_IS_DPCPP)
// Tell DPC++ that TensorEvaluator may be passed to a SYCL kernel even when it
// is not trivially copyable (trivially-copyable types are already covered by
// the default trait, hence the enable_if guard against redefinition).
template <typename Derived, typename Device>
struct cl::sycl::is_device_copyable<
    Eigen::TensorEvaluator<Derived, Device>,
    std::enable_if_t<!std::is_trivially_copyable<Eigen::TensorEvaluator<Derived, Device>>::value>> : std::true_type {};
#endif
Tensor binary expression.
Definition TensorExpr.h:171
Tensor nullary expression.
Definition TensorExpr.h:42
Tensor unary expression.
Definition TensorExpr.h:98
Namespace containing all symbols from the Eigen library.
The tensor evaluator class.
Definition TensorEvaluator.h:30