Eigen  5.0.1-dev+60122df6
 
GeneralMatrixVector.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2016 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
#define EIGEN_GENERAL_MATRIX_VECTOR_H

// IWYU pragma: private
#include "../InternalHeaderCheck.h"

namespace Eigen {

namespace internal {

enum GEMVPacketSizeType { GEMVPacketFull = 0, GEMVPacketHalf, GEMVPacketQuarter };

template <int N, typename T1, typename T2, typename T3>
struct gemv_packet_cond {
  typedef T3 type;
};

template <typename T1, typename T2, typename T3>
struct gemv_packet_cond<GEMVPacketFull, T1, T2, T3> {
  typedef T1 type;
};

template <typename T1, typename T2, typename T3>
struct gemv_packet_cond<GEMVPacketHalf, T1, T2, T3> {
  typedef T2 type;
};

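// Illustrative note (not part of the original file): gemv_packet_cond is a
// three-way compile-time switch over the enum above, e.g.
//   gemv_packet_cond<GEMVPacketFull,    P, Ph, Pq>::type  is  P
//   gemv_packet_cond<GEMVPacketHalf,    P, Ph, Pq>::type  is  Ph
//   gemv_packet_cond<GEMVPacketQuarter, P, Ph, Pq>::type  is  Pq
// (GEMVPacketQuarter is caught by the unspecialized primary template).
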
template <typename LhsScalar, typename RhsScalar, int PacketSize_ = GEMVPacketFull>
class gemv_traits {
  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;

#define PACKET_DECL_COND_POSTFIX(postfix, name, packet_size)                                                \
  typedef typename gemv_packet_cond<                                                                        \
      packet_size, typename packet_traits<name##Scalar>::type, typename packet_traits<name##Scalar>::half,  \
      typename unpacket_traits<typename packet_traits<name##Scalar>::half>::half>::type name##Packet##postfix

  PACKET_DECL_COND_POSTFIX(_, Lhs, PacketSize_);
  PACKET_DECL_COND_POSTFIX(_, Rhs, PacketSize_);
  PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
#undef PACKET_DECL_COND_POSTFIX

 public:
  enum {
    Vectorizable = unpacket_traits<LhsPacket_>::vectorizable && unpacket_traits<RhsPacket_>::vectorizable &&
                   int(unpacket_traits<LhsPacket_>::size) == int(unpacket_traits<RhsPacket_>::size),
    LhsPacketSize = Vectorizable ? unpacket_traits<LhsPacket_>::size : 1,
    RhsPacketSize = Vectorizable ? unpacket_traits<RhsPacket_>::size : 1,
    ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1
  };

  typedef std::conditional_t<Vectorizable, LhsPacket_, LhsScalar> LhsPacket;
  typedef std::conditional_t<Vectorizable, RhsPacket_, RhsScalar> RhsPacket;
  typedef std::conditional_t<Vectorizable, ResPacket_, ResScalar> ResPacket;
};

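// A minimal sketch of what gemv_traits resolves to; illustrative only, since
// the concrete packet types depend on the SIMD target. Assuming an AVX build
// where packet_traits<float>::type is Packet8f and its half is Packet4f:
//
//   static_assert(std::is_same<gemv_traits<float, float>::ResPacket,
//                              Packet8f>::value, "full packets");
//   static_assert(std::is_same<gemv_traits<float, float, GEMVPacketHalf>::ResPacket,
//                              Packet4f>::value, "half packets");
//
// On that target Packet4f has no smaller half, so the quarter packet collapses
// back to Packet4f. When the lhs/rhs packet sizes disagree (e.g. some mixed
// real/complex combinations), Vectorizable is 0 and the Packet typedefs fall
// back to the plain scalar types.
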
/* Optimized col-major matrix * vector product:
 * This algorithm processes the matrix by vertical panels,
 * which are then processed horizontally per chunk of 8*PacketSize x 1 vertical segments.
 *
 * Mixing type logic: C += alpha * A * B
 * | A | B |alpha| comments
 * |real |cplx |cplx | no vectorization
 * |real |cplx |real | alpha is converted to a cplx when calling the run function, no vectorization
 * |cplx |real |cplx | invalid, the caller has to do tmp = A * B; C += alpha*tmp
 * |cplx |real |real | optimal case, vectorization possible via real-cplx mul
 *
 * The same reasoning applies to the transposed case.
 */
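// For context, a hedged usage sketch (not part of the original file): this is
// the kernel behind a plain GEMV expression on a column-major matrix, e.g.
//
//   #include <Eigen/Dense>
//   Eigen::MatrixXf A(1000, 1000);  // MatrixXf is column-major by default
//   Eigen::VectorXf x(1000), y(1000);
//   float alpha = 2.f;
//   // ... fill A, x, y ...
//   y.noalias() += alpha * A * x;   // dispatched to the ColMajor kernel below
//
// Column j contributes rhs(j) * A.col(j) to the result, which is why the loops
// below broadcast a single rhs coefficient (pset1) and load packets from the
// lhs column.
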
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
          typename RhsMapper, bool ConjugateRhs, int Version>
struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper,
                                     ConjugateRhs, Version> {
  typedef gemv_traits<LhsScalar, RhsScalar> Traits;
  typedef gemv_traits<LhsScalar, RhsScalar, GEMVPacketHalf> HalfTraits;
  typedef gemv_traits<LhsScalar, RhsScalar, GEMVPacketQuarter> QuarterTraits;

  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  typedef typename HalfTraits::LhsPacket LhsPacketHalf;
  typedef typename HalfTraits::RhsPacket RhsPacketHalf;
  typedef typename HalfTraits::ResPacket ResPacketHalf;

  typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
  typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
  typedef typename QuarterTraits::ResPacket ResPacketQuarter;

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs,
                                                      const RhsMapper& rhs, ResScalar* res, Index resIncr,
                                                      RhsScalar alpha);
};

template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
          typename RhsMapper, bool ConjugateRhs, int Version>
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void
general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
                              Version>::run(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
                                            ResScalar* res, Index resIncr, RhsScalar alpha) {
  EIGEN_UNUSED_VARIABLE(resIncr);
  eigen_internal_assert(resIncr == 1);

  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);

  conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
  conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
  conj_helper<LhsPacketHalf, RhsPacketHalf, ConjugateLhs, ConjugateRhs> pcj_half;
  conj_helper<LhsPacketQuarter, RhsPacketQuarter, ConjugateLhs, ConjugateRhs> pcj_quarter;

  const Index lhsStride = lhs.stride();
  // TODO: for padded aligned inputs, we could enable aligned reads
  enum {
    LhsAlignment = Unaligned,
    ResPacketSize = Traits::ResPacketSize,
    ResPacketSizeHalf = HalfTraits::ResPacketSize,
    ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
    LhsPacketSize = Traits::LhsPacketSize,
    HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
    HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
  };

  const Index n8 = rows - 8 * ResPacketSize + 1;
  const Index n4 = rows - 4 * ResPacketSize + 1;
  const Index n3 = rows - 3 * ResPacketSize + 1;
  const Index n2 = rows - 2 * ResPacketSize + 1;
  const Index n1 = rows - 1 * ResPacketSize + 1;
  const Index n_half = rows - 1 * ResPacketSizeHalf + 1;
  const Index n_quarter = rows - 1 * ResPacketSizeQuarter + 1;
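  // Illustrative example (not in the original source): with rows == 37 and
  // ResPacketSize == 4 the bounds above are n8 == 6, n4 == 22, n3 == 26,
  // n2 == 30, n1 == 34. Row index i enters the k-packet peel below only while
  // i < nk, i.e. while k * ResPacketSize rows starting at i still fit: rows
  // 0..31 go through the 8-packet loop in one iteration, rows 32..35 through
  // the 1-packet branch, and the final scalar loop handles row 36.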

  // TODO: improve the following heuristic:
  const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(LhsScalar) < 32000 ? 16 : 4);
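  // Worked example of the heuristic (illustrative): for a 1000x1000 float lhs,
  // cols >= 128 and lhsStride * sizeof(float) == 4000 < 32000, so block_cols
  // == 16: columns are consumed 16 at a time, keeping the partial result
  // packets in registers while a strip of rows is revisited. Above ~32KB of
  // stride the block shrinks to 4 columns to bound the cache footprint.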
  ResPacket palpha = pset1<ResPacket>(alpha);
  ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha);
  ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha);

  for (Index j2 = 0; j2 < cols; j2 += block_cols) {
    Index jend = numext::mini(j2 + block_cols, cols);
    Index i = 0;
    for (; i < n8; i += ResPacketSize * 8) {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)), c1 = pset1<ResPacket>(ResScalar(0)),
                c2 = pset1<ResPacket>(ResScalar(0)), c3 = pset1<ResPacket>(ResScalar(0)),
                c4 = pset1<ResPacket>(ResScalar(0)), c5 = pset1<ResPacket>(ResScalar(0)),
                c6 = pset1<ResPacket>(ResScalar(0)), c7 = pset1<ResPacket>(ResScalar(0));

      for (Index j = j2; j < jend; j += 1) {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
        c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 2, j), b0, c2);
        c3 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 3, j), b0, c3);
        c4 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 4, j), b0, c4);
        c5 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 5, j), b0, c5);
        c6 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 6, j), b0, c6);
        c7 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 7, j), b0, c7);
      }
      pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
      pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
      pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 2)));
      pstoreu(res + i + ResPacketSize * 3, pmadd(c3, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 3)));
      pstoreu(res + i + ResPacketSize * 4, pmadd(c4, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 4)));
      pstoreu(res + i + ResPacketSize * 5, pmadd(c5, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 5)));
      pstoreu(res + i + ResPacketSize * 6, pmadd(c6, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 6)));
      pstoreu(res + i + ResPacketSize * 7, pmadd(c7, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 7)));
    }
    if (i < n4) {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)), c1 = pset1<ResPacket>(ResScalar(0)),
                c2 = pset1<ResPacket>(ResScalar(0)), c3 = pset1<ResPacket>(ResScalar(0));

      for (Index j = j2; j < jend; j += 1) {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
        c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 2, j), b0, c2);
        c3 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 3, j), b0, c3);
      }
      pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
      pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
      pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 2)));
      pstoreu(res + i + ResPacketSize * 3, pmadd(c3, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 3)));

      i += ResPacketSize * 4;
    }
    if (i < n3) {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)), c1 = pset1<ResPacket>(ResScalar(0)),
                c2 = pset1<ResPacket>(ResScalar(0));

      for (Index j = j2; j < jend; j += 1) {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
        c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 2, j), b0, c2);
      }
      pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
      pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
      pstoreu(res + i + ResPacketSize * 2, pmadd(c2, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 2)));

      i += ResPacketSize * 3;
    }
    if (i < n2) {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0)), c1 = pset1<ResPacket>(ResScalar(0));

      for (Index j = j2; j < jend; j += 1) {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 0, j), b0, c0);
        c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + LhsPacketSize * 1, j), b0, c1);
      }
      pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
      pstoreu(res + i + ResPacketSize * 1, pmadd(c1, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 1)));
      i += ResPacketSize * 2;
    }
    if (i < n1) {
      ResPacket c0 = pset1<ResPacket>(ResScalar(0));
      for (Index j = j2; j < jend; j += 1) {
        RhsPacket b0 = pset1<RhsPacket>(rhs(j, 0));
        c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, c0);
      }
      pstoreu(res + i + ResPacketSize * 0, pmadd(c0, palpha, ploadu<ResPacket>(res + i + ResPacketSize * 0)));
      i += ResPacketSize;
    }
    if (HasHalf && i < n_half) {
      ResPacketHalf c0 = pset1<ResPacketHalf>(ResScalar(0));
      for (Index j = j2; j < jend; j += 1) {
        RhsPacketHalf b0 = pset1<RhsPacketHalf>(rhs(j, 0));
        c0 = pcj_half.pmadd(lhs.template load<LhsPacketHalf, LhsAlignment>(i + 0, j), b0, c0);
      }
      pstoreu(res + i + ResPacketSizeHalf * 0,
              pmadd(c0, palpha_half, ploadu<ResPacketHalf>(res + i + ResPacketSizeHalf * 0)));
      i += ResPacketSizeHalf;
    }
    if (HasQuarter && i < n_quarter) {
      ResPacketQuarter c0 = pset1<ResPacketQuarter>(ResScalar(0));
      for (Index j = j2; j < jend; j += 1) {
        RhsPacketQuarter b0 = pset1<RhsPacketQuarter>(rhs(j, 0));
        c0 = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter, LhsAlignment>(i + 0, j), b0, c0);
      }
      pstoreu(res + i + ResPacketSizeQuarter * 0,
              pmadd(c0, palpha_quarter, ploadu<ResPacketQuarter>(res + i + ResPacketSizeQuarter * 0)));
      i += ResPacketSizeQuarter;
    }
    for (; i < rows; ++i) {
      ResScalar c0(0);
      for (Index j = j2; j < jend; j += 1) c0 += cj.pmul(lhs(i, j), rhs(j, 0));
      res[i] += alpha * c0;
    }
  }
}

/* Optimized row-major matrix * vector product:
 * This algorithm processes several rows at once (8, 4, then 2), which both
 * reduces the number of loads/stores on the result by the same factor and
 * shortens instruction dependency chains. Moreover, we know that all bands
 * have the same alignment pattern.
 *
 * Mixing type logic:
 * - alpha is always a complex (or converted to a complex)
 * - no vectorization
 */
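// For context, a hedged usage sketch (not part of the original file): the
// RowMajor kernel is reached from expressions such as
//
//   #include <Eigen/Dense>
//   using RowMajorMatrixXf =
//       Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
//   RowMajorMatrixXf A(m, n);
//   Eigen::VectorXf x(n), y(m);
//   // ... fill A, x, y ...
//   y.noalias() += alpha * A * x;
//
// and is also what a transposed product on a column-major matrix
// (A.transpose() * x) maps to. Each result coefficient is an independent dot
// product of one lhs row with the rhs, so the loops below accumulate
// packet-wide partial sums per row and reduce them with predux().
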
template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
          typename RhsMapper, bool ConjugateRhs, int Version>
struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper,
                                     ConjugateRhs, Version> {
  typedef gemv_traits<LhsScalar, RhsScalar> Traits;
  typedef gemv_traits<LhsScalar, RhsScalar, GEMVPacketHalf> HalfTraits;
  typedef gemv_traits<LhsScalar, RhsScalar, GEMVPacketQuarter> QuarterTraits;

  typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;

  typedef typename Traits::LhsPacket LhsPacket;
  typedef typename Traits::RhsPacket RhsPacket;
  typedef typename Traits::ResPacket ResPacket;

  typedef typename HalfTraits::LhsPacket LhsPacketHalf;
  typedef typename HalfTraits::RhsPacket RhsPacketHalf;
  typedef typename HalfTraits::ResPacket ResPacketHalf;

  typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
  typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
  typedef typename QuarterTraits::ResPacket ResPacketQuarter;

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs,
                                                      const RhsMapper& rhs, ResScalar* res, Index resIncr,
                                                      ResScalar alpha);
};

template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
          typename RhsMapper, bool ConjugateRhs, int Version>
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void
general_matrix_vector_product<Index, LhsScalar, LhsMapper, RowMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
                              Version>::run(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
                                            ResScalar* res, Index resIncr, ResScalar alpha) {
  // The following copy tells the compiler that lhs's attributes are not modified outside this function.
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);

  eigen_internal_assert(rhs.stride() == 1);
  conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
  conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
  conj_helper<LhsPacketHalf, RhsPacketHalf, ConjugateLhs, ConjugateRhs> pcj_half;
  conj_helper<LhsPacketQuarter, RhsPacketQuarter, ConjugateLhs, ConjugateRhs> pcj_quarter;

  // TODO: fine tune the following heuristic. The rationale is that if the matrix is very large,
  // processing 8 rows at once might be counterproductive with respect to the cache.
  const Index n8 = lhs.stride() * sizeof(LhsScalar) > 32000 ? 0 : rows - 7;
  const Index n4 = rows - 3;
  const Index n2 = rows - 1;
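  // Illustrative numbers (not in the original source): for float with
  // lhs.stride() == 10000, stride * sizeof(float) == 40000 > 32000, so n8 == 0
  // and the 8-row peel is skipped entirely. For a stride of 1000 floats
  // (4000 bytes), n8 == rows - 7 and up to 8 rows are combined per pass over
  // the rhs.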

  // TODO: for padded aligned inputs, we could enable aligned reads
  enum {
    LhsAlignment = Unaligned,
    ResPacketSize = Traits::ResPacketSize,
    ResPacketSizeHalf = HalfTraits::ResPacketSize,
    ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
    LhsPacketSize = Traits::LhsPacketSize,
    LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
    LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
    HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
    HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
  };

  using UnsignedIndex = typename make_unsigned<Index>::type;
  const Index fullColBlockEnd = LhsPacketSize * (UnsignedIndex(cols) / LhsPacketSize);
  const Index halfColBlockEnd = LhsPacketSizeHalf * (UnsignedIndex(cols) / LhsPacketSizeHalf);
  const Index quarterColBlockEnd = LhsPacketSizeQuarter * (UnsignedIndex(cols) / LhsPacketSizeQuarter);
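  // Illustrative example (not in the original source): with cols == 37,
  // LhsPacketSize == 8 and LhsPacketSizeHalf == 4, fullColBlockEnd == 32 and
  // halfColBlockEnd == 36, i.e. each bound rounds cols down to a multiple of
  // the corresponding packet size. The unsigned cast lets the compiler turn
  // the division into cheap bit operations when the packet size is a power
  // of two.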

  Index i = 0;
  for (; i < n8; i += 8) {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0)), c1 = pset1<ResPacket>(ResScalar(0)),
              c2 = pset1<ResPacket>(ResScalar(0)), c3 = pset1<ResPacket>(ResScalar(0)),
              c4 = pset1<ResPacket>(ResScalar(0)), c5 = pset1<ResPacket>(ResScalar(0)),
              c6 = pset1<ResPacket>(ResScalar(0)), c7 = pset1<ResPacket>(ResScalar(0));

    for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) {
      RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j, 0);

      c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, c0);
      c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 1, j), b0, c1);
      c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 2, j), b0, c2);
      c3 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 3, j), b0, c3);
      c4 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 4, j), b0, c4);
      c5 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 5, j), b0, c5);
      c6 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 6, j), b0, c6);
      c7 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 7, j), b0, c7);
    }
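    // predux() horizontally sums the lanes of each accumulator, collapsing the
    // packet-wide partial sums into one scalar dot-product value per row.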
    ResScalar cc0 = predux(c0);
    ResScalar cc1 = predux(c1);
    ResScalar cc2 = predux(c2);
    ResScalar cc3 = predux(c3);
    ResScalar cc4 = predux(c4);
    ResScalar cc5 = predux(c5);
    ResScalar cc6 = predux(c6);
    ResScalar cc7 = predux(c7);

    for (Index j = fullColBlockEnd; j < cols; ++j) {
      RhsScalar b0 = rhs(j, 0);

      cc0 += cj.pmul(lhs(i + 0, j), b0);
      cc1 += cj.pmul(lhs(i + 1, j), b0);
      cc2 += cj.pmul(lhs(i + 2, j), b0);
      cc3 += cj.pmul(lhs(i + 3, j), b0);
      cc4 += cj.pmul(lhs(i + 4, j), b0);
      cc5 += cj.pmul(lhs(i + 5, j), b0);
      cc6 += cj.pmul(lhs(i + 6, j), b0);
      cc7 += cj.pmul(lhs(i + 7, j), b0);
    }
    res[(i + 0) * resIncr] += alpha * cc0;
    res[(i + 1) * resIncr] += alpha * cc1;
    res[(i + 2) * resIncr] += alpha * cc2;
    res[(i + 3) * resIncr] += alpha * cc3;
    res[(i + 4) * resIncr] += alpha * cc4;
    res[(i + 5) * resIncr] += alpha * cc5;
    res[(i + 6) * resIncr] += alpha * cc6;
    res[(i + 7) * resIncr] += alpha * cc7;
  }
  for (; i < n4; i += 4) {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0)), c1 = pset1<ResPacket>(ResScalar(0)),
              c2 = pset1<ResPacket>(ResScalar(0)), c3 = pset1<ResPacket>(ResScalar(0));

    for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) {
      RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j, 0);

      c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, c0);
      c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 1, j), b0, c1);
      c2 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 2, j), b0, c2);
      c3 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 3, j), b0, c3);
    }
    ResScalar cc0 = predux(c0);
    ResScalar cc1 = predux(c1);
    ResScalar cc2 = predux(c2);
    ResScalar cc3 = predux(c3);

    for (Index j = fullColBlockEnd; j < cols; ++j) {
      RhsScalar b0 = rhs(j, 0);

      cc0 += cj.pmul(lhs(i + 0, j), b0);
      cc1 += cj.pmul(lhs(i + 1, j), b0);
      cc2 += cj.pmul(lhs(i + 2, j), b0);
      cc3 += cj.pmul(lhs(i + 3, j), b0);
    }
    res[(i + 0) * resIncr] += alpha * cc0;
    res[(i + 1) * resIncr] += alpha * cc1;
    res[(i + 2) * resIncr] += alpha * cc2;
    res[(i + 3) * resIncr] += alpha * cc3;
  }
  for (; i < n2; i += 2) {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0)), c1 = pset1<ResPacket>(ResScalar(0));

    for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) {
      RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j, 0);

      c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, c0);
      c1 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 1, j), b0, c1);
    }
    ResScalar cc0 = predux(c0);
    ResScalar cc1 = predux(c1);

    for (Index j = fullColBlockEnd; j < cols; ++j) {
      RhsScalar b0 = rhs(j, 0);

      cc0 += cj.pmul(lhs(i + 0, j), b0);
      cc1 += cj.pmul(lhs(i + 1, j), b0);
    }
    res[(i + 0) * resIncr] += alpha * cc0;
    res[(i + 1) * resIncr] += alpha * cc1;
  }
  for (; i < rows; ++i) {
    ResPacket c0 = pset1<ResPacket>(ResScalar(0));
    ResPacketHalf c0_h = pset1<ResPacketHalf>(ResScalar(0));
    ResPacketQuarter c0_q = pset1<ResPacketQuarter>(ResScalar(0));

    for (Index j = 0; j < fullColBlockEnd; j += LhsPacketSize) {
      RhsPacket b0 = rhs.template load<RhsPacket, Unaligned>(j, 0);
      c0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i, j), b0, c0);
    }
    ResScalar cc0 = predux(c0);
    if (HasHalf) {
      for (Index j = fullColBlockEnd; j < halfColBlockEnd; j += LhsPacketSizeHalf) {
        RhsPacketHalf b0 = rhs.template load<RhsPacketHalf, Unaligned>(j, 0);
        c0_h = pcj_half.pmadd(lhs.template load<LhsPacketHalf, LhsAlignment>(i, j), b0, c0_h);
      }
      cc0 += predux(c0_h);
    }
    if (HasQuarter) {
      for (Index j = halfColBlockEnd; j < quarterColBlockEnd; j += LhsPacketSizeQuarter) {
        RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter, Unaligned>(j, 0);
        c0_q = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter, LhsAlignment>(i, j), b0, c0_q);
      }
      cc0 += predux(c0_q);
    }
    for (Index j = quarterColBlockEnd; j < cols; ++j) {
      cc0 += cj.pmul(lhs(i, j), rhs(j, 0));
    }
    res[i * resIncr] += alpha * cc0;
  }
}

} // end namespace internal

} // end namespace Eigen

#endif // EIGEN_GENERAL_MATRIX_VECTOR_H