#ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
#define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H

#include "./InternalHeaderCheck.h"

namespace Eigen {
namespace internal {
template <typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType>> : public traits<XprType> {
  typedef typename XprType::Scalar Scalar;
  typedef traits<XprType> XprTraits;
  typedef typename XprTraits::StorageKind StorageKind;
  typedef typename XprTraits::Index Index;
  typedef typename XprType::Nested Nested;
  typedef std::remove_reference_t<Nested> Nested_;
  static constexpr int NumDimensions = XprTraits::NumDimensions;
  static constexpr int Layout = XprTraits::Layout;
  typedef typename XprTraits::PointerType PointerType;

  enum {
    // Broadcasting produces a read-only expression, so strip the lvalue bit.
    Flags = traits<XprType>::Flags & ~LvalueBit
  };
};
template <typename Broadcast, typename XprType>
struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense> {
  typedef const TensorBroadcastingOp<Broadcast, XprType> EIGEN_DEVICE_REF type;
};
template <typename Broadcast, typename XprType>
struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1,
              typename eval<TensorBroadcastingOp<Broadcast, XprType>>::type> {
  typedef TensorBroadcastingOp<Broadcast, XprType> type;
};
template <typename Dims>
struct is_input_scalar {
  static const bool value = false;
};
template <>
struct is_input_scalar<Sizes<>> {
  static const bool value = true;
};
template <typename std::ptrdiff_t... Indices>
struct is_input_scalar<Sizes<Indices...>> {
  static constexpr bool value = (Sizes<Indices...>::total_size == 1);
};

}  // end namespace internal
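// Illustration (not from the library source): is_input_scalar classifies
// rank-0 and all-ones static dimension lists as scalar inputs, which lets
// coeff()/packet() below short-circuit to m_impl.coeff(0):
//
//   static_assert(internal::is_input_scalar<Sizes<>>::value, "rank-0");
//   static_assert(internal::is_input_scalar<Sizes<1, 1>>::value, "1x1");
//   static_assert(!internal::is_input_scalar<Sizes<2, 3>>::value, "2x3");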
template <typename Broadcast, typename XprType>
class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors> {
 public:
  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Scalar Scalar;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename Eigen::internal::nested<TensorBroadcastingOp>::type Nested;
  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::StorageKind StorageKind;
  typedef typename Eigen::internal::traits<TensorBroadcastingOp>::Index Index;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBroadcastingOp(const XprType& expr, const Broadcast& broadcast)
      : m_xpr(expr), m_broadcast(broadcast) {}

  EIGEN_DEVICE_FUNC const Broadcast& broadcast() const { return m_broadcast; }

  EIGEN_DEVICE_FUNC const internal::remove_all_t<typename XprType::Nested>& expression() const { return m_xpr; }

 protected:
  typename XprType::Nested m_xpr;
  const Broadcast m_broadcast;
};
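// Usage sketch (not part of the header): this expression is normally created
// through TensorBase::broadcast(). Each input dimension is tiled by the
// matching broadcast factor, so for
//
//   Eigen::Tensor<float, 2> in(2, 3);
//   in.setRandom();
//   Eigen::array<Eigen::Index, 2> bcast = {4, 5};
//   Eigen::Tensor<float, 2> out = in.broadcast(bcast);  // 8 x 15 result
//
// the result satisfies out(i, j) == in(i % 2, j % 3).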
// Eval as rvalue
template <typename Broadcast, typename ArgType, typename Device>
struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device> {
  typedef TensorBroadcastingOp<Broadcast, ArgType> XprType;
  typedef typename XprType::Index Index;
  static constexpr int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
  typedef DSizes<Index, NumDims> Dimensions;
  typedef typename XprType::Scalar Scalar;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions InputDimensions;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static constexpr int PacketSize = PacketType<CoeffReturnType, Device>::size;
 protected:
  // All non-static fields share one access section so the evaluator stays standard layout.
  bool isCopy, nByOne, oneByN;

 public:
  typedef StorageMemory<CoeffReturnType, Device> Storage;
  typedef typename Storage::Type EvaluatorPointerType;

  enum {
    IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
    BlockAccess = TensorEvaluator<ArgType, Device>::BlockAccess,
    PreferBlockAccess = true,
    RawAccess = false
  };
  static constexpr int Layout = TensorEvaluator<ArgType, Device>::Layout;

  typedef std::remove_const_t<Scalar> ScalarNoConst;

  // Block-based broadcasting uses a 2x-rank trick with zero input strides;
  // see the block() method implementation for details.
  typedef DSizes<Index, 2 * NumDims> BroadcastDimensions;

  //===- Tensor block evaluation strategy (see TensorBlock.h) -------------===//
  typedef internal::TensorBlockDescriptor<NumDims, Index> TensorBlockDesc;
  typedef internal::TensorBlockScratchAllocator<Device> TensorBlockScratch;

  typedef typename TensorEvaluator<const ArgType, Device>::TensorBlock ArgTensorBlock;

  typedef typename internal::TensorMaterializedBlock<ScalarNoConst, NumDims, Layout, Index> TensorBlock;
  //===--------------------------------------------------------------------===//
  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : isCopy(false),
        nByOne(false),
        oneByN(false),
        m_device(device),
        m_broadcast(op.broadcast()),
        m_impl(op.expression(), device) {
    // The broadcasting op doesn't change the rank of the tensor. One can't broadcast a
    // scalar and store the result in a scalar; instead, reshape the scalar into an N-D
    // tensor (N >= 1) with one element first, and then broadcast.
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
    const InputDimensions& input_dims = m_impl.dimensions();
    isCopy = true;
    for (int i = 0; i < NumDims; ++i) {
      eigen_assert(input_dims[i] > 0);
      m_dimensions[i] = input_dims[i] * m_broadcast[i];
      if (m_broadcast[i] != 1) {
        isCopy = false;
      }
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputStrides[0] = 1;
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStrides[i] = m_inputStrides[i - 1] * input_dims[i - 1];
        m_outputStrides[i] = m_outputStrides[i - 1] * m_dimensions[i - 1];
      }
    } else {
      m_inputStrides[NumDims - 1] = 1;
      m_outputStrides[NumDims - 1] = 1;
      for (int i = NumDims - 2; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i + 1] * input_dims[i + 1];
        m_outputStrides[i] = m_outputStrides[i + 1] * m_dimensions[i + 1];
      }
    }

    if (input_dims[0] == 1) {
      oneByN = true;
      for (int i = 1; i < NumDims; ++i) {
        if (m_broadcast[i] != 1) {
          oneByN = false;
          break;
        }
      }
    } else if (input_dims[NumDims - 1] == 1) {
      nByOne = true;
      for (int i = 0; i < NumDims - 1; ++i) {
        if (m_broadcast[i] != 1) {
          nByOne = false;
          break;
        }
      }
    }

    // Handle special formats like NCHW: the input shape is [1, N..., 1] and the
    // broadcast shape is [N, 1..., N].
    if (!oneByN && !nByOne) {
      if (input_dims[0] == 1 && input_dims[NumDims - 1] == 1 && NumDims > 2) {
        oneByN = true;
        nByOne = true;
        for (int i = 1; i < NumDims - 1; ++i) {
          if (m_broadcast[i] != 1) {
            oneByN = false;
            nByOne = false;
            break;
          }
        }
      }
    }
  }
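  // Worked example (illustrative only): a col-major 2x3 input broadcast by
  // {4, 5} yields m_dimensions = {8, 15}, m_inputStrides = {1, 2} and
  // m_outputStrides = {1, 8}. Neither boundary input dimension is 1, so
  // isCopy, oneByN and nByOne all stay false and the generic
  // coeffColMajor()/packetColMajor() paths are used.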
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }

  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
#ifdef EIGEN_USE_THREADS
  template <typename EvalSubExprsCallback>
  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(EvaluatorPointerType, EvalSubExprsCallback done) {
    m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); });
  }
#endif  // EIGEN_USE_THREADS

  EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffReturnType coeff(Index index) const {
    if (internal::is_input_scalar<internal::remove_all_t<InputDimensions>>::value) {
      return m_impl.coeff(0);
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      if (isCopy) {
        return m_impl.coeff(index);
      } else {
        return coeffColMajor(index);
      }
    } else {
      if (isCopy) {
        return m_impl.coeff(index);
      } else {
        return coeffRowMajor(index);
      }
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexColMajor(Index index) const {
    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[0]);
      }
    }
    return inputIndex;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const {
    return m_impl.coeff(indexColMajor(index));
  }
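  // Worked example (illustrative only): continuing the 2x3-broadcast-to-8x15
  // col-major case above, indexColMajor(10) computes idx = 10 / 8 = 1 for the
  // outer dimension, adds (1 % 3) * m_inputStrides[1] = 2, leaves remainder 2,
  // and adds 2 % 2 = 0 for the inner dimension: input index 2, i.e. in(0, 1),
  // which indeed equals out(2, 1) = in(2 % 2, 1 % 3).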
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexRowMajor(Index index) const {
    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(NumDims - 1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims - 1]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims - 1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims - 1] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[NumDims - 1]);
      }
    }
    return inputIndex;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const {
    return m_impl.coeff(indexRowMajor(index));
  }
  template <int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketReturnType packet(Index index) const {
    if (internal::is_input_scalar<internal::remove_all_t<InputDimensions>>::value) {
      return internal::pset1<PacketReturnType>(m_impl.coeff(0));
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      if (isCopy) {
#ifdef EIGEN_GPU_COMPILE_PHASE
        // Enforce unaligned loads on GPU, observed to be faster on some devices.
        return m_impl.template packet<Unaligned>(index);
#else
        return m_impl.template packet<LoadMode>(index);
#endif
      } else if (oneByN && !nByOne) {
        return packetNByOne<LoadMode>(index);
      } else if (!oneByN && nByOne) {
        return packetOneByN<LoadMode>(index);
      } else if (oneByN && nByOne) {
        return packetOneByNByOne<LoadMode>(index);
      } else {
        return packetColMajor<LoadMode>(index);
      }
    } else {
      if (isCopy) {
#ifdef EIGEN_GPU_COMPILE_PHASE
        // See above.
        return m_impl.template packet<Unaligned>(index);
#else
        return m_impl.template packet<LoadMode>(index);
#endif
      } else if (oneByN && !nByOne) {
        return packetOneByN<LoadMode>(index);
      } else if (!oneByN && nByOne) {
        return packetNByOne<LoadMode>(index);
      } else if (oneByN && nByOne) {
        return packetOneByNByOne<LoadMode>(index);
      } else {
        return packetRowMajor<LoadMode>(index);
      }
    }
  }
  template <int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByNByOne(Index index) const {
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[PacketSize];
    Index startDim, endDim;
    Index inputIndex, outputOffset, batchedIndex;

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      startDim = NumDims - 1;
      endDim = 1;
    } else {
      startDim = 0;
      endDim = NumDims - 2;
    }

    batchedIndex = index % m_outputStrides[startDim];
    inputIndex = batchedIndex / m_outputStrides[endDim];
    outputOffset = batchedIndex % m_outputStrides[endDim];

    if (outputOffset + PacketSize <= m_outputStrides[endDim]) {
      values[0] = m_impl.coeff(inputIndex);
      return internal::pload1<PacketReturnType>(values);
    } else {
      EIGEN_UNROLL_LOOP
      for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
        if (outputOffset + cur < m_outputStrides[endDim]) {
          values[i] = m_impl.coeff(inputIndex);
        } else {
          // Advance to the next input coefficient, wrapping back to 0 at the end.
          ++inputIndex;
          inputIndex = (inputIndex == m_inputStrides[startDim] ? 0 : inputIndex);
          values[i] = m_impl.coeff(inputIndex);
          outputOffset = 0;
          cur = 0;
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
  template <int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const {
    // The flattened input is repeated end to end in the output, so the packet
    // is read modulo the flattened input size.
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());

    // Size of the flattened input.
    const Index M =
        (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? m_inputStrides[NumDims - 1] : m_inputStrides[0];
    Index inputIndex = index % M;
    if (inputIndex + PacketSize <= M) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[PacketSize];
      EIGEN_UNROLL_LOOP
      for (int i = 0; i < PacketSize; ++i) {
        if (inputIndex > M - 1) {
          inputIndex = 0;
        }
        values[i] = m_impl.coeff(inputIndex++);
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
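  // Illustration (not from the source): for a flattened input [v0, v1, v2]
  // (M == 3) and PacketSize == 4, a load at output index 2 crosses the
  // wrap-around point and gathers [v2, v0, v1, v2] element by element.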
  template <int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetNByOne(Index index) const {
    // Each input coefficient is replicated M times in a row in the output, so
    // the packet is either a splat of one coefficient or a gather across a
    // replication boundary.
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());

    // Replication factor along the broadcast dimension.
    const Index M =
        (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? m_broadcast[0] : m_broadcast[NumDims - 1];

    Index inputIndex = index / M;
    Index outputOffset = index % M;
    if (outputOffset + PacketSize <= M) {
      return internal::pset1<PacketReturnType>(m_impl.coeff(inputIndex));
    } else {
      EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[PacketSize];
      EIGEN_UNROLL_LOOP
      for (int i = 0; i < PacketSize; ++i) {
        if (outputOffset < M) {
          values[i] = m_impl.coeff(inputIndex);
          ++outputOffset;
        } else {
          values[i] = m_impl.coeff(++inputIndex);
          outputOffset = 1;  // Next offset.
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
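  // Illustration (not from the source): with M == 3 (each input coefficient
  // replicated three times) and PacketSize == 4, a load at output index 2
  // gathers [v0, v1, v1, v1].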
  // Ignore the LoadMode and always use unaligned loads, since the alignment
  // cannot be guaranteed at compile time.
  template <int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const {
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[0];
      }
    }
    inputIndex += innermostLoc;

    // Load the packet directly when it fits in the innermost input dimension;
    // otherwise gather it coefficient by coefficient.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[0]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      EIGEN_UNROLL_LOOP
      for (int i = 1; i < PacketSize; ++i) {
        if (innermostLoc + i < m_impl.dimensions()[0]) {
          values[i] = m_impl.coeff(inputIndex + i);
        } else {
          values[i] = coeffColMajor(originalIndex + i);
        }
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }
  template <int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const {
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(NumDims - 1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims - 1]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims - 1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims - 1] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[NumDims - 1];
      }
    }
    inputIndex += innermostLoc;

    // Load the packet directly when it fits in the innermost input dimension;
    // otherwise gather it coefficient by coefficient.
    if (innermostLoc + PacketSize <= m_impl.dimensions()[NumDims - 1]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX std::remove_const_t<CoeffReturnType> values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      EIGEN_UNROLL_LOOP
      for (int i = 1; i < PacketSize; ++i) {
        if (innermostLoc + i < m_impl.dimensions()[NumDims - 1]) {
          values[i] = m_impl.coeff(inputIndex + i);
        } else {
          values[i] = coeffRowMajor(originalIndex + i);
        }
      }
      PacketReturnType rslt = internal::pload<PacketReturnType>(values);
      return rslt;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    double compute_cost = TensorOpCost::AddCost<Index>();
    if (!isCopy && NumDims > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = NumDims - 1; i > 0; --i) {
        compute_cost += TensorOpCost::DivCost<Index>();
        if (internal::index_statically_eq<Broadcast>(i, 1)) {
          compute_cost += TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
        } else {
          if (!internal::index_statically_eq<InputDimensions>(i, 1)) {
            compute_cost +=
                TensorOpCost::MulCost<Index>() + TensorOpCost::ModCost<Index>() + TensorOpCost::AddCost<Index>();
          }
        }
        compute_cost += TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
      }
    }
    return m_impl.costPerCoeff(vectorized) + TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE internal::TensorBlockResourceRequirements getResourceRequirements() const {
    const size_t target_size = m_device.firstLevelCacheSize();
    return internal::TensorBlockResourceRequirements::merge(
        m_impl.getResourceRequirements(), internal::TensorBlockResourceRequirements::skewed<Scalar>(target_size));
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
                                                          bool = false) const {
    BlockBroadcastingParams params = blockBroadcastingParams(desc);

    if (params.inner_dim_size == 0 || params.bcast_dim_size == 0) {
      return emptyBlock();
    }

    // Prepare storage for the materialized broadcasting result.
    const typename TensorBlock::Storage block_storage = TensorBlock::prepareStorage(desc, scratch);
    ScalarNoConst* materialized_output = block_storage.data();

    // We might need to materialize input blocks into a temporary buffer.
    size_t materialized_input_size = 0;
    ScalarNoConst* materialized_input = NULL;

    // Iteration state over the output dimensions that are outer with regard to
    // the broadcast dimension, always in inner-most to outer-most order.
    array<BlockBroadcastingIteratorState, NumDims> it;
    int idx = 0;

    for (int i = params.inner_dim_count + 1; i < NumDims; ++i) {
      const Index dim = IsColMajor ? i : NumDims - 1 - i;
      it[idx].size = params.output_dims[dim];
      it[idx].count = 0;
      it[idx].output_stride = m_outputStrides[dim];
      it[idx].output_span = it[idx].output_stride * (it[idx].size - 1);
      idx++;
    }

    // Fill the output block by broadcasting along the bcast dimension and
    // iterating over the outer dimensions.
    Index output_offset = 0;
    const Index output_size = NumDims == 0 ? 1 : params.output_dims.TotalSize();

    for (Index num_output_coeffs = 0; num_output_coeffs < output_size;) {
      ScalarNoConst* bcast_output = materialized_output + num_output_coeffs;
      Index bcast_offset = desc.offset() + output_offset;

      // Broadcast along the bcast dimension (this might require multiple copies).
      num_output_coeffs += BroadcastBlockAlongBcastDim(params, bcast_offset, scratch, bcast_output, &materialized_input,
                                                       &materialized_input_size);

      // Advance to the next outer dimension.
      for (int j = 0; j < idx; ++j) {
        if (++it[j].count < it[j].size) {
          output_offset += it[j].output_stride;
          break;
        }
        it[j].count = 0;
        output_offset -= it[j].output_span;
      }
    }

    return block_storage.AsTensorMaterializedBlock();
  }
  EIGEN_DEVICE_FUNC EvaluatorPointerType data() const { return NULL; }

  const TensorEvaluator<ArgType, Device>& impl() const { return m_impl; }

  Broadcast functor() const { return m_broadcast; }

 private:
  static constexpr bool IsColMajor = static_cast<int>(Layout) == static_cast<int>(ColMajor);
  // Block broadcasting is built on a primitive that broadcasts along the inner
  // dimensions and a single `bcast_dim`: the first dimension whose requested
  // block size is smaller than the broadcasted size.
  struct BlockBroadcastingParams {
    Dimensions input_dims;      // input expression dimensions
    Dimensions output_dims;     // output block sizes
    Dimensions output_strides;  // output block strides

    int inner_dim_count;   // number of inner dimensions that match in size
    int bcast_dim;         // broadcasting dimension index
    Index bcast_dim_size;  // broadcasting dimension size
    Index inner_dim_size;  // inner dimensions size

    // Block sizes and strides for the input block; dimensions outside the
    // matching inner dimensions are set to size 1 (callers adjust bcast_dim).
    Dimensions input_block_sizes;
    Dimensions input_block_strides;

    // Sizes and strides for the 2x-rank (copy, broadcast) dimension pairs.
    BroadcastDimensions bcast_block_sizes;
    BroadcastDimensions bcast_block_strides;
    BroadcastDimensions bcast_input_strides;
  };

  struct BlockBroadcastingIteratorState {
    Index size;
    Index count;
    Index output_stride;
    Index output_span;
  };
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlockBroadcastingParams blockBroadcastingParams(TensorBlockDesc& desc) const {
    BlockBroadcastingParams params;

    params.input_dims = Dimensions(m_impl.dimensions());

    // Output block sizes and strides.
    params.output_dims = desc.dimensions();
    params.output_strides = internal::strides<Layout>(params.output_dims);

    // Find the broadcasting dimension: the first dimension (inner-most first)
    // whose block size is smaller than the broadcasted size.
    params.bcast_dim = 0;
    params.bcast_dim_size = 1;
    params.inner_dim_size = 1;

    // Count the inner dimensions that fully match the broadcasted dimensions.
    params.inner_dim_count = 0;

    for (int i = 0; i < NumDims; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;

      if (params.output_dims[dim] == m_dimensions[dim]) {
        params.inner_dim_size *= params.output_dims[dim];
        ++params.inner_dim_count;
        continue;
      }

      // The first non-matching dimension is the broadcasting dimension.
      eigen_assert(params.output_dims[dim] < m_dimensions[dim]);
      params.bcast_dim = dim;
      params.bcast_dim_size = params.output_dims[dim];
      break;
    }

    // Input block sizes: full input for the matching inner dimensions, 1 for
    // all remaining dimensions.
    for (int i = 0; i < params.inner_dim_count; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;
      params.input_block_sizes[dim] = params.input_dims[dim];
    }
    for (int i = params.inner_dim_count; i < NumDims; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;
      params.input_block_sizes[dim] = 1;
    }
    params.input_block_strides = internal::strides<Layout>(params.input_block_sizes);

    // Split each inner dimension into a (copy, broadcast) pair of dimensions:
    // the copy member walks the input normally, while the broadcast member
    // re-reads it with input stride 0, replicating the data.
    for (int i = 0; i < params.inner_dim_count; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;

      const int copy_dim = IsColMajor ? 2 * i : 2 * NumDims - 2 * i - 1;
      const int broadcast_dim = IsColMajor ? copy_dim + 1 : copy_dim - 1;

      params.bcast_block_sizes[copy_dim] = params.input_dims[dim];
      params.bcast_block_sizes[broadcast_dim] = m_broadcast[dim];
      params.bcast_block_strides[copy_dim] = params.output_strides[dim];
      params.bcast_block_strides[broadcast_dim] = params.output_strides[dim] * params.input_dims[dim];
      params.bcast_input_strides[copy_dim] = params.input_block_strides[dim];
      params.bcast_input_strides[broadcast_dim] = 0;
    }

    // The remaining dimension pairs are trivial.
    for (int i = 2 * params.inner_dim_count; i < 2 * NumDims; ++i) {
      const int dim = IsColMajor ? i : 2 * NumDims - i - 1;
      params.bcast_block_sizes[dim] = 1;
      params.bcast_block_strides[dim] = 0;
      params.bcast_input_strides[dim] = 0;
    }

    return params;
  }
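  // Worked example (illustrative only): a 2x3 input broadcast by {2, 2},
  // col-major, with the whole 4x6 output requested as one block. Every output
  // dimension matches m_dimensions, so inner_dim_count == 2 and bcast_dim_size
  // stays 1. The rank-4 bcast block has sizes {2, 2, 3, 2}, input strides
  // {1, 0, 2, 0} and output strides {1, 2, 4, 12}: the zero input strides in
  // each (copy, broadcast) pair are what replicate the input coefficients.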
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock emptyBlock() const {
    DSizes<Index, NumDims> dimensions;
    for (int i = 0; i < NumDims; ++i) dimensions[i] = 0;
    return TensorBlock(internal::TensorBlockKind::kView, NULL, dimensions);
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlockAlongBcastDim(
      BlockBroadcastingParams params, Index bcast_offset, TensorBlockScratch& scratch,
      ScalarNoConst* materialized_output, ScalarNoConst** materialized_input,
      size_t* materialized_input_size) const {
    if (params.bcast_dim_size == 1) {
      // A single block read with the precomputed broadcasting parameters.
      return BroadcastBlock(params.input_block_sizes, params.input_block_strides, params.bcast_block_sizes,
                            params.bcast_block_strides, params.bcast_input_strides, bcast_offset, 0, scratch,
                            materialized_output, materialized_input, materialized_input_size);

    } else if (params.input_dims[params.bcast_dim] == 1) {
      // Broadcast the bcast dimension (input size 1) by bcast_dim_size.
      const int broadcast_bcast_dim =
          IsColMajor ? 2 * params.inner_dim_count + 1 : 2 * NumDims - 2 * params.inner_dim_count - 2;

      params.bcast_block_sizes[broadcast_bcast_dim] = params.bcast_dim_size;
      params.bcast_input_strides[broadcast_bcast_dim] = 0;
      params.bcast_block_strides[broadcast_bcast_dim] = params.output_strides[params.bcast_dim];

      return BroadcastBlock(params.input_block_sizes, params.input_block_strides, params.bcast_block_sizes,
                            params.bcast_block_strides, params.bcast_input_strides, bcast_offset, 0, scratch,
                            materialized_output, materialized_input, materialized_input_size);

    } else {
      // General case: the output slice along the bcast dimension may start and
      // end in the middle of an input period, so it is split into up to three
      // sub-blocks: a head before the first multiple of the input size, a
      // middle of whole input periods, and a tail after the last multiple.
      Index num_output_coeffs = 0;

      // Left index of the output slice along the bcast dimension.
      const Index bcast_dim_left_index = bcast_offset / m_outputStrides[params.bcast_dim];

      // Input size of the broadcast dimension.
      const Index input_bcast_dim_size = params.input_dims[params.bcast_dim];

      // First multiple of input_bcast_dim_size at or after bcast_dim_left_index.
      const Index first_multiple =
          numext::div_ceil<Index>(bcast_dim_left_index, input_bcast_dim_size) * input_bcast_dim_size;

      if (first_multiple <= bcast_dim_left_index + params.bcast_dim_size) {
        // Last multiple of input_bcast_dim_size inside the output slice.
        const Index last_multiple =
            (bcast_dim_left_index + params.bcast_dim_size) / input_bcast_dim_size * input_bcast_dim_size;
        const int copy_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count : 2 * NumDims - 2 * params.inner_dim_count - 1;
        const int broadcast_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count + 1 : 2 * NumDims - 2 * params.inner_dim_count - 2;

        if (first_multiple > bcast_dim_left_index) {
          // Head sub-block.
          const Index head_size = first_multiple - bcast_dim_left_index;
          params.input_block_sizes[params.bcast_dim] = head_size;
          params.bcast_block_sizes[copy_bcast_dim] = head_size;
          params.bcast_input_strides[copy_bcast_dim] = params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] = params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] = 1;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] * params.input_dims[params.bcast_dim];

          num_output_coeffs +=
              BroadcastBlock(params.input_block_sizes, params.input_block_strides, params.bcast_block_sizes,
                             params.bcast_block_strides, params.bcast_input_strides, bcast_offset, 0, scratch,
                             materialized_output, materialized_input, materialized_input_size);
        }
        if (first_multiple < last_multiple) {
          // Middle sub-block of whole input periods.
          params.input_block_sizes[params.bcast_dim] = input_bcast_dim_size;
          params.bcast_block_sizes[copy_bcast_dim] = input_bcast_dim_size;
          params.bcast_input_strides[copy_bcast_dim] = params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] = params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] = (last_multiple - first_multiple) / input_bcast_dim_size;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] * params.input_dims[params.bcast_dim];
          const Index offset = (first_multiple - bcast_dim_left_index) * m_outputStrides[params.bcast_dim];

          num_output_coeffs +=
              BroadcastBlock(params.input_block_sizes, params.input_block_strides, params.bcast_block_sizes,
                             params.bcast_block_strides, params.bcast_input_strides, bcast_offset, offset, scratch,
                             materialized_output, materialized_input, materialized_input_size);
        }
        if (last_multiple < bcast_dim_left_index + params.bcast_dim_size) {
          // Tail sub-block.
          const Index tail_size = bcast_dim_left_index + params.bcast_dim_size - last_multiple;
          params.input_block_sizes[params.bcast_dim] = tail_size;
          params.bcast_block_sizes[copy_bcast_dim] = tail_size;
          params.bcast_input_strides[copy_bcast_dim] = params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] = params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] = 1;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] * params.input_dims[params.bcast_dim];
          const Index offset = (last_multiple - bcast_dim_left_index) * m_outputStrides[params.bcast_dim];

          num_output_coeffs +=
              BroadcastBlock(params.input_block_sizes, params.input_block_strides, params.bcast_block_sizes,
                             params.bcast_block_strides, params.bcast_input_strides, bcast_offset, offset, scratch,
                             materialized_output, materialized_input, materialized_input_size);
        }
      } else {
        // No multiple of the input size falls inside the slice: process the
        // whole sub-block in one read.
        const int copy_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count : 2 * NumDims - 2 * params.inner_dim_count - 1;
        params.input_block_sizes[params.bcast_dim] = params.bcast_dim_size;
        params.bcast_block_sizes[copy_bcast_dim] = params.bcast_dim_size;
        params.bcast_input_strides[copy_bcast_dim] = params.input_block_strides[params.bcast_dim];
        params.bcast_block_strides[copy_bcast_dim] = params.output_strides[params.bcast_dim];

        num_output_coeffs +=
            BroadcastBlock(params.input_block_sizes, params.input_block_strides, params.bcast_block_sizes,
                           params.bcast_block_strides, params.bcast_input_strides, bcast_offset, 0, scratch,
                           materialized_output, materialized_input, materialized_input_size);
      }
      return num_output_coeffs;
    }
  }
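  // Illustration (not from the source): with input_dims[bcast_dim] == 3 and an
  // output slice covering bcast indices [2, 9), first_multiple == 3 and
  // last_multiple == 9, so the slice splits into a head [2, 3), a middle
  // [3, 9) holding two full input periods, and an empty tail.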
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlock(
      const Dimensions& input_block_sizes, const Dimensions& input_block_strides,
      const BroadcastDimensions& bcast_block_sizes, const BroadcastDimensions& bcast_block_strides,
      const BroadcastDimensions& bcast_input_strides, Index bcast_offset, Index offset, TensorBlockScratch& scratch,
      ScalarNoConst* materialized_output, ScalarNoConst** materialized_input,
      size_t* materialized_input_size) const {
    // Read the input block at the offset corresponding to the output offset.
    const Index input_offset = bcast_offset + offset;
    TensorBlockDesc input_desc(IsColMajor ? indexColMajor(input_offset) : indexRowMajor(input_offset),
                               input_block_sizes);

    ArgTensorBlock input_block = m_impl.block(input_desc, scratch);

    // Materialize the input block into a temporary buffer only if it does not
    // already expose raw data.
    const ScalarNoConst* input_buffer = NULL;

    if (input_block.data() != NULL) {
      // The input block already has raw data; no need to materialize it.
      input_buffer = input_block.data();
    } else {
      // Materialize the input block into scratch memory, reusing (or growing)
      // the buffer shared across calls.
      const size_t input_total_size = input_block_sizes.TotalSize();
      if (*materialized_input == NULL || *materialized_input_size < input_total_size) {
        *materialized_input_size = input_total_size;
        void* mem = scratch.allocate(*materialized_input_size * sizeof(Scalar));
        *materialized_input = static_cast<ScalarNoConst*>(mem);
      }

      typedef internal::TensorBlockAssignment<ScalarNoConst, NumDims, typename ArgTensorBlock::XprType, Index>
          TensorBlockAssignment;

      TensorBlockAssignment::Run(
          TensorBlockAssignment::target(input_block_sizes, input_block_strides, *materialized_input),
          input_block.expr());

      input_buffer = *materialized_input;
    }

    // Copy data from the materialized input to the output with broadcasting
    // over the 2x-rank dimensions (zero input strides replicate coefficients).
    typedef internal::TensorBlockIO<ScalarNoConst, Index, 2 * NumDims, Layout> TensorBlockIO;

    typename TensorBlockIO::Src src(bcast_input_strides, input_buffer);
    typename TensorBlockIO::Dst dst(bcast_block_sizes, bcast_block_strides, materialized_output + offset);

    return TensorBlockIO::Copy(dst, src);
  }
 protected:
  const Device EIGEN_DEVICE_REF m_device;
  const std::remove_reference_t<Broadcast> m_broadcast;
  Dimensions m_dimensions;
  array<Index, NumDims> m_outputStrides;
  array<Index, NumDims> m_inputStrides;
  TensorEvaluator<ArgType, Device> m_impl;
};

}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H