TensorBroadcasting.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
#define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H

namespace Eigen {

namespace internal {
template<typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
};

template<typename Broadcast, typename XprType>
struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense>
{
  typedef const TensorBroadcastingOp<Broadcast, XprType>& type;
};

template<typename Broadcast, typename XprType>
struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
{
  typedef TensorBroadcastingOp<Broadcast, XprType> type;
};

template <typename Dims>
struct is_input_scalar {
  static const bool value = false;
};
template <>
struct is_input_scalar<Sizes<> > {
  static const bool value = true;
};
#ifndef EIGEN_EMULATE_CXX11_META_H
template <typename std::size_t... Indices>
struct is_input_scalar<Sizes<Indices...> > {
  static const bool value = (Sizes<Indices...>::total_size == 1);
};
#endif
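
// Illustrative note (not part of the original header): is_input_scalar detects
// effectively scalar inputs (rank 0, or fixed sizes with a single element) so
// that coeff() and packet() below can short-circuit to m_impl.coeff(0).
// For instance, with fixed-size Sizes types one would expect:
//
//   static_assert(is_input_scalar<Sizes<> >::value, "rank-0 input is scalar");
//   static_assert(is_input_scalar<Sizes<1> >::value, "one element is scalar");
//   static_assert(!is_input_scalar<Sizes<2, 3> >::value, "2x3 is not scalar");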

} // end namespace internal

/**
 * \ingroup CXX11_Tensor_Module
 *
 * \brief Tensor broadcasting class.
 */
template <typename Broadcast, typename XprType>
class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors> {
 public:
  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(const XprType& expr, const Broadcast& broadcast)
      : m_xpr(expr), m_broadcast(broadcast) {}

  EIGEN_DEVICE_FUNC
  const Broadcast& broadcast() const { return m_broadcast; }

  EIGEN_DEVICE_FUNC
  const typename internal::remove_all<typename XprType::Nested>::type&
  expression() const { return m_xpr; }

 protected:
  typename XprType::Nested m_xpr;
  const Broadcast m_broadcast;
};
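
// Illustrative usage (not part of the original header), assuming
// #include <unsupported/Eigen/CXX11/Tensor>. Broadcasting tiles a tensor
// along each dimension by the given factor; a 2x3 input broadcast by {2, 2}
// yields a 4x6 output:
//
//   Eigen::Tensor<float, 2> input(2, 3);
//   input.setRandom();
//   Eigen::array<Eigen::Index, 2> bcast{{2, 2}};
//   Eigen::Tensor<float, 2> output = input.broadcast(bcast);
//   // output.dimension(0) == 4 && output.dimension(1) == 6
//   // output(i, j) == input(i % 2, j % 3) for all valid (i, j)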


// Eval as rvalue
template<typename Broadcast, typename ArgType, typename Device>
struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
{
  typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;

  enum {
    IsAligned = true,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_broadcast(op.broadcast()), m_impl(op.expression(), device)
  {
    // The broadcasting op doesn't change the rank of the tensor. One can't broadcast a scalar
    // and store the result in a scalar. Instead one should first reshape the scalar into an
    // N-D tensor (N >= 1) with a single element, and then broadcast that tensor.
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
    const InputDimensions& input_dims = m_impl.dimensions();
    const Broadcast& broadcast = op.broadcast();
    for (int i = 0; i < NumDims; ++i) {
      eigen_assert(input_dims[i] > 0);
      m_dimensions[i] = input_dims[i] * broadcast[i];
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputStrides[0] = 1;
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      m_inputStrides[NumDims-1] = 1;
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }
  }
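
  // Illustrative example (not part of the original header): for a column-major
  // 2x3 input broadcast by {2, 2}, the output is 4x6 and the constructor
  // computes
  //   m_inputStrides  = {1, 2}   // strides of the 2x3 input
  //   m_outputStrides = {1, 4}   // strides of the 4x6 output
  // so the index arithmetic below can map an output coordinate back to an
  // input coordinate via division (to extract the coordinate) and modulo
  // (to wrap it into the input's extent).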

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return m_impl.coeff(0);
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return coeffColMajor(index);
    } else {
      return coeffRowMajor(index);
    }
  }

  // TODO: attempt to speed this up. The integer divisions and modulo operations are slow.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const
  {
    Index inputIndex = 0;
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[0]);
      }
    }
    return m_impl.coeff(inputIndex);
  }
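
  // Illustrative walk-through (not part of the original header), continuing
  // the 2x3-input / {2, 2}-broadcast example above (4x6 output, column-major).
  // For output index 10, i.e. output coordinates (2, 2):
  //   i = 1: idx = 10 / 4 = 2;  inputIndex += (2 % 3) * 2 = 4;  index = 10 - 8 = 2
  //   i = 0: inputIndex += 2 % 2 = 0
  // so inputIndex == 4, which is input(0, 2) -- exactly input(2 % 2, 2 % 3).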

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const
  {
    Index inputIndex = 0;
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[NumDims-1]);
      }
    }
    return m_impl.coeff(inputIndex);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return internal::pset1<PacketReturnType>(m_impl.coeff(0));
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return packetColMajor<LoadMode>(index);
    } else {
      return packetRowMajor<LoadMode>(index);
    }
  }

  // Ignore the LoadMode and always use unaligned loads since we can't guarantee
  // the alignment at compile time.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[0];
      }
    }
    inputIndex += innermostLoc;

    // TODO: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[0]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      for (int i = 1; i < PacketSize; ++i) {
        values[i] = coeffColMajor(originalIndex+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }
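
  // Illustrative note (not part of the original header): the fast path above
  // only fires when a whole packet fits inside one copy of the innermost input
  // dimension. With PacketSize == 4 and an innermost input dimension of size 2,
  // innermostLoc + 4 always exceeds 2, so the four coefficients are instead
  // gathered one by one through coeffColMajor() and reassembled with pload.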

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[NumDims-1];
      }
    }
    inputIndex += innermostLoc;

    // TODO: this could be extended to the second dimension if we're not
    // broadcasting alongside the first dimension, and so on.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[NumDims-1]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      for (int i = 1; i < PacketSize; ++i) {
        values[i] = coeffRowMajor(originalIndex+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    double compute_cost = TensorOpCost::AddCost<Index>();
    if (NumDims > 0) {
      for (int i = NumDims - 1; i > 0; --i) {
        compute_cost += TensorOpCost::DivCost<Index>();
        if (internal::index_statically_eq<Broadcast>(i, 1)) {
          compute_cost +=
              TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
        } else {
          if (!internal::index_statically_eq<InputDimensions>(i, 1)) {
            compute_cost += TensorOpCost::MulCost<Index>() +
                            TensorOpCost::ModCost<Index>() +
                            TensorOpCost::AddCost<Index>();
          }
        }
        compute_cost +=
            TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
      }
    }
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  Broadcast functor() const { return m_broadcast; }

 protected:
  const Broadcast m_broadcast;
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_inputStrides;
  TensorEvaluator<ArgType, Device> m_impl;
};


} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H