Eigen  5.0.1-dev+60122df6
 
Loading...
Searching...
No Matches
InnerProduct.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2024 Charlie Schlosser <cs.schlosser@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_INNER_PRODUCT_EVAL_H
11#define EIGEN_INNER_PRODUCT_EVAL_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
20// recursively searches for the largest simd type that does not exceed Size, or the smallest if no such type exists
21template <typename Scalar, int Size, typename Packet = typename packet_traits<Scalar>::type,
22 bool Stop =
23 (unpacket_traits<Packet>::size <= Size) || is_same<Packet, typename unpacket_traits<Packet>::half>::value>
24struct find_inner_product_packet_helper;
25
26template <typename Scalar, int Size, typename Packet>
27struct find_inner_product_packet_helper<Scalar, Size, Packet, false> {
28 using type = typename find_inner_product_packet_helper<Scalar, Size, typename unpacket_traits<Packet>::half>::type;
29};
30
31template <typename Scalar, int Size, typename Packet>
32struct find_inner_product_packet_helper<Scalar, Size, Packet, true> {
33 using type = Packet;
34};
35
36template <typename Scalar, int Size>
37struct find_inner_product_packet : find_inner_product_packet_helper<Scalar, Size> {};
38
39template <typename Scalar>
40struct find_inner_product_packet<Scalar, Dynamic> {
41 using type = typename packet_traits<Scalar>::type;
42};
43
44template <typename Lhs, typename Rhs>
45struct inner_product_assert {
46 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Lhs)
47 EIGEN_STATIC_ASSERT_VECTOR_ONLY(Rhs)
48 EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(Lhs, Rhs)
49#ifndef EIGEN_NO_DEBUG
50 static EIGEN_DEVICE_FUNC void run(const Lhs& lhs, const Rhs& rhs) {
51 eigen_assert((lhs.size() == rhs.size()) && "Inner product: lhs and rhs vectors must have same size");
52 }
53#else
54 static EIGEN_DEVICE_FUNC void run(const Lhs&, const Rhs&) {}
55#endif
56};
57
58template <typename Func, typename Lhs, typename Rhs>
59struct inner_product_evaluator {
60 static constexpr int LhsFlags = evaluator<Lhs>::Flags;
61 static constexpr int RhsFlags = evaluator<Rhs>::Flags;
62 static constexpr int SizeAtCompileTime = size_prefer_fixed(Lhs::SizeAtCompileTime, Rhs::SizeAtCompileTime);
63 static constexpr int MaxSizeAtCompileTime =
64 min_size_prefer_fixed(Lhs::MaxSizeAtCompileTime, Rhs::MaxSizeAtCompileTime);
65 static constexpr int LhsAlignment = evaluator<Lhs>::Alignment;
66 static constexpr int RhsAlignment = evaluator<Rhs>::Alignment;
67
68 using Scalar = typename Func::result_type;
69 using Packet = typename find_inner_product_packet<Scalar, SizeAtCompileTime>::type;
70
71 static constexpr bool Vectorize =
72 bool(LhsFlags & RhsFlags & PacketAccessBit) && Func::PacketAccess &&
73 ((MaxSizeAtCompileTime == Dynamic) || (unpacket_traits<Packet>::size <= MaxSizeAtCompileTime));
74
75 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit inner_product_evaluator(const Lhs& lhs, const Rhs& rhs,
76 Func func = Func())
77 : m_func(func), m_lhs(lhs), m_rhs(rhs), m_size(lhs.size()) {
78 inner_product_assert<Lhs, Rhs>::run(lhs, rhs);
79 }
80
81 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return m_size.value(); }
82
83 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index index) const {
84 return m_func.coeff(m_lhs.coeff(index), m_rhs.coeff(index));
85 }
86
87 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& value, Index index) const {
88 return m_func.coeff(value, m_lhs.coeff(index), m_rhs.coeff(index));
89 }
90
91 template <typename PacketType, int LhsMode = LhsAlignment, int RhsMode = RhsAlignment>
92 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index index) const {
93 return m_func.packet(m_lhs.template packet<LhsMode, PacketType>(index),
94 m_rhs.template packet<RhsMode, PacketType>(index));
95 }
96
97 template <typename PacketType, int LhsMode = LhsAlignment, int RhsMode = RhsAlignment>
98 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(const PacketType& value, Index index) const {
99 return m_func.packet(value, m_lhs.template packet<LhsMode, PacketType>(index),
100 m_rhs.template packet<RhsMode, PacketType>(index));
101 }
102
103 const Func m_func;
104 const evaluator<Lhs> m_lhs;
105 const evaluator<Rhs> m_rhs;
106 const variable_if_dynamic<Index, SizeAtCompileTime> m_size;
107};
108
109template <typename Evaluator, bool Vectorize = Evaluator::Vectorize>
110struct inner_product_impl;
111
112// scalar loop
113template <typename Evaluator>
114struct inner_product_impl<Evaluator, false> {
115 using Scalar = typename Evaluator::Scalar;
116 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) {
117 const Index size = eval.size();
118 if (size == 0) return Scalar(0);
119
120 Scalar result = eval.coeff(0);
121 for (Index k = 1; k < size; k++) {
122 result = eval.coeff(result, k);
123 }
124
125 return result;
126 }
127};
128
129// vector loop
130template <typename Evaluator>
131struct inner_product_impl<Evaluator, true> {
132 using UnsignedIndex = std::make_unsigned_t<Index>;
133 using Scalar = typename Evaluator::Scalar;
134 using Packet = typename Evaluator::Packet;
135 static constexpr int PacketSize = unpacket_traits<Packet>::size;
136 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Evaluator& eval) {
137 const UnsignedIndex size = static_cast<UnsignedIndex>(eval.size());
138 if (size < PacketSize) return inner_product_impl<Evaluator, false>::run(eval);
139
140 const UnsignedIndex packetEnd = numext::round_down(size, PacketSize);
141 const UnsignedIndex quadEnd = numext::round_down(size, 4 * PacketSize);
142 const UnsignedIndex numPackets = size / PacketSize;
143 const UnsignedIndex numRemPackets = (packetEnd - quadEnd) / PacketSize;
144
145 Packet presult0, presult1, presult2, presult3;
146
147 presult0 = eval.template packet<Packet>(0 * PacketSize);
148 if (numPackets >= 2) presult1 = eval.template packet<Packet>(1 * PacketSize);
149 if (numPackets >= 3) presult2 = eval.template packet<Packet>(2 * PacketSize);
150 if (numPackets >= 4) {
151 presult3 = eval.template packet<Packet>(3 * PacketSize);
152
153 for (UnsignedIndex k = 4 * PacketSize; k < quadEnd; k += 4 * PacketSize) {
154 presult0 = eval.packet(presult0, k + 0 * PacketSize);
155 presult1 = eval.packet(presult1, k + 1 * PacketSize);
156 presult2 = eval.packet(presult2, k + 2 * PacketSize);
157 presult3 = eval.packet(presult3, k + 3 * PacketSize);
158 }
159
160 if (numRemPackets >= 1) presult0 = eval.packet(presult0, quadEnd + 0 * PacketSize);
161 if (numRemPackets >= 2) presult1 = eval.packet(presult1, quadEnd + 1 * PacketSize);
162 if (numRemPackets == 3) presult2 = eval.packet(presult2, quadEnd + 2 * PacketSize);
163
164 presult2 = padd(presult2, presult3);
165 }
166
167 if (numPackets >= 3) presult1 = padd(presult1, presult2);
168 if (numPackets >= 2) presult0 = padd(presult0, presult1);
169
170 Scalar result = predux(presult0);
171 for (UnsignedIndex k = packetEnd; k < size; k++) {
172 result = eval.coeff(result, k);
173 }
174
175 return result;
176 }
177};
178
179template <typename Scalar, bool Conj>
180struct conditional_conj;
181
182template <typename Scalar>
183struct conditional_conj<Scalar, true> {
184 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a) { return numext::conj(a); }
185 template <typename Packet>
186 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a) {
187 return pconj(a);
188 }
189};
190
191template <typename Scalar>
192struct conditional_conj<Scalar, false> {
193 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a) { return a; }
194 template <typename Packet>
195 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a) {
196 return a;
197 }
198};
199
200template <typename LhsScalar, typename RhsScalar, bool Conj>
201struct scalar_inner_product_op {
202 using result_type = typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType;
203 using conj_helper = conditional_conj<LhsScalar, Conj>;
204 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(const LhsScalar& a, const RhsScalar& b) const {
205 return (conj_helper::coeff(a) * b);
206 }
207 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type coeff(const result_type& accum, const LhsScalar& a,
208 const RhsScalar& b) const {
209 return (conj_helper::coeff(a) * b) + accum;
210 }
211 static constexpr bool PacketAccess = false;
212};
213
214// Partial specialization for packet access if and only if
215// LhsScalar == RhsScalar == ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType.
216template <typename Scalar, bool Conj>
217struct scalar_inner_product_op<
218 Scalar,
219 typename std::enable_if<internal::is_same<typename ScalarBinaryOpTraits<Scalar, Scalar>::ReturnType, Scalar>::value,
220 Scalar>::type,
221 Conj> {
222 using result_type = Scalar;
223 using conj_helper = conditional_conj<Scalar, Conj>;
224 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& a, const Scalar& b) const {
225 return pmul(conj_helper::coeff(a), b);
226 }
227 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(const Scalar& accum, const Scalar& a, const Scalar& b) const {
228 return pmadd(conj_helper::coeff(a), b, accum);
229 }
230 template <typename Packet>
231 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& a, const Packet& b) const {
232 return pmul(conj_helper::packet(a), b);
233 }
234 template <typename Packet>
235 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packet(const Packet& accum, const Packet& a, const Packet& b) const {
236 return pmadd(conj_helper::packet(a), b, accum);
237 }
238 static constexpr bool PacketAccess = packet_traits<Scalar>::HasMul && packet_traits<Scalar>::HasAdd;
239};
240
241template <typename Lhs, typename Rhs, bool Conj>
242struct default_inner_product_impl {
243 using LhsScalar = typename traits<Lhs>::Scalar;
244 using RhsScalar = typename traits<Rhs>::Scalar;
245 using Op = scalar_inner_product_op<LhsScalar, RhsScalar, Conj>;
246 using Evaluator = inner_product_evaluator<Op, Lhs, Rhs>;
247 using result_type = typename Evaluator::Scalar;
248 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type run(const MatrixBase<Lhs>& a, const MatrixBase<Rhs>& b) {
249 Evaluator eval(a.derived(), b.derived(), Op());
250 return inner_product_impl<Evaluator>::run(eval);
251 }
252};
253
254template <typename Lhs, typename Rhs>
255struct dot_impl : default_inner_product_impl<Lhs, Rhs, true> {};
256
257} // namespace internal
258} // namespace Eigen
259
260#endif // EIGEN_INNER_PRODUCT_EVAL_H
const unsigned int PacketAccessBit
Definition Constants.h:97
Namespace containing all symbols from the Eigen library.
Definition B01_Experimental.dox:1
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition Meta.h:82
const int Dynamic
Definition Constants.h:25