TensorContractionMapper.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H

namespace Eigen {

namespace internal {

enum {
  Rhs = 0,
  Lhs = 1
};

/*
 * Implementation of the Eigen blas_data_mapper class for tensors.
 */

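// A tensor contraction is evaluated as a matrix product: the mappers below
// present an arbitrary-rank tensor as a logical 2-D matrix, with the
// non-contracting dimensions flattened into one axis and the contracting
// dimensions flattened into the other. As an illustration (sizes chosen
// arbitrarily), contracting a tensor of dimensions (m, k1, k2) over its last
// two dimensions exposes it to the kernels as an m x (k1*k2) matrix.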
template <typename Tensor, bool HasRawAccess> struct CoeffLoader {
  enum {
    DirectOffsets = false
  };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_tensor(tensor) { }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index) {
    eigen_assert(false && "unsupported");
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
  {
    return m_tensor.template packet<LoadMode>(index);
  }

 private:
  const Tensor m_tensor;
};

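// Specialization for tensors whose coefficients are directly addressable in
// memory (RawAccess): coefficients are read straight from the underlying
// buffer, and the buffer pointer itself can be advanced via offsetBuffer().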
template <typename Tensor> struct CoeffLoader<Tensor, true> {
  enum {
    DirectOffsets = true
  };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_data(tensor.data()) {}

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
    m_data += offset;
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  typename Tensor::PacketReturnType packet(typename Tensor::Index index) const
  {
    return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
  }

 private:
  typedef typename Tensor::Scalar Scalar;
  const Scalar* m_data;
};

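// Maps a (row, col) position in the logical matrix view of the contraction
// back to a linear index into the underlying tensor, using the precomputed
// nocontract/contract strides. Only scalar (coefficient-wise) access is
// provided at this level.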
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous, int Alignment>
class SimpleTensorContractionMapper {
 public:
  EIGEN_DEVICE_FUNC
  SimpleTensorContractionMapper(const Tensor& tensor,
                                const nocontract_t& nocontract_strides,
                                const nocontract_t& ij_strides,
                                const contract_t& contract_strides,
                                const contract_t& k_strides) :
      m_tensor(tensor),
      m_nocontract_strides(nocontract_strides),
      m_ij_strides(ij_strides),
      m_contract_strides(contract_strides),
      m_k_strides(k_strides) { }

  enum {
    DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess>::DirectOffsets
  };

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) {
    m_tensor.offsetBuffer(offset);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row) const {
    // column major assumption
    return operator()(row, 0);
  }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const {
    return m_tensor.coeff(computeIndex(row, col));
  }

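  // Decomposes the logical (row, col) coordinates into per-dimension tensor
  // indices (a mixed-radix decomposition against m_ij_strides/m_k_strides)
  // and re-linearizes them with the tensor's own strides. For example, with
  // two non-contracting dimensions of sizes (4, 5) (sizes chosen purely for
  // illustration), row 13 splits into per-dimension indices
  // (13 % 4, 13 / 4) = (1, 3) before the tensor strides are applied.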
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const {
    const bool left = (side == Lhs);
    EIGEN_UNUSED_VARIABLE(left); // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
    Index nocontract_val = left ? row : col;
    Index linidx = 0;
    for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
      const Index idx = nocontract_val / m_ij_strides[i];
      linidx += idx * m_nocontract_strides[i];
      nocontract_val -= idx * m_ij_strides[i];
    }
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx += nocontract_val;
      } else {
        linidx += nocontract_val * m_nocontract_strides[0];
      }
    }

    Index contract_val = left ? col : row;
    if (array_size<contract_t>::value > 0) {
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx = contract_val / m_k_strides[i];
        linidx += idx * m_contract_strides[i];
        contract_val -= idx * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx += contract_val;
      } else {
        linidx += contract_val * m_contract_strides[0];
      }
    }

    return linidx;
  }

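  // Same decomposition as computeIndex, but performed for two positions at
  // once: (row, col) and (row + distance, col). Packet loads use the pair to
  // check whether the coefficients they span are contiguous in memory.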
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col, const Index distance) const {
    const bool left = (side == Lhs);
    EIGEN_UNUSED_VARIABLE(left); // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963
    Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
    Index linidx[2] = {0, 0};
    if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
      for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = nocontract_val[0] / m_ij_strides[i];
        const Index idx1 = nocontract_val[1] / m_ij_strides[i];
        linidx[0] += idx0 * m_nocontract_strides[i];
        linidx[1] += idx1 * m_nocontract_strides[i];
        nocontract_val[0] -= idx0 * m_ij_strides[i];
        nocontract_val[1] -= idx1 * m_ij_strides[i];
      }
      if (side == Lhs && inner_dim_contiguous) {
        eigen_assert(m_nocontract_strides[0] == 1);
        linidx[0] += nocontract_val[0];
        linidx[1] += nocontract_val[1];
      } else {
        linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
        linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
      }
    }

    Index contract_val[2] = {left ? col : row, left ? col : row + distance};
    if (array_size<contract_t>::value > 0) {
      for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
        const Index idx0 = contract_val[0] / m_k_strides[i];
        const Index idx1 = contract_val[1] / m_k_strides[i];
        linidx[0] += idx0 * m_contract_strides[i];
        linidx[1] += idx1 * m_contract_strides[i];
        contract_val[0] -= idx0 * m_k_strides[i];
        contract_val[1] -= idx1 * m_k_strides[i];
      }

      if (side == Rhs && inner_dim_contiguous) {
        eigen_assert(m_contract_strides[0] == 1);
        linidx[0] += contract_val[0];
        linidx[1] += contract_val[1];
      } else {
        linidx[0] += contract_val[0] * m_contract_strides[0];
        linidx[1] += contract_val[1] * m_contract_strides[0];
      }
    }
    return IndexPair<Index>(linidx[0], linidx[1]);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const {
    // Only claim alignment when we can compute the actual stride (i.e. when we're
    // dealing with the lhs with inner_dim_contiguous). This is because the
    // matrix-vector product relies on the stride when dealing with aligned inputs.
    return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const {
    return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
  }

 protected:
  CoeffLoader<Tensor, Tensor::RawAccess> m_tensor;
  const nocontract_t m_nocontract_strides;
  const nocontract_t m_ij_strides;
  const contract_t m_contract_strides;
  const contract_t m_k_strides;
};

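// Extends SimpleTensorContractionMapper with vectorized packet loads. When
// the coefficients spanned by a packet are contiguous in the tensor they are
// read with a single packet load; otherwise they are gathered one by one.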
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size, bool inner_dim_contiguous,
         bool inner_dim_reordered, int Alignment>
class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment>
{
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> ParentMapper;

  EIGEN_DEVICE_FUNC
  BaseTensorContractionMapper(const Tensor& tensor,
                              const nocontract_t& nocontract_strides,
                              const nocontract_t& ij_strides,
                              const contract_t& contract_strides,
                              const contract_t& k_strides) :
      ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  typedef typename Tensor::PacketReturnType Packet;
  typedef typename unpacket_traits<Packet>::half HalfPacket;

  template <int AlignmentType>
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
    // whole method makes column major assumption

    // don't need to add offsets for now (because operator handles that)
    // current code assumes packet size must be a multiple of 2
    EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);

    if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
      const Index index = this->computeIndex(i, j);
      eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1);
      return this->m_tensor.template packet<AlignmentType>(index);
    }

    const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
    const Index first = indexPair.first;
    const Index last = indexPair.second;

    // We can always do optimized packet reads from left hand side right now, because
    // the vertical matrix dimension on the left hand side is never contracting.
    // On the right hand side we need to check if the contracting dimensions may have
    // been shuffled first.
    if (Tensor::PacketAccess &&
        (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
        (last - first) == (packet_size - 1)) {
      return this->m_tensor.template packet<AlignmentType>(first);
    }

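    // Fallback: the coefficients are not contiguous in memory, so gather them
    // one scalar at a time. computeIndexPair resolves two indices per call,
    // which halves the number of index computations inside the loop.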
    EIGEN_ALIGN_MAX Scalar data[packet_size];

    data[0] = this->m_tensor.coeff(first);
    for (Index k = 1; k < packet_size - 1; k += 2) {
      const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
      data[k] = this->m_tensor.coeff(internal_pair.first);
      data[k + 1] = this->m_tensor.coeff(internal_pair.second);
    }
    data[packet_size - 1] = this->m_tensor.coeff(last);

    return pload<Packet>(data);
  }

  template <int AlignmentType>
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
    // whole method makes column major assumption

    // don't need to add offsets for now (because operator handles that)
    const Index half_packet_size = unpacket_traits<HalfPacket>::size;
    if (half_packet_size == packet_size) {
      return loadPacket<AlignmentType>(i, j);
    }
    EIGEN_ALIGN_MAX Scalar data[half_packet_size];
    for (Index k = 0; k < half_packet_size; k++) {
      data[k] = operator()(i + k, j);
    }
    return pload<HalfPacket>(data);
  }
};

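// Specialization for packet_size == 1: a "packet" degenerates to a single
// coefficient, so both packet loads simply read one scalar.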
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         bool inner_dim_contiguous,
         bool inner_dim_reordered, int Alignment>
class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment>
{
 public:
  typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> ParentMapper;

  EIGEN_DEVICE_FUNC
  BaseTensorContractionMapper(const Tensor& tensor,
                              const nocontract_t& nocontract_strides,
                              const nocontract_t& ij_strides,
                              const contract_t& contract_strides,
                              const contract_t& k_strides) :
      ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  typedef typename Tensor::PacketReturnType Packet;
  template <int LoadMode> EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const {
    EIGEN_ALIGN_MAX Scalar data[1];
    data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
    return pload<typename Tensor::PacketReturnType>(data);
  }
  template <int LoadMode> EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const {
    // pass the load mode explicitly: it cannot be deduced from the arguments
    return loadPacket<LoadMode>(i, j);
  }
};

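// A view on a rectangular block of the parent mapper, anchored at
// (vert_offset, horiz_offset). This is what the contraction kernels actually
// iterate over; when possible the offsets are folded into the underlying
// buffer pointer up front (see UseDirectOffsets below).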
template<typename Scalar, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionSubMapper {
 public:
  typedef typename Tensor::PacketReturnType Packet;
  typedef typename unpacket_traits<Packet>::half HalfPacket;

  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self;
  typedef Self LinearMapper;

  enum {
    // We can use direct offsets iff the parent mapper supports them and we can compute the strides.
    // TODO: we should also enable direct offsets for the Rhs case.
    UseDirectOffsets = ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0)
  };

  EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
      : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
    // Bake the offsets into the buffer used by the base mapper whenever possible. This avoids the need to recompute
    // this offset every time we attempt to access a coefficient.
    if (UseDirectOffsets) {
      Index stride = m_base_mapper.stride();
      m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride);
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    if (UseDirectOffsets) {
      return m_base_mapper(i, 0);
    }
    return m_base_mapper(i + m_vert_offset, m_horiz_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const {
    if (UseDirectOffsets) {
      return m_base_mapper(i, j);
    }
    return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
    if (UseDirectOffsets) {
      return m_base_mapper.template loadPacket<Alignment>(i, 0);
    }
    return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
  }
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
    if (UseDirectOffsets) {
      return m_base_mapper.template loadPacket<Alignment>(i, j);
    }
    return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
    if (UseDirectOffsets) {
      return m_base_mapper.template loadHalfPacket<Alignment>(i, 0);
    }
    return m_base_mapper.template loadHalfPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const {
    if (UseDirectOffsets) {
      // the offsets are already baked into the buffer; don't store twice
      m_base_mapper.storePacket(i, 0, p);
      return;
    }
    m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    if (UseDirectOffsets) {
      return LinearMapper(m_base_mapper, i, j);
    }
    return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
  }

  template <typename PacketT, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
    EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
    const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned;
    if (UseDirectOffsets) {
      return m_base_mapper.template loadPacket<ActualAlignment>(i, 0);
    }
    return m_base_mapper.template loadPacket<ActualAlignment>(i + m_vert_offset, m_horiz_offset);
  }

  template <typename PacketT>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const {
    return false;
  }

 private:
  ParentMapper m_base_mapper;
  const Index m_vert_offset;
  const Index m_horiz_offset;
};

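// Top-level mapper handed to the contraction kernels: a
// BaseTensorContractionMapper that can hand out sub-mappers for rectangular
// blocks and vector mappers for single columns.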
template<typename Scalar_, typename Index, int side,
         typename Tensor,
         typename nocontract_t, typename contract_t,
         int packet_size,
         bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
class TensorContractionInputMapper
  : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {

 public:
  typedef Scalar_ Scalar;
  typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
  typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
  typedef SubMapper VectorMapper;

  EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
                                                 const nocontract_t& nocontract_strides,
                                                 const nocontract_t& ij_strides,
                                                 const contract_t& contract_strides,
                                                 const contract_t& k_strides)
      : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }

  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
    return SubMapper(*this, i, j);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
    return VectorMapper(*this, i, j);
  }
};


} // end namespace internal
} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H