#ifndef EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
#define EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
17template<
typename Dims,
typename XprType>
18struct traits<TensorTraceOp<Dims, XprType> > :
public traits<XprType>
20 typedef typename XprType::Scalar Scalar;
21 typedef traits<XprType> XprTraits;
22 typedef typename XprTraits::StorageKind StorageKind;
23 typedef typename XprTraits::Index
Index;
24 typedef typename XprType::Nested Nested;
25 typedef typename remove_reference<Nested>::type _Nested;
26 static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
27 static const int Layout = XprTraits::Layout;
30template<
typename Dims,
typename XprType>
31struct eval<TensorTraceOp<Dims, XprType>, Eigen::Dense>
33 typedef const TensorTraceOp<Dims, XprType>& type;
36template<
typename Dims,
typename XprType>
37struct nested<TensorTraceOp<Dims, XprType>, 1, typename eval<TensorTraceOp<Dims, XprType> >::type>
39 typedef TensorTraceOp<Dims, XprType> type;
49template <
typename Dims,
typename XprType>
50class TensorTraceOp :
public TensorBase<TensorTraceOp<Dims, XprType> > {
52 typedef typename Eigen::internal::traits<TensorTraceOp>::Scalar Scalar;
54 typedef typename XprType::CoeffReturnType CoeffReturnType;
55 typedef typename Eigen::internal::nested<TensorTraceOp>::type Nested;
56 typedef typename Eigen::internal::traits<TensorTraceOp>::StorageKind StorageKind;
57 typedef typename Eigen::internal::traits<TensorTraceOp>::Index Index;
59 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTraceOp(
const XprType& expr,
const Dims& dims)
60 : m_xpr(expr), m_dims(dims) {
63 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
64 const Dims& dims()
const {
return m_dims; }
66 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
67 const typename internal::remove_all<typename XprType::Nested>::type& expression()
const {
return m_xpr; }
70 typename XprType::Nested m_xpr;
76template<
typename Dims,
typename ArgType,
typename Device>
80 static const int NumInputDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
81 static const int NumReducedDims = internal::array_size<Dims>::value;
82 static const int NumOutputDims = NumInputDims - NumReducedDims;
83 typedef typename XprType::Index
Index;
84 typedef DSizes<Index, NumOutputDims>
Dimensions;
85 typedef typename XprType::Scalar
Scalar;
87 typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
88 static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
89 typedef StorageMemory<CoeffReturnType, Device> Storage;
90 typedef typename Storage::Type EvaluatorPointerType;
94 PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
96 PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
97 Layout = TensorEvaluator<ArgType, Device>::Layout,
103 typedef internal::TensorBlockNotImplemented TensorBlock;
106 EIGEN_STRONG_INLINE TensorEvaluator(
const XprType& op,
const Device& device)
107 : m_impl(op.expression(), device), m_traceDim(1), m_device(device)
110 EIGEN_STATIC_ASSERT((NumOutputDims >= 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
111 EIGEN_STATIC_ASSERT((NumReducedDims >= 2) || ((NumReducedDims == 0) && (NumInputDims == 0)), YOU_MADE_A_PROGRAMMING_MISTAKE);
113 for (
int i = 0; i < NumInputDims; ++i) {
114 m_reduced[i] =
false;
117 const Dims& op_dims = op.dims();
118 for (
int i = 0; i < NumReducedDims; ++i) {
119 eigen_assert(op_dims[i] >= 0);
120 eigen_assert(op_dims[i] < NumInputDims);
121 m_reduced[op_dims[i]] =
true;
125 int num_distinct_reduce_dims = 0;
126 for (
int i = 0; i < NumInputDims; ++i) {
128 ++num_distinct_reduce_dims;
132 EIGEN_ONLY_USED_FOR_DEBUG(num_distinct_reduce_dims);
133 eigen_assert(num_distinct_reduce_dims == NumReducedDims);
136 const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
138 int output_index = 0;
139 int reduced_index = 0;
140 for (
int i = 0; i < NumInputDims; ++i) {
142 m_reducedDims[reduced_index] = input_dims[i];
143 if (reduced_index > 0) {
145 eigen_assert(m_reducedDims[0] == m_reducedDims[reduced_index]);
150 m_dimensions[output_index] = input_dims[i];
155 if (NumReducedDims != 0) {
156 m_traceDim = m_reducedDims[0];
160 if (NumOutputDims > 0) {
161 if (
static_cast<int>(Layout) ==
static_cast<int>(
ColMajor)) {
162 m_outputStrides[0] = 1;
163 for (
int i = 1; i < NumOutputDims; ++i) {
164 m_outputStrides[i] = m_outputStrides[i - 1] * m_dimensions[i - 1];
168 m_outputStrides.back() = 1;
169 for (
int i = NumOutputDims - 2; i >= 0; --i) {
170 m_outputStrides[i] = m_outputStrides[i + 1] * m_dimensions[i + 1];
176 if (NumInputDims > 0) {
177 array<Index, NumInputDims> input_strides;
178 if (
static_cast<int>(Layout) ==
static_cast<int>(
ColMajor)) {
179 input_strides[0] = 1;
180 for (
int i = 1; i < NumInputDims; ++i) {
181 input_strides[i] = input_strides[i - 1] * input_dims[i - 1];
185 input_strides.back() = 1;
186 for (
int i = NumInputDims - 2; i >= 0; --i) {
187 input_strides[i] = input_strides[i + 1] * input_dims[i + 1];
193 for (
int i = 0; i < NumInputDims; ++i) {
195 m_reducedStrides[reduced_index] = input_strides[i];
199 m_preservedStrides[output_index] = input_strides[i];
206 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions()
const {
210 EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(EvaluatorPointerType ) {
211 m_impl.evalSubExprsIfNeeded(NULL);
215 EIGEN_STRONG_INLINE
void cleanup() {
219 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index)
const
222 CoeffReturnType result = internal::cast<int, CoeffReturnType>(0);
223 Index index_stride = 0;
224 for (
int i = 0; i < NumReducedDims; ++i) {
225 index_stride += m_reducedStrides[i];
230 if (NumOutputDims != 0)
231 cur_index = firstInput(index);
232 for (Index i = 0; i < m_traceDim; ++i) {
233 result += m_impl.coeff(cur_index);
234 cur_index += index_stride;
240 template<
int LoadMode>
241 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index)
const {
243 EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
244 eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());
246 EIGEN_ALIGN_MAX
typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
247 for (
int i = 0; i < PacketSize; ++i) {
248 values[i] = coeff(index + i);
250 PacketReturnType result = internal::ploadt<PacketReturnType, LoadMode>(values);
256 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void bind(cl::sycl::handler &cgh)
const {
263 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index firstInput(Index index)
const {
264 Index startInput = 0;
265 if (
static_cast<int>(Layout) ==
static_cast<int>(
ColMajor)) {
266 for (
int i = NumOutputDims - 1; i > 0; --i) {
267 const Index idx = index / m_outputStrides[i];
268 startInput += idx * m_preservedStrides[i];
269 index -= idx * m_outputStrides[i];
271 startInput += index * m_preservedStrides[0];
274 for (
int i = 0; i < NumOutputDims - 1; ++i) {
275 const Index idx = index / m_outputStrides[i];
276 startInput += idx * m_preservedStrides[i];
277 index -= idx * m_outputStrides[i];
279 startInput += index * m_preservedStrides[NumOutputDims - 1];
284 Dimensions m_dimensions;
285 TensorEvaluator<ArgType, Device> m_impl;
288 const Device EIGEN_DEVICE_REF m_device;
289 array<bool, NumInputDims> m_reduced;
290 array<Index, NumReducedDims> m_reducedDims;
291 array<Index, NumOutputDims> m_outputStrides;
292 array<Index, NumReducedDims> m_reducedStrides;
293 array<Index, NumOutputDims> m_preservedStrides;
The tensor base class.
Definition TensorForwardDeclarations.h:56
Tensor Trace class.
Definition TensorTrace.h:50
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The tensor evaluator class.
Definition TensorEvaluator.h:27