10#ifndef EIGEN_VISITOR_H
11#define EIGEN_VISITOR_H
14#include "./InternalHeaderCheck.h"
20template <
typename Visitor,
typename Derived,
int UnrollCount,
21 bool Vectorize = (Derived::PacketAccess && functor_traits<Visitor>::PacketAccess),
bool LinearAccess =
false,
22 bool ShortCircuitEvaluation =
false>
25template <
typename Visitor,
bool ShortCircuitEvaluation = false>
26struct short_circuit_eval_impl {
28 static constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
bool run(
const Visitor&) {
return false; }
30template <
typename Visitor>
31struct short_circuit_eval_impl<Visitor, true> {
33 static constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
bool run(
const Visitor& visitor) {
return visitor.done(); }
37template <
typename Visitor,
typename Derived,
int UnrollCount,
bool Vectorize,
bool ShortCircuitEvaluation>
38struct visitor_impl<Visitor, Derived, UnrollCount, Vectorize, false, ShortCircuitEvaluation> {
40 using Scalar =
typename Derived::Scalar;
41 using Packet =
typename packet_traits<Scalar>::type;
42 static constexpr bool RowMajor = Derived::IsRowMajor;
43 static constexpr int RowsAtCompileTime = Derived::RowsAtCompileTime;
44 static constexpr int ColsAtCompileTime = Derived::ColsAtCompileTime;
45 static constexpr int PacketSize = packet_traits<Scalar>::size;
47 static constexpr bool CanVectorize(
int K) {
48 constexpr int InnerSizeAtCompileTime =
RowMajor ? ColsAtCompileTime : RowsAtCompileTime;
49 if (InnerSizeAtCompileTime < PacketSize)
return false;
50 return Vectorize && (InnerSizeAtCompileTime - (K % InnerSizeAtCompileTime) >= PacketSize);
53 template <
int K = 0,
bool Empty = (K == UnrollCount), std::enable_if_t<Empty,
bool> = true>
54 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived&, Visitor&) {}
56 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
57 std::enable_if_t<!Empty && Initialize && !DoVectorOp, bool> =
true>
58 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
59 visitor.init(mat.coeff(0, 0), 0, 0);
63 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
64 std::enable_if_t<!Empty && !Initialize && !DoVectorOp, bool> =
true>
65 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
66 static constexpr int R =
RowMajor ? (K / ColsAtCompileTime) : (K % RowsAtCompileTime);
67 static constexpr int C =
RowMajor ? (K % ColsAtCompileTime) : (K / RowsAtCompileTime);
68 visitor(mat.coeff(R, C), R, C);
69 run<K + 1>(mat, visitor);
72 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
73 std::enable_if_t<!Empty && Initialize && DoVectorOp, bool> =
true>
74 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
75 Packet P = mat.template packet<Packet>(0, 0);
76 visitor.initpacket(P, 0, 0);
77 run<PacketSize>(mat, visitor);
80 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
81 std::enable_if_t<!Empty && !Initialize && DoVectorOp, bool> =
true>
82 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
83 static constexpr int R =
RowMajor ? (K / ColsAtCompileTime) : (K % RowsAtCompileTime);
84 static constexpr int C =
RowMajor ? (K % ColsAtCompileTime) : (K / RowsAtCompileTime);
85 Packet P = mat.template packet<Packet>(R, C);
86 visitor.packet(P, R, C);
87 run<K + PacketSize>(mat, visitor);
92template <
typename Visitor,
typename Derived,
int UnrollCount,
bool Vectorize,
bool ShortCircuitEvaluation>
93struct visitor_impl<Visitor, Derived, UnrollCount, Vectorize, true, ShortCircuitEvaluation> {
95 using Scalar =
typename Derived::Scalar;
96 using Packet =
typename packet_traits<Scalar>::type;
97 static constexpr int PacketSize = packet_traits<Scalar>::size;
99 static constexpr bool CanVectorize(
int K) {
return Vectorize && ((UnrollCount - K) >= PacketSize); }
102 template <
int K = 0,
bool Empty = (K == UnrollCount), std::enable_if_t<Empty,
bool> = true>
103 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived&, Visitor&) {}
106 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
107 std::enable_if_t<!Empty && Initialize && !DoVectorOp, bool> =
true>
108 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
109 visitor.init(mat.coeff(0), 0);
110 run<1>(mat, visitor);
114 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
115 std::enable_if_t<!Empty && !Initialize && !DoVectorOp, bool> =
true>
116 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
117 visitor(mat.coeff(K), K);
118 run<K + 1>(mat, visitor);
122 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
123 std::enable_if_t<!Empty && Initialize && DoVectorOp, bool> =
true>
124 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
125 Packet P = mat.template packet<Packet>(0);
126 visitor.initpacket(P, 0);
127 run<PacketSize>(mat, visitor);
131 template <
int K = 0,
bool Empty = (K == UnrollCount),
bool Initialize = (K == 0),
bool DoVectorOp = CanVectorize(K),
132 std::enable_if_t<!Empty && !Initialize && DoVectorOp, bool> =
true>
133 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
134 Packet P = mat.template packet<Packet>(K);
135 visitor.packet(P, K);
136 run<K + PacketSize>(mat, visitor);
// Runtime-loop visitor for the (UnrollCount = Dynamic, Vectorize = false,
// LinearAccess = false) case: coefficients are fed to the visitor one at a
// time through coeff(r, c).
// NOTE(review): the declarations of innerSize, outerSize, r and c are not
// present in this listing; presumably they are derived from RowMajor and
// mat.rows()/mat.cols() - confirm against the real header.
141template <
typename Visitor,
typename Derived,
bool ShortCircuitEvaluation>
142struct visitor_impl<Visitor, Derived,
Dynamic, false, false, ShortCircuitEvaluation> {
// Early-exit probe; returns false unconditionally unless ShortCircuitEvaluation is true.
143 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
144 static constexpr bool RowMajor = Derived::IsRowMajor;
146 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
// Nothing to visit: return before init() is ever called on the visitor.
149 if (innerSize == 0 || outerSize == 0)
return;
// Seed the visitor with coefficient (0, 0), then give it a first chance to stop.
151 visitor.init(mat.coeff(0, 0), 0, 0);
152 if (short_circuit::run(visitor))
return;
// Rest of the first inner vector; starts at 1 because element 0 went to init().
153 for (
Index i = 1; i < innerSize; ++i) {
156 visitor(mat.coeff(r, c), r, c);
157 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// Remaining inner vectors (j >= 1), each walked in full from i = 0.
160 for (
Index j = 1; j < outerSize; j++) {
161 for (
Index i = 0; i < innerSize; ++i) {
164 visitor(mat.coeff(r, c), r, c);
165 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// Runtime-loop visitor for the (UnrollCount = Dynamic, Vectorize = true,
// LinearAccess = false) case: each inner vector is consumed in packets of
// PacketSize coefficients, followed by a scalar loop for what is left.
// NOTE(review): the declarations of innerSize, outerSize, i, r and c, and the
// branch structure around the first init()/initpacket() pair, are not present
// in this listing - confirm against the real header before relying on the
// notes below.
172template <
typename Visitor,
typename Derived,
bool ShortCircuitEvaluation>
173struct visitor_impl<Visitor, Derived,
Dynamic, true, false, ShortCircuitEvaluation> {
174 using Scalar =
typename Derived::Scalar;
175 using Packet =
typename packet_traits<Scalar>::type;
176 static constexpr int PacketSize = packet_traits<Scalar>::size;
// Early-exit probe; returns false unconditionally unless ShortCircuitEvaluation is true.
177 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
178 static constexpr bool RowMajor = Derived::IsRowMajor;
180 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
// Nothing to visit: return before the visitor is initialised.
183 if (innerSize == 0 || outerSize == 0)
return;
// Inner vector shorter than one packet: seed the visitor with a single scalar.
186 if (innerSize < PacketSize) {
187 visitor.init(mat.coeff(0, 0), 0, 0);
// Presumably the alternative branch: seed the visitor with the packet at (0, 0).
190 Packet p = mat.template packet<Packet>(0, 0);
191 visitor.initpacket(p, 0, 0);
194 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// First inner vector, packet part: runs while a whole packet [i, i + PacketSize) still fits.
195 for (; i + PacketSize - 1 < innerSize; i += PacketSize) {
198 Packet p = mat.template packet<Packet>(r, c);
199 visitor.packet(p, r, c);
200 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// First inner vector, scalar tail: coefficients left over after the packet loop.
202 for (; i < innerSize; ++i) {
205 visitor(mat.coeff(r, c), r, c);
206 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// Remaining inner vectors (j >= 1): same packet loop followed by scalar tail.
209 for (
Index j = 1; j < outerSize; j++) {
211 for (; i + PacketSize - 1 < innerSize; i += PacketSize) {
214 Packet p = mat.template packet<Packet>(r, c);
215 visitor.packet(p, r, c);
216 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
218 for (; i < innerSize; ++i) {
221 visitor(mat.coeff(r, c), r, c);
222 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
229template <
typename Visitor,
typename Derived,
bool ShortCircuitEvaluation>
230struct visitor_impl<Visitor, Derived,
Dynamic, false, true, ShortCircuitEvaluation> {
231 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
233 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
234 const Index size = mat.size();
235 if (size == 0)
return;
236 visitor.init(mat.coeff(0), 0);
237 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
238 for (
Index k = 1; k < size; k++) {
239 visitor(mat.coeff(k), k);
240 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// Runtime-loop visitor for the (UnrollCount = Dynamic, Vectorize = true,
// LinearAccess = true) case: the expression is walked by linear index, in
// packets of PacketSize coefficients followed by a scalar tail.
// NOTE(review): the declaration of k and the branch structure around the first
// init()/initpacket() pair are not present in this listing - confirm against
// the real header before relying on the notes below.
246template <
typename Visitor,
typename Derived,
bool ShortCircuitEvaluation>
247struct visitor_impl<Visitor, Derived,
Dynamic, true, true, ShortCircuitEvaluation> {
248 using Scalar =
typename Derived::Scalar;
249 using Packet =
typename packet_traits<Scalar>::type;
250 static constexpr int PacketSize = packet_traits<Scalar>::size;
// Early-exit probe; returns false unconditionally unless ShortCircuitEvaluation is true.
251 using short_circuit = short_circuit_eval_impl<Visitor, ShortCircuitEvaluation>;
253 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const Derived& mat, Visitor& visitor) {
254 const Index size = mat.size();
// Empty expression: return before the visitor is initialised.
255 if (size == 0)
return;
// Fewer coefficients than one packet: seed the visitor with a single scalar.
257 if (size < PacketSize) {
258 visitor.init(mat.coeff(0), 0);
// Presumably the alternative branch: seed the visitor with the packet at index k.
261 Packet p = mat.template packet<Packet>(k);
262 visitor.initpacket(p, k);
265 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// Packet part: runs while a whole packet [k, k + PacketSize) still fits.
266 for (; k + PacketSize - 1 < size; k += PacketSize) {
267 Packet p = mat.template packet<Packet>(k);
268 visitor.packet(p, k);
269 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// Scalar tail: coefficients left over after the packet loop.
271 for (; k < size; k++) {
272 visitor(mat.coeff(k), k);
273 if EIGEN_PREDICT_FALSE (short_circuit::run(visitor))
return;
// Thin adaptor that pairs an evaluator<XprType> with a reference to the
// expression it was built from. Dimensions are answered by the expression;
// coefficient and packet reads are forwarded to the evaluator.
// NOTE(review): no access specifiers are visible in this listing; confirm
// which members are public in the real header.
279template <
typename XprType>
280class visitor_evaluator {
282 typedef evaluator<XprType> Evaluator;
283 typedef typename XprType::Scalar Scalar;
284 using Packet =
typename packet_traits<Scalar>::type;
// CoeffReturnType with any top-level const removed.
285 typedef std::remove_const_t<typename XprType::CoeffReturnType> CoeffReturnType;
// Capability flags read from the evaluator's Flags bit mask.
287 static constexpr bool PacketAccess =
static_cast<bool>(Evaluator::Flags &
PacketAccessBit);
288 static constexpr bool LinearAccess =
static_cast<bool>(Evaluator::Flags &
LinearAccessBit);
// Compile-time shape and storage order are taken from the expression type itself.
289 static constexpr bool IsRowMajor =
static_cast<bool>(XprType::IsRowMajor);
290 static constexpr int RowsAtCompileTime = XprType::RowsAtCompileTime;
291 static constexpr int ColsAtCompileTime = XprType::ColsAtCompileTime;
292 static constexpr int XprAlignment = Evaluator::Alignment;
293 static constexpr int CoeffReadCost = Evaluator::CoeffReadCost;
// Builds the evaluator from xpr and keeps a reference to xpr (not a copy),
// so xpr must stay alive for as long as this object is used.
295 EIGEN_DEVICE_FUNC
explicit visitor_evaluator(
const XprType& xpr) : m_evaluator(xpr), m_xpr(xpr) {}
// Runtime dimensions, answered by the referenced expression.
297 EIGEN_DEVICE_FUNC
constexpr Index rows() const noexcept {
return m_xpr.rows(); }
298 EIGEN_DEVICE_FUNC
constexpr Index cols() const noexcept {
return m_xpr.cols(); }
299 EIGEN_DEVICE_FUNC
constexpr Index size() const noexcept {
return m_xpr.size(); }
// Coefficient at (row, col), forwarded to the evaluator.
301 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(
Index row,
Index col)
const {
302 return m_evaluator.coeff(row, col);
// Packet at (row, col). Alignment defaults to Unaligned here. Note the template
// argument order is <Packet, Alignment> on this class but <Alignment, Packet>
// on the wrapped evaluator.
304 template <
typename Packet,
int Alignment = Unaligned>
305 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(
Index row,
Index col)
const {
306 return m_evaluator.template packet<Alignment, Packet>(row, col);
// Coefficient at a linear index, forwarded to the evaluator.
309 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(
Index index)
const {
return m_evaluator.coeff(index); }
// Packet at a linear index. Unlike the (row, col) overload, Alignment defaults
// to XprAlignment (the evaluator's Alignment).
310 template <
typename Packet,
int Alignment = XprAlignment>
311 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(
Index index)
const {
312 return m_evaluator.template packet<Alignment, Packet>(index);
// Evaluator built from the expression, and the expression itself (held by reference).
316 Evaluator m_evaluator;
317 const XprType& m_xpr;
320template <
typename Derived,
typename Visitor,
bool ShortCircuitEvaulation>
322 using Evaluator = visitor_evaluator<Derived>;
329 static constexpr int InnerSizeAtCompileTime = IsRowMajor ? ColsAtCompileTime : RowsAtCompileTime;
330 static constexpr int OuterSizeAtCompileTime = IsRowMajor ? RowsAtCompileTime : ColsAtCompileTime;
332 static constexpr bool LinearAccess =
333 Evaluator::LinearAccess &&
static_cast<bool>(functor_traits<Visitor>::LinearAccess);
334 static constexpr bool Vectorize = Evaluator::PacketAccess &&
static_cast<bool>(functor_traits<Visitor>::PacketAccess);
336 static constexpr int PacketSize = packet_traits<Scalar>::size;
337 static constexpr int VectorOps =
338 Vectorize ? (LinearAccess ? (SizeAtCompileTime / PacketSize)
339 : (OuterSizeAtCompileTime * (InnerSizeAtCompileTime / PacketSize)))
341 static constexpr int ScalarOps = SizeAtCompileTime - (VectorOps * PacketSize);
343 static constexpr int TotalOps = VectorOps + ScalarOps;
345 static constexpr int UnrollCost = int(Evaluator::CoeffReadCost) + int(functor_traits<Visitor>::Cost);
346 static constexpr bool Unroll = (SizeAtCompileTime !=
Dynamic) && ((TotalOps * UnrollCost) <= EIGEN_UNROLLING_LIMIT);
347 static constexpr int UnrollCount = Unroll ? int(SizeAtCompileTime) :
Dynamic;
349 using impl = visitor_impl<Visitor, Evaluator, UnrollCount, Vectorize, LinearAccess, ShortCircuitEvaulation>;
351 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void run(
const DenseBase<Derived>& mat, Visitor& visitor) {
352 Evaluator evaluator(mat.derived());
353 impl::run(evaluator, visitor);
378template <
typename Derived>
379template <
typename Visitor>
381 using impl = internal::visit_impl<Derived, Visitor,
false>;
382 impl::run(derived(), visitor);
387template <
typename Scalar>
389 using result_type = bool;
390 using Packet =
typename packet_traits<Scalar>::type;
392 EIGEN_DEVICE_FUNC
inline void init(
const Scalar& value, Index) { res = (value != Scalar(0)); }
393 EIGEN_DEVICE_FUNC
inline bool all_predux(
const Packet& p)
const {
return !predux_any(pcmp_eq(p, pzero(p))); }
394 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p,
Index,
Index) { res = all_predux(p); }
395 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p,
Index) { res = all_predux(p); }
396 EIGEN_DEVICE_FUNC
inline void operator()(
const Scalar& value,
Index,
Index) { res = res && (value != Scalar(0)); }
397 EIGEN_DEVICE_FUNC
inline void operator()(
const Scalar& value,
Index) { res = res && (value != Scalar(0)); }
398 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p,
Index,
Index) { res = res && all_predux(p); }
399 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p,
Index) { res = res && all_predux(p); }
400 EIGEN_DEVICE_FUNC
inline bool done()
const {
return !res; }
403template <
typename Scalar>
404struct functor_traits<all_visitor<Scalar>> {
405 enum { Cost = NumTraits<Scalar>::ReadCost, LinearAccess =
true, PacketAccess = packet_traits<Scalar>::HasCmp };
408template <
typename Scalar>
410 using result_type = bool;
411 using Packet =
typename packet_traits<Scalar>::type;
412 EIGEN_DEVICE_FUNC
inline void init(
const Scalar& value,
Index,
Index) { res = (value != Scalar(0)); }
413 EIGEN_DEVICE_FUNC
inline void init(
const Scalar& value,
Index) { res = (value != Scalar(0)); }
414 EIGEN_DEVICE_FUNC
inline bool any_predux(
const Packet& p)
const {
415 return predux_any(pandnot(ptrue(p), pcmp_eq(p, pzero(p))));
417 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p,
Index,
Index) { res = any_predux(p); }
418 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p,
Index) { res = any_predux(p); }
419 EIGEN_DEVICE_FUNC
inline void operator()(
const Scalar& value,
Index,
Index) { res = res || (value != Scalar(0)); }
420 EIGEN_DEVICE_FUNC
inline void operator()(
const Scalar& value,
Index) { res = res || (value != Scalar(0)); }
421 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p,
Index,
Index) { res = res || any_predux(p); }
422 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p,
Index) { res = res || any_predux(p); }
423 EIGEN_DEVICE_FUNC
inline bool done()
const {
return res; }
426template <
typename Scalar>
427struct functor_traits<any_visitor<Scalar>> {
428 enum { Cost = NumTraits<Scalar>::ReadCost, LinearAccess =
true, PacketAccess = packet_traits<Scalar>::HasCmp };
// Visitor that counts the coefficients that compare unequal to Scalar(0).
// The running total is kept in `res`.
// NOTE(review): the declaration of `res` is not present in this listing;
// presumably it is an Index member (result_type is Index) - confirm.
431template <
typename Scalar>
432struct count_visitor {
433 using result_type =
Index;
434 using Packet =
typename packet_traits<Scalar>::type;
// Seed the count from the first coefficient: 1 if it is non-zero, else 0.
// The index arguments are ignored.
435 EIGEN_DEVICE_FUNC
inline void init(
const Scalar& value,
Index,
Index) { res = value != Scalar(0) ? 1 : 0; }
436 EIGEN_DEVICE_FUNC
inline void init(
const Scalar& value,
Index) { res = value != Scalar(0) ? 1 : 0; }
// Number of non-zero lanes in p: a packet of ones is masked by the
// "lane == 0" comparison result, the surviving ones are summed with predux(),
// and the sum is converted to Index.
// NOTE(review): assumes pandnot(a, b) clears in a the lanes selected by b - confirm.
437 EIGEN_DEVICE_FUNC
inline Index count_redux(
const Packet& p)
const {
438 const Packet cst_one = pset1<Packet>(Scalar(1));
439 Packet true_vals = pandnot(cst_one, pcmp_eq(p, pzero(p)));
440 Scalar num_true = predux(true_vals);
441 return static_cast<Index>(num_true);
// Seed the count from the first packet.
443 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p,
Index,
Index) { res = count_redux(p); }
444 EIGEN_DEVICE_FUNC
inline void initpacket(
const Packet& p,
Index) { res = count_redux(p); }
// Scalar step: add one to the count when the coefficient is non-zero.
445 EIGEN_DEVICE_FUNC
inline void operator()(
const Scalar& value,
Index,
Index) {
446 if (value != Scalar(0)) res++;
448 EIGEN_DEVICE_FUNC
inline void operator()(
const Scalar& value,
Index) {
449 if (value != Scalar(0)) res++;
// Packet step: add the number of non-zero lanes to the count.
451 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p,
Index,
Index) { res += count_redux(p); }
452 EIGEN_DEVICE_FUNC
inline void packet(
const Packet& p,
Index) { res += count_redux(p); }
// Traits for count_visitor<Scalar>.
// NOTE(review): the line opening the enumeration and any enumerators between
// Cost and PacketAccess are not present in this listing - confirm against the
// real header.
456template <
typename Scalar>
457struct functor_traits<count_visitor<Scalar>> {
// Per-coefficient cost is taken to be one scalar addition.
459 Cost = NumTraits<Scalar>::AddCost,
// Packets are used only when the packet type supports both comparison and
// addition, and never when Scalar is bool.
462 PacketAccess = packet_traits<Scalar>::HasCmp && packet_traits<Scalar>::HasAdd && !is_same<Scalar, bool>::value
466template <typename Derived, bool AlwaysTrue = NumTraits<typename traits<Derived>::Scalar>::IsInteger>
467struct all_finite_impl {
468 static EIGEN_DEVICE_FUNC
inline bool run(
const Derived& ) {
return true; }
470#if !defined(__FINITE_MATH_ONLY__) || !(__FINITE_MATH_ONLY__)
471template <
typename Derived>
472struct all_finite_impl<Derived, false> {
473 static EIGEN_DEVICE_FUNC
inline bool run(
const Derived& derived) {
return derived.array().isFiniteTyped().all(); }
486template <
typename Derived>
488 using Visitor = internal::all_visitor<Scalar>;
489 using impl = internal::visit_impl<Derived, Visitor,
true>;
491 impl::run(derived(), visitor);
499template <
typename Derived>
501 using Visitor = internal::any_visitor<Scalar>;
502 using impl = internal::visit_impl<Derived, Visitor,
true>;
504 impl::run(derived(), visitor);
512template <
typename Derived>
514 using Visitor = internal::count_visitor<Scalar>;
515 using impl = internal::visit_impl<Derived, Visitor,
false>;
517 impl::run(derived(), visitor);
521template <
typename Derived>
522EIGEN_DEVICE_FUNC
inline bool DenseBase<Derived>::hasNaN()
const {
523 return derived().cwiseTypedNotEqual(derived()).any();
530template <
typename Derived>
532 return internal::all_finite_impl<Derived>::run(derived());
void visit(Visitor &func) const
Definition Visitor.h:380
internal::traits< Derived >::Scalar Scalar
Definition DenseBase.h:62
@ SizeAtCompileTime
Definition DenseBase.h:108
@ IsRowMajor
Definition DenseBase.h:166
@ ColsAtCompileTime
Definition DenseBase.h:102
@ RowsAtCompileTime
Definition DenseBase.h:96
Index count() const
Definition Visitor.h:513
bool any() const
Definition Visitor.h:500
bool all() const
Definition Visitor.h:487
bool allFinite() const
Definition Visitor.h:531
@ RowMajor
Definition Constants.h:320
const unsigned int PacketAccessBit
Definition Constants.h:97
const unsigned int LinearAccessBit
Definition Constants.h:133
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82
const int Dynamic
Definition Constants.h:25