Loading...
Searching...
No Matches
TensorLayoutSwap.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_LAYOUT_SWAP_H
11#define EIGEN_CXX11_TENSOR_TENSOR_LAYOUT_SWAP_H
12
13namespace Eigen {
14
15namespace internal {
16template<typename XprType>
17struct traits<TensorLayoutSwapOp<XprType> > : public traits<XprType>
18{
19 typedef typename XprType::Scalar Scalar;
20 typedef traits<XprType> XprTraits;
21 typedef typename XprTraits::StorageKind StorageKind;
22 typedef typename XprTraits::Index Index;
23 typedef typename XprType::Nested Nested;
24 typedef typename remove_reference<Nested>::type _Nested;
25 static const int NumDimensions = traits<XprType>::NumDimensions;
26 static const int Layout = (traits<XprType>::Layout == ColMajor) ? RowMajor : ColMajor;
27};
28
29template<typename XprType>
30struct eval<TensorLayoutSwapOp<XprType>, Eigen::Dense>
31{
32 typedef const TensorLayoutSwapOp<XprType>& type;
33};
34
35template<typename XprType>
36struct nested<TensorLayoutSwapOp<XprType>, 1, typename eval<TensorLayoutSwapOp<XprType> >::type>
37{
38 typedef TensorLayoutSwapOp<XprType> type;
39};
40
41} // end namespace internal
42
65template <typename XprType>
66class TensorLayoutSwapOp : public TensorBase<TensorLayoutSwapOp<XprType>, WriteAccessors> {
67 public:
68 typedef typename Eigen::internal::traits<TensorLayoutSwapOp>::Scalar Scalar;
69 typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
70 typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
71 typedef typename Eigen::internal::nested<TensorLayoutSwapOp>::type Nested;
72 typedef typename Eigen::internal::traits<TensorLayoutSwapOp>::StorageKind StorageKind;
73 typedef typename Eigen::internal::traits<TensorLayoutSwapOp>::Index Index;
74
75 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorLayoutSwapOp(const XprType& expr)
76 : m_xpr(expr) {}
77
78 EIGEN_DEVICE_FUNC
79 const typename internal::remove_all<typename XprType::Nested>::type&
80 expression() const { return m_xpr; }
81
82 EIGEN_DEVICE_FUNC
83 EIGEN_STRONG_INLINE TensorLayoutSwapOp& operator = (const TensorLayoutSwapOp& other)
84 {
85 typedef TensorAssignOp<TensorLayoutSwapOp, const TensorLayoutSwapOp> Assign;
86 Assign assign(*this, other);
87 internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
88 return *this;
89 }
90
91 template<typename OtherDerived>
92 EIGEN_DEVICE_FUNC
93 EIGEN_STRONG_INLINE TensorLayoutSwapOp& operator = (const OtherDerived& other)
94 {
95 typedef TensorAssignOp<TensorLayoutSwapOp, const OtherDerived> Assign;
96 Assign assign(*this, other);
97 internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
98 return *this;
99 }
100
101 protected:
102 typename XprType::Nested m_xpr;
103};
104
105
106// Eval as rvalue
107template<typename ArgType, typename Device>
108struct TensorEvaluator<const TensorLayoutSwapOp<ArgType>, Device>
109{
110 typedef TensorLayoutSwapOp<ArgType> XprType;
111 typedef typename XprType::Index Index;
112 static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
113 typedef DSizes<Index, NumDims> Dimensions;
114
115 enum {
116 IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
117 PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
118 Layout = (static_cast<int>(TensorEvaluator<ArgType, Device>::Layout) == static_cast<int>(ColMajor)) ? RowMajor : ColMajor,
119 CoordAccess = false, // to be implemented
120 RawAccess = TensorEvaluator<ArgType, Device>::RawAccess
121 };
122
123 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
124 : m_impl(op.expression(), device)
125 {
126 for(int i = 0; i < NumDims; ++i) {
127 m_dimensions[i] = m_impl.dimensions()[NumDims-1-i];
128 }
129 }
130
131 typedef typename XprType::Scalar Scalar;
132 typedef typename XprType::CoeffReturnType CoeffReturnType;
133 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
134
135 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
136
137 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
138 return m_impl.evalSubExprsIfNeeded(data);
139 }
140 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
141 m_impl.cleanup();
142 }
143
144 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
145 {
146 return m_impl.coeff(index);
147 }
148
149 template<int LoadMode>
150 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
151 {
152 return m_impl.template packet<LoadMode>(index);
153 }
154
155 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
156 return m_impl.costPerCoeff(vectorized);
157 }
158
159 EIGEN_DEVICE_FUNC Scalar* data() const { return m_impl.data(); }
160
161 const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }
162
163 protected:
164 TensorEvaluator<ArgType, Device> m_impl;
165 Dimensions m_dimensions;
166};
167
168
169// Eval as lvalue
170template<typename ArgType, typename Device>
171 struct TensorEvaluator<TensorLayoutSwapOp<ArgType>, Device>
172 : public TensorEvaluator<const TensorLayoutSwapOp<ArgType>, Device>
173{
174 typedef TensorEvaluator<const TensorLayoutSwapOp<ArgType>, Device> Base;
175 typedef TensorLayoutSwapOp<ArgType> XprType;
176
177 enum {
178 IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
179 PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
180 Layout = (static_cast<int>(TensorEvaluator<ArgType, Device>::Layout) == static_cast<int>(ColMajor)) ? RowMajor : ColMajor,
181 CoordAccess = false // to be implemented
182 };
183
184 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
185 : Base(op, device)
186 { }
187
188 typedef typename XprType::Index Index;
189 typedef typename XprType::Scalar Scalar;
190 typedef typename XprType::CoeffReturnType CoeffReturnType;
191 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
192
193 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
194 {
195 return this->m_impl.coeffRef(index);
196 }
197 template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
198 void writePacket(Index index, const PacketReturnType& x)
199 {
200 this->m_impl.template writePacket<StoreMode>(index, x);
201 }
202};
203
204} // end namespace Eigen
205
206#endif // EIGEN_CXX11_TENSOR_TENSOR_LAYOUT_SWAP_H
The tensor base class.
Definition TensorForwardDeclarations.h:29
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor evaluator class.
Definition TensorEvaluator.h:27
const Device & device() const
Required by SYCL in order to construct a SYCL buffer from a raw pointer.
Definition TensorEvaluator.h:112