Eigen-unsupported  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
TensorContractionMapper.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
// Side tags for the contraction. Note the deliberately counter-intuitive
// numbering: Rhs == 0 and Lhs == 1.
enum { Rhs = 0, Lhs = 1 };

/*
 * Implementation of the Eigen blas_data_mapper class for tensors.
 */

// Reads coefficients out of a tensor evaluator. The HasRawAccess == true
// specialization (below) reads directly from the evaluator's raw buffer; the
// primary template goes through the coeff()/packet() interface.
template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer>
struct CoeffLoader;

// Forward declaration; the primary template and the packet_size == 1
// specialization are defined further down in this file.
template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          template <class> class MakePointer_ = MakePointer>
class BaseTensorContractionMapper;
34
35template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_>
36struct CoeffLoader {
37 enum { DirectOffsets = false };
38
39 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_tensor(tensor) {}
40
41 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index) {
42 eigen_assert(false && "unsupported");
43 }
44
45 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type data() const {
46 eigen_assert(false && "unsupported");
47 return NULL;
48 }
49
50 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const {
51 return m_tensor.coeff(index);
52 }
53
54 template <int LoadMode>
55 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Tensor::PacketReturnType packet(typename Tensor::Index index) const {
56 return m_tensor.template packet<LoadMode>(index);
57 }
58
59 private:
60 const Tensor m_tensor;
61};
62
// Specialization for tensors with raw data access: keeps a (possibly offset)
// pointer into the underlying buffer and reads from it directly.
template <typename Tensor, template <class> class MakePointer_>
struct CoeffLoader<Tensor, true, MakePointer_> {
  enum { DirectOffsets = true };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_data(tensor.data()) {}

