#ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
#define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
namespace Eigen {
namespace internal {

template<typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
{
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef typename remove_reference<Nested>::type _Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
  typedef typename XprTraits::PointerType PointerType;
};
template<typename Broadcast, typename XprType>
struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense>
{
  typedef const TensorBroadcastingOp<Broadcast, XprType> EIGEN_DEVICE_REF type;
};

template<typename Broadcast, typename XprType>
struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
{
  typedef TensorBroadcastingOp<Broadcast, XprType> type;
};
template <typename Dims>
struct is_input_scalar {
  static const bool value = false;
};
template <>
struct is_input_scalar<Sizes<> > {
  static const bool value = true;
};
#ifndef EIGEN_EMULATE_CXX11_META_H
template <typename std::ptrdiff_t... Indices>
struct is_input_scalar<Sizes<Indices...> > {
  static const bool value = (Sizes<Indices...>::total_size == 1);
};
#endif

}  // end namespace internal
template <typename Broadcast, typename XprType>
class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors>
{
  public:
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
    typedef typename XprType::CoeffReturnType CoeffReturnType;
    typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
    typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(const XprType& expr, const Broadcast& broadcast)
        : m_xpr(expr), m_broadcast(broadcast) {}

    EIGEN_DEVICE_FUNC
    const Broadcast& broadcast() const { return m_broadcast; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename XprType::Nested>::type&
    expression() const { return m_xpr; }

  protected:
    typename XprType::Nested m_xpr;
    const Broadcast m_broadcast;
};
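// Illustrative usage sketch (not part of the original header; the tensor names
// below are assumed user-side code). Broadcasting replicates an expression
// along each dimension by the given factor, so a 2x3 tensor broadcast by
// {2, 3} yields a 4x9 result whose coefficients tile the input:
//
//   Eigen::Tensor<float, 2> a(2, 3);
//   a.setRandom();
//   Eigen::array<Eigen::Index, 2> bcast = {2, 3};
//   Eigen::Tensor<float, 2> b = a.broadcast(bcast);  // b has dimensions 4x9
//   // b(i, j) == a(i % 2, j % 3)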
// Eval as rvalue
template<typename Broadcast, typename ArgType, typename Device>
struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
{
  typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
  typedef typename XprType::Index Index;
  static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = PacketType<CoeffReturnType, Device>::size;

  bool isCopy, nByOne, oneByN;

  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  enum {
    IsAligned         = TensorEvaluator<ArgType, Device>::IsAligned,
    PacketAccess      = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess       = TensorEvaluator<ArgType, Device>::BlockAccess,
    PreferBlockAccess = true,
    Layout            = TensorEvaluator<ArgType, Device>::Layout,
    RawAccess         = false
  };

  typedef typename internal::remove_const<Scalar>::type ScalarNoConst;

  // Block-based broadcasting uses 2x the tensor rank with zero input strides
  // along the broadcast dimensions (see the block() implementation below).
  typedef DSizes<Index, 2 * NumDims> BroadcastDimensions;

  typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
  typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;

  typedef typename TensorEvaluator<const ArgType, Device>::TensorBlock
      ArgTensorBlock;

  typedef typename internal::TensorMaterializedBlock<ScalarNoConst, NumDims,
                                                     Layout, Index>
      TensorBlock;
  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : isCopy(false), nByOne(false), oneByN(false),
        m_device(device), m_broadcast(op.broadcast()), m_impl(op.expression(), device)
  {
    // Broadcasting doesn't change the rank of the tensor. A scalar cannot be
    // broadcast directly; reshape it into a rank >= 1 tensor of one element
    // first and then broadcast.
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
    const InputDimensions& input_dims = m_impl.dimensions();
    isCopy = true;
    for (int i = 0; i < NumDims; ++i) {
      eigen_assert(input_dims[i] > 0);
      m_dimensions[i] = input_dims[i] * m_broadcast[i];
      if (m_broadcast[i] != 1) {
        isCopy = false;
      }
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputStrides[0] = 1;
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      m_inputStrides[NumDims-1] = 1;
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }

    if (input_dims[0] == 1) {
      oneByN = true;
      for (int i = 1; i < NumDims; ++i) {
        if (m_broadcast[i] != 1) {
          oneByN = false;
          break;
        }
      }
    } else if (input_dims[NumDims-1] == 1) {
      nByOne = true;
      for (int i = 0; i < NumDims-1; ++i) {
        if (m_broadcast[i] != 1) {
          nByOne = false;
          break;
        }
      }
    }

    // Handle inputs shaped like [1, N..., 1] broadcast to [N, 1..., N]
    // (e.g. NCHW-style layouts): both flags are set when only the two
    // outermost dims are broadcast.
    if (!oneByN && !nByOne) {
      if (input_dims[0] == 1 && input_dims[NumDims-1] == 1 && NumDims > 2) {
        oneByN = true;
        nByOne = true;
        for (int i = 1; i < NumDims-1; ++i) {
          if (m_broadcast[i] != 1) {
            oneByN = false;
            nByOne = false;
            break;
          }
        }
      }
    }
  }
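  // Worked sketch of the bookkeeping above (illustrative values, not from the
  // original header). For a column-major input of dimensions [1, 3] broadcast
  // by {4, 1}:
  //   m_dimensions    = [4, 3]   (input_dims[i] * m_broadcast[i])
  //   m_inputStrides  = [1, 1]   (strides over the 1x3 input)
  //   m_outputStrides = [1, 4]   (strides over the 4x3 output)
  // oneByN is set (the first input dimension is 1 and only dimension 0 carries
  // a broadcast factor), nByOne stays false, and isCopy is false because
  // m_broadcast[0] != 1.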
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }

#ifdef EIGEN_USE_THREADS
  template <typename EvalSubExprsCallback>
  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
      EvaluatorPointerType, EvalSubExprsCallback done) {
    m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); });
  }
#endif  // EIGEN_USE_THREADS

  EIGEN_STRONG_INLINE void cleanup() {
    m_impl.cleanup();
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return m_impl.coeff(0);
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      if (isCopy) {
        return m_impl.coeff(index);
      } else {
        return coeffColMajor(index);
      }
    } else {
      if (isCopy) {
        return m_impl.coeff(index);
      } else {
        return coeffRowMajor(index);
      }
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexColMajor(Index index) const {
    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[0]);
      }
    }
    return inputIndex;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const
  {
    return m_impl.coeff(indexColMajor(index));
  }
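  // Illustrative example of the output-to-input index mapping (values assumed,
  // not from the original header): for a column-major input of dimensions
  // [2, 3] broadcast by {1, 2}, the output is 2x6 with m_outputStrides = [1, 2]
  // and m_inputStrides = [1, 2]. For output index 9 (coefficient (1, 4)):
  //   idx = 9 / 2 = 4,  inputIndex += (4 % 3) * 2 = 2,  index -= 4 * 2 -> 1,
  //   inputIndex += 1 % 2 = 1   =>  inputIndex = 3, i.e. input coefficient (1, 1).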
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexRowMajor(Index index) const {
    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(NumDims - 1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims - 1]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims - 1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims - 1] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[NumDims - 1]);
      }
    }
    return inputIndex;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const
  {
    return m_impl.coeff(indexRowMajor(index));
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return internal::pset1<PacketReturnType>(m_impl.coeff(0));
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      if (isCopy) {
        #ifdef EIGEN_GPU_COMPILE_PHASE
        return m_impl.template packet<Unaligned>(index);
        #else
        return m_impl.template packet<LoadMode>(index);
        #endif
      } else if (oneByN && !nByOne) {
        return packetNByOne<LoadMode>(index);
      } else if (!oneByN && nByOne) {
        return packetOneByN<LoadMode>(index);
      } else if (oneByN && nByOne) {
        return packetOneByNByOne<LoadMode>(index);
      } else {
        return packetColMajor<LoadMode>(index);
      }
    } else {
      if (isCopy) {
        #ifdef EIGEN_GPU_COMPILE_PHASE
        return m_impl.template packet<Unaligned>(index);
        #else
        return m_impl.template packet<LoadMode>(index);
        #endif
      } else if (oneByN && !nByOne) {
        return packetOneByN<LoadMode>(index);
      } else if (!oneByN && nByOne) {
        return packetNByOne<LoadMode>(index);
      } else if (oneByN && nByOne) {
        return packetOneByNByOne<LoadMode>(index);
      } else {
        return packetRowMajor<LoadMode>(index);
      }
    }
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByNByOne
  (Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
    Index startDim, endDim;
    Index inputIndex, outputOffset, batchedIndex;

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      startDim = NumDims - 1;
      endDim = 1;
    } else {
      startDim = 0;
      endDim = NumDims - 2;
    }

    batchedIndex = index % m_outputStrides[startDim];
    inputIndex   = batchedIndex / m_outputStrides[endDim];
    outputOffset = batchedIndex % m_outputStrides[endDim];

    if (outputOffset + PacketSize <= m_outputStrides[endDim]) {
      values[0] = m_impl.coeff(inputIndex);
      return internal::pload1<PacketReturnType>(values);
    } else {
      EIGEN_UNROLL_LOOP
      for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
        if (outputOffset + cur < m_outputStrides[endDim]) {
          values[i] = m_impl.coeff(inputIndex);
        } else {
          ++inputIndex;
          inputIndex = (inputIndex == m_inputStrides[startDim] ? 0 : inputIndex);
          values[i] = m_impl.coeff(inputIndex);
          outputOffset = 0;
          cur = 0;
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    // Size of the flattened input tensor.
    const Index M = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ?
        m_inputStrides[NumDims - 1] : m_inputStrides[0];
    Index inputIndex = index % M;
    if (inputIndex + PacketSize <= M) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      EIGEN_UNROLL_LOOP
      for (int i = 0; i < PacketSize; ++i) {
        if (inputIndex > M - 1) {
          inputIndex = 0;
        }
        values[i] = m_impl.coeff(inputIndex++);
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
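  // Illustrative sketch (assumed values): this path handles the case where the
  // flattened input [v0, ..., v(M-1)] is simply repeated, so the output reads
  // [v0, ..., v(M-1), v0, ..., v(M-1), ...]. With M == 3 and PacketSize == 4,
  // a packet starting at output index 2 gathers {v2, v0, v1, v2} via the
  // wrap-around loop above.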
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetNByOne(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index + PacketSize-1 < dimensions().TotalSize());

    // Number of consecutive repetitions of each input coefficient.
    const Index M = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ?
        m_broadcast[0] : m_broadcast[NumDims - 1];

    Index inputIndex   = index / M;
    Index outputOffset = index % M;
    if (outputOffset + PacketSize <= M) {
      return internal::pset1<PacketReturnType>(m_impl.coeff(inputIndex));
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      EIGEN_UNROLL_LOOP
      for (int i = 0; i < PacketSize; ++i) {
        if (outputOffset < M) {
          values[i] = m_impl.coeff(inputIndex);
          ++outputOffset;
        } else {
          values[i] = m_impl.coeff(++inputIndex);
          outputOffset = 1;  // Next offset.
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
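  // Illustrative sketch (assumed values): this path handles the case where
  // each input coefficient is repeated M times in a row, so the output reads
  // [v0, v0, ..., v1, v1, ...]. With M == 3 and PacketSize == 4, a packet
  // starting at output index 2 gathers {v0, v1, v1, v1}.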
  // Ignore the LoadMode and always use unaligned loads, since the alignment
  // cannot be guaranteed at compile time.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[0];
      }
    }
    inputIndex += innermostLoc;

    // Load a full packet from the input when it fits inside the innermost
    // dimension; otherwise fall back to coefficient-by-coefficient loads.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[0]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      EIGEN_UNROLL_LOOP
      for (int i = 1; i < PacketSize; ++i) {
        if (innermostLoc + i < m_impl.dimensions()[0]) {
          values[i] = m_impl.coeff(inputIndex+i);
        } else {
          values[i] = coeffColMajor(originalIndex+i);
        }
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const
  {
    EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(index+PacketSize-1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[NumDims-1];
      }
    }
    inputIndex += innermostLoc;

    // Load a full packet from the input when it fits inside the innermost
    // dimension; otherwise fall back to coefficient-by-coefficient loads.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[NumDims-1]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      EIGEN_UNROLL_LOOP
      for (int i = 1; i < PacketSize; ++i) {
        if (innermostLoc + i < m_impl.dimensions()[NumDims-1]) {
          values[i] = m_impl.coeff(inputIndex+i);
        } else {
          values[i] = coeffRowMajor(originalIndex+i);
        }
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    double compute_cost = TensorOpCost::AddCost<Index>();
    if (!isCopy && NumDims > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = NumDims - 1; i > 0; --i) {
        compute_cost += TensorOpCost::DivCost<Index>();
        if (internal::index_statically_eq<Broadcast>(i, 1)) {
          compute_cost +=
              TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
        } else {
          if (!internal::index_statically_eq<InputDimensions>(i, 1)) {
            compute_cost += TensorOpCost::MulCost<Index>() +
                            TensorOpCost::ModCost<Index>() +
                            TensorOpCost::AddCost<Index>();
          }
        }
        compute_cost +=
            TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
      }
    }
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }
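  // Rough reading of the estimate above (not normative): a non-copy broadcast
  // pays, per coefficient, roughly one integer division plus a multiply/add
  // pair for each of the NumDims - 1 outer dimensions (plus a modulo term when
  // the corresponding input dimension is not statically 1), on top of the
  // wrapped evaluator's own per-coefficient cost.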
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  internal::TensorBlockResourceRequirements getResourceRequirements() const {
    const size_t target_size = m_device.firstLevelCacheSize();
    return internal::TensorBlockResourceRequirements::merge(
        m_impl.getResourceRequirements(),
        internal::TensorBlockResourceRequirements::skewed<Scalar>(target_size));
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock
  block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
        bool /*root_of_expr_ast*/ = false) const {
    BlockBroadcastingParams params = blockBroadcastingParams(desc);

    if (params.inner_dim_size == 0 || params.bcast_dim_size == 0) {
      return emptyBlock();
    }

    // Prepare storage for the materialized broadcasting result.
    const typename TensorBlock::Storage block_storage =
        TensorBlock::prepareStorage(desc, scratch);
    ScalarNoConst* materialized_output = block_storage.data();

    // We may need to materialize input blocks into a scratch buffer.
    size_t materialized_input_size = 0;
    ScalarNoConst* materialized_input = NULL;

    // Iterator state for the dimensions outside the broadcast dimension,
    // always in inner-most to outer-most order.
    array<BlockBroadcastingIteratorState, NumDims> it;
    int idx = 0;

    for (int i = params.inner_dim_count + 1; i < NumDims; ++i) {
      const Index dim = IsColMajor ? i : NumDims - 1 - i;
      it[idx].size = params.output_dims[dim];
      it[idx].count = 0;
      it[idx].output_stride = m_outputStrides[dim];
      it[idx].output_span = it[idx].output_stride * (it[idx].size - 1);
      idx++;
    }

    // Write output into the beginning of `materialized_output`.
    Index output_offset = 0;

    // Fill the output block by broadcasting along the bcast dimension while
    // iterating over the outer dimensions.
    const Index output_size = NumDims == 0 ? 1 : params.output_dims.TotalSize();

    for (Index num_output_coeffs = 0; num_output_coeffs < output_size;) {
      ScalarNoConst* bcast_output = materialized_output + num_output_coeffs;
      Index bcast_offset = desc.offset() + output_offset;

      // Broadcast along the bcast dimension.
      num_output_coeffs += BroadcastBlockAlongBcastDim(
          params, bcast_offset, scratch, bcast_output, &materialized_input,
          &materialized_input_size);

      // Switch to the next outer dimension.
      for (int j = 0; j < idx; ++j) {
        if (++it[j].count < it[j].size) {
          output_offset += it[j].output_stride;
          break;
        }
        it[j].count = 0;
        output_offset -= it[j].output_span;
      }
    }

    return block_storage.AsTensorMaterializedBlock();
  }
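  // Illustrative sketch (assumed shapes): for a column-major broadcast whose
  // full output has dimensions [9, 6, 6] (input [9, 3, 6] broadcast by
  // {1, 2, 1}) and a requested output block of [9, 2, 6], dimension 0 matches
  // the full output (inner_dim_count == 1), dimension 1 is the broadcast
  // dimension (bcast_dim_size == 2), and the loop above iterates over the 6
  // slices of dimension 2, calling BroadcastBlockAlongBcastDim once per slice
  // (18 of the 108 output coefficients per call).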
  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  Broadcast functor() const { return m_broadcast; }

#ifdef EIGEN_USE_SYCL
  // Binds placeholder accessors to a command group handler for SYCL.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(
      cl::sycl::handler& cgh) const {
    m_impl.bind(cgh);
  }
#endif

 private:
  static const bool IsColMajor =
      static_cast<int>(Layout) == static_cast<int>(ColMajor);
  // Precomputed parameters for broadcasting a single block along the
  // broadcasting dimension; sizes and strides along `bcast_dim` may still be
  // adjusted in BroadcastBlockAlongBcastDim.
  struct BlockBroadcastingParams {
    Dimensions input_dims;      // input expression dimensions
    Dimensions output_dims;     // output block sizes
    Dimensions output_strides;  // output block strides

    int inner_dim_count;   // number of inner dimensions with matching size
    int bcast_dim;         // broadcasting dimension index
    Index bcast_dim_size;  // broadcasting dimension size
    Index inner_dim_size;  // inner dimensions size

    // Input block sizes/strides, and the doubled-rank sizes/strides with the
    // extra broadcast dimensions (see blockBroadcastingParams below).
    Dimensions input_block_sizes;
    Dimensions input_block_strides;
    BroadcastDimensions bcast_block_sizes;
    BroadcastDimensions bcast_block_strides;
    BroadcastDimensions bcast_input_strides;
  };

  struct BlockBroadcastingIteratorState {
    Index size;
    Index count;
    Index output_stride;
    Index output_span;
  };
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlockBroadcastingParams
  blockBroadcastingParams(TensorBlockDesc& desc) const {
    BlockBroadcastingParams params;

    params.input_dims = Dimensions(m_impl.dimensions());

    // Output block sizes and strides.
    params.output_dims = desc.dimensions();
    params.output_strides = internal::strides<Layout>(params.output_dims);

    // Find the broadcasting dimension: the first dimension whose output size
    // is smaller than the broadcast expression size.
    params.bcast_dim = 0;
    params.bcast_dim_size = 1;
    params.inner_dim_size = 1;

    // Count the inner dimensions that have the same size in the block and in
    // the broadcast expression.
    params.inner_dim_count = 0;

    for (int i = 0; i < NumDims; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;

      if (params.output_dims[dim] == m_dimensions[dim]) {
        params.inner_dim_size *= params.output_dims[dim];
        ++params.inner_dim_count;
        continue;
      }

      // The first non-matching dimension is the broadcasting dimension.
      eigen_assert(params.output_dims[dim] < m_dimensions[dim]);
      params.bcast_dim = dim;
      params.bcast_dim_size = params.output_dims[dim];
      break;
    }

    // Compute the input block sizes for reading from the input expression.
    for (int i = 0; i < params.inner_dim_count; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;
      params.input_block_sizes[dim] = params.input_dims[dim];
    }
    for (int i = params.inner_dim_count; i < NumDims; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;
      params.input_block_sizes[dim] = 1;
    }
    params.input_block_strides =
        internal::strides<Layout>(params.input_block_sizes);

    // Broadcast with the zero-stride trick: create one extra dimension for
    // each broadcast and set its input stride to 0.
    for (int i = 0; i < params.inner_dim_count; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;

      const int copy_dim = IsColMajor ? 2 * i : 2 * NumDims - 2 * i - 1;
      const int broadcast_dim = IsColMajor ? copy_dim + 1 : copy_dim - 1;

      params.bcast_block_sizes[copy_dim] = params.input_dims[dim];
      params.bcast_block_sizes[broadcast_dim] = m_broadcast[dim];
      params.bcast_block_strides[copy_dim] = params.output_strides[dim];
      params.bcast_block_strides[broadcast_dim] =
          params.output_strides[dim] * params.input_dims[dim];
      params.bcast_input_strides[copy_dim] = params.input_block_strides[dim];
      params.bcast_input_strides[broadcast_dim] = 0;
    }

    for (int i = 2 * params.inner_dim_count; i < 2 * NumDims; ++i) {
      const int dim = IsColMajor ? i : 2 * NumDims - i - 1;
      params.bcast_block_sizes[dim] = 1;
      params.bcast_block_strides[dim] = 0;
      params.bcast_input_strides[dim] = 0;
    }

    return params;
  }
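  // Illustrative sketch of the zero-stride layout produced above (assumed
  // column-major): with inner dims d_0, d_1, ... and broadcast factors
  // b_0, b_1, ..., the doubled-rank block interleaves copy and broadcast dims:
  //   bcast_block_sizes   = [d_0, b_0, d_1, b_1, ...]
  //   bcast_block_strides = [output_strides[0], output_strides[0] * d_0,
  //                          output_strides[1], output_strides[1] * d_1, ...]
  //   bcast_input_strides = [input_block_strides[0], 0,
  //                          input_block_strides[1], 0, ...]
  // so TensorBlockIO re-reads the same input span b_i times for each broadcast
  // dimension instead of materializing copies of the input first.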
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock emptyBlock() const {
    DSizes<Index, NumDims> dimensions;
    for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
    return TensorBlock(internal::TensorBlockKind::kView, NULL, dimensions);
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlockAlongBcastDim(
      BlockBroadcastingParams params, Index bcast_offset,
      TensorBlockScratch& scratch, ScalarNoConst* materialized_output,
      ScalarNoConst** materialized_input,
      size_t* materialized_input_size) const {
    if (params.bcast_dim_size == 1) {
      // A single block read with the precomputed parameters is enough.
      return BroadcastBlock(
          params.input_block_sizes, params.input_block_strides,
          params.bcast_block_sizes, params.bcast_block_strides,
          params.bcast_input_strides, bcast_offset, 0, scratch,
          materialized_output, materialized_input, materialized_input_size);

    } else if (params.input_dims[params.bcast_dim] == 1) {
      // Broadcast the bcast dimension (< NumDims) by bcast_dim_size.
      const int broadcast_bcast_dim =
          IsColMajor ? 2 * params.inner_dim_count + 1
                     : 2 * NumDims - 2 * params.inner_dim_count - 2;

      params.bcast_block_sizes[broadcast_bcast_dim] = params.bcast_dim_size;
      params.bcast_input_strides[broadcast_bcast_dim] = 0;
      params.bcast_block_strides[broadcast_bcast_dim] =
          params.output_strides[params.bcast_dim];

      return BroadcastBlock(
          params.input_block_sizes, params.input_block_strides,
          params.bcast_block_sizes, params.bcast_block_strides,
          params.bcast_input_strides, bcast_offset, 0, scratch,
          materialized_output, materialized_input, materialized_input_size);

    } else {
      // General case: keep track of the coefficients written to the output.
      Index num_output_coeffs = 0;

      // Denote the output block along the bcast dimension as the slice
      // [a, a + bcast_dim_size). It may need to be split into up to three
      // sub-blocks: a head before the first multiple of the input size, a body
      // covering whole multiples, and a tail after the last multiple.

      // Find a.
      const Index bcast_dim_left_index =
          bcast_offset / m_outputStrides[params.bcast_dim];

      const Index input_bcast_dim_size = params.input_dims[params.bcast_dim];

      // First multiple of the input size at or after a.
      const Index first_multiple =
          divup<Index>(bcast_dim_left_index, input_bcast_dim_size) *
          input_bcast_dim_size;

      if (first_multiple <= bcast_dim_left_index + params.bcast_dim_size) {
        // The body exists; find its end.
        const Index last_multiple =
            (bcast_dim_left_index + params.bcast_dim_size) /
            input_bcast_dim_size * input_bcast_dim_size;
        const int copy_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count
                       : 2 * NumDims - 2 * params.inner_dim_count - 1;
        const int broadcast_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count + 1
                       : 2 * NumDims - 2 * params.inner_dim_count - 2;

        if (first_multiple > bcast_dim_left_index) {
          // Head sub-block.
          const Index head_size = first_multiple - bcast_dim_left_index;
          params.input_block_sizes[params.bcast_dim] = head_size;
          params.bcast_block_sizes[copy_bcast_dim] = head_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] = 1;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, 0, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
        if (first_multiple < last_multiple) {
          // Body sub-block covering whole multiples of the input size.
          params.input_block_sizes[params.bcast_dim] = input_bcast_dim_size;
          params.bcast_block_sizes[copy_bcast_dim] = input_bcast_dim_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] =
              (last_multiple - first_multiple) / input_bcast_dim_size;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];
          const Index offset = (first_multiple - bcast_dim_left_index) *
                               m_outputStrides[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, offset, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
        if (last_multiple < bcast_dim_left_index + params.bcast_dim_size) {
          // Tail sub-block.
          const Index tail_size =
              bcast_dim_left_index + params.bcast_dim_size - last_multiple;
          params.input_block_sizes[params.bcast_dim] = tail_size;
          params.bcast_block_sizes[copy_bcast_dim] = tail_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] = 1;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];
          const Index offset = (last_multiple - bcast_dim_left_index) *
                               m_outputStrides[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, offset, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
      } else {
        // No whole multiple of the input size falls inside the slice; process
        // the whole block as a single copy.
        const int copy_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count
                       : 2 * NumDims - 2 * params.inner_dim_count - 1;
        params.input_block_sizes[params.bcast_dim] = params.bcast_dim_size;
        params.bcast_block_sizes[copy_bcast_dim] = params.bcast_dim_size;
        params.bcast_input_strides[copy_bcast_dim] =
            params.input_block_strides[params.bcast_dim];
        params.bcast_block_strides[copy_bcast_dim] =
            params.output_strides[params.bcast_dim];

        num_output_coeffs += BroadcastBlock(
            params.input_block_sizes, params.input_block_strides,
            params.bcast_block_sizes, params.bcast_block_strides,
            params.bcast_input_strides, bcast_offset, 0, scratch,
            materialized_output, materialized_input, materialized_input_size);
      }

      return num_output_coeffs;
    }
  }
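  // Illustrative numbers (assumed): with input_dims[bcast_dim] == 4 and an
  // output slice covering bcast-dim indices [a, a + bcast_dim_size) = [6, 15),
  // first_multiple == 8 and last_multiple == 12, so the slice is processed as
  // a head [6, 8), a body [8, 12) (one full copy of the input, broadcast once),
  // and a tail [12, 15).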
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlock(
      const Dimensions& input_block_sizes,
      const Dimensions& input_block_strides,
      const BroadcastDimensions& bcast_block_sizes,
      const BroadcastDimensions& bcast_block_strides,
      const BroadcastDimensions& bcast_input_strides, Index bcast_offset,
      Index offset, TensorBlockScratch& scratch,
      ScalarNoConst* materialized_output, ScalarNoConst** materialized_input,
      size_t* materialized_input_size) const {
    // `bcast_offset` projected into the input expression.
    const Index input_offset = bcast_offset + offset;
    TensorBlockDesc input_desc(
        IsColMajor ? indexColMajor(input_offset) : indexRowMajor(input_offset),
        input_block_sizes);

    ArgTensorBlock input_block = m_impl.block(input_desc, scratch);

    // Raw pointer to the input block data.
    const ScalarNoConst* input_buffer = NULL;

    if (input_block.data() != NULL) {
      // The input block already has raw data; no need to materialize it.
      input_buffer = input_block.data();

    } else {
      // Materialize the input block into a temporary buffer, reusing the
      // previously allocated buffer when it is large enough.
      const size_t input_total_size = input_block_sizes.TotalSize();
      if (*materialized_input == NULL ||
          *materialized_input_size < input_total_size) {
        *materialized_input_size = input_total_size;
        void* mem = scratch.allocate(*materialized_input_size * sizeof(Scalar));
        *materialized_input = static_cast<ScalarNoConst*>(mem);
      }

      typedef internal::TensorBlockAssignment<
          ScalarNoConst, NumDims, typename ArgTensorBlock::XprType, Index>
          TensorBlockAssignment;

      TensorBlockAssignment::Run(
          TensorBlockAssignment::target(input_block_sizes, input_block_strides,
                                        *materialized_input),
          input_block.expr());

      input_buffer = *materialized_input;
    }

    // Copy the input block into the output using the doubled-rank sizes and
    // strides (zero input stride along the broadcast dimensions).
    typedef internal::TensorBlockIO<ScalarNoConst, Index, 2 * NumDims, Layout>
        TensorBlockIO;

    typename TensorBlockIO::Src src(bcast_input_strides, input_buffer);
    typename TensorBlockIO::Dst dst(bcast_block_sizes, bcast_block_strides,
                                    materialized_output + offset);

    return TensorBlockIO::Copy(dst, src);
  }
 protected:
  const Device EIGEN_DEVICE_REF m_device;
  const typename internal::remove_reference<Broadcast>::type m_broadcast;
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_inputStrides;
  TensorEvaluator<ArgType, Device> m_impl;
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H