Eigen-unsupported  5.0.1-dev+7c7d8473
 
Loading...
Searching...
No Matches
TensorContractionMapper.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
// Tags identifying which side of the contraction a mapper addresses.
// Note the deliberately counter-intuitive values: Rhs == 0 and Lhs == 1.
enum { Rhs = 0, Lhs = 1 };

/*
 * Implementation of the Eigen blas_data_mapper class for tensors.
 */

// Forward declaration: loads coefficients from a tensor expression, specialized
// below on whether the tensor exposes raw (contiguous buffer) data access.
template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer>
struct CoeffLoader;

// Forward declaration of the packet-aware contraction mapper defined below.
template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          template <class> class MakePointer_ = MakePointer>
class BaseTensorContractionMapper;
34
35template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_>
36struct CoeffLoader {
37 enum { DirectOffsets = false };
38
39 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_tensor(tensor) {}
40
41 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index) {
42 eigen_assert(false && "unsupported");
43 }
44
45 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type data() const {
46 eigen_assert(false && "unsupported");
47 return NULL;
48 }
49
50 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const {
51 return m_tensor.coeff(index);
52 }
53
54 template <int LoadMode>
55 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Tensor::PacketReturnType packet(typename Tensor::Index index) const {
56 return m_tensor.template packet<LoadMode>(index);
57 }
58
59 private:
60 const Tensor m_tensor;
61};
62
63template <typename Tensor, template <class> class MakePointer_>
64struct CoeffLoader<Tensor, true, MakePointer_> {
65 enum { DirectOffsets = true };
66
67 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_data(tensor.data()) {}
68
69 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) { m_data += offset; }
70
71 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type data() const {
72 return m_data;
73 }
74
75 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const {
76 return loadConstant(m_data + index);
77 }
78
79 template <int LoadMode>
80 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Tensor::PacketReturnType packet(typename Tensor::Index index) const {
81 return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
82 }
83
84 private:
85 typedef typename Tensor::Scalar Scalar;
86
87 typename MakePointer_<const Scalar>::Type m_data;
88};
89
// Maps the 2-D (row, col) coordinates used by the matrix kernels back to a
// linear coefficient index of an arbitrary-rank tensor, for one side of a
// tensor contraction.
//  - `side` selects whether this mapper addresses the Lhs or the Rhs.
//  - m_nocontract_strides / m_contract_strides are the tensor's own strides
//    for the non-contracting resp. contracting dimensions.
//  - m_ij_strides / m_k_strides are the strides of the flattened row/col
//    resp. depth coordinate, used to peel a flat coordinate back into
//    per-dimension indices.
template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, int Alignment, template <class> class MakePointer_ = MakePointer>
class SimpleTensorContractionMapper {
 public:
  EIGEN_DEVICE_FUNC SimpleTensorContractionMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
                                                  const nocontract_t& ij_strides, const contract_t& contract_strides,
                                                  const contract_t& k_strides)
      : m_tensor(tensor),
        m_nocontract_strides(nocontract_strides),
        m_ij_strides(ij_strides),
        m_contract_strides(contract_strides),
        m_k_strides(k_strides) {}

  // Direct offsetting is only available when the coefficient loader wraps a raw buffer.
  enum { DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>::DirectOffsets };

  // Shift the underlying buffer by `offset` coefficients (raw-access tensors only;
  // asserts otherwise — see CoeffLoader).
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
    m_tensor.offsetBuffer(offset);
  }

  // Prefetching is a no-op for this mapper.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void prefetch(Index /*i*/) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
    // column major assumption
    return operator()(row, 0);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const {
    return m_tensor.coeff(computeIndex(row, col));
  }

