10#ifndef EIGEN_INNER_PRODUCT_EVAL_H
11#define EIGEN_INNER_PRODUCT_EVAL_H
14#include "./InternalHeaderCheck.h"
21template <typename Scalar, int Size, typename Packet = typename packet_traits<Scalar>::type,
23 (unpacket_traits<Packet>::size <= Size) || is_same<Packet,
typename unpacket_traits<Packet>::half>::value>
24struct find_inner_product_packet_helper;
26template <
typename Scalar,
int Size,
typename Packet>
27struct find_inner_product_packet_helper<Scalar, Size, Packet, false> {
28 using type =
typename find_inner_product_packet_helper<Scalar, Size, typename unpacket_traits<Packet>::half>::type;
31template <
typename Scalar,
int Size,
typename Packet>
32struct find_inner_product_packet_helper<Scalar, Size, Packet, true> {
36template <
typename Scalar,
int Size>
37struct find_inner_product_packet : find_inner_product_packet_helper<Scalar, Size> {};
39template <
typename Scalar>
40struct find_inner_product_packet<Scalar,
Dynamic> {
41 using type =
typename packet_traits<Scalar>::type;
44template <
typename Lhs,
typename Rhs>
45struct inner_product_assert {
46 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Lhs)
47 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Rhs)
48 EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(Lhs, Rhs)
50 static EIGEN_DEVICE_FUNC
void run(
const Lhs& lhs,
const Rhs& rhs) {
51 eigen_assert((lhs.size() == rhs.size()) &&
"Inner product: lhs and rhs vectors must have same size");
54 static EIGEN_DEVICE_FUNC
void run(
const Lhs&,
const Rhs&) {}
58template <
typename Func,
typename Lhs,
typename Rhs>
59struct inner_product_evaluator {
60 static constexpr int LhsFlags = evaluator<Lhs>::Flags;
61 static constexpr int RhsFlags = evaluator<Rhs>::Flags;
62 static constexpr int SizeAtCompileTime = size_prefer_fixed(Lhs::SizeAtCompileTime, Rhs::SizeAtCompileTime);
63 static constexpr int MaxSizeAtCompileTime =
64 min_size_prefer_fixed(Lhs::MaxSizeAtCompileTime, Rhs::MaxSizeAtCompileTime);
65 static constexpr int LhsAlignment = evaluator<Lhs>::Alignment;
66 static constexpr int RhsAlignment = evaluator<Rhs>::Alignment;
68 using Scalar =
typename Func::result_type;
69 using Packet =
typename find_inner_product_packet<Scalar, SizeAtCompileTime>::type;
71 static constexpr bool Vectorize =
73 ((MaxSizeAtCompileTime ==
Dynamic) || (unpacket_traits<Packet>::size <= MaxSizeAtCompileTime));
75 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
explicit inner_product_evaluator(
const Lhs& lhs,
const Rhs& rhs,
77 : m_func(func), m_lhs(lhs), m_rhs(rhs), m_size(lhs.size()) {
78 inner_product_assert<Lhs, Rhs>::run(lhs, rhs);
81 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Index size()
const {
return m_size.value(); }
83 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(
Index index)
const {
84 return m_func.coeff(m_lhs.coeff(index), m_rhs.coeff(index));
87 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(
const Scalar& value,
Index index)
const {
88 return m_func.coeff(value, m_lhs.coeff(index), m_rhs.coeff(index));
91 template <
typename PacketType,
int LhsMode = LhsAlignment,
int RhsMode = RhsAlignment>
92 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(
Index index)
const {
93 return m_func.packet(m_lhs.template packet<LhsMode, PacketType>(index),
94 m_rhs.template packet<RhsMode, PacketType>(index));
97 template <
typename PacketType,
int LhsMode = LhsAlignment,
int RhsMode = RhsAlignment>
98 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(
const PacketType& value,
Index index)
const {
99 return m_func.packet(value, m_lhs.template packet<LhsMode, PacketType>(index),
100 m_rhs.template packet<RhsMode, PacketType>(index));
104 const evaluator<Lhs> m_lhs;
105 const evaluator<Rhs> m_rhs;
106 const variable_if_dynamic<Index, SizeAtCompileTime> m_size;
// Reduction kernel over an inner_product_evaluator. The boolean parameter picks the scalar
// loop (false) or the packet loop (true) and defaults to the evaluator's Vectorize flag.
template <class Eval, bool UsePackets = Eval::Vectorize>
struct inner_product_impl;
113template <
typename Evaluator>
114struct inner_product_impl<Evaluator, false> {
115 using Scalar =
typename Evaluator::Scalar;
116 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(
const Evaluator& eval) {
117 const Index size = eval.size();
118 if (size == 0)
return Scalar(0);
120 Scalar result = eval.coeff(0);
121 for (
Index k = 1; k < size; k++) {
122 result = eval.coeff(result, k);
130template <
typename Evaluator>
131struct inner_product_impl<Evaluator, true> {
132 using UnsignedIndex = std::make_unsigned_t<Index>;
133 using Scalar =
typename Evaluator::Scalar;
134 using Packet =
typename Evaluator::Packet;
135 static constexpr int PacketSize = unpacket_traits<Packet>::size;
136 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(
const Evaluator& eval) {
137 const UnsignedIndex size =
static_cast<UnsignedIndex
>(eval.size());
138 if (size < PacketSize)
return inner_product_impl<Evaluator, false>::run(eval);
140 const UnsignedIndex packetEnd = numext::round_down(size, PacketSize);
141 const UnsignedIndex quadEnd = numext::round_down(size, 4 * PacketSize);
142 const UnsignedIndex numPackets = size / PacketSize;
143 const UnsignedIndex numRemPackets = (packetEnd - quadEnd) / PacketSize;
145 Packet presult0, presult1, presult2, presult3;
147 presult0 = eval.template packet<Packet>(0 * PacketSize);
148 if (numPackets >= 2) presult1 = eval.template packet<Packet>(1 * PacketSize);
149 if (numPackets >= 3) presult2 = eval.template packet<Packet>(2 * PacketSize);
150 if (numPackets >= 4) {
151 presult3 = eval.template packet<Packet>(3 * PacketSize);
153 for (UnsignedIndex k = 4 * PacketSize; k < quadEnd; k += 4 * PacketSize) {
154 presult0 = eval.packet(presult0, k + 0 * PacketSize);
155 presult1 = eval.packet(presult1, k + 1 * PacketSize);
156 presult2 = eval.packet(presult2, k + 2 * PacketSize);
157 presult3 = eval.packet(presult3, k + 3 * PacketSize);
160 if (numRemPackets >= 1) presult0 = eval.packet(presult0, quadEnd + 0 * PacketSize);
161 if (numRemPackets >= 2) presult1 = eval.packet(presult1, quadEnd + 1 * PacketSize);
162 if (numRemPackets == 3) presult2 = eval.packet(presult2, quadEnd + 2 * PacketSize);
164 presult2 = padd(presult2, presult3);
167 if (numPackets >= 3) presult1 = padd(presult1, presult2);
168 if (numPackets >= 2) presult0 = padd(presult0, presult1);
170 Scalar result = predux(presult0);
171 for (UnsignedIndex k = packetEnd; k < size; k++) {
172 result = eval.coeff(result, k);
// Optionally conjugates a value: conjugation when the boolean parameter is true, identity
// otherwise. scalar_inner_product_op applies it to its left-hand operand.
template <class ScalarT, bool DoConj>
struct conditional_conj;
182template <
typename Scalar>
183struct conditional_conj<Scalar, true> {
184 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(
const Scalar& a) {
return numext::conj(a); }
185 template <
typename Packet>
186 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(
const Packet& a) {
191template <
typename Scalar>
192struct conditional_conj<Scalar, false> {
193 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(
const Scalar& a) {
return a; }
194 template <
typename Packet>
195 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(
const Packet& a) {
200template <
typename LhsScalar,
typename RhsScalar,
bool Conj>
201struct scalar_inner_product_op {
202 using result_type =
typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType;
203 using conj_helper = conditional_conj<LhsScalar, Conj>;
204 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(
const LhsScalar& a,
const RhsScalar& b)
const {
205 return (conj_helper::coeff(a) * b);
207 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(
const result_type& accum,
const LhsScalar& a,
208 const RhsScalar& b)
const {
209 return (conj_helper::coeff(a) * b) + accum;
211 static constexpr bool PacketAccess =
false;
216template <
typename Scalar,
bool Conj>
217struct scalar_inner_product_op<
219 typename std::enable_if<internal::is_same<typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType, Scalar>::value,
222 using result_type = Scalar;
223 using conj_helper = conditional_conj<Scalar, Conj>;
224 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(
const Scalar& a,
const Scalar& b)
const {
225 return pmul(conj_helper::coeff(a), b);
227 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(
const Scalar& accum,
const Scalar& a,
const Scalar& b)
const {
228 return pmadd(conj_helper::coeff(a), b, accum);
230 template <
typename Packet>
231 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(
const Packet& a,
const Packet& b)
const {
232 return pmul(conj_helper::packet(a), b);
234 template <
typename Packet>
235 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(
const Packet& accum,
const Packet& a,
const Packet& b)
const {
236 return pmadd(conj_helper::packet(a), b, accum);
238 static constexpr bool PacketAccess = packet_traits<Scalar>::HasMul && packet_traits<Scalar>::HasAdd;
241template <
typename Lhs,
typename Rhs,
bool Conj>
242struct default_inner_product_impl {
243 using LhsScalar =
typename traits<Lhs>::Scalar;
244 using RhsScalar =
typename traits<Rhs>::Scalar;
245 using Op = scalar_inner_product_op<LhsScalar, RhsScalar, Conj>;
246 using Evaluator = inner_product_evaluator<Op, Lhs, Rhs>;
247 using result_type =
typename Evaluator::Scalar;
248 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type run(
const MatrixBase<Lhs>& a,
const MatrixBase<Rhs>& b) {
249 Evaluator eval(a.derived(), b.derived(), Op());
250 return inner_product_impl<Evaluator>::run(eval);
254template <
typename Lhs,
typename Rhs>
255struct dot_impl : default_inner_product_impl<Lhs, Rhs, true> {};
const unsigned int PacketAccessBit
Definition Constants.h:97
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82
const int Dynamic
Definition Constants.h:25