#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H

#include "./InternalHeaderCheck.h"

namespace Eigen {
namespace internal {
// Type promotion to handle the case where the scalar types of the lhs and the rhs differ.
template <typename Axis, typename LhsXprType, typename RhsXprType>
struct traits<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> > {
  typedef typename promote_storage_type<typename LhsXprType::Scalar, typename RhsXprType::Scalar>::ret Scalar;
  typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
                                        typename traits<RhsXprType>::StorageKind>::ret StorageKind;
  typedef typename promote_index_type<typename traits<LhsXprType>::Index,
                                      typename traits<RhsXprType>::Index>::type Index;
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  typedef std::remove_reference_t<LhsNested> LhsNested_;
  typedef std::remove_reference_t<RhsNested> RhsNested_;
  static constexpr int NumDimensions = traits<LhsXprType>::NumDimensions;
  static constexpr int Layout = traits<LhsXprType>::Layout;
  typedef std::conditional_t<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
                             typename traits<LhsXprType>::PointerType,
                             typename traits<RhsXprType>::PointerType>
      PointerType;
};
template <typename Axis, typename LhsXprType, typename RhsXprType>
struct eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, Eigen::Dense> {
  typedef const TensorConcatenationOp<Axis, LhsXprType, RhsXprType>& type;
};
template <typename Axis, typename LhsXprType, typename RhsXprType>
struct nested<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, 1,
              typename eval<TensorConcatenationOp<Axis, LhsXprType, RhsXprType> >::type> {
  typedef TensorConcatenationOp<Axis, LhsXprType, RhsXprType> type;
};

}  // end namespace internal
template <typename Axis, typename LhsXprType, typename RhsXprType>
class TensorConcatenationOp
    : public TensorBase<TensorConcatenationOp<Axis, LhsXprType, RhsXprType>, WriteAccessors> {
 public:
  typedef typename internal::traits<TensorConcatenationOp>::Scalar Scalar;
  typedef typename internal::traits<TensorConcatenationOp>::StorageKind StorageKind;
  typedef typename internal::traits<TensorConcatenationOp>::Index Index;
  typedef typename internal::nested<TensorConcatenationOp>::type Nested;
  typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
                                                  typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorConcatenationOp(const LhsXprType& lhs, const RhsXprType& rhs, Axis axis)
      : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_axis(axis) {}

  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename LhsXprType::Nested>& lhsExpression() const {
    return m_lhs_xpr;
  }

  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename RhsXprType::Nested>& rhsExpression() const {
    return m_rhs_xpr;
  }

  EIGEN_DEVICE_FUNC const Axis& axis() const { return m_axis; }

  EIGEN_TENSOR_INHERIT_ASSIGNMENT_OPERATORS(TensorConcatenationOp)

 protected:
  typename LhsXprType::Nested m_lhs_xpr;
  typename RhsXprType::Nested m_rhs_xpr;
  const Axis m_axis;
};
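// Usage sketch (illustrative, not part of this header): a concatenation expression
// is normally built through TensorBase::concatenate(); the tensor names and sizes
// below are hypothetical.
//
//   Eigen::Tensor<float, 2> a(2, 3), b(2, 5);
//   a.setRandom();
//   b.setRandom();
//   Eigen::Tensor<float, 2> c = a.concatenate(b, 1);  // c has dimensions 2 x 8
//
// All dimensions except the concatenation axis must match; the evaluator below
// checks this with eigen_assert.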
// Eval as rvalue
template <typename Axis, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> {
  typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
  typedef typename XprType::Index Index;
  static constexpr int NumDims =
      internal::array_size<typename TensorEvaluator<LeftArgType, Device>::Dimensions>::value;
  static constexpr int RightNumDims =
      internal::array_size<typename TensorEvaluator<RightArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;
  static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
  enum {
    IsAligned = false,
    PacketAccess =
        TensorEvaluator<LeftArgType, Device>::PacketAccess && TensorEvaluator<RightArgType, Device>::PacketAccess,
    BlockAccess = false,
    PreferBlockAccess = TensorEvaluator<LeftArgType, Device>::PreferBlockAccess ||
                        TensorEvaluator<RightArgType, Device>::PreferBlockAccess,
    RawAccess = false
  };

  typedef internal::TensorBlockNotImplemented TensorBlock;
  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device), m_axis(op.axis()) {
    // The two operands must use the same layout (unless they are 1-dimensional)
    // and must have the same, non-zero rank.
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
                             static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout) ||
                         NumDims == 1),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((NumDims == RightNumDims), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);

    eigen_assert(0 <= m_axis && m_axis < NumDims);
    const Dimensions& lhs_dims = m_leftImpl.dimensions();
    const Dimensions& rhs_dims = m_rightImpl.dimensions();
    {
      // All dimensions must match except along the concatenation axis, where the
      // output size is the sum of the two operand sizes.
      int i = 0;
      for (; i < m_axis; ++i) {
        eigen_assert(lhs_dims[i] > 0);
        eigen_assert(lhs_dims[i] == rhs_dims[i]);
        m_dimensions[i] = lhs_dims[i];
      }
      eigen_assert(lhs_dims[i] > 0);  // i == m_axis here.
      eigen_assert(rhs_dims[i] > 0);
      m_dimensions[i] = lhs_dims[i] + rhs_dims[i];
      for (++i; i < NumDims; ++i) {
        eigen_assert(lhs_dims[i] > 0);
        eigen_assert(lhs_dims[i] == rhs_dims[i]);
        m_dimensions[i] = lhs_dims[i];
      }
    }

    // Precompute strides for the two inputs and the output, following the layout.
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_leftStrides[0] = 1;
      m_rightStrides[0] = 1;
      m_outputStrides[0] = 1;

      for (int j = 1; j < NumDims; ++j) {
        m_leftStrides[j] = m_leftStrides[j - 1] * lhs_dims[j - 1];
        m_rightStrides[j] = m_rightStrides[j - 1] * rhs_dims[j - 1];
        m_outputStrides[j] = m_outputStrides[j - 1] * m_dimensions[j - 1];
      }
    } else {
      m_leftStrides[NumDims - 1] = 1;
      m_rightStrides[NumDims - 1] = 1;
      m_outputStrides[NumDims - 1] = 1;

      for (int j = NumDims - 2; j >= 0; --j) {
        m_leftStrides[j] = m_leftStrides[j + 1] * lhs_dims[j + 1];
        m_rightStrides[j] = m_rightStrides[j + 1] * rhs_dims[j + 1];
        m_outputStrides[j] = m_outputStrides[j + 1] * m_dimensions[j + 1];
      }
    }
  }
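  // Illustrative example (hypothetical sizes): concatenating a column-major 2x3 lhs
  // with a 2x5 rhs along axis 1 yields m_dimensions = {2, 8}, with
  // m_leftStrides = {1, 2}, m_rightStrides = {1, 2} and m_outputStrides = {1, 2}.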
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();
  }
  // Maps a linear index in the output tensor to a coefficient of either the
  // left or the right operand.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    // Collect dimension-wise indices (subs).
    array<Index, NumDims> subs;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      for (int i = NumDims - 1; i > 0; --i) {
        subs[i] = index / m_outputStrides[i];
        index -= subs[i] * m_outputStrides[i];
      }
      subs[0] = index;
    } else {
      for (int i = 0; i < NumDims - 1; ++i) {
        subs[i] = index / m_outputStrides[i];
        index -= subs[i] * m_outputStrides[i];
      }
      subs[NumDims - 1] = index;
    }

    const Dimensions& left_dims = m_leftImpl.dimensions();
    if (subs[m_axis] < left_dims[m_axis]) {
      Index left_index;
      if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
        left_index = subs[0];
        for (int i = 1; i < NumDims; ++i) {
          left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
        }
      } else {
        left_index = subs[NumDims - 1];
        for (int i = NumDims - 2; i >= 0; --i) {
          left_index += (subs[i] % left_dims[i]) * m_leftStrides[i];
        }
      }
      return m_leftImpl.coeff(left_index);
    } else {
      subs[m_axis] -= left_dims[m_axis];
      const Dimensions& right_dims = m_rightImpl.dimensions();
      Index right_index;
      if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
        right_index = subs[0];
        for (int i = 1; i < NumDims; ++i) {
          right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
        }
      } else {
        right_index = subs[NumDims - 1];
        for (int i = NumDims - 2; i >= 0; --i) {
          right_index += (subs[i] % right_dims[i]) * m_rightStrides[i];
        }
      }
      return m_rightImpl.coeff(right_index);
    }
  }
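  // Worked example (hypothetical sizes): for a column-major concatenation of a
  // 2x3 lhs and a 2x5 rhs along axis 1, the output is 2x8 with output strides
  // {1, 2}. Linear index 9 decomposes into subs = {1, 4}; since 4 >= 3 the value
  // comes from the right operand at subs = {1, 1}, i.e. output(1, 4) reads rhs(1, 1).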
  template <int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
    // No real vectorization yet: gather packetSize coefficients one by one and
    // load them as a packet.
    const int packetSize = PacketType<CoeffReturnType, Device>::size;
    EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index + packetSize - 1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
    for (int i = 0; i < packetSize; ++i) {
      values[i] = coeff(index + i);
    }
    PacketReturnType rslt = internal::pload<PacketReturnType>(values);
    return rslt;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    const double compute_cost = NumDims * (2 * TensorOpCost::AddCost<Index>() + 2 * TensorOpCost::MulCost<Index>() +
                                           TensorOpCost::DivCost<Index>() + TensorOpCost::ModCost<Index>());
    const double lhs_size = m_leftImpl.dimensions().TotalSize();
    const double rhs_size = m_rightImpl.dimensions().TotalSize();
    return (lhs_size / (lhs_size + rhs_size)) * m_leftImpl.costPerCoeff(vectorized) +
           (rhs_size / (lhs_size + rhs_size)) * m_rightImpl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost);
  }

  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }
 protected:
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_leftStrides;
  array<Index, NumDims> m_rightStrides;
  TensorEvaluator<LeftArgType, Device> m_leftImpl;
  TensorEvaluator<RightArgType, Device> m_rightImpl;
  const Axis m_axis;
};
// Eval as lvalue
template <typename Axis, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device>
    : public TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> {
  typedef TensorEvaluator<const TensorConcatenationOp<Axis, LeftArgType, RightArgType>, Device> Base;
  typedef TensorConcatenationOp<Axis, LeftArgType, RightArgType> XprType;
  typedef typename Base::Dimensions Dimensions;
  static constexpr int Layout = TensorEvaluator<LeftArgType, Device>::Layout;
  enum {
    IsAligned = false,
    PacketAccess =
        TensorEvaluator<LeftArgType, Device>::PacketAccess && TensorEvaluator<RightArgType, Device>::PacketAccess,
    BlockAccess = false,
    PreferBlockAccess = TensorEvaluator<LeftArgType, Device>::PreferBlockAccess ||
                        TensorEvaluator<RightArgType, Device>::PreferBlockAccess,
    RawAccess = false
  };

  typedef internal::TensorBlockNotImplemented TensorBlock;

  EIGEN_STRONG_INLINE TensorEvaluator(XprType& op, const Device& device) : Base(op, device) {
    // Writing into a concatenation is only implemented for column-major layouts.
    EIGEN_STATIC_ASSERT((static_cast<int>(Layout) == static_cast<int>(ColMajor)), YOU_MADE_A_PROGRAMMING_MISTAKE);
  }
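  // Because this evaluator exposes coeffRef() and writePacket(), a concatenation
  // can also appear on the left-hand side of an assignment, scattering values back
  // into the two operands. Illustrative sketch (tensor names and sizes are hypothetical):
  //   Eigen::Tensor<float, 2> a(2, 3), b(2, 5), src(2, 8);
  //   src.setRandom();
  //   a.concatenate(b, 1) = src;  // a receives columns 0-2 of src, b columns 3-7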
  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index) const {
    // Collect dimension-wise indices (subs); column-major only, per the static
    // assert in the constructor.
    array<Index, Base::NumDims> subs;
    for (int i = Base::NumDims - 1; i > 0; --i) {
      subs[i] = index / this->m_outputStrides[i];
      index -= subs[i] * this->m_outputStrides[i];
    }
    subs[0] = index;

    const Dimensions& left_dims = this->m_leftImpl.dimensions();
    if (subs[this->m_axis] < left_dims[this->m_axis]) {
      Index left_index = subs[0];
      for (int i = 1; i < Base::NumDims; ++i) {
        left_index += (subs[i] % left_dims[i]) * this->m_leftStrides[i];
      }
      return this->m_leftImpl.coeffRef(left_index);
    } else {
      subs[this->m_axis] -= left_dims[this->m_axis];
      const Dimensions& right_dims = this->m_rightImpl.dimensions();
      Index right_index = subs[0];
      for (int i = 1; i < Base::NumDims; ++i) {
        right_index += (subs[i] % right_dims[i]) * this->m_rightStrides[i];
      }
      return this->m_rightImpl.coeffRef(right_index);
    }
  }
  template <int StoreMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void writePacket(Index index, const PacketReturnType& x) const {
    // Store the packet into a scratch buffer, then scatter it coefficient by
    // coefficient through coeffRef().
    const int packetSize = PacketType<CoeffReturnType, Device>::size;
    EIGEN_STATIC_ASSERT((packetSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index + packetSize - 1 < this->dimensions().TotalSize());

    EIGEN_ALIGN_MAX CoeffReturnType values[packetSize];
    internal::pstore<CoeffReturnType, PacketReturnType>(values, x);
    for (int i = 0; i < packetSize; ++i) {
      coeffRef(index + i) = values[i];
    }
  }
};

}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_CONCATENATION_H