#ifdef EIGEN_MULTIDIMENSIONAL_SUBSCRIPT
  // C++23 multidimensional subscript, forwarded to operator().
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator[](Index row, Index col) const { return operator()(row, col); }
#endif

  // Translate the matrix coordinate (row, col) into a linear coefficient index
  // of the underlying tensor.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
    const bool left = (side == Lhs);
    EIGEN_UNUSED_VARIABLE(left);  // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
    // For the lhs the non-contracting dims form the rows; for the rhs, the columns.
    Index nocontract_val = left ? row : col;
    Index linidx = 0;
    // Peel the flat coordinate into per-dimension indices, outermost first,
    // and re-accumulate them with the tensor's own strides.
    EIGEN_UNROLL_LOOP
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx = nocontract_val / m_ij_strides[i];
      linidx += idx * m_nocontract_strides[i];
      nocontract_val -= idx * m_ij_strides[i];
    }
    // Innermost non-contracting dimension — only present when the tensor has
    // more dimensions than are being contracted.
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx += nocontract_val;
      } else {
        linidx += nocontract_val * m_nocontract_strides[0];
      }
    }

    // Same decomposition for the contracting dimensions.
    Index contract_val = left ? col : row;
    if (array_size<contract_t>::value > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx = contract_val / m_k_strides[i];
        linidx += idx * m_contract_strides[i];
        contract_val -= idx * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx += contract_val;
      } else {
        linidx += contract_val * m_contract_strides[0];
      }
    }

    return linidx;
  }

  // Compute the linear indices of (row, col) and (row + distance, col) in one
  // pass; the resulting pair brackets a prospective packet load.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col,
                                                                          const Index distance) const {
    const bool left = (side == Lhs);
    EIGEN_UNUSED_VARIABLE(left);  // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
    // Index [0] is the base coordinate, index [1] the coordinate advanced by
    // `distance` along the row direction.
    Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
    Index linidx[2] = {0, 0};
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = nocontract_val[0] / m_ij_strides[i];
        const Index idx1 = nocontract_val[1] / m_ij_strides[i];
        linidx[0] += idx0 * m_nocontract_strides[i];
        linidx[1] += idx1 * m_nocontract_strides[i];
        nocontract_val[0] -= idx0 * m_ij_strides[i];
        nocontract_val[1] -= idx1 * m_ij_strides[i];
      }
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx[0] += nocontract_val[0];
        linidx[1] += nocontract_val[1];
      } else {
        linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
        linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
      }
    }

    // For the lhs the contracting coordinate is the column, so both entries
    // coincide; for the rhs the second entry is advanced by `distance`.
    Index contract_val[2] = {left ? col : row, left ? col : row + distance};
    if (array_size<contract_t>::value > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = contract_val[0] / m_k_strides[i];
        const Index idx1 = contract_val[1] / m_k_strides[i];
        linidx[0] += idx0 * m_contract_strides[i];
        linidx[1] += idx1 * m_contract_strides[i];
        contract_val[0] -= idx0 * m_k_strides[i];
        contract_val[1] -= idx1 * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx[0] += contract_val[0];
        linidx[1] += contract_val[1];
      } else {
        linidx[0] += contract_val[0] * m_contract_strides[0];
        linidx[1] += contract_val[1] * m_contract_strides[0];
      }
    }
    return IndexPair<Index>(linidx[0], linidx[1]);
  }

  // First index from which aligned packet loads may be issued; returning
  // `size` means "never aligned".
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const {
    // Only claim alignment when we can compute the actual stride (ie when we're
    // dealing with the lhs with inner_dim_contiguous. This is because the
    // matrix-vector product relies on the stride when dealing with aligned inputs.
    return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
  }
  // Column stride; only meaningful for a contiguous lhs with at least one
  // contracting dimension, 1 otherwise.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const {
    return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
  }

  // Read-only access to the wrapped loader and the stride arrays.
  const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& tensor() const { return m_tensor; }

  const nocontract_t& nocontract_strides() const { return m_nocontract_strides; }
  const nocontract_t& ij_strides() const { return m_ij_strides; }
  const contract_t& contract_strides() const { return m_contract_strides; }
  const contract_t& k_strides() const { return m_k_strides; }

 protected:
  CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_> m_tensor;
  const nocontract_t m_nocontract_strides;
  const nocontract_t m_ij_strides;
  const contract_t m_contract_strides;
  const contract_t m_k_strides;
};
238
// Adds packet loads on top of SimpleTensorContractionMapper. The first `load`
// overload handles packet types whose size matches `packet_size`; the second
// (selected via enable_if) handles other sizes; the packet_size == 1
// specialization below degenerates to scalar loads.
template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          template <class> class MakePointer_>
class BaseTensorContractionMapper
    : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
                                           inner_dim_contiguous, Alignment, MakePointer_> {
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
                                        inner_dim_contiguous, Alignment, MakePointer_>
      ParentMapper;

  EIGEN_DEVICE_FUNC BaseTensorContractionMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
                                                const nocontract_t& ij_strides, const contract_t& contract_strides,
                                                const contract_t& k_strides)
      : ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) {}

  // Load a packet of coefficients starting at (i, j), walking down the column.
  // Uses a single tensor packet load when the coefficients are provably
  // consecutive, and otherwise gathers them one by one.
  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
      std::enable_if_t<internal::unpacket_traits<PacketT>::size == packet_size, PacketT>
      load(Index i, Index j) const {
    // whole method makes column major assumption

    // don't need to add offsets for now (because operator handles that)
    // current code assumes packet size must be a multiple of 2
    EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);

    // Fast path: contiguous, unshuffled inner dimension means the packet's
    // coefficients are consecutive in memory.
    if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
      const Index index = this->computeIndex(i, j);
      eigen_assert(this->computeIndex(i + packet_size - 1, j) == index + packet_size - 1);
      return this->m_tensor.template packet<AlignmentType>(index);
    }

    // Compute the linear indices of the first and last coefficient of the packet.
    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
    const Index first = indexPair.first;
    const Index lastIdx = indexPair.second;

    // We can always do optimized packet reads from left hand side right now, because
    // the vertical matrix dimension on the left hand side is never contracting.
    // On the right hand side we need to check if the contracting dimensions may have
    // been shuffled first.
    if (Tensor::PacketAccess && (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
        (lastIdx - first) == (packet_size - 1)) {
      return this->m_tensor.template packet<AlignmentType>(first);
    }

    // Slow path: gather the coefficients two at a time (hence the even
    // packet_size requirement above) into an aligned stack buffer.
    EIGEN_ALIGN_MAX Scalar data[packet_size];

    data[0] = this->m_tensor.coeff(first);
    EIGEN_UNROLL_LOOP
    for (Index k = 1; k < packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[packet_size - 1] = this->m_tensor.coeff(lastIdx);

    return pload<PacketT>(data);
  }

  // Overload for packet types whose size differs from packet_size: always
  // gathers coefficient by coefficient (no contiguous fast path).
  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
      std::enable_if_t<internal::unpacket_traits<PacketT>::size != packet_size, PacketT>
      load(Index i, Index j) const {
    const Index requested_packet_size = internal::unpacket_traits<PacketT>::size;
    EIGEN_ALIGN_MAX Scalar data[requested_packet_size];

    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, requested_packet_size - 1);
    const Index first = indexPair.first;
    const Index lastIdx = indexPair.second;

    data[0] = this->m_tensor.coeff(first);
    for (Index k = 1; k < requested_packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[requested_packet_size - 1] = this->m_tensor.coeff(lastIdx);

    return pload<PacketT>(data);
  }

  // loadPacket is a synonym for load.
  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
    return this->load<PacketT, AlignmentType>(i, j);
  }
};
325
326template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
327 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, template <class> class MakePointer_>
328class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous,
329 inner_dim_reordered, Alignment, MakePointer_>
330 : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1,
331 inner_dim_contiguous, Alignment, MakePointer_> {
332 public:
333 typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous,
334 Alignment, MakePointer_>
335 ParentMapper;
336
337 EIGEN_DEVICE_FUNC BaseTensorContractionMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
338 const nocontract_t& ij_strides, const contract_t& contract_strides,
339 const contract_t& k_strides)
340 : ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) {}
341
342 template <typename PacketT, int>
343 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
344 EIGEN_ALIGN_MAX Scalar data[1];
345 data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
346 return pload<PacketT>(data);
347 }
348 template <typename PacketT, int>
349 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketT load(Index i, Index j) const {
350 EIGEN_ALIGN_MAX Scalar data[1];
351 data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
352 return pload<PacketT>(data);
353 }
354};
355
356template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
357 int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
358 template <class> class MakePointer_ = MakePointer>
359class TensorContractionSubMapper {
360 public:
361 typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
362 inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
363 ParentMapper;
364 typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
365 inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
366 Self;
367 typedef Self LinearMapper;
368 typedef Self SubMapper;
369
370 enum {
371 // We can use direct offsets iff the parent mapper supports then and we can compute the strides.
372 // TODO: we should also enable direct offsets for the Rhs case.
373 UseDirectOffsets =
374 ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0)
375 };
376
377 EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
378 : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
379 // Bake the offsets into the buffer used by the base mapper whenever possible. This avoids the need to recompute
380 // this offset every time we attempt to access a coefficient.
381 if (UseDirectOffsets) {
382 Index stride = m_base_mapper.stride();
383 m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride);
384 }
385 }
386
387 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
388 if (UseDirectOffsets) {
389 return m_base_mapper(i, 0);
390 }
391 return m_base_mapper(i + m_vert_offset, m_horiz_offset);
392 }
393 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
394 if (UseDirectOffsets) {
395 return m_base_mapper(i, j);
396 }
397 return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
398 }
399
400 template <typename PacketT>
401 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i) const {
402 if (UseDirectOffsets) {
403 return m_base_mapper.template loadPacket<PacketT, Alignment>(i, 0);
404 }
405 return m_base_mapper.template loadPacket<PacketT, Alignment>(i + m_vert_offset, m_horiz_offset);
406 }
407
408 template <typename PacketT>
409 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
410 if (UseDirectOffsets) {
411 return m_base_mapper.template loadPacket<PacketT, Alignment>(i, j);
412 }
413 return m_base_mapper.template loadPacket<PacketT, Alignment>(i + m_vert_offset, j + m_horiz_offset);
414 }
415
416 template <typename PacketT>
417 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacketPartial(Index i, Index j, Index, Index = 0) const {
418 if (UseDirectOffsets) {
419 return m_base_mapper.template loadPacket<PacketT, Alignment>(i, j);
420 }
421 return m_base_mapper.template loadPacket<PacketT, Alignment>(i + m_vert_offset, j + m_horiz_offset);
422 }
423
424 template <typename PacketT, int AlignmentType>
425 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
426 if (UseDirectOffsets) {
427 return m_base_mapper.template load<PacketT, AlignmentType>(i, j);
428 }
429 return m_base_mapper.template loadPacket<PacketT, AlignmentType>(i + m_vert_offset, j + m_horiz_offset);
430 }
431
432 template <typename PacketT>
433 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketT& p) const {
434 if (UseDirectOffsets) {
435 m_base_mapper.storePacket(i, 0, p);
436 }
437 m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
438 }
439
440 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
441 if (UseDirectOffsets) {
442 return LinearMapper(m_base_mapper, i, j);
443 }
444 return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
445 }
446
447 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper getSubMapper(Index i, Index j) const {
448 if (UseDirectOffsets) {
449 return SubMapper(m_base_mapper, i, j);
450 }
451 return SubMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
452 }
453
454 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const Index stride() const { return m_base_mapper.stride(); }
455
456 template <typename PacketT, int AlignmentType>
457 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
458 EIGEN_STATIC_ASSERT((internal::is_same<PacketT, PacketT>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
459 const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
460 if (UseDirectOffsets) {
461 return m_base_mapper.template loadPacket<PacketT, ActualAlignment>(i, 0);
462 }
463 return m_base_mapper.template loadPacket<PacketT, ActualAlignment>(i + m_vert_offset, m_horiz_offset);
464 }
465
466 template <typename PacketT>
467 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const {
468 return false;
469 }
470
471 const ParentMapper& base_mapper() const { return m_base_mapper; }
472 Index vert_offset() const { return m_vert_offset; }
473 Index horiz_offset() const { return m_horiz_offset; }
474
475 private:
476 ParentMapper m_base_mapper;
477 const Index m_vert_offset;
478 const Index m_horiz_offset;
479};
480
481template <typename Scalar_, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
482 int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
483 template <class> class MakePointer_ = MakePointer>
484class TensorContractionInputMapper
485 : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size,
486 inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> {
487 public:
488 typedef Scalar_ Scalar;
489 typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
490 inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
491 Base;
492 typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
493 inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
494 SubMapper;
495 typedef SubMapper VectorMapper;
496 typedef SubMapper LinearMapper;
497
498 EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
499 const nocontract_t& ij_strides, const contract_t& contract_strides,
500 const contract_t& k_strides)
501 : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) {}
502
503 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
504 return SubMapper(*this, i, j);
505 }
506
507 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
508 return LinearMapper(*this, i, j);
509 }
510
511 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
512 return VectorMapper(*this, i, j);
513 }
514
515 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& get_tensor() const {
516 return Base::m_tensor;
517 }
518};
519
520template <typename T>
521struct TensorContractionInputMapperTrait;
522
523template <typename Scalar_, typename Index_, int side_, typename Tensor_, typename nocontract_t_, typename contract_t_,
524 int packet_size_, bool inner_dim_contiguous_, bool inner_dim_reordered_, int Alignment_,
525 template <class> class MakePointer_>
526struct TensorContractionInputMapperTrait<
527 TensorContractionInputMapper<Scalar_, Index_, side_, Tensor_, nocontract_t_, contract_t_, packet_size_,
528 inner_dim_contiguous_, inner_dim_reordered_, Alignment_, MakePointer_> > {
529 typedef Tensor_ XprType;
530 static const bool inner_dim_contiguous = inner_dim_contiguous_;
531 static const bool inner_dim_reordered = inner_dim_reordered_;
532};
533
534} // end namespace internal
535} // end namespace Eigen
536
537#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
The tensor class.
Definition Tensor.h:68
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
Definition TensorContractionMapper.h:36