Eigen-unsupported  5.0.1-dev+284dcc12
 
Loading...
Searching...
No Matches
TensorDimensions.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
11#define EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
12
13// IWYU pragma: private
14#include "./InternalHeaderCheck.h"
15
16namespace Eigen {
17
// Boilerplate code: compile-time helpers that Sizes<> uses to linearize indices and look up extents.
19namespace internal {
20
21template <std::ptrdiff_t n, typename Dimension>
22struct dget {
23 static const std::ptrdiff_t value = get<n, Dimension>::value;
24};
25
26template <typename Index, std::ptrdiff_t NumIndices, std::ptrdiff_t n, bool RowMajor>
27struct fixed_size_tensor_index_linearization_helper {
28 template <typename Dimensions>
29 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Index run(array<Index, NumIndices> const& indices,
30 const Dimensions& dimensions) {
31 return array_get < RowMajor ? n - 1
32 : (NumIndices - n) > (indices) + dget < RowMajor ? n - 1
33 : (NumIndices - n),
34 Dimensions > ::value * fixed_size_tensor_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(
35 indices, dimensions);
36 }
37};
38
39template <typename Index, std::ptrdiff_t NumIndices, bool RowMajor>
40struct fixed_size_tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor> {
41 template <typename Dimensions>
42 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Index run(array<Index, NumIndices> const&, const Dimensions&) {
43 return 0;
44 }
45};
46
47template <typename Index, std::ptrdiff_t n>
48struct fixed_size_tensor_index_extraction_helper {
49 template <typename Dimensions>
50 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Index run(const Index index, const Dimensions& dimensions) {
51 const Index mult = (index == n - 1) ? 1 : 0;
52 return array_get<n - 1>(dimensions) * mult +
53 fixed_size_tensor_index_extraction_helper<Index, n - 1>::run(index, dimensions);
54 }
55};
56
57template <typename Index>
58struct fixed_size_tensor_index_extraction_helper<Index, 0> {
59 template <typename Dimensions>
60 EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Index run(const Index, const Dimensions&) {
61 return 0;
62 }
63};
64
65} // end namespace internal
66
79template <typename std::ptrdiff_t... Indices>
80struct Sizes {
81 typedef internal::numeric_list<std::ptrdiff_t, Indices...> Base;
82 const Base t = Base();
83 static const std::ptrdiff_t total_size = internal::arg_prod(Indices...);
84 static const ptrdiff_t count = Base::count;
85
86 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t rank() const { return Base::count; }
87
88 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t TotalSize() { return internal::arg_prod(Indices...); }
89
90 EIGEN_DEVICE_FUNC Sizes() {}
91 template <typename DenseIndex>
92 explicit EIGEN_DEVICE_FUNC Sizes(const array<DenseIndex, Base::count>& /*indices*/) {
93 // todo: add assertion
94 }
95 template <typename... DenseIndex>
96 EIGEN_DEVICE_FUNC Sizes(DenseIndex...) {}
97 explicit EIGEN_DEVICE_FUNC Sizes(std::initializer_list<std::ptrdiff_t> /*l*/) {
98 // todo: add assertion
99 }
100
101 template <typename T>
102 Sizes& operator=(const T& /*other*/) {
103 // add assertion failure if the size of other is different
104 return *this;
105 }
106
107 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t operator[](const std::ptrdiff_t index) const {
108 return internal::fixed_size_tensor_index_extraction_helper<std::ptrdiff_t, Base::count>::run(index, t);
109 }
110
111 template <typename DenseIndex>
112 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptrdiff_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const {
113 return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, false>::run(
114 indices, t);
115 }
116 template <typename DenseIndex>
117 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptrdiff_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const {
118 return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, true>::run(
119 indices, t);
120 }
121};
122
123namespace internal {
124template <typename std::ptrdiff_t... Indices>
125EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_prod(const Sizes<Indices...>&) {
126 return Sizes<Indices...>::total_size;
127}
128} // namespace internal
129
// Boilerplate: recursive helpers that DSizes::IndexOfColMajor / IndexOfRowMajor use to linearize an index.
131namespace internal {
132template <typename Index, std::ptrdiff_t NumIndices, std::ptrdiff_t n, bool RowMajor>
133struct tensor_index_linearization_helper {
134 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index run(array<Index, NumIndices> const& indices,
135 array<Index, NumIndices> const& dimensions) {
136 return array_get < RowMajor ? n
137 : (NumIndices - n - 1) > (indices) + array_get < RowMajor
138 ? n
139 : (NumIndices - n - 1) >
140 (dimensions)*tensor_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(
141 indices, dimensions);
142 }
143};
144
145template <typename Index, std::ptrdiff_t NumIndices, bool RowMajor>
146struct tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor> {
147 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index run(array<Index, NumIndices> const& indices,
148 array<Index, NumIndices> const&) {
149 return array_get < RowMajor ? 0 : NumIndices - 1 > (indices);
150 }
151};
152} // end namespace internal
153
165template <typename DenseIndex, int NumDims>
166struct DSizes : array<DenseIndex, NumDims> {
167 typedef array<DenseIndex, NumDims> Base;
168 static const int count = NumDims;
169
170 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const { return NumDims; }
171
172 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex TotalSize() const {
173 return (NumDims == 0) ? 1 : internal::array_prod(*static_cast<const Base*>(this));
174 }
175
176 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DSizes() {
177 for (int i = 0; i < NumDims; ++i) {
178 (*this)[i] = 0;
179 }
180 }
181 EIGEN_DEVICE_FUNC explicit DSizes(const array<DenseIndex, NumDims>& a) : Base(a) {}
182
183 EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0) {
184 eigen_assert(NumDims == 1);
185 (*this)[0] = i0;
186 }
187
188 EIGEN_DEVICE_FUNC DSizes(const DimensionList<DenseIndex, NumDims>& a) {
189 for (int i = 0; i < NumDims; ++i) {
190 (*this)[i] = a[i];
191 }
192 }
193
194 // Enable DSizes index type promotion only if we are promoting to the
195 // larger type, e.g. allow to promote dimensions of type int to long.
196 template <typename OtherIndex>
197 EIGEN_DEVICE_FUNC explicit DSizes(
198 const array<OtherIndex, NumDims>& other,
199 // Default template parameters require c++11.
200 std::enable_if_t<
201 internal::is_same<DenseIndex, typename internal::promote_index_type<DenseIndex, OtherIndex>::type>::value,
202 void*> = 0) {
203 for (int i = 0; i < NumDims; ++i) {
204 (*this)[i] = static_cast<DenseIndex>(other[i]);
205 }
206 }
207
208 template <typename FirstType, typename... OtherTypes>
209 EIGEN_DEVICE_FUNC explicit DSizes(const Eigen::IndexList<FirstType, OtherTypes...>& dimensions) {
210 for (int i = 0; i < dimensions.count; ++i) {
211 (*this)[i] = dimensions[i];
212 }
213 }
214
215 template <typename std::ptrdiff_t... Indices>
216 EIGEN_DEVICE_FUNC DSizes(const Sizes<Indices...>& a) {
217 for (int i = 0; i < NumDims; ++i) {
218 (*this)[i] = a[i];
219 }
220 }
221
222 template <typename... IndexTypes>
223 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit DSizes(DenseIndex firstDimension, DenseIndex secondDimension,
224 IndexTypes... otherDimensions)
225 : Base({{firstDimension, secondDimension, otherDimensions...}}) {
226 EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 2 == NumDims, YOU_MADE_A_PROGRAMMING_MISTAKE)
227 }
228
229 EIGEN_DEVICE_FUNC DSizes& operator=(const array<DenseIndex, NumDims>& other) {
230 *static_cast<Base*>(this) = other;
231 return *this;
232 }
233
234 // A constexpr would be so much better here
235 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex IndexOfColMajor(const array<DenseIndex, NumDims>& indices) const {
236 return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, false>::run(
237 indices, *static_cast<const Base*>(this));
238 }
239 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex IndexOfRowMajor(const array<DenseIndex, NumDims>& indices) const {
240 return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, true>::run(
241 indices, *static_cast<const Base*>(this));
242 }
243};
244
245template <typename IndexType, int NumDims>
246std::ostream& operator<<(std::ostream& os, const DSizes<IndexType, NumDims>& dims) {
247 os << "[";
248 for (int i = 0; i < NumDims; ++i) {
249 if (i > 0) os << ", ";
250 os << dims[i];
251 }
252 os << "]";
253 return os;
254}
255
// Boilerplate: the same index linearization, with the extents supplied as a std::vector.
257namespace internal {
258template <typename Index, std::ptrdiff_t NumIndices, std::ptrdiff_t n, bool RowMajor>
259struct tensor_vsize_index_linearization_helper {
260 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index run(array<Index, NumIndices> const& indices,
261 std::vector<DenseIndex> const& dimensions) {
262 return array_get < RowMajor ? n
263 : (NumIndices - n - 1) > (indices) + array_get < RowMajor
264 ? n
265 : (NumIndices - n - 1) >
266 (dimensions)*tensor_vsize_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(
267 indices, dimensions);
268 }
269};
270
271template <typename Index, std::ptrdiff_t NumIndices, bool RowMajor>
272struct tensor_vsize_index_linearization_helper<Index, NumIndices, 0, RowMajor> {
273 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index run(array<Index, NumIndices> const& indices,
274 std::vector<DenseIndex> const&) {
275 return array_get < RowMajor ? 0 : NumIndices - 1 > (indices);
276 }
277};
278} // end namespace internal
279
280namespace internal {
281
282template <typename DenseIndex, int NumDims>
283struct array_size<const DSizes<DenseIndex, NumDims> > {
284 static const ptrdiff_t value = NumDims;
285};
286template <typename DenseIndex, int NumDims>
287struct array_size<DSizes<DenseIndex, NumDims> > {
288 static const ptrdiff_t value = NumDims;
289};
290template <typename std::ptrdiff_t... Indices>
291struct array_size<const Sizes<Indices...> > {
292 static const std::ptrdiff_t value = Sizes<Indices...>::count;
293};
294template <typename std::ptrdiff_t... Indices>
295struct array_size<Sizes<Indices...> > {
296 static const std::ptrdiff_t value = Sizes<Indices...>::count;
297};
298template <std::ptrdiff_t n, typename std::ptrdiff_t... Indices>
299EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes<Indices...>&) {
300 return get<n, internal::numeric_list<std::ptrdiff_t, Indices...> >::value;
301}
302template <std::ptrdiff_t n>
303EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t array_get(const Sizes<>&) {
304 eigen_assert(false && "should never be called");
305 return -1;
306}
307
308template <typename Dims1, typename Dims2, ptrdiff_t n, ptrdiff_t m>
309struct sizes_match_below_dim {
310 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(Dims1&, Dims2&) { return false; }
311};
312template <typename Dims1, typename Dims2, ptrdiff_t n>
313struct sizes_match_below_dim<Dims1, Dims2, n, n> {
314 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(Dims1& dims1, Dims2& dims2) {
315 return numext::equal_strict(array_get<n - 1>(dims1), array_get<n - 1>(dims2)) &&
316 sizes_match_below_dim<Dims1, Dims2, n - 1, n - 1>::run(dims1, dims2);
317 }
318};
319template <typename Dims1, typename Dims2>
320struct sizes_match_below_dim<Dims1, Dims2, 0, 0> {
321 static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(Dims1&, Dims2&) { return true; }
322};
323
324} // end namespace internal
325
326template <typename Dims1, typename Dims2>
327EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool dimensions_match(Dims1 dims1, Dims2 dims2) {
328 return internal::sizes_match_below_dim<Dims1, Dims2, internal::array_size<Dims1>::value,
329 internal::array_size<Dims2>::value>::run(dims1, dims2);
330}
331
332} // end namespace Eigen
333
334#endif // EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H
Namespace containing all symbols from the Eigen library.
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index