Eigen-unsupported  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
TensorConcatenation.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19template <typename Axis, typename LhsXprType, typename RhsXprType>
20struct traits<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> > {
21 // Type promotion to handle the case where the types of the lhs and the rhs are different.
22 typedef typename promote_storage_type<typename LhsXprType::Scalar, typename RhsXprType::Scalar>::ret Scalar;
23 typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
24 typename traits<RhsXprType>::StorageKind>::ret StorageKind;
25 typedef
26 typename promote_index_type<typename traits<LhsXprType>::Index, typename traits<RhsXprType>::Index>::type Index;
27 typedef typename LhsXprType::Nested LhsNested;
28 typedef typename RhsXprType::Nested RhsNested;
29 typedef std::remove_reference_t<LhsNested> LhsNested_;
30 typedef std::remove_reference_t<RhsNested> RhsNested_;
31 static constexpr int NumDimensions = traits<LhsXprType>::NumDimensions;
32 static constexpr int Layout = traits<LhsXprType>::Layout;
33 enum { Flags = 0 };
34 typedef std::conditional_t<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
35 typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>
36 PointerType;
37};
38
39template <typename Axis, typename LhsXprType, typename RhsXprType>
40struct eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, Eigen::Dense> {
41 typedef const TensorConcatenationOp<Axis, LhsXprType, RhsXprType>& type;
42};
43
44template <typename Axis, typename LhsXprType, typename RhsXprType>
45struct nested<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, 1,
46 typename eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >::type> {
47 typedef TensorConcatenationOp<Axis, LhsXprType, RhsXprType> type;
48};
49
50} // end namespace internal
51
57template <typename Axis, typename LhsXprType, typename RhsXprType>
58class TensorConcatenationOp : public TensorBase<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, WriteAccessors> {
59 public:
61 typedef typename internal::traits<TensorConcatenationOp>::Scalar Scalar;
62 typedef typename internal::traits<TensorConcatenationOp>::StorageKind StorageKind;
63 typedef typename internal::traits<TensorConcatenationOp>::Index Index;
64 typedef typename internal::nested<TensorConcatenationOp>::type Nested;
65 typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
66 typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
67 typedef typename NumTraits<Scalar>::Real RealScalar;
68
69 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConcatenationOp(const LhsXprType& lhs, const RhsXprType& rhs, Axis axis)
70 : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_axis(axis) {}
71
72 EIGEN_DEVICE_FUNC const internal::remove_all_t<typename LhsXprType::Nested>& lhsExpression() const {
73 return m_lhs_xpr;
74 }
75
76 EIGEN_DEVICE_FUNC const internal::remove_all_t<typename RhsXprType::Nested>& rhsExpression() const {
77 return m_rhs_xpr;
78 }
79
80 EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }
81
82 EIGEN_TENSOR_INHERIT_ASSIGNMENT_OPERATORS(TensorConcatenationOp)
83 protected:
84 typename LhsXprType::Nested m_lhs_xpr;
85 typename RhsXprType::Nested m_rhs_xpr;
86 const Axis m_axis;
87};
88
89// Eval as rvalue
90template <typename Axis, typename LeftArgType, typename RightArgType, typename Device>
91struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> {
93 typedef typename XprType::Index Index;
94 static constexpr int NumDims = internal::array_size<typename TensorEvaluator<LeftArgType, Device>::Dimensions>::value;
95 static constexpr int RightNumDims =
96 internal::array_size<typename TensorEvaluator<RightArgType, Device>::Dimensions>::value;
97 typedef DSizes<Index, NumDims> Dimensions;
98 typedef typename XprType::Scalar Scalar;
99 typedef typename XprType::CoeffReturnType CoeffReturnType;
100 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
101 typedef StorageMemory<CoeffReturnType, Device> Storage;
102 typedef typename Storage::Type EvaluatorPointerType;
103 static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
104 enum {
105 IsAligned = false,
106 PacketAccess =
107 TensorEvaluator<LeftArgType, Device>::PacketAccess && TensorEvaluator<RightArgType, Device>::PacketAccess,
108 BlockAccess = false,
111 RawAccess = false
112 };
113
114 //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
115 typedef internal::TensorBlockNotImplemented TensorBlock;
116 //===--------------------------------------------------------------------===//
117
118 EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
119 : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device), m_axis(op.axis()) {
120 EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
121 static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout) ||
122 NumDims == 1),
123 YOU_MADE_A_PROGRAMMING_MISTAKE);
124 EIGEN_STATIC_ASSERT((NumDims == RightNumDims), YOU_MADE_A_PROGRAMMING_MISTAKE);
125 EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
126
127 eigen_assert(0 <= m_axis && m_axis < NumDims);
128 const Dimensions& lhs_dims = m_leftImpl.dimensions();
129 const Dimensions& rhs_dims = m_rightImpl.dimensions();
130 {
131 int i = 0;
132 for (; i < m_axis; ++i) {
133 eigen_assert(lhs_dims[i] > 0);
134 eigen_assert(lhs_dims[i] == rhs_dims[i]);
135 m_dimensions[i] = lhs_dims[i];
136 }
137 eigen_assert(lhs_dims[i] > 0); // Now i == m_axis.
138 eigen_assert(rhs_dims[i] > 0);
139 m_dimensions[i] = lhs_dims[i] + rhs_dims[i];
140 for (++i; i < NumDims; ++i) {
141 eigen_assert(lhs_dims[i] > 0);
142 eigen_assert(lhs_dims[i] == rhs_dims[i]);
143 m_dimensions[i] = lhs_dims[i];
144 }
145 }
146
147 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
148 m_leftStrides[0] = 1;
149 m_rightStrides[0] = 1;
150 m_outputStrides[0] = 1;
151
152 for (int j = 1; j < NumDims; ++j) {
153 m_leftStrides[j] = m_leftStrides[j - 1] * lhs_dims[j - 1];
154 m_rightStrides[j] = m_rightStrides[j - 1] * rhs_dims[j - 1];
155 m_outputStrides[j] = m_outputStrides[j - 1] * m_dimensions[j - 1];
156 }
157 } else {
158 m_leftStrides[NumDims - 1] = 1;
159 m_rightStrides[NumDims - 1] = 1;
160 m_outputStrides[NumDims - 1] = 1;
161
162 for (int j = NumDims - 2; j >= 0; --j) {
163 m_leftStrides[j] = m_leftStrides[j + 1] * lhs_dims[j + 1];
164 m_rightStrides[j] = m_rightStrides[j + 1] * rhs_dims[j + 1];
165 m_outputStrides[j] = m_outputStrides[j + 1] * m_dimensions[j + 1];
166 }
167 }
168 }
169
170 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
171
172 // TODO(phli): Add short-circuit memcpy evaluation if underlying data are linear?
173 EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
174 m_leftImpl.evalSubExprsIfNeeded(NULL);
175 m_rightImpl.evalSubExprsIfNeeded(NULL);
176 return true;
177 }
178
179 EIGEN_STRONG_INLINE void cleanup() {
180 m_leftImpl.cleanup();
181 m_rightImpl.cleanup();
182 }
183
184 // TODO(phli): attempt to speed this up. The integer divisions and modulo are slow.
185 // See CL/76180724 comments for more ideas.
186 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
187 // Collect dimension-wise indices (subs).
188 array<Index, NumDims> subs;
189 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
190 for (int i = NumDims - 1; i > 0; --i) {
191 subs[i] = index / m_outputStrides[i];
192 index -= subs[i] * m_outputStrides[i];
193 }
194 subs[0] = index;
195 } else {
196 for (int i = 0; i < NumDims - 1; ++i) {
197 subs[i] = index / m_outputStrides[i];
198 index -= subs[i] * m_outputStrides[i];
199 }
200 subs[NumDims - 1] = index;
201 }
202
203 const Dimensions& left_dims = m_leftImpl.dimensions();
204 if (subs[m_axis] < left_dims[m_axis]) {
205 Index left_index;
206 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
207 left_index = subs[0];
208 EIGEN_UNROLL_LOOP
209 for (int i = 1; i < NumDims; ++i) {
210 left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
211 }
212 } else {
213 left_index = subs[NumDims - 1];
214 EIGEN_UNROLL_LOOP
215 for (int i = NumDims - 2; i >= 0; --i) {
216 left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
217 }
218 }
219 return m_leftImpl.coeff(left_index);
220 } else {
221 subs[m_axis] -= left_dims[m_axis];
222 const Dimensions& right_dims = m_rightImpl.dimensions();
223 Index right_index;
224 if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
225 right_index = subs[0];
226 EIGEN_UNROLL_LOOP
227 for (int i = 1; i < NumDims; ++i) {
228 right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
229 }
230 } else {
231 right_index = subs[NumDims - 1];
232 EIGEN_UNROLL_LOOP
233 for (int i = NumDims - 2; i >= 0; --i) {
234 right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
235 }
236 }
237 return m_rightImpl.coeff(right_index);
238 }
239 }
240
241 // TODO(phli): Add a real vectorization.
242 template <int LoadMode>
243 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
244 const int packetSize = PacketType<CoeffReturnType, Device>::size;
245 EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
246 eigen_assert(index + packetSize - 1 < dimensions().TotalSize());
247
248 EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
249 EIGEN_UNROLL_LOOP
250 for (int i = 0; i < packetSize; ++i) {
251 values[i] = coeff(index + i);
252 }
253 PacketReturnType rslt = internal::pload<PacketReturnType>(values);
254 return rslt;
255 }
256
257 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
258 const double compute_cost = NumDims * (2 * TensorOpCost::AddCost<Index>() + 2 * TensorOpCost::MulCost<Index>() +
259 TensorOpCost::DivCost<Index>() + TensorOpCost::ModCost<Index>());
260 const double lhs_size = m_leftImpl.dimensions().TotalSize();
261 const double rhs_size = m_rightImpl.dimensions().TotalSize();
262 return (lhs_size / (lhs_size + rhs_size)) * m_leftImpl.costPerCoeff(vectorized) +
263 (rhs_size / (lhs_size + rhs_size)) * m_rightImpl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost);
264 }
265
266 EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
267
268 protected:
269 Dimensions m_dimensions;
270 array<Index, NumDims> m_outputStrides;
271 array<Index, NumDims> m_leftStrides;
272 array<Index, NumDims> m_rightStrides;
273 TensorEvaluator<LeftArgType, Device> m_leftImpl;
274 TensorEvaluator<RightArgType, Device> m_rightImpl;
275 const Axis m_axis;
276};
277
// Eval as lvalue
//
// Evaluator used when a concatenation expression is assigned to, e.g.
// `lhs.concatenate(rhs, axis) = expr;`. It inherits the read path from the
// rvalue evaluator above and adds coeffRef/writePacket.
// NOTE: the constructor statically asserts ColMajor layout — the write path
// below hard-codes column-major index decomposition.
template <typename Axis, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
    : public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> {
  typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base;
  typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
  typedef typename Base::Dimensions Dimensions;
  static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
  enum {
    IsAligned = false,
    // Writable packet access still requires both operands to support packets.
    PacketAccess =
        TensorEvaluator<LeftArgType, Device>::PacketAccess && TensorEvaluator<RightArgType, Device>::PacketAccess,
    BlockAccess = false,
    PreferBlockAccess = TensorEvaluator<LeftArgType, Device>::PreferBlockAccess ||
                        TensorEvaluator<RightArgType, Device>::PreferBlockAccess,
    RawAccess = false
  };

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockNotImplemented TensorBlock;
  //===--------------------------------------------------------------------===//

  // Delegates all setup (dimension checks, stride computation) to Base.
  EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device) : Base(op, device) {
    EIGEN_STATIC_ASSERT((static_cast<int>(Layout) == static_cast<int>(ColMajor)), YOU_MADE_A_PROGRAMMING_MISTAKE);
  }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  // Returns a writable reference to the coefficient at linear position
  // `index`, routed to the left or right operand like the const coeff()
  // above, but decomposing the index in ColMajor order only (see the
  // static assert in the constructor).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index) const {
    // Collect dimension-wise indices (subs).
    array<Index, Base::NumDims> subs;
    for (int i = Base::NumDims - 1; i > 0; --i) {
      subs[i] = index / this->m_outputStrides[i];
      index -= subs[i] * this->m_outputStrides[i];
    }
    subs[0] = index;

    const Dimensions& left_dims = this->m_leftImpl.dimensions();
    if (subs[this->m_axis] < left_dims[this->m_axis]) {
      // Position falls within the lhs extent along the axis.
      Index left_index = subs[0];
      for (int i = 1; i < Base::NumDims; ++i) {
        left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
      }
      return this->m_leftImpl.coeffRef(left_index);
    } else {
      // Shift the axis index past the lhs extent, then address the rhs.
      subs[this->m_axis] -= left_dims[this->m_axis];
      const Dimensions& right_dims = this->m_rightImpl.dimensions();
      Index right_index = subs[0];
      for (int i = 1; i < Base::NumDims; ++i) {
        right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
      }
      return this->m_rightImpl.coeffRef(right_index);
    }
  }

  // Emulated packet store: spills the packet to a stack buffer and writes
  // each coefficient through coeffRef().
  template <int StoreMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void writePacket(Index index, const PacketReturnType& x) const {
    const int packetSize = PacketType<CoeffReturnType, Device>::size;
    EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());

    EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
    internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
    for (int i = 0; i < packetSize; ++i) {
      coeffRef(index + i) = values[i];
    }
  }
};
349
350} // end namespace Eigen
351
352#endif // EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
The tensor base class.
Definition TensorForwardDeclarations.h:68
Tensor concatenation class.
Definition TensorConcatenation.h:58
WriteAccessors
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor evaluator class.
Definition TensorEvaluator.h:30