  // Bakes a fixed offset into the data pointer so later reads can skip the
  // per-access index arithmetic.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) { m_data += offset; }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type data() const {
    return m_data;
  }

  // Scalar load at a linear index (via the loadConstant helper).
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const {
    return loadConstant(m_data + index);
  }

  // Read-only packet load honoring the LoadMode alignment hint.
  template <int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Tensor::PacketReturnType packet(typename Tensor::Index index) const {
    return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
  }

 private:
  typedef typename Tensor::Scalar Scalar;

  // Pointer (wrapped by MakePointer_) into the tensor's buffer; mutated only
  // by offsetBuffer().
  typename MakePointer_<const Scalar>::Type m_data;
};
89
// Maps a position (row, col) of the logical 2-D matrix seen by the matrix
// product kernels back to a linear index into the underlying tensor.
// - nocontract_t / contract_t hold stride arrays for the non-contracting and
//   contracting dimensions respectively;
// - side selects whether this mapper serves the Lhs or the Rhs of the product.
template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, int Alignment, template <class> class MakePointer_ = MakePointer>
class SimpleTensorContractionMapper {
 public:
  EIGEN_DEVICE_FUNC SimpleTensorContractionMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
                                                  const nocontract_t& ij_strides, const contract_t& contract_strides,
                                                  const contract_t& k_strides)
      : m_tensor(tensor),
        m_nocontract_strides(nocontract_strides),
        m_ij_strides(ij_strides),
        m_contract_strides(contract_strides),
        m_k_strides(k_strides) {}

  // Direct offsets are available iff the wrapped loader reads raw memory.
  enum { DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>::DirectOffsets };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
    m_tensor.offsetBuffer(offset);
  }

  // Intentionally a no-op in this mapper.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void prefetch(Index /*i*/) {}

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
    // column major assumption
    return operator()(row, 0);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const {
    return m_tensor.coeff(computeIndex(row, col));
  }

  // Converts (row, col) of the logical matrix into a linear tensor index. The
  // coordinate is decomposed digit-by-digit using the ij/k strides, and each
  // digit is re-multiplied by the corresponding tensor stride.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
    const bool left = (side == Lhs);
    EIGEN_UNUSED_VARIABLE(left);  // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
    // For the Lhs the row coordinate spans the non-contracting dimensions; for
    // the Rhs it is the col coordinate.
    Index nocontract_val = left ? row : col;
    Index linidx = 0;
    EIGEN_UNROLL_LOOP
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx = nocontract_val / m_ij_strides[i];
      linidx += idx * m_nocontract_strides[i];
      nocontract_val -= idx * m_ij_strides[i];
    }
    // Innermost non-contracting dimension (skipped when every tensor dimension
    // is contracted).
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx += nocontract_val;  // stride asserted to be 1
      } else {
        linidx += nocontract_val * m_nocontract_strides[0];
      }
    }

    // Same digit decomposition for the contracting coordinate.
    Index contract_val = left ? col : row;
    if (array_size<contract_t>::value > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx = contract_val / m_k_strides[i];
        linidx += idx * m_contract_strides[i];
        contract_val -= idx * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx += contract_val;  // stride asserted to be 1
      } else {
        linidx += contract_val * m_contract_strides[0];
      }
    }

    return linidx;
  }

  // Computes the linear indices of two coefficients at once: (row, col) and
  // (row + distance, col). Used by the packet-load paths to find the first and
  // last element of a packet.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col,
                                                                          const Index distance) const {
    const bool left = (side == Lhs);
    EIGEN_UNUSED_VARIABLE(left);  // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
    // Index 0 is (row, col); index 1 is (row + distance, col). The distance
    // lands on the non-contracting coordinate for the Lhs and the contracting
    // one for the Rhs, hence the two different initializers below.
    Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
    Index linidx[2] = {0, 0};
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = nocontract_val[0] / m_ij_strides[i];
        const Index idx1 = nocontract_val[1] / m_ij_strides[i];
        linidx[0] += idx0 * m_nocontract_strides[i];
        linidx[1] += idx1 * m_nocontract_strides[i];
        nocontract_val[0] -= idx0 * m_ij_strides[i];
        nocontract_val[1] -= idx1 * m_ij_strides[i];
      }
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx[0] += nocontract_val[0];
        linidx[1] += nocontract_val[1];
      } else {
        linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
        linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
      }
    }

    Index contract_val[2] = {left ? col : row, left ? col : row + distance};
    if (array_size<contract_t>::value > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = contract_val[0] / m_k_strides[i];
        const Index idx1 = contract_val[1] / m_k_strides[i];
        linidx[0] += idx0 * m_contract_strides[i];
        linidx[1] += idx1 * m_contract_strides[i];
        contract_val[0] -= idx0 * m_k_strides[i];
        contract_val[1] -= idx1 * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx[0] += contract_val[0];
        linidx[1] += contract_val[1];
      } else {
        linidx[0] += contract_val[0] * m_contract_strides[0];
        linidx[1] += contract_val[1] * m_contract_strides[0];
      }
    }
    return IndexPair<Index>(linidx[0], linidx[1]);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const {
    // Only claim alignment when we can compute the actual stride (ie when we're
    // dealing with the lhs with inner_dim_contiguous. This is because the
    // matrix-vector product relies on the stride when dealing with aligned inputs.
    return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
  }
  // Stride of the logical matrix when it can be computed; 1 otherwise.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const {
    return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
  }

  // Accessors exposing the raw pieces of the mapper.
  const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& tensor() const { return m_tensor; }

  const nocontract_t& nocontract_strides() const { return m_nocontract_strides; }
  const nocontract_t& ij_strides() const { return m_ij_strides; }
  const contract_t& contract_strides() const { return m_contract_strides; }
  const contract_t& k_strides() const { return m_k_strides; }

 protected:
  CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_> m_tensor;
  const nocontract_t m_nocontract_strides;
  const nocontract_t m_ij_strides;
  const contract_t m_contract_strides;
  const contract_t m_k_strides;
};
234
// Adds packet loads on top of SimpleTensorContractionMapper. The
// packet_size == 1 case is handled by the specialization below.
template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          template <class> class MakePointer_>
class BaseTensorContractionMapper
    : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
                                           inner_dim_contiguous, Alignment, MakePointer_> {
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
                                        inner_dim_contiguous, Alignment, MakePointer_>
      ParentMapper;

  EIGEN_DEVICE_FUNC BaseTensorContractionMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
                                                const nocontract_t& ij_strides, const contract_t& contract_strides,
                                                const contract_t& k_strides)
      : ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) {}

  // Loads a packet of packet_size coefficients starting at logical position
  // (i, j), running down the column. Takes contiguous fast paths when
  // possible, otherwise gathers coefficient by coefficient.
  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
      std::enable_if_t<internal::unpacket_traits<PacketT>::size == packet_size, PacketT>
      load(Index i, Index j) const {
    // whole method makes column major assumption

    // don't need to add offsets for now (because operator handles that)
    // current code assumes packet size must be a multiple of 2
    EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);

    if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
      // Fast path: inner dimension is contiguous and unshuffled, so the whole
      // packet is contiguous in memory (checked by the assert).
      const Index index = this->computeIndex(i, j);
      eigen_assert(this->computeIndex(i + packet_size - 1, j) == index + packet_size - 1);
      return this->m_tensor.template packet<AlignmentType>(index);
    }

    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
    const Index first = indexPair.first;
    const Index lastIdx = indexPair.second;

    // We can always do optimized packet reads from left hand side right now, because
    // the vertical matrix dimension on the left hand side is never contracting.
    // On the right hand side we need to check if the contracting dimensions may have
    // been shuffled first.
    if (Tensor::PacketAccess && (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
        (lastIdx - first) == (packet_size - 1)) {
      return this->m_tensor.template packet<AlignmentType>(first);
    }

    // Slow path: gather the coefficients two at a time (via computeIndexPair)
    // into an aligned buffer, then load the buffer as a packet.
    EIGEN_ALIGN_MAX Scalar data[packet_size];

    data[0] = this->m_tensor.coeff(first);
    EIGEN_UNROLL_LOOP
    for (Index k = 1; k < packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[packet_size - 1] = this->m_tensor.coeff(lastIdx);

    return pload<PacketT>(data);
  }

  // Overload for a requested packet type whose size differs from packet_size:
  // always gathers coefficient by coefficient (no fast paths).
  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
      std::enable_if_t<internal::unpacket_traits<PacketT>::size != packet_size, PacketT>
      load(Index i, Index j) const {
    const Index requested_packet_size = internal::unpacket_traits<PacketT>::size;
    EIGEN_ALIGN_MAX Scalar data[requested_packet_size];

    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, requested_packet_size - 1);
    const Index first = indexPair.first;
    const Index lastIdx = indexPair.second;

    data[0] = this->m_tensor.coeff(first);
    for (Index k = 1; k < requested_packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[requested_packet_size - 1] = this->m_tensor.coeff(lastIdx);

    return pload<PacketT>(data);
  }

  // Convenience alias forwarding to load().
  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
    return this->load<PacketT, AlignmentType>(i, j);
  }
};
321
322template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
323 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, template <class> class MakePointer_>
324class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous,
325 inner_dim_reordered, Alignment, MakePointer_>
326 : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1,
327 inner_dim_contiguous, Alignment, MakePointer_> {
328 public:
329 typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous,
330 Alignment, MakePointer_>
331 ParentMapper;
332
333 EIGEN_DEVICE_FUNC BaseTensorContractionMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
334 const nocontract_t& ij_strides, const contract_t& contract_strides,
335 const contract_t& k_strides)
336 : ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) {}
337
338 template <typename PacketT, int>
339 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketT loadPacket(Index i, Index j) const {
340 EIGEN_ALIGN_MAX Scalar data[1];
341 data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
342 return pload<PacketT>(data);
343 }
344 template <typename PacketT, int>
345 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketT load(Index i, Index j) const {
346 EIGEN_ALIGN_MAX Scalar data[1];
347 data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
348 return pload<PacketT>(data);
349 }
350};
351
352template <typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
353 int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
354 template <class> class MakePointer_ = MakePointer>
355class TensorContractionSubMapper {
356 public:
357 typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
358 inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
359 ParentMapper;
360 typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
361 inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
362 Self;
363 typedef Self LinearMapper;
364 typedef Self SubMapper;
365
366 enum {
367 // We can use direct offsets iff the parent mapper supports then and we can compute the strides.
368 // TODO: we should also enable direct offsets for the Rhs case.
369 UseDirectOffsets =
370 ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0)
371 };
372
373 EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
374 : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
375 // Bake the offsets into the buffer used by the base mapper whenever possible. This avoids the need to recompute
376 // this offset every time we attempt to access a coefficient.
377 if (UseDirectOffsets) {
378 Index stride = m_base_mapper.stride();
379 m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride);
380 }
381 }
382
383 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
384 if (UseDirectOffsets) {
385 return m_base_mapper(i, 0);
386 }
387 return m_base_mapper(i + m_vert_offset, m_horiz_offset);
388 }
389 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
390 if (UseDirectOffsets) {
391 return m_base_mapper(i, j);
392 }
393 return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
394 }
395
396 template <typename PacketT>
397 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i) const {
398 if (UseDirectOffsets) {
399 return m_base_mapper.template loadPacket<PacketT, Alignment>(i, 0);
400 }
401 return m_base_mapper.template loadPacket<PacketT, Alignment>(i + m_vert_offset, m_horiz_offset);
402 }
403
404 template <typename PacketT>
405 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
406 if (UseDirectOffsets) {
407 return m_base_mapper.template loadPacket<PacketT, Alignment>(i, j);
408 }
409 return m_base_mapper.template loadPacket<PacketT, Alignment>(i + m_vert_offset, j + m_horiz_offset);
410 }
411
412 template <typename PacketT>
413 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacketPartial(Index i, Index j, Index, Index = 0) const {
414 if (UseDirectOffsets) {
415 return m_base_mapper.template loadPacket<PacketT, Alignment>(i, j);
416 }
417 return m_base_mapper.template loadPacket<PacketT, Alignment>(i + m_vert_offset, j + m_horiz_offset);
418 }
419
420 template <typename PacketT, int AlignmentType>
421 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(Index i, Index j) const {
422 if (UseDirectOffsets) {
423 return m_base_mapper.template load<PacketT, AlignmentType>(i, j);
424 }
425 return m_base_mapper.template loadPacket<PacketT, AlignmentType>(i + m_vert_offset, j + m_horiz_offset);
426 }
427
428 template <typename PacketT>
429 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const PacketT& p) const {
430 if (UseDirectOffsets) {
431 m_base_mapper.storePacket(i, 0, p);
432 }
433 m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
434 }
435
436 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
437 if (UseDirectOffsets) {
438 return LinearMapper(m_base_mapper, i, j);
439 }
440 return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
441 }
442
443 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper getSubMapper(Index i, Index j) const {
444 if (UseDirectOffsets) {
445 return SubMapper(m_base_mapper, i, j);
446 }
447 return SubMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
448 }
449
450 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const Index stride() const { return m_base_mapper.stride(); }
451
452 template <typename PacketT, int AlignmentType>
453 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
454 EIGEN_STATIC_ASSERT((internal::is_same<PacketT, PacketT>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
455 const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
456 if (UseDirectOffsets) {
457 return m_base_mapper.template loadPacket<PacketT, ActualAlignment>(i, 0);
458 }
459 return m_base_mapper.template loadPacket<PacketT, ActualAlignment>(i + m_vert_offset, m_horiz_offset);
460 }
461
462 template <typename PacketT>
463 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const {
464 return false;
465 }
466
467 const ParentMapper& base_mapper() const { return m_base_mapper; }
468 Index vert_offset() const { return m_vert_offset; }
469 Index horiz_offset() const { return m_horiz_offset; }
470
471 private:
472 ParentMapper m_base_mapper;
473 const Index m_vert_offset;
474 const Index m_horiz_offset;
475};
476
// Entry-point mapper handed to the contraction kernels: a
// BaseTensorContractionMapper that can additionally hand out
// sub/linear/vector mappers anchored at a given (i, j) offset.
template <typename Scalar_, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          template <class> class MakePointer_ = MakePointer>
class TensorContractionInputMapper
    : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size,
                                         inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> {
 public:
  typedef Scalar_ Scalar;
  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
                                      inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
      Base;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
                                     inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
      SubMapper;
  // Vector and linear mappers are the same sub-mapper type under other names.
  typedef SubMapper VectorMapper;
  typedef SubMapper LinearMapper;

  EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor, const nocontract_t& nocontract_strides,
                                                 const nocontract_t& ij_strides, const contract_t& contract_strides,
                                                 const contract_t& k_strides)
      : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) {}

  // The three factories below all anchor a sub-mapper view at offset (i, j).
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    return LinearMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
    return VectorMapper(*this, i, j);
  }

  // Exposes the underlying coefficient loader inherited from the base mapper.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& get_tensor() const {
    return Base::m_tensor;
  }
};
515
// Compile-time trait recovering selected template parameters from a
// TensorContractionInputMapper type.
template <typename T>
struct TensorContractionInputMapperTrait;

template <typename Scalar_, typename Index_, int side_, typename Tensor_, typename nocontract_t_, typename contract_t_,
          int packet_size_, bool inner_dim_contiguous_, bool inner_dim_reordered_, int Alignment_,
          template <class> class MakePointer_>
struct TensorContractionInputMapperTrait<
    TensorContractionInputMapper<Scalar_, Index_, side_, Tensor_, nocontract_t_, contract_t_, packet_size_,
                                 inner_dim_contiguous_, inner_dim_reordered_, Alignment_, MakePointer_> > {
  typedef Tensor_ XprType;  // the tensor expression the mapper reads from
  static const bool inner_dim_contiguous = inner_dim_contiguous_;
  static const bool inner_dim_reordered = inner_dim_reordered_;
};
529
530} // end namespace internal
531} // end namespace Eigen
532
533#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
The tensor class.
Definition Tensor.h:68
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
Definition TensorContractionMapper.h:36