10#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
11#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
14#include "./InternalHeaderCheck.h"
20enum { Rhs = 0, Lhs = 1 };
27template <
typename Tensor,
bool HasRawAccess,
template <
class>
class MakePointer_ = MakePointer>
30template <
typename Scalar,
typename Index,
int side,
typename Tensor,
typename nocontract_t,
typename contract_t,
31 int packet_size,
bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment,
32 template <
class>
class MakePointer_ = MakePointer>
33class BaseTensorContractionMapper;
35template <
typename Tensor,
bool HasRawAccess,
template <
class>
class MakePointer_>
37 enum { DirectOffsets =
false };
39 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(
const Tensor& tensor) : m_tensor(tensor) {}
41 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void offsetBuffer(
typename Tensor::Index) {
42 eigen_assert(
false &&
"unsupported");
45 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
const typename MakePointer_<const typename Tensor::Scalar>::Type data()
const {
46 eigen_assert(
false &&
"unsupported");
50 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
typename Tensor::Scalar coeff(
typename Tensor::Index index)
const {
51 return m_tensor.coeff(index);
54 template <
int LoadMode>
55 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
typename Tensor::PacketReturnType packet(
typename Tensor::Index index)
const {
56 return m_tensor.template packet<LoadMode>(index);
63template <
typename Tensor,
template <
class>
class MakePointer_>
65 enum { DirectOffsets =
true };
67 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(
const Tensor& tensor) : m_data(tensor.data()) {}
69 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void offsetBuffer(
typename Tensor::Index offset) { m_data += offset; }
71 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
const typename MakePointer_<const typename Tensor::Scalar>::Type data()
const {
75 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
typename Tensor::Scalar coeff(
typename Tensor::Index index)
const {
76 return loadConstant(m_data + index);
79 template <
int LoadMode>
80 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
typename Tensor::PacketReturnType packet(
typename Tensor::Index index)
const {
81 return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
85 typedef typename Tensor::Scalar Scalar;
87 typename MakePointer_<const Scalar>::Type m_data;
90template <
typename Scalar,
typename Index,
int side,
typename Tensor,
typename nocontract_t,
typename contract_t,
91 int packet_size,
bool inner_dim_contiguous,
int Alignment,
template <
class>
class MakePointer_ = MakePointer>
92class SimpleTensorContractionMapper {
94 EIGEN_DEVICE_FUNC SimpleTensorContractionMapper(
const Tensor& tensor,
const nocontract_t& nocontract_strides,
95 const nocontract_t& ij_strides,
const contract_t& contract_strides,
96 const contract_t& k_strides)
98 m_nocontract_strides(nocontract_strides),
99 m_ij_strides(ij_strides),
100 m_contract_strides(contract_strides),
101 m_k_strides(k_strides) {}
103 enum { DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>::DirectOffsets };
105 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void offsetBuffer(
typename Tensor::Index offset) {
106 m_tensor.offsetBuffer(offset);
109 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void prefetch(
Index ) {}
111 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(
Index row)
const {
113 return operator()(row, 0);
116 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(
Index row,
Index col)
const {
117 return m_tensor.coeff(computeIndex(row, col));
120 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Index computeIndex(
Index row,
Index col)
const {
121 const bool left = (side == Lhs);
122 EIGEN_UNUSED_VARIABLE(left);
123 Index nocontract_val = left ? row : col;
126 for (
int i =
static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
127 const Index idx = nocontract_val / m_ij_strides[i];
128 linidx += idx * m_nocontract_strides[i];
129 nocontract_val -= idx * m_ij_strides[i];
131 if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
132 if (side == Lhs && inner_dim_contiguous) {
133 eigen_assert(m_nocontract_strides[0] == 1);
134 linidx += nocontract_val;
136 linidx += nocontract_val * m_nocontract_strides[0];
140 Index contract_val = left ? col : row;
141 if (array_size<contract_t>::value > 0) {
143 for (
int i =
static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
144 const Index idx = contract_val / m_k_strides[i];
145 linidx += idx * m_contract_strides[i];
146 contract_val -= idx * m_k_strides[i];
149 if (side == Rhs && inner_dim_contiguous) {
150 eigen_assert(m_contract_strides[0] == 1);
151 linidx += contract_val;
153 linidx += contract_val * m_contract_strides[0];
160 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(
Index row,
Index col,
161 const Index distance)
const {
162 const bool left = (side == Lhs);
163 EIGEN_UNUSED_VARIABLE(left);
164 Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
165 Index linidx[2] = {0, 0};
166 if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
168 for (
int i =
static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
169 const Index idx0 = nocontract_val[0] / m_ij_strides[i];
170 const Index idx1 = nocontract_val[1] / m_ij_strides[i];
171 linidx[0] += idx0 * m_nocontract_strides[i];
172 linidx[1] += idx1 * m_nocontract_strides[i];
173 nocontract_val[0] -= idx0 * m_ij_strides[i];
174 nocontract_val[1] -= idx1 * m_ij_strides[i];
176 if (side == Lhs && inner_dim_contiguous) {
177 eigen_assert(m_nocontract_strides[0] == 1);
178 linidx[0] += nocontract_val[0];
179 linidx[1] += nocontract_val[1];
181 linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
182 linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
186 Index contract_val[2] = {left ? col : row, left ? col : row + distance};
187 if (array_size<contract_t>::value > 0) {
189 for (
int i =
static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
190 const Index idx0 = contract_val[0] / m_k_strides[i];
191 const Index idx1 = contract_val[1] / m_k_strides[i];
192 linidx[0] += idx0 * m_contract_strides[i];
193 linidx[1] += idx1 * m_contract_strides[i];
194 contract_val[0] -= idx0 * m_k_strides[i];
195 contract_val[1] -= idx1 * m_k_strides[i];
198 if (side == Rhs && inner_dim_contiguous) {
199 eigen_assert(m_contract_strides[0] == 1);
200 linidx[0] += contract_val[0];
201 linidx[1] += contract_val[1];
203 linidx[0] += contract_val[0] * m_contract_strides[0];
204 linidx[1] += contract_val[1] * m_contract_strides[0];
207 return IndexPair<Index>(linidx[0], linidx[1]);
210 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
Index firstAligned(
Index size)
const {
214 return (Alignment ==
Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
216 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
Index stride()
const {
217 return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
220 const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& tensor()
const {
return m_tensor; }
222 const nocontract_t& nocontract_strides()
const {
return m_nocontract_strides; }
223 const nocontract_t& ij_strides()
const {
return m_ij_strides; }
224 const contract_t& contract_strides()
const {
return m_contract_strides; }
225 const contract_t& k_strides()
const {
return m_k_strides; }
228 CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_> m_tensor;
229 const nocontract_t m_nocontract_strides;
230 const nocontract_t m_ij_strides;
231 const contract_t m_contract_strides;
232 const contract_t m_k_strides;
235template <
typename Scalar,
typename Index,
int side,
typename Tensor,
typename nocontract_t,
typename contract_t,
236 int packet_size,
bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment,
237 template <
class>
class MakePointer_>
238class BaseTensorContractionMapper
239 :
public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size,
240 inner_dim_contiguous, Alignment, MakePointer_> {
242 typedef SimpleTensorContractionMapper<Scalar,
Index, side, Tensor, nocontract_t, contract_t, packet_size,
243 inner_dim_contiguous, Alignment, MakePointer_>
246 EIGEN_DEVICE_FUNC BaseTensorContractionMapper(
const Tensor& tensor,
const nocontract_t& nocontract_strides,
247 const nocontract_t& ij_strides,
const contract_t& contract_strides,
248 const contract_t& k_strides)
249 : ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) {}
251 template <
typename PacketT,
int AlignmentType>
252 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
253 std::enable_if_t<internal::unpacket_traits<PacketT>::size == packet_size, PacketT>
259 EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
261 if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
262 const Index index = this->computeIndex(i, j);
263 eigen_assert(this->computeIndex(i + packet_size - 1, j) == index + packet_size - 1);
264 return this->m_tensor.template packet<AlignmentType>(index);
267 const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
268 const Index first = indexPair.first;
269 const Index lastIdx = indexPair.second;
275 if (Tensor::PacketAccess && (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) &&
276 (lastIdx - first) == (packet_size - 1)) {
277 return this->m_tensor.template packet<AlignmentType>(first);
280 EIGEN_ALIGN_MAX Scalar data[packet_size];
282 data[0] = this->m_tensor.coeff(first);
284 for (
Index k = 1; k < packet_size - 1; k += 2) {
285 const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
286 data[k] = this->m_tensor.coeff(internal_pair.first);
287 data[k + 1] = this->m_tensor.coeff(internal_pair.second);
289 data[packet_size - 1] = this->m_tensor.coeff(lastIdx);
291 return pload<PacketT>(data);
294 template <
typename PacketT,
int AlignmentType>
295 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
296 std::enable_if_t<internal::unpacket_traits<PacketT>::size != packet_size, PacketT>
298 const Index requested_packet_size = internal::unpacket_traits<PacketT>::size;
299 EIGEN_ALIGN_MAX Scalar data[requested_packet_size];
301 const IndexPair<Index> indexPair = this->computeIndexPair(i, j, requested_packet_size - 1);
302 const Index first = indexPair.first;
303 const Index lastIdx = indexPair.second;
305 data[0] = this->m_tensor.coeff(first);
306 for (
Index k = 1; k < requested_packet_size - 1; k += 2) {
307 const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1);
308 data[k] = this->m_tensor.coeff(internal_pair.first);
309 data[k + 1] = this->m_tensor.coeff(internal_pair.second);
311 data[requested_packet_size - 1] = this->m_tensor.coeff(lastIdx);
313 return pload<PacketT>(data);
316 template <
typename PacketT,
int AlignmentType>
317 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketT loadPacket(
Index i,
Index j)
const {
318 return this->load<PacketT, AlignmentType>(i, j);
322template <
typename Scalar,
typename Index,
int side,
typename Tensor,
typename nocontract_t,
typename contract_t,
323 bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment,
template <
class>
class MakePointer_>
324class BaseTensorContractionMapper<Scalar,
Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous,
325 inner_dim_reordered, Alignment, MakePointer_>
326 :
public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1,
327 inner_dim_contiguous, Alignment, MakePointer_> {
329 typedef SimpleTensorContractionMapper<Scalar,
Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous,
330 Alignment, MakePointer_>
333 EIGEN_DEVICE_FUNC BaseTensorContractionMapper(
const Tensor& tensor,
const nocontract_t& nocontract_strides,
334 const nocontract_t& ij_strides,
const contract_t& contract_strides,
335 const contract_t& k_strides)
336 : ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) {}
338 template <
typename PacketT,
int>
339 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketT loadPacket(
Index i,
Index j)
const {
340 EIGEN_ALIGN_MAX Scalar data[1];
341 data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
342 return pload<PacketT>(data);
344 template <
typename PacketT,
int>
345 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketT load(
Index i,
Index j)
const {
346 EIGEN_ALIGN_MAX Scalar data[1];
347 data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
348 return pload<PacketT>(data);
352template <
typename Scalar,
typename Index,
int side,
typename Tensor,
typename nocontract_t,
typename contract_t,
353 int packet_size,
bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment,
354 template <
class>
class MakePointer_ = MakePointer>
355class TensorContractionSubMapper {
357 typedef BaseTensorContractionMapper<Scalar,
Index, side, Tensor, nocontract_t, contract_t, packet_size,
358 inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
360 typedef TensorContractionSubMapper<Scalar,
Index, side, Tensor, nocontract_t, contract_t, packet_size,
361 inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
363 typedef Self LinearMapper;
364 typedef Self SubMapper;
370 ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0)
373 EIGEN_DEVICE_FUNC TensorContractionSubMapper(
const ParentMapper& base_mapper,
Index vert_offset,
Index horiz_offset)
374 : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
377 if (UseDirectOffsets) {
378 Index stride = m_base_mapper.stride();
379 m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride);
383 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(
Index i)
const {
384 if (UseDirectOffsets) {
385 return m_base_mapper(i, 0);
387 return m_base_mapper(i + m_vert_offset, m_horiz_offset);
389 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(
Index i,
Index j)
const {
390 if (UseDirectOffsets) {
391 return m_base_mapper(i, j);
393 return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
396 template <
typename PacketT>
397 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(
Index i)
const {
398 if (UseDirectOffsets) {
399 return m_base_mapper.template loadPacket<PacketT, Alignment>(i, 0);
401 return m_base_mapper.template loadPacket<PacketT, Alignment>(i + m_vert_offset, m_horiz_offset);
404 template <
typename PacketT>
405 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(
Index i,
Index j)
const {
406 if (UseDirectOffsets) {
407 return m_base_mapper.template loadPacket<PacketT, Alignment>(i, j);
409 return m_base_mapper.template loadPacket<PacketT, Alignment>(i + m_vert_offset, j + m_horiz_offset);
412 template <
typename PacketT>
413 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacketPartial(
Index i,
Index j,
Index,
Index = 0)
const {
414 if (UseDirectOffsets) {
415 return m_base_mapper.template loadPacket<PacketT, Alignment>(i, j);
417 return m_base_mapper.template loadPacket<PacketT, Alignment>(i + m_vert_offset, j + m_horiz_offset);
420 template <
typename PacketT,
int AlignmentType>
421 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT loadPacket(
Index i,
Index j)
const {
422 if (UseDirectOffsets) {
423 return m_base_mapper.template load<PacketT, AlignmentType>(i, j);
425 return m_base_mapper.template loadPacket<PacketT, AlignmentType>(i + m_vert_offset, j + m_horiz_offset);
428 template <
typename PacketT>
429 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void storePacket(
Index i,
const PacketT& p)
const {
430 if (UseDirectOffsets) {
431 m_base_mapper.storePacket(i, 0, p);
433 m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
436 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(
Index i,
Index j)
const {
437 if (UseDirectOffsets) {
438 return LinearMapper(m_base_mapper, i, j);
440 return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
443 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper getSubMapper(
Index i,
Index j)
const {
444 if (UseDirectOffsets) {
445 return SubMapper(m_base_mapper, i, j);
447 return SubMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
450 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
const Index stride()
const {
return m_base_mapper.stride(); }
452 template <
typename PacketT,
int AlignmentType>
453 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(
Index i)
const {
454 EIGEN_STATIC_ASSERT((internal::is_same<PacketT, PacketT>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
456 if (UseDirectOffsets) {
457 return m_base_mapper.template loadPacket<PacketT, ActualAlignment>(i, 0);
459 return m_base_mapper.template loadPacket<PacketT, ActualAlignment>(i + m_vert_offset, m_horiz_offset);
462 template <
typename PacketT>
463 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
bool aligned(
Index)
const {
467 const ParentMapper& base_mapper()
const {
return m_base_mapper; }
468 Index vert_offset()
const {
return m_vert_offset; }
469 Index horiz_offset()
const {
return m_horiz_offset; }
472 ParentMapper m_base_mapper;
473 const Index m_vert_offset;
474 const Index m_horiz_offset;
477template <
typename Scalar_,
typename Index,
int side,
typename Tensor,
typename nocontract_t,
typename contract_t,
478 int packet_size,
bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment,
479 template <
class>
class MakePointer_ = MakePointer>
480class TensorContractionInputMapper
481 :
public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size,
482 inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> {
484 typedef Scalar_ Scalar;
485 typedef BaseTensorContractionMapper<Scalar,
Index, side, Tensor, nocontract_t, contract_t, packet_size,
486 inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
488 typedef TensorContractionSubMapper<Scalar,
Index, side, Tensor, nocontract_t, contract_t, packet_size,
489 inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
491 typedef SubMapper VectorMapper;
492 typedef SubMapper LinearMapper;
494 EIGEN_DEVICE_FUNC TensorContractionInputMapper(
const Tensor& tensor,
const nocontract_t& nocontract_strides,
495 const nocontract_t& ij_strides,
const contract_t& contract_strides,
496 const contract_t& k_strides)
497 : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) {}
499 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE SubMapper getSubMapper(
Index i,
Index j)
const {
500 return SubMapper(*
this, i, j);
503 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(
Index i,
Index j)
const {
504 return LinearMapper(*
this, i, j);
507 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(
Index i,
Index j)
const {
508 return VectorMapper(*
this, i, j);
511 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& get_tensor()
const {
512 return Base::m_tensor;
517struct TensorContractionInputMapperTrait;
519template <
typename Scalar_,
typename Index_,
int side_,
typename Tensor_,
typename nocontract_t_,
typename contract_t_,
520 int packet_size_,
bool inner_dim_contiguous_,
bool inner_dim_reordered_,
int Alignment_,
521 template <
class>
class MakePointer_>
522struct TensorContractionInputMapperTrait<
523 TensorContractionInputMapper<Scalar_, Index_, side_, Tensor_, nocontract_t_, contract_t_, packet_size_,
524 inner_dim_contiguous_, inner_dim_reordered_, Alignment_, MakePointer_> > {
525 typedef Tensor_ XprType;
526 static const bool inner_dim_contiguous = inner_dim_contiguous_;
527 static const bool inner_dim_reordered = inner_dim_reordered_;
// NOTE(review): the lines below are documentation-generator cross-reference
// residue, not code; kept as a comment so the header stays parseable.
// - "The tensor class. Definition Tensor.h:68"
// - "Namespace containing all symbols from the Eigen library."
// - "EIGEN_DEFAULT_DENSE_INDEX_TYPE Index — Definition TensorContractionMapper.h:36"
// The file's closing `}` for the enclosing namespace(s) and the matching
// `#endif` for the include guard appear to have been truncated — confirm
// against the upstream header.