10#ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H
11#define EIGEN_CXX11_TENSOR_TENSOR_REF_H
14#include "./InternalHeaderCheck.h"
// Abstract, type-erased evaluator interface used behind TensorRef.
// Concrete evaluators expose the referenced expression's dimensions, a raw
// data pointer, and coefficient access by a single (linear) DenseIndex, all
// through virtual calls. Each evaluator also carries its own reference count,
// which owners adjust manually via incrRefCount()/decrRefCount().
20template <
typename Dimensions,
typename Scalar>
21class TensorLazyBaseEvaluator {
// The count starts at zero; whoever allocates an evaluator is expected to
// call incrRefCount() right after construction (TensorRefBase does so
// immediately after `new`).
23 TensorLazyBaseEvaluator() : m_refcount(0) {}
// Virtual so the concrete evaluator's destructor runs when it is destroyed
// through a TensorLazyBaseEvaluator pointer (the pointer type TensorRefBase
// stores).
24 virtual ~TensorLazyBaseEvaluator() {}
// Dimensions of the referenced expression.
26 EIGEN_DEVICE_FUNC
virtual const Dimensions& dimensions()
const = 0;
// Raw coefficient pointer, as reported by the concrete evaluator.
27 EIGEN_DEVICE_FUNC
virtual const Scalar* data()
const = 0;
// Read one coefficient by linear index; returned by value.
29 EIGEN_DEVICE_FUNC
virtual const Scalar coeff(DenseIndex index)
const = 0;
// Mutable access to one coefficient by linear index. The read-only
// evaluator in this file implements it as an assertion failure.
30 EIGEN_DEVICE_FUNC
virtual Scalar& coeffRef(DenseIndex index) = 0;
// Manual reference counting. The counter is changed with plain ++/-- and
// there is no visible synchronization.
// NOTE(review): not safe if refs sharing one evaluator are copied or
// destroyed concurrently. The declaration of m_refcount is not visible in
// this listing; confirm its type.
32 void incrRefCount() { ++m_refcount; }
33 void decrRefCount() { --m_refcount; }
34 int refCount()
const {
return m_refcount; }
// Copy construction and copy assignment are declared without a definition
// in the visible code, so copying an evaluator is not supported. The access
// specifier in front of them is not shown in this listing (presumably
// private; confirm).
38 TensorLazyBaseEvaluator(
const TensorLazyBaseEvaluator& other);
39 TensorLazyBaseEvaluator& operator=(
const TensorLazyBaseEvaluator& other);
// Concrete lazy evaluator. It holds a TensorEvaluator<Expr, Device> by value
// (m_impl) and serves dimension, data and coefficient reads through it.
// Mutable access via coeffRef() is rejected with an assertion.
44template <
typename Dimensions,
typename Expr,
typename Device>
45class TensorLazyEvaluatorReadOnly
46 :
public TensorLazyBaseEvaluator<Dimensions, typename TensorEvaluator<Expr, Device>::Scalar> {
49 typedef typename TensorEvaluator<Expr, Device>::Scalar Scalar;
50 typedef StorageMemory<Scalar, Device> Storage;
51 typedef typename Storage::Type EvaluatorPointerType;
52 typedef TensorEvaluator<Expr, Device> EvalType;
// Builds the wrapped evaluator and prepares it:
//  1. Requires Dimensions to have as many entries as the expression's own
//     dimension type. The two lines after the initializer list are the
//     arguments of an assertion whose opening line is not in this listing
//     (presumably EIGEN_STATIC_ASSERT).
//  2. Copies the expression's dimensions into m_dims, element by element.
//  3. Calls m_impl.evalSubExprsIfNeeded(NULL) once, here in the constructor,
//     before any coefficient is read. No destination buffer is passed.
// m_dummy is initialized to Scalar(0); its declaration is not visible here.
54 TensorLazyEvaluatorReadOnly(
const Expr& expr,
const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) {
56 internal::array_size<Dimensions>::value == internal::array_size<typename EvalType::Dimensions>::value,
57 "Dimension sizes must match.");
58 const auto& other_dims = m_impl.dimensions();
59 for (std::size_t i = 0; i < m_dims.size(); ++i) {
60 m_dims[i] = other_dims[i];
62 m_impl.evalSubExprsIfNeeded(NULL);
// Calls cleanup() on the wrapped evaluator, pairing with the
// evalSubExprsIfNeeded() call made in the constructor.
64 virtual ~TensorLazyEvaluatorReadOnly() { m_impl.cleanup(); }
// Dimensions cached by the constructor.
66 EIGEN_DEVICE_FUNC
virtual const Dimensions& dimensions()
const {
return m_dims; }
// data() and coeff() forward directly to the wrapped evaluator.
67 EIGEN_DEVICE_FUNC
virtual const Scalar* data()
const {
return m_impl.data(); }
69 EIGEN_DEVICE_FUNC
virtual const Scalar coeff(DenseIndex index)
const {
return m_impl.coeff(index); }
// Mutable access is not allowed here: triggers eigen_assert.
// NOTE(review): whatever follows the assertion is not in this listing. With
// assertions disabled the function must still return a reference, presumably
// m_dummy (initialized to Scalar(0) in the constructor); confirm.
70 EIGEN_DEVICE_FUNC
virtual Scalar& coeffRef(DenseIndex ) {
71 eigen_assert(
false &&
"can't reference the coefficient of a rvalue");
// Wrapped evaluator for the expression. The writable subclass reaches it as
// this->m_impl.
76 TensorEvaluator<Expr, Device> m_impl;
// Writable lazy evaluator. It behaves like TensorLazyEvaluatorReadOnly,
// except that coeffRef() returns a mutable reference obtained from the
// wrapped evaluator instead of asserting.
81template <
typename Dimensions,
typename Expr,
typename Device>
82class TensorLazyEvaluatorWritable :
public TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> {
84 typedef TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> Base;
85 typedef typename Base::Scalar Scalar;
// Storage typedefs mirror the base class. No visible member of this class
// references them.
86 typedef StorageMemory<Scalar, Device> Storage;
87 typedef typename Storage::Type EvaluatorPointerType;
// All construction work (dimension copy, evalSubExprsIfNeeded) happens in
// the read-only base constructor.
89 TensorLazyEvaluatorWritable(
const Expr& expr,
const Device& device) : Base(expr, device) {}
90 virtual ~TensorLazyEvaluatorWritable() {}
// Forwards to the wrapped evaluator's coeffRef(). TensorRefBase only selects
// this class when the expression is an lvalue (is_lvalue<Expression>).
92 EIGEN_DEVICE_FUNC
virtual Scalar& coeffRef(DenseIndex index) {
return this->m_impl.coeffRef(index); }
// Compile-time switch between the two lazy evaluators. When IsWritable is
// true it derives from TensorLazyEvaluatorWritable<Dimensions, Expr, Device>.
// Otherwise it derives from TensorLazyEvaluatorReadOnly instantiated with
// `const Expr`, so the read-only path only evaluates a const expression.
95template <
typename Dimensions,
typename Expr,
typename Device,
bool IsWritable>
96class TensorLazyEvaluator :
public std::conditional_t<IsWritable, TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
97 TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device>> {
// The same conditional base, repeated as a typedef. The line naming it is not
// in this listing; presumably it is `Base`, which the following lines use.
99 typedef std::conditional_t<IsWritable, TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
100 TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device>>
102 typedef typename Base::Scalar Scalar;
104 TensorLazyEvaluator(
const Expr& expr,
const Device& device) : Base(expr, device) {}
105 virtual ~TensorLazyEvaluator() {}
// CRTP base shared by TensorRef<T> and TensorRef<const T>.
// It stores one pointer to a heap-allocated, type-erased lazy evaluator
// (TensorLazyBaseEvaluator). Copies of a ref share that evaluator and
// increment its reference count. When a ref goes away, unrefEvaluator()
// drops its reference. All coefficient access goes through virtual calls on
// the evaluator, using a linear index computed here from multi-indices.
//
// NOTE(review): the default constructor leaves m_evaluator null. The copy
// constructor, copy assignment (when `other` is default-constructed) and all
// accessors dereference m_evaluator without a null check. Copying or querying
// a default-constructed ref is therefore a null-pointer dereference; those
// paths need a guard.
108template <
typename Derived>
109class TensorRefBase :
public TensorBase<Derived> {
// Scalar, index, storage and shape types all come from the referenced plain
// tensor type, traits<Derived>::PlainObjectType.
111 typedef typename traits<Derived>::PlainObjectType PlainObjectType;
112 typedef typename PlainObjectType::Base Base;
113 typedef typename Eigen::internal::nested<Derived>::type Nested;
114 typedef typename traits<PlainObjectType>::StorageKind StorageKind;
115 typedef typename traits<PlainObjectType>::Index Index;
116 typedef typename traits<PlainObjectType>::Scalar Scalar;
117 typedef typename NumTraits<Scalar>::Real RealScalar;
118 typedef typename Base::CoeffReturnType CoeffReturnType;
119 typedef Scalar* PointerType;
120 typedef PointerType PointerArgType;
122 static constexpr Index NumIndices = PlainObjectType::NumIndices;
123 typedef typename PlainObjectType::Dimensions Dimensions;
125 static constexpr int Layout = PlainObjectType::Layout;
// Access is scalar-only through virtual calls: no packet access and no block
// evaluation.
128 PacketAccess =
false,
130 PreferBlockAccess =
false,
136 typedef TensorBlockNotImplemented TensorBlock;
// A default-constructed ref points at no evaluator (see the NOTE above).
139 EIGEN_STRONG_INLINE TensorRefBase() : m_evaluator(NULL) {}
// Copy: share other's evaluator and take one more reference to it. Asserts
// the shared evaluator is still live (refCount() > 0).
141 TensorRefBase(
const TensorRefBase& other) : TensorBase<Derived>(other), m_evaluator(other.m_evaluator) {
142 eigen_assert(m_evaluator->refCount() > 0);
143 m_evaluator->incrRefCount();
// Copy assignment: switch to other's evaluator and take a reference to it.
// Self-assignment is skipped.
// NOTE(review): one statement between the self-assignment check and the
// pointer reassignment is not visible in this listing, and neither is the
// return. That statement must release the previously held evaluator
// (presumably unrefEvaluator()); otherwise the old evaluator leaks.
146 TensorRefBase& operator=(
const TensorRefBase& other) {
147 if (
this != &other) {
149 m_evaluator = other.m_evaluator;
150 eigen_assert(m_evaluator->refCount() > 0);
151 m_evaluator->incrRefCount();
// Converting constructor from any tensor expression other than Derived
// itself (the enable_if leaves copies to the copy constructor). Allocates a
// lazy evaluator for the expression on DefaultDevice with `new`. The
// evaluator is writable only if PlainObjectType is non-const AND the
// expression is an lvalue. After incrRefCount() the new evaluator holds
// exactly one reference.
156 template <
typename Expression,
157 typename EnableIf = std::enable_if_t<!std::is_same<std::decay_t<Expression>, Derived>::value>>
158 EIGEN_STRONG_INLINE TensorRefBase(
const Expression& expr)
159 : m_evaluator(new TensorLazyEvaluator<Dimensions, Expression, DefaultDevice,
160 !std::is_const<PlainObjectType>::value &&
161 bool(is_lvalue<Expression>::value)>(expr, DefaultDevice())) {
162 m_evaluator->incrRefCount();
// Rebind to a new expression. Allocates a fresh evaluator the same way as
// the converting constructor and takes one reference to it.
// NOTE(review): the statement before the allocation and the return statement
// are not visible in this listing. The previously held evaluator must be
// released before m_evaluator is overwritten (presumably unrefEvaluator());
// otherwise it leaks.
165 template <
typename Expression,
166 typename EnableIf = std::enable_if_t<!std::is_same<std::decay_t<Expression>, Derived>::value>>
167 EIGEN_STRONG_INLINE TensorRefBase& operator=(
const Expression& expr) {
169 m_evaluator =
new TensorLazyEvaluator < Dimensions, Expression, DefaultDevice,
170 !std::is_const<PlainObjectType>::value&& bool(is_lvalue<Expression>::value) >
171 (expr, DefaultDevice());
172 m_evaluator->incrRefCount();
// Drop this ref's reference to the shared evaluator.
176 ~TensorRefBase() { unrefEvaluator(); }
// Shape queries, answered from the evaluator's dimensions: rank is the
// number of dimensions, size is their product (TotalSize()).
178 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank()
const {
return m_evaluator->dimensions().size(); }
179 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(Index n)
const {
return m_evaluator->dimensions()[n]; }
180 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions()
const {
return m_evaluator->dimensions(); }
181 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size()
const {
return m_evaluator->dimensions().TotalSize(); }
182 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Scalar* data()
const {
return m_evaluator->data(); }
// Read one coefficient by linear index.
184 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Scalar operator()(Index index)
const {
return m_evaluator->coeff(index); }
// Read one coefficient by multi-index. The indices are packed into an array
// and linearized by coeff(const array&).
186 template <
typename... IndexTypes>
187 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Scalar operator()(Index firstIndex, IndexTypes... otherIndices)
const {
188 const std::size_t num_indices = (
sizeof...(otherIndices) + 1);
189 const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
190 return coeff(indices);
// Convert a multi-index to a linear index in the referenced type's storage
// order, then read through the evaluator:
//  - RowMajor (PlainObjectType::Options & RowMajor): fold i = 1..N-1 with
//    index = index * dims[i] + indices[i], so the last index varies fastest.
//  - Otherwise (column-major): seed with indices[N-1] and fold back down to
//    i = 0, so indices[0] varies fastest.
// The lines declaring `index` and seeding the row-major branch are not
// visible in this listing.
// NOTE(review): the template parameter NumIndices shadows the class-level
// NumIndices constant. With NumIndices == 1, `NumIndices - 2` is computed in
// std::size_t and wraps. The loop then relies on the conversion to int
// yielding -1, which is implementation-defined before C++20.
193 template <std::
size_t NumIndices>
194 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Scalar coeff(
const array<Index, NumIndices>& indices)
const {
195 const Dimensions& dims = this->dimensions();
197 if (PlainObjectType::Options &
RowMajor) {
199 for (
size_t i = 1; i < NumIndices; ++i) {
200 index = index * dims[i] + indices[i];
203 index += indices[NumIndices - 1];
204 for (
int i = NumIndices - 2; i >= 0; --i) {
205 index = index * dims[i] + indices[i];
208 return m_evaluator->coeff(index);
// Read one coefficient by linear index (same as operator()(Index)).
211 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Scalar coeff(Index index)
const {
return m_evaluator->coeff(index); }
// Mutable access by linear index. For a read-only evaluator this reaches the
// evaluator's asserting coeffRef().
213 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
return m_evaluator->coeffRef(index); }
// Raw access to the shared evaluator. TensorRef<T>::coeffRef uses it via
// Base::evaluator().
216 TensorLazyBaseEvaluator<Dimensions, Scalar>* evaluator() {
return m_evaluator; }
// Release this ref's reference. Visible steps: decrement the count and test
// it for zero.
// NOTE(review): the surrounding lines are not visible in this listing: the
// null guard for default-constructed refs and the body of the zero-count
// branch (presumably `delete m_evaluator`). As shown, decrRefCount() would
// dereference a null m_evaluator when a default-constructed ref is destroyed.
219 EIGEN_STRONG_INLINE
void unrefEvaluator() {
221 m_evaluator->decrRefCount();
222 if (m_evaluator->refCount() == 0) {
// Shared, reference-counted evaluator; null for default-constructed refs.
228 TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
// A reference to a tensor expression, evaluated lazily through a shared
// evaluator held by TensorRefBase. This primary template is the mutable
// form. It can only be bound to lvalue expressions (enforced by the static
// assertions below), and it adds coeffRef() overloads for writes.
// TensorRef<const T> is the form that accepts arbitrary, read-only
// expressions.
240template <
typename PlainObjectType>
241class TensorRef :
public internal::TensorRefBase<TensorRef<PlainObjectType>> {
242 typedef internal::TensorRefBase<TensorRef<PlainObjectType>> Base;
245 using Scalar =
typename Base::Scalar;
246 using Dimensions =
typename Base::Dimensions;
248 EIGEN_STRONG_INLINE TensorRef() : Base() {}
// Copy shares the other ref's evaluator (TensorRefBase copy constructor).
250 EIGEN_STRONG_INLINE TensorRef(
const TensorRef& other) : Base(other) {}
// Bind to an expression. Rejected at compile time unless the expression is
// an lvalue.
// NOTE(review): the assertion message ends with a stray ')' ("...?)").
252 template <
typename Expression>
253 EIGEN_STRONG_INLINE TensorRef(
const Expression& expr) : Base(expr) {
254 EIGEN_STATIC_ASSERT(internal::is_lvalue<Expression>::value,
255 "Expression must be mutable to create a mutable TensorRef<Expression>. Did you mean "
256 "TensorRef<const Expression>?)");
// Copy assignment delegates to TensorRefBase::operator= and returns *this as
// the derived type.
259 TensorRef& operator=(
const TensorRef& other) {
return Base::operator=(other).derived(); }
// Rebind to a new expression, with the same compile-time lvalue check as the
// converting constructor.
261 template <
typename Expression>
262 EIGEN_STRONG_INLINE TensorRef& operator=(
const Expression& expr) {
263 EIGEN_STATIC_ASSERT(internal::is_lvalue<Expression>::value,
264 "Expression must be mutable to create a mutable TensorRef<Expression>. Did you mean "
265 "TensorRef<const Expression>?)");
266 return Base::operator=(expr).derived();
// Writable access by multi-index. The indices are packed into an array.
269 template <
typename... IndexTypes>
270 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) {
271 const std::size_t num_indices = (
sizeof...(otherIndices) + 1);
272 const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
273 return coeffRef(indices);
// Writable access by index array. Linearizes with the same row-/column-major
// rule as TensorRefBase::coeff(const array&); the lines declaring and seeding
// `index` are not visible in this listing.
// NOTE(review): as in the base class, `NumIndices - 2` wraps in std::size_t
// when NumIndices == 1.
276 template <std::
size_t NumIndices>
277 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(
const array<Index, NumIndices>& indices) {
278 const Dimensions& dims = this->dimensions();
280 if (PlainObjectType::Options &
RowMajor) {
282 for (
size_t i = 1; i < NumIndices; ++i) {
283 index = index * dims[i] + indices[i];
286 index += indices[NumIndices - 1];
287 for (
int i = NumIndices - 2; i >= 0; --i) {
288 index = index * dims[i] + indices[i];
291 return Base::evaluator()->coeffRef(index);
// Writable access by linear index.
294 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
return Base::evaluator()->coeffRef(index); }
// Read-only specialization. TensorRef<const T> can be bound to any tensor
// expression: unlike the mutable form, it has no lvalue static assertion. It
// declares no multi-index coeffRef() overloads of its own.
// Presumably traits<TensorRef<const T>>::PlainObjectType is `const T` (the
// traits are not visible here). In that case TensorRefBase always selects the
// read-only lazy evaluator, because its writability test requires a
// non-const PlainObjectType.
304template <
typename PlainObjectType>
305class TensorRef<const PlainObjectType> :
public internal::TensorRefBase<TensorRef<const PlainObjectType>> {
306 typedef internal::TensorRefBase<TensorRef<const PlainObjectType>> Base;
309 EIGEN_STRONG_INLINE TensorRef() : Base() {}
// Copy shares the other ref's evaluator.
311 EIGEN_STRONG_INLINE TensorRef(
const TensorRef& other) : Base(other) {}
// Bind to any expression; the work is done by TensorRefBase.
313 template <
typename Expression>
314 EIGEN_STRONG_INLINE TensorRef(
const Expression& expr) : Base(expr) {}
// Assignment delegates to TensorRefBase and returns *this as the derived type.
316 TensorRef& operator=(
const TensorRef& other) {
return Base::operator=(other).derived(); }
318 template <
typename Expression>
319 EIGEN_STRONG_INLINE TensorRef& operator=(
const Expression& expr) {
320 return Base::operator=(expr).derived();
// Evaluator used when a TensorRef<Derived> appears inside another tensor
// expression. The declaration line naming this specialization is not in this
// listing; the non-const evaluator that follows refers to it as
// TensorEvaluator<const TensorRef<Derived>, Device>. It keeps its own copy of
// the TensorRef, which shares the ref's lazy evaluator, and forwards every
// query to that copy. It has no packet or block access.
325template <
typename Derived,
typename Device>
327 typedef typename Derived::Index
Index;
328 typedef typename Derived::Scalar
Scalar;
// CoeffReturnType is declared on a line not visible in this listing.
330 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
331 typedef typename Derived::Dimensions
Dimensions;
332 typedef StorageMemory<CoeffReturnType, Device> Storage;
333 typedef typename Storage::Type EvaluatorPointerType;
335 static constexpr int Layout = TensorRef<Derived>::Layout;
338 PacketAccess =
false,
340 PreferBlockAccess =
false,
346 typedef internal::TensorBlockNotImplemented TensorBlock;
// Copies the ref, taking another reference on its evaluator. The device
// argument is unused.
349 EIGEN_STRONG_INLINE TensorEvaluator(
const TensorRef<Derived>& m,
const Device&) : m_ref(m) {}
351 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions()
const {
return m_ref.dimensions(); }
// Does no work and returns true. The ref's own lazy evaluator already ran
// evalSubExprsIfNeeded() when the ref was bound.
353 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType) {
return true; }
355 EIGEN_STRONG_INLINE
void cleanup() {}
// Coefficient access and data() forward to the stored ref.
357 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index)
const {
return m_ref.coeff(index); }
359 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
return m_ref.coeffRef(index); }
361 EIGEN_DEVICE_FUNC
const Scalar* data()
const {
return m_ref.data(); }
// Owned copy of the ref; keeps the shared lazy evaluator alive.
364 TensorRef<Derived> m_ref;
// Evaluator for a non-const TensorRef<Derived>. The declaration line naming
// the specialization is not in this listing; presumably it is
// TensorEvaluator<TensorRef<Derived>, Device>, given the constructor's
// parameter. It reuses the const-ref evaluator (Base) for everything and
// provides a writable coeffRef() that forwards to the stored ref.
368template <
typename Derived,
typename Device>
370 typedef typename Derived::Index Index;
371 typedef typename Derived::Scalar Scalar;
372 typedef typename Derived::Scalar CoeffReturnType;
373 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
374 typedef typename Derived::Dimensions Dimensions;
376 typedef TensorEvaluator<const TensorRef<Derived>, Device> Base;
// Scalar access only: no alignment, packet, block or raw-pointer access.
378 enum { IsAligned =
false, PacketAccess =
false, BlockAccess =
false, PreferBlockAccess =
false, RawAccess =
false };
381 typedef internal::TensorBlockNotImplemented TensorBlock;
// Delegates to the const evaluator, which copies the ref.
384 EIGEN_STRONG_INLINE TensorEvaluator(TensorRef<Derived>& m,
const Device& d) : Base(m, d) {}
// Writable access by linear index through the stored ref.
386 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
return this->m_ref.coeffRef(index); }
A reference to a tensor expression. The expression will be evaluated lazily (as much as possible).
Definition TensorRef.h:241
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor evaluator class.
Definition TensorEvaluator.h:30