10#ifndef EIGEN_PARTIALREDUX_H
11#define EIGEN_PARTIALREDUX_H
14#include "./InternalHeaderCheck.h"
// Compile-time traits from which the unrolling strategy of a packet-wise reduction is chosen
// (consumed as the default `Unrolling` argument of packetwise_redux_impl).
// NOTE(review): this text looks like a scraped source listing -- every code line carries a fused
// listing number, and some lines (apparently the `enum {` opener, the head of the `Cost`
// expression and the closing braces) are absent. Restore them from the upstream header.
43template <
typename Func,
typename Evaluator>
44struct packetwise_redux_traits {
// OuterSize: the evaluator's compile-time row count when it is row-major, its column count otherwise.
46 OuterSize = int(Evaluator::IsRowMajor) ? Evaluator::RowsAtCompileTime : Evaluator::ColsAtCompileTime,
// Else-branch of a conditional whose condition is not visible here: OuterSize coefficient reads
// plus (OuterSize - 1) applications of Func.
48 : OuterSize * Evaluator::CoeffReadCost + (OuterSize - 1) * functor_traits<Func>::Cost,
// Unrolling: CompleteUnrolling when Cost does not exceed EIGEN_UNROLLING_LIMIT, NoUnrolling otherwise.
49 Unrolling = Cost <= EIGEN_UNROLLING_LIMIT ? CompleteUnrolling : NoUnrolling
54template <
typename PacketType,
typename Func>
55EIGEN_DEVICE_FUNC PacketType packetwise_redux_empty_value(
const Func&) {
56 const typename unpacket_traits<PacketType>::type zero(0);
57 return pset1<PacketType>(zero);
61template <
typename PacketType,
typename Scalar>
62EIGEN_DEVICE_FUNC PacketType packetwise_redux_empty_value(
const scalar_product_op<Scalar, Scalar>&) {
63 return pset1<PacketType>(Scalar(1));
67template <typename Func, typename Evaluator, int Unrolling = packetwise_redux_traits<Func, Evaluator>::Unrolling>
68struct packetwise_redux_impl;
71template <
typename Func,
typename Evaluator>
72struct packetwise_redux_impl<Func, Evaluator, CompleteUnrolling> {
73 typedef redux_novec_unroller<Func, Evaluator, 0, Evaluator::SizeAtCompileTime> Base;
74 typedef typename Evaluator::Scalar Scalar;
76 template <
typename PacketType>
77 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE PacketType run(
const Evaluator& eval,
const Func& func, Index ) {
78 return redux_vec_unroller<Func, Evaluator, 0,
79 packetwise_redux_traits<Func, Evaluator>::OuterSize>::template run<PacketType>(eval,
88template <
typename Func,
typename Evaluator, Index Start>
89struct redux_vec_unroller<Func, Evaluator, Start, 0> {
90 template <
typename PacketType>
91 EIGEN_DEVICE_FUNC
static EIGEN_STRONG_INLINE PacketType run(
const Evaluator&,
const Func& f) {
92 return packetwise_redux_empty_value<PacketType>(f);
97template <
typename Func,
typename Evaluator>
98struct packetwise_redux_impl<Func, Evaluator, NoUnrolling> {
99 typedef typename Evaluator::Scalar Scalar;
100 typedef typename redux_traits<Func, Evaluator>::PacketType PacketScalar;
102 template <
typename PacketType>
103 EIGEN_DEVICE_FUNC
static PacketType run(
const Evaluator& eval,
const Func& func, Index size) {
104 if (size == 0)
return packetwise_redux_empty_value<PacketType>(func);
106 const Index size4 = 1 + numext::round_down(size - 1, 4);
107 PacketType p = eval.template packetByOuterInner<Unaligned, PacketType>(0, 0);
111 for (Index i = 1; i < size4; i += 4)
113 p, func.packetOp(func.packetOp(eval.template packetByOuterInner<Unaligned, PacketType>(i + 0, 0),
114 eval.template packetByOuterInner<Unaligned, PacketType>(i + 1, 0)),
115 func.packetOp(eval.template packetByOuterInner<Unaligned, PacketType>(i + 2, 0),
116 eval.template packetByOuterInner<Unaligned, PacketType>(i + 3, 0))));
117 for (Index i = size4; i < size; ++i)
118 p = func.packetOp(p, eval.template packetByOuterInner<Unaligned, PacketType>(i, 0));
123template <
typename Func,
typename Evaluator>
124struct packetwise_segment_redux_impl {
125 typedef typename Evaluator::Scalar Scalar;
126 typedef typename redux_traits<Func, Evaluator>::PacketType PacketScalar;
128 template <
typename PacketType>
129 EIGEN_DEVICE_FUNC
static PacketType run(
const Evaluator& eval,
const Func& func, Index size, Index begin,
131 if (size == 0)
return packetwise_redux_empty_value<PacketType>(func);
133 PacketType p = eval.template packetSegmentByOuterInner<Unaligned, PacketType>(0, 0, begin, count);
134 for (Index i = 1; i < size; ++i)
135 p = func.packetOp(p, eval.template packetSegmentByOuterInner<Unaligned, PacketType>(i, 0, begin, count));
// Evaluator specialization for PartialReduxExpr<ArgType, MemberOp, Direction>.
// It offers coefficient access (coeff), packet access (packet) and packet-segment access
// (packetSegment), each in an (i, j) form and a single-index form.
// NOTE(review): this text looks like a scraped source listing -- every code line carries a fused
// listing number, and several lines are absent (enum openers, the head of the cost conditional,
// the end of the Flags expression, the definitions of startRow/startCol, what is done with `p`
// after Impl::run, closing braces). Restore them from the upstream header.
140template <
typename ArgType,
typename MemberOp,
int Direction>
141struct evaluator<PartialReduxExpr<ArgType, MemberOp, Direction> >
142 : evaluator_base<PartialReduxExpr<ArgType, MemberOp, Direction> > {
143 typedef PartialReduxExpr<ArgType, MemberOp, Direction> XprType;
144 typedef typename internal::nested_eval<ArgType, 1>::type ArgTypeNested;
145 typedef add_const_on_value_type_t<ArgTypeNested> ConstArgTypeNested;
146 typedef internal::remove_all_t<ArgTypeNested> ArgTypeNestedCleaned;
147 typedef typename ArgType::Scalar InputScalar;
148 typedef typename XprType::Scalar Scalar;
// TraversalSize: compile-time length of the reduced dimension -- rows of ArgType when
// Direction == Vertical, columns otherwise.
150 TraversalSize = Direction == int(Vertical) ? int(ArgType::RowsAtCompileTime) : int(ArgType::ColsAtCompileTime)
// Cost model supplied by the member functor for a traversal of TraversalSize elements.
152 typedef typename MemberOp::template Cost<int(TraversalSize)> CostOpType;
// Else-branch of a conditional whose condition is not visible here: TraversalSize coefficient
// reads of the argument plus the functor cost. (Presumably the CoeffReadCost checked in the
// constructor -- confirm against upstream.)
157 : int(TraversalSize) * int(evaluator<ArgType>::CoeffReadCost) + int(CostOpType::value),
159 ArgFlags_ = evaluator<ArgType>::Flags,
// Vectorizable_ requires all of: the argument has PacketAccessBit, MemberOp::Vectorizable is set,
// the argument is row-major for Vertical (not row-major otherwise), and TraversalSize != 0.
161 Vectorizable_ = bool(
int(ArgFlags_) & PacketAccessBit) && bool(MemberOp::Vectorizable) &&
162 (Direction == int(Vertical) ? bool(ArgFlags_ & RowMajorBit) : (ArgFlags_ &
RowMajorBit) == 0) &&
163 (TraversalSize != 0),
// Flags: the expression's own RowMajorBit combined with the argument's hereditary bits other
// than RowMajorBit; the rest of the expression (after the trailing `|`) is not visible here.
165 Flags = (traits<XprType>::Flags &
RowMajorBit) | (evaluator<ArgType>::Flags & (HereditaryBits & (~
RowMajorBit))) |
// Stores the nested expression and the member functor of xpr, then passes the cost values to
// EIGEN_INTERNAL_CHECK_COST_VALUE.
// NOTE(review): xpr is received by value (const XprType), not by reference -- confirm this is intended.
171 EIGEN_DEVICE_FUNC
explicit evaluator(
const XprType xpr) : m_arg(xpr.nestedExpression()), m_functor(xpr.functor()) {
172 EIGEN_INTERNAL_CHECK_COST_VALUE(TraversalSize == Dynamic ? HugeCost
173 : (TraversalSize == 0 ? 1 :
int(CostOpType::value)));
174 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
177 typedef typename XprType::CoeffReturnType CoeffReturnType;
// Two-index coefficient access: forwards to coeff(index) using j when Direction == Vertical, i otherwise.
179 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Scalar coeff(Index i, Index j)
const {
180 return coeff(Direction == Vertical ? j : i);
// Single-index coefficient access: applies the member functor to sub-vector `index` of the
// argument taken along Direction.
183 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Scalar coeff(Index index)
const {
184 return m_functor(m_arg.template subVector<
DirectionType(Direction)>(index));
// Two-index packet access: forwards to packet(idx) using j when Direction == Vertical, i otherwise.
187 template <
int LoadMode,
typename PacketType>
188 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index i, Index j)
const {
189 return packet<LoadMode, PacketType>(Direction == Vertical ? j : i);
// Single-index packet access: reduces a panel (a Block of the argument) packet-wise.
192 template <
int LoadMode,
typename PacketType>
193 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC PacketType packet(Index idx)
const {
194 static constexpr int PacketSize = internal::unpacket_traits<PacketType>::size;
// Compile-time panel shape: the full reduced dimension by PacketSize along the other one.
195 static constexpr int PanelRows = Direction ==
Vertical ? ArgType::RowsAtCompileTime : PacketSize;
196 static constexpr int PanelCols = Direction ==
Vertical ? PacketSize : ArgType::ColsAtCompileTime;
197 using PanelType = Block<
const ArgTypeNestedCleaned, PanelRows, PanelCols,
true >;
198 using PanelEvaluator =
typename internal::redux_evaluator<PanelType>;
199 using BinaryOp =
typename MemberOp::BinaryOp;
200 using Impl = internal::packetwise_redux_impl<BinaryOp, PanelEvaluator>;
// With a single-element packet, fall back to the scalar coefficient path.
206 if (PacketSize == 1)
return internal::pset1<PacketType>(coeff(idx));
// Run-time panel shape matching PanelRows/PanelCols above.
210 Index numRows = Direction ==
Vertical ? m_arg.rows() : PacketSize;
211 Index numCols = Direction ==
Vertical ? PacketSize : m_arg.cols();
// NOTE(review): startRow/startCol are not defined in the visible lines (presumably derived
// from idx -- confirm against upstream).
213 PanelType panel(m_arg, startRow, startCol, numRows, numCols);
214 PanelEvaluator panel_eval(panel);
// Reduce the panel with the functor's binary operation over m_arg.outerSize() steps.
215 PacketType p = Impl::template run<PacketType>(panel_eval, m_functor.binaryFunc(), m_arg.outerSize());
// Two-index segment access: forwards to packetSegment(idx, begin, count) using j when
// Direction == Vertical, i otherwise.
219 template <
int LoadMode,
typename PacketType>
220 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegment(Index i, Index j, Index begin, Index count)
const {
221 return packetSegment<LoadMode, PacketType>(Direction == Vertical ? j : i, begin, count);
// Single-index segment access: like packet(idx), but the panel's non-reduced dimension is
// Dynamic at compile time and begin + count at run time, and the reduction goes through
// packetwise_segment_redux_impl.
224 template <
int LoadMode,
typename PacketType>
225 EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC PacketType packetSegment(Index idx, Index begin, Index count)
const {
226 static constexpr int PanelRows = Direction ==
Vertical ? ArgType::RowsAtCompileTime :
Dynamic;
227 static constexpr int PanelCols = Direction ==
Vertical ?
Dynamic : ArgType::ColsAtCompileTime;
228 using PanelType = Block<
const ArgTypeNestedCleaned, PanelRows, PanelCols,
true >;
229 using PanelEvaluator =
typename internal::redux_evaluator<PanelType>;
230 using BinaryOp =
typename MemberOp::BinaryOp;
231 using Impl = internal::packetwise_segment_redux_impl<BinaryOp, PanelEvaluator>;
235 Index numRows = Direction ==
Vertical ? m_arg.rows() : begin + count;
236 Index numCols = Direction ==
Vertical ? begin + count : m_arg.cols();
// NOTE(review): startRow/startCol are not defined in the visible lines (presumably derived
// from idx -- confirm against upstream).
238 PanelType panel(m_arg, startRow, startCol, numRows, numCols);
239 PanelEvaluator panel_eval(panel);
240 PacketType p = Impl::template run<PacketType>(panel_eval, m_functor.binaryFunc(), m_arg.outerSize(), begin, count);
// Nested argument expression being reduced (initialized from xpr.nestedExpression()).
245 ConstArgTypeNested m_arg;
// Member functor applied to each sub-vector (initialized from xpr.functor()).
246 const MemberOp m_functor;
DirectionType
Definition Constants.h:263
@ Vertical
Definition Constants.h:266
const unsigned int PacketAccessBit
Definition Constants.h:97
const unsigned int LinearAccessBit
Definition Constants.h:133
const unsigned int RowMajorBit
Definition Constants.h:70
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
const int HugeCost
Definition Constants.h:48
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82
const int Dynamic
Definition Constants.h:25