#ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
#define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H

namespace Eigen {

namespace internal {
template<typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
};
template<typename Broadcast, typename XprType>
struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense>
{
  typedef const TensorBroadcastingOp<Broadcast, XprType>& type;
};
template<typename Broadcast, typename XprType>
struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
{
  typedef TensorBroadcastingOp<Broadcast, XprType> type;
};
// Detects whether the broadcast argument is effectively a scalar, i.e. a
// tensor whose static dimensions describe a single element.
template <typename Dims>
struct is_input_scalar {
  static const bool value = false;
};
template <>
struct is_input_scalar<Sizes<> > {
  static const bool value = true;
};
#ifndef EIGEN_EMULATE_CXX11_META_H
template <typename std::size_t... Indices>
struct is_input_scalar<Sizes<Indices...> > {
  static const bool value = (Sizes<Indices...>::total_size == 1);
};
#endif

}  // end namespace internal
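
/** \class TensorBroadcasting
  * \ingroup CXX11_Tensor_Module
  *
  * \brief Tensor broadcasting class.
  *
  * Read-only expression that tiles its argument along each dimension by the
  * repetition factors given in the Broadcast array.
  */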
template<typename Broadcast, typename XprType>
class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
    typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
    typedef typename XprType::CoeffReturnType CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(const XprType& expr, const Broadcast& broadcast)
        : m_xpr(expr), m_broadcast(broadcast) {}

    EIGEN_DEVICE_FUNC
    const Broadcast& broadcast() const { return m_broadcast; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
    const Broadcast m_broadcast;
};
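
// Illustrative usage sketch (not part of this header; the tensor names are
// hypothetical): broadcasting a 2x3 tensor twice along dimension 0 and three
// times along dimension 1 yields a 4x9 tensor.
//
//   Eigen::Tensor<float, 2> input(2, 3);
//   input.setRandom();
//   Eigen::array<int, 2> bcast{{2, 3}};
//   Eigen::Tensor<float, 2> tiled = input.broadcast(bcast);  // 4 x 9 result
//
// A broadcast factor of 1 leaves the corresponding dimension unchanged.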

// Eval as rvalue
template<typename Broadcast, typename ArgType, typename Device>
struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
{
  typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;

  enum {
    IsAligned = true,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    RawAccess = false
  };
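
  // The constructor precomputes the strides of the input tensor
  // (m_inputStrides) and of the broadcast output (m_outputStrides) for the
  // chosen layout. coeff() and packet() use them to translate a linear output
  // index back into a linear index into the underlying input.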
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : m_broadcast(op.broadcast()), m_impl(op.expression(), device)
  {
    // The broadcasting op doesn't change the rank of the tensor. One can't
    // broadcast a scalar directly; reshape it into a rank >= 1 tensor of one
    // element first, then broadcast.
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
    const InputDimensions& input_dims = m_impl.dimensions();
    const Broadcast& broadcast = op.broadcast();
    for (int i = 0; i < NumDims; ++i) {
      eigen_assert(input_dims[i] > 0);
      m_dimensions[i] = input_dims[i] * broadcast[i];
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputStrides[0] = 1;
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      m_inputStrides[NumDims-1] = 1;
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return m_impl.coeff(0);
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return coeffColMajor(index);
    } else {
      return coeffRowMajor(index);
    }
  }
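
  // coeffColMajor()/coeffRowMajor() peel off one output coordinate per
  // dimension with an integer division by the output stride, then wrap that
  // coordinate back into the input's extent with a modulo. As an illustrative
  // example (not from this file): broadcasting a column-major 2x3 input by
  // {2, 1} produces a 4x3 output, and output coordinate (3, 2) reads input
  // coordinate (3 % 2, 2 % 3) = (1, 2).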
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const
  {
    Index inputIndex = 0;
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[0]);
      }
    }
    return m_impl.coeff(inputIndex);
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const
  {
    Index inputIndex = 0;
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[NumDims-1]);
      }
    }
    return m_impl.coeff(inputIndex);
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return internal::pset1<PacketReturnType>(m_impl.coeff(0));
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return packetColMajor<LoadMode>(index);
    } else {
      return packetRowMajor<LoadMode>(index);
    }
  }
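
  // The LoadMode template argument is ignored below: alignment of the mapped
  // input position cannot be guaranteed, so unaligned loads are used. When the
  // whole packet fits inside the innermost input dimension a single packet
  // load is issued; otherwise the packet is gathered coefficient by
  // coefficient through coeffColMajor()/coeffRowMajor().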
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[0];
      }
    }
    inputIndex += innermostLoc;

    if (innermostLoc + PacketSize <= m_impl.dimensions()[0]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      for (int i = 1; i < PacketSize; ++i) {
        values[i] = coeffColMajor(originalIndex+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[NumDims-1];
      }
    }
    inputIndex += innermostLoc;

    if (innermostLoc + PacketSize <= m_impl.dimensions()[NumDims-1]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      for (int i = 1; i < PacketSize; ++i) {
        values[i] = coeffRowMajor(originalIndex+i);
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }
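
  // Cost estimate for producing one output coefficient: roughly one integer
  // division per non-innermost dimension, plus the multiply/add (and, when the
  // input dimension is not statically 1, modulo) work needed to rebuild the
  // input index, on top of the cost of evaluating the argument expression.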
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    double compute_cost = TensorOpCost::AddCost<Index>();
    if (NumDims > 0) {
      for (int i = NumDims - 1; i > 0; --i) {
        compute_cost += TensorOpCost::DivCost<Index>();
        if (internal::index_statically_eq<Broadcast>(i, 1)) {
          compute_cost +=
              TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
        } else {
          if (!internal::index_statically_eq<InputDimensions>(i, 1)) {
            compute_cost += TensorOpCost::MulCost<Index>() +
                            TensorOpCost::ModCost<Index>() +
                            TensorOpCost::AddCost<Index>();
          }
        }
        compute_cost +=
            TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
      }
    }
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }
  EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  Broadcast functor() const { return m_broadcast; }

 protected:
  const Broadcast m_broadcast;
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_inputStrides;
  TensorEvaluator<ArgType, Device> m_impl;
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H