Eigen-unsupported  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
TensorRef.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H
11#define EIGEN_CXX11_TENSOR_TENSOR_REF_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20template <typename Dimensions, typename Scalar>
21class TensorLazyBaseEvaluator {
22 public:
23 TensorLazyBaseEvaluator() : m_refcount(0) {}
24 virtual ~TensorLazyBaseEvaluator() {}
25
26 EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const = 0;
27 EIGEN_DEVICE_FUNC virtual const Scalar* data() const = 0;
28
29 EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const = 0;
30 EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) = 0;
31
32 void incrRefCount() { ++m_refcount; }
33 void decrRefCount() { --m_refcount; }
34 int refCount() const { return m_refcount; }
35
36 private:
37 // No copy, no assignment;
38 TensorLazyBaseEvaluator(const TensorLazyBaseEvaluator& other);
39 TensorLazyBaseEvaluator& operator=(const TensorLazyBaseEvaluator& other);
40
41 int m_refcount;
42};
43
44template <typename Dimensions, typename Expr, typename Device>
45class TensorLazyEvaluatorReadOnly
46 : public TensorLazyBaseEvaluator<Dimensions, typename TensorEvaluator<Expr, Device>::Scalar> {
47 public:
48 // typedef typename TensorEvaluator<Expr, Device>::Dimensions Dimensions;
49 typedef typename TensorEvaluator<Expr, Device>::Scalar Scalar;
50 typedef StorageMemory<Scalar, Device> Storage;
51 typedef typename Storage::Type EvaluatorPointerType;
52 typedef TensorEvaluator<Expr, Device> EvalType;
53
54 TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device), m_dummy(Scalar(0)) {
55 EIGEN_STATIC_ASSERT(
56 internal::array_size<Dimensions>::value == internal::array_size<typename EvalType::Dimensions>::value,
57 "Dimension sizes must match.");
58 const auto& other_dims = m_impl.dimensions();
59 for (std::size_t i = 0; i < m_dims.size(); ++i) {
60 m_dims[i] = other_dims[i];
61 }
62 m_impl.evalSubExprsIfNeeded(NULL);
63 }
64 virtual ~TensorLazyEvaluatorReadOnly() { m_impl.cleanup(); }
65
66 EIGEN_DEVICE_FUNC virtual const Dimensions& dimensions() const { return m_dims; }
67 EIGEN_DEVICE_FUNC virtual const Scalar* data() const { return m_impl.data(); }
68
69 EIGEN_DEVICE_FUNC virtual const Scalar coeff(DenseIndex index) const { return m_impl.coeff(index); }
70 EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex /*index*/) {
71 eigen_assert(false && "can't reference the coefficient of a rvalue");
72 return m_dummy;
73 };
74
75 protected:
76 TensorEvaluator<Expr, Device> m_impl;
77 Dimensions m_dims;
78 Scalar m_dummy;
79};
80
81template <typename Dimensions, typename Expr, typename Device>
82class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> {
83 public:
84 typedef TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> Base;
85 typedef typename Base::Scalar Scalar;
86 typedef StorageMemory<Scalar, Device> Storage;
87 typedef typename Storage::Type EvaluatorPointerType;
88
89 TensorLazyEvaluatorWritable(const Expr& expr, const Device& device) : Base(expr, device) {}
90 virtual ~TensorLazyEvaluatorWritable() {}
91
92 EIGEN_DEVICE_FUNC virtual Scalar& coeffRef(DenseIndex index) { return this->m_impl.coeffRef(index); }
93};
94
95template <typename Dimensions, typename Expr, typename Device, bool IsWritable>
96class TensorLazyEvaluator : public std::conditional_t<IsWritable, TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
97 TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device>> {
98 public:
99 typedef std::conditional_t<IsWritable, TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
100 TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device>>
101 Base;
102 typedef typename Base::Scalar Scalar;
103
104 TensorLazyEvaluator(const Expr& expr, const Device& device) : Base(expr, device) {}
105 virtual ~TensorLazyEvaluator() {}
106};
107
108template <typename Derived>
109class TensorRefBase : public TensorBase<Derived> {
110 public:
111 typedef typename traits<Derived>::PlainObjectType PlainObjectType;
112 typedef typename PlainObjectType::Base Base;
113 typedef typename Eigen::internal::nested<Derived>::type Nested;
114 typedef typename traits<PlainObjectType>::StorageKind StorageKind;
115 typedef typename traits<PlainObjectType>::Index Index;
116 typedef typename traits<PlainObjectType>::Scalar Scalar;
117 typedef typename NumTraits<Scalar>::Real RealScalar;
118 typedef typename Base::CoeffReturnType CoeffReturnType;
119 typedef Scalar* PointerType;
120 typedef PointerType PointerArgType;
121
122 static constexpr Index NumIndices = PlainObjectType::NumIndices;
123 typedef typename PlainObjectType::Dimensions Dimensions;
124
125 static constexpr int Layout = PlainObjectType::Layout;
126 enum {
127 IsAligned = false,
128 PacketAccess = false,
129 BlockAccess = false,
130 PreferBlockAccess = false,
131 CoordAccess = false, // to be implemented
132 RawAccess = false
133 };
134
135 //===- Tensor block evaluation strategy (see TensorBlock.h) -----------===//
136 typedef TensorBlockNotImplemented TensorBlock;
137 //===------------------------------------------------------------------===//
138
139 EIGEN_STRONG_INLINE TensorRefBase() : m_evaluator(NULL) {}
140
141 TensorRefBase(const TensorRefBase& other) : TensorBase<Derived>(other), m_evaluator(other.m_evaluator) {
142 eigen_assert(m_evaluator->refCount() > 0);
143 m_evaluator->incrRefCount();
144 }
145
146 TensorRefBase& operator=(const TensorRefBase& other) {
147 if (this != &other) {
148 unrefEvaluator();
149 m_evaluator = other.m_evaluator;
150 eigen_assert(m_evaluator->refCount() > 0);
151 m_evaluator->incrRefCount();
152 }
153 return *this;
154 }
155
156 template <typename Expression,
157 typename EnableIf = std::enable_if_t<!std::is_same<std::decay_t<Expression>, Derived>::value>>
158 EIGEN_STRONG_INLINE TensorRefBase(const Expression& expr)
159 : m_evaluator(new TensorLazyEvaluator<Dimensions, Expression, DefaultDevice,
160 /*IsWritable=*/!std::is_const<PlainObjectType>::value &&
161 bool(is_lvalue<Expression>::value)>(expr, DefaultDevice())) {
162 m_evaluator->incrRefCount();
163 }
164
165 template <typename Expression,
166 typename EnableIf = std::enable_if_t<!std::is_same<std::decay_t<Expression>, Derived>::value>>
167 EIGEN_STRONG_INLINE TensorRefBase& operator=(const Expression& expr) {
168 unrefEvaluator();
169 m_evaluator = new TensorLazyEvaluator < Dimensions, Expression, DefaultDevice,
170 /*IsWritable=*/!std::is_const<PlainObjectType>::value&& bool(is_lvalue<Expression>::value) >
171 (expr, DefaultDevice());
172 m_evaluator->incrRefCount();
173 return *this;
174 }
175
176 ~TensorRefBase() { unrefEvaluator(); }
177
178 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
179 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
180 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); }
181 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return m_evaluator->dimensions().TotalSize(); }
182 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar* data() const { return m_evaluator->data(); }
183
184 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(Index index) const { return m_evaluator->coeff(index); }
185
186 template <typename... IndexTypes>
187 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar operator()(Index firstIndex, IndexTypes... otherIndices) const {
188 const std::size_t num_indices = (sizeof...(otherIndices) + 1);
189 const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
190 return coeff(indices);
191 }
192
193 template <std::size_t NumIndices>
194 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const {
195 const Dimensions& dims = this->dimensions();
196 Index index = 0;
197 if (PlainObjectType::Options & RowMajor) {
198 index += indices[0];
199 for (size_t i = 1; i < NumIndices; ++i) {
200 index = index * dims[i] + indices[i];
201 }
202 } else {
203 index += indices[NumIndices - 1];
204 for (int i = NumIndices - 2; i >= 0; --i) {
205 index = index * dims[i] + indices[i];
206 }
207 }
208 return m_evaluator->coeff(index);
209 }
210
211 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const { return m_evaluator->coeff(index); }
212
213 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return m_evaluator->coeffRef(index); }
214
215 protected:
216 TensorLazyBaseEvaluator<Dimensions, Scalar>* evaluator() { return m_evaluator; }
217
218 private:
219 EIGEN_STRONG_INLINE void unrefEvaluator() {
220 if (m_evaluator) {
221 m_evaluator->decrRefCount();
222 if (m_evaluator->refCount() == 0) {
223 delete m_evaluator;
224 }
225 }
226 }
227
228 TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
229};
230
231} // namespace internal
232
240template <typename PlainObjectType>
241class TensorRef : public internal::TensorRefBase<TensorRef<PlainObjectType>> {
242 typedef internal::TensorRefBase<TensorRef<PlainObjectType>> Base;
243
244 public:
245 using Scalar = typename Base::Scalar;
246 using Dimensions = typename Base::Dimensions;
247
248 EIGEN_STRONG_INLINE TensorRef() : Base() {}
249
250 EIGEN_STRONG_INLINE TensorRef(const TensorRef& other) : Base(other) {}
251
252 template <typename Expression>
253 EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : Base(expr) {
254 EIGEN_STATIC_ASSERT(internal::is_lvalue<Expression>::value,
255 "Expression must be mutable to create a mutable TensorRef<Expression>. Did you mean "
256 "TensorRef<const Expression>?)");
257 }
258
259 TensorRef& operator=(const TensorRef& other) { return Base::operator=(other).derived(); }
260
261 template <typename Expression>
262 EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) {
263 EIGEN_STATIC_ASSERT(internal::is_lvalue<Expression>::value,
264 "Expression must be mutable to create a mutable TensorRef<Expression>. Did you mean "
265 "TensorRef<const Expression>?)");
266 return Base::operator=(expr).derived();
267 }
268
269 template <typename... IndexTypes>
270 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) {
271 const std::size_t num_indices = (sizeof...(otherIndices) + 1);
272 const array<Index, num_indices> indices{{firstIndex, otherIndices...}};
273 return coeffRef(indices);
274 }
275
276 template <std::size_t NumIndices>
277 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices) {
278 const Dimensions& dims = this->dimensions();
279 Index index = 0;
280 if (PlainObjectType::Options & RowMajor) {
281 index += indices[0];
282 for (size_t i = 1; i < NumIndices; ++i) {
283 index = index * dims[i] + indices[i];
284 }
285 } else {
286 index += indices[NumIndices - 1];
287 for (int i = NumIndices - 2; i >= 0; --i) {
288 index = index * dims[i] + indices[i];
289 }
290 }
291 return Base::evaluator()->coeffRef(index);
292 }
293
294 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return Base::evaluator()->coeffRef(index); }
295};
296
304template <typename PlainObjectType>
305class TensorRef<const PlainObjectType> : public internal::TensorRefBase<TensorRef<const PlainObjectType>> {
306 typedef internal::TensorRefBase<TensorRef<const PlainObjectType>> Base;
307
308 public:
309 EIGEN_STRONG_INLINE TensorRef() : Base() {}
310
311 EIGEN_STRONG_INLINE TensorRef(const TensorRef& other) : Base(other) {}
312
313 template <typename Expression>
314 EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : Base(expr) {}
315
316 TensorRef& operator=(const TensorRef& other) { return Base::operator=(other).derived(); }
317
318 template <typename Expression>
319 EIGEN_STRONG_INLINE TensorRef& operator=(const Expression& expr) {
320 return Base::operator=(expr).derived();
321 }
322};
323
324// evaluator for rvalues
325template <typename Derived, typename Device>
326struct TensorEvaluator<const TensorRef<Derived>, Device> {
327 typedef typename Derived::Index Index;
328 typedef typename Derived::Scalar Scalar;
329 typedef typename Derived::Scalar CoeffReturnType;
330 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
331 typedef typename Derived::Dimensions Dimensions;
332 typedef StorageMemory<CoeffReturnType, Device> Storage;
333 typedef typename Storage::Type EvaluatorPointerType;
334
335 static constexpr int Layout = TensorRef<Derived>::Layout;
336 enum {
337 IsAligned = false,
338 PacketAccess = false,
339 BlockAccess = false,
340 PreferBlockAccess = false,
341 CoordAccess = false, // to be implemented
342 RawAccess = false
343 };
344
345 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
346 typedef internal::TensorBlockNotImplemented TensorBlock;
347 //===--------------------------------------------------------------------===//
348
349 EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&) : m_ref(m) {}
350
351 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_ref.dimensions(); }
352
353 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) { return true; }
354
355 EIGEN_STRONG_INLINE void cleanup() {}
356
357 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { return m_ref.coeff(index); }
358
359 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return m_ref.coeffRef(index); }
360
361 EIGEN_DEVICE_FUNC const Scalar* data() const { return m_ref.data(); }
362
363 protected:
364 TensorRef<Derived> m_ref;
365};
366
367// evaluator for lvalues
368template <typename Derived, typename Device>
369struct TensorEvaluator<TensorRef<Derived>, Device> : public TensorEvaluator<const TensorRef<Derived>, Device> {
370 typedef typename Derived::Index Index;
371 typedef typename Derived::Scalar Scalar;
372 typedef typename Derived::Scalar CoeffReturnType;
373 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
374 typedef typename Derived::Dimensions Dimensions;
375
376 typedef TensorEvaluator<const TensorRef<Derived>, Device> Base;
377
378 enum { IsAligned = false, PacketAccess = false, BlockAccess = false, PreferBlockAccess = false, RawAccess = false };
379
380 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
381 typedef internal::TensorBlockNotImplemented TensorBlock;
382 //===--------------------------------------------------------------------===//
383
384 EIGEN_STRONG_INLINE TensorEvaluator(TensorRef<Derived>& m, const Device& d) : Base(m, d) {}
385
386 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { return this->m_ref.coeffRef(index); }
387};
388
389} // end namespace Eigen
390
391#endif // EIGEN_CXX11_TENSOR_TENSOR_REF_H
A reference to a tensor expression. The expression will be evaluated lazily (as much as possible).
Definition TensorRef.h:241
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor evaluator class.
Definition TensorEvaluator.h:30