TensorIO.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_IO_H
#define EIGEN_CXX11_TENSOR_TENSOR_IO_H

// IWYU pragma: private
#include "./InternalHeaderCheck.h"

namespace Eigen {

struct TensorIOFormat;

namespace internal {
template <typename Tensor, std::size_t rank, typename Format, typename EnableIf = void>
struct TensorPrinter;
}

template <typename Derived_>
struct TensorIOFormatBase {
  using Derived = Derived_;
  TensorIOFormatBase(const std::vector<std::string>& separator, const std::vector<std::string>& prefix,
                     const std::vector<std::string>& suffix, int precision = StreamPrecision, int flags = 0,
                     const std::string& tenPrefix = "", const std::string& tenSuffix = "", const char fill = ' ')
      : tenPrefix(tenPrefix),
        tenSuffix(tenSuffix),
        prefix(prefix),
        suffix(suffix),
        separator(separator),
        fill(fill),
        precision(precision),
        flags(flags) {
    init_spacer();
  }

  void init_spacer() {
    if ((flags & DontAlignCols)) return;
    spacer.resize(prefix.size());
    spacer[0] = "";
    int i = int(tenPrefix.length()) - 1;
    while (i >= 0 && tenPrefix[i] != '\n') {
      spacer[0] += ' ';
      i--;
    }

    for (std::size_t k = 1; k < prefix.size(); k++) {
      int j = int(prefix[k].length()) - 1;
      while (j >= 0 && prefix[k][j] != '\n') {
        spacer[k] += ' ';
        j--;
      }
    }
  }

  std::string tenPrefix;
  std::string tenSuffix;
  std::vector<std::string> prefix;
  std::vector<std::string> suffix;
  std::vector<std::string> separator;
  char fill;
  int precision;
  int flags;
  std::vector<std::string> spacer{};
};

struct TensorIOFormatNumpy : public TensorIOFormatBase<TensorIOFormatNumpy> {
  using Base = TensorIOFormatBase<TensorIOFormatNumpy>;
  TensorIOFormatNumpy()
      : Base(/*separator=*/{" ", "\n"}, /*prefix=*/{"", "["}, /*suffix=*/{"", "]"}, /*precision=*/StreamPrecision,
             /*flags=*/0, /*tenPrefix=*/"[", /*tenSuffix=*/"]") {}
};

struct TensorIOFormatNative : public TensorIOFormatBase<TensorIOFormatNative> {
  using Base = TensorIOFormatBase<TensorIOFormatNative>;
  TensorIOFormatNative()
      : Base(/*separator=*/{", ", ",\n", "\n"}, /*prefix=*/{"", "{"}, /*suffix=*/{"", "}"},
             /*precision=*/StreamPrecision, /*flags=*/0, /*tenPrefix=*/"{", /*tenSuffix=*/"}") {}
};

struct TensorIOFormatPlain : public TensorIOFormatBase<TensorIOFormatPlain> {
  using Base = TensorIOFormatBase<TensorIOFormatPlain>;
  TensorIOFormatPlain()
      : Base(/*separator=*/{" ", "\n", "\n", ""}, /*prefix=*/{""}, /*suffix=*/{""}, /*precision=*/StreamPrecision,
             /*flags=*/0, /*tenPrefix=*/"", /*tenSuffix=*/"") {}
};

struct TensorIOFormatLegacy : public TensorIOFormatBase<TensorIOFormatLegacy> {
  using Base = TensorIOFormatBase<TensorIOFormatLegacy>;
  TensorIOFormatLegacy()
      : Base(/*separator=*/{", ", "\n"}, /*prefix=*/{"", "["}, /*suffix=*/{"", "]"}, /*precision=*/StreamPrecision,
             /*flags=*/0, /*tenPrefix=*/"", /*tenSuffix=*/"") {}
};

struct TensorIOFormat : public TensorIOFormatBase<TensorIOFormat> {
  using Base = TensorIOFormatBase<TensorIOFormat>;
  TensorIOFormat(const std::vector<std::string>& separator, const std::vector<std::string>& prefix,
                 const std::vector<std::string>& suffix, int precision = StreamPrecision, int flags = 0,
                 const std::string& tenPrefix = "", const std::string& tenSuffix = "", const char fill = ' ')
      : Base(separator, prefix, suffix, precision, flags, tenPrefix, tenSuffix, fill) {}

  static inline const TensorIOFormatNumpy Numpy() { return TensorIOFormatNumpy{}; }

  static inline const TensorIOFormatPlain Plain() { return TensorIOFormatPlain{}; }

  static inline const TensorIOFormatNative Native() { return TensorIOFormatNative{}; }

  static inline const TensorIOFormatLegacy Legacy() { return TensorIOFormatLegacy{}; }
};
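
// Illustrative usage sketch (added for exposition, not part of the original header):
// the presets above are typically handed to a tensor's format() member (declared on
// TensorBase, not in this file) when streaming. Exact spacing depends on the stream's
// precision settings, but the output looks roughly like this:
//
//   Eigen::Tensor<int, 2> t(2, 3);
//   t.setValues({{1, 2, 3}, {4, 5, 6}});
//   std::cout << t.format(Eigen::TensorIOFormat::Numpy());
//   // [[1 2 3]
//   //  [4 5 6]]
//   std::cout << t.format(Eigen::TensorIOFormat::Native());
//   // {{1, 2, 3},
//   //  {4, 5, 6}}
//   std::cout << t.format(Eigen::TensorIOFormat::Plain());
//   // 1 2 3
//   // 4 5 6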

template <typename T, int Layout, int rank, typename Format>
class TensorWithFormat;
// specialize for Layout=ColMajor, Layout=RowMajor and rank=0.
template <typename T, int rank, typename Format>
class TensorWithFormat<T, RowMajor, rank, Format> {
 public:
  TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}

  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, RowMajor, rank, Format>& wf) {
    // Evaluate the expression if needed
    typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
    TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
    Evaluator tensor(eval, DefaultDevice());
    tensor.evalSubExprsIfNeeded(NULL);
    internal::TensorPrinter<Evaluator, rank, Format>::run(os, tensor, wf.t_format);
    // Cleanup.
    tensor.cleanup();
    return os;
  }

 protected:
  T t_tensor;
  Format t_format;
};

template <typename T, int rank, typename Format>
class TensorWithFormat<T, ColMajor, rank, Format> {
 public:
  TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}

  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, rank, Format>& wf) {
    // Switch to RowMajor storage and print afterwards
    typedef typename T::Index IndexType;
    std::array<IndexType, rank> shuffle;
    std::array<IndexType, rank> id;
    std::iota(id.begin(), id.end(), IndexType(0));
    std::copy(id.begin(), id.end(), shuffle.rbegin());
    auto tensor_row_major = wf.t_tensor.swap_layout().shuffle(shuffle);

    // Evaluate the expression if needed
    typedef TensorEvaluator<const TensorForcedEvalOp<const decltype(tensor_row_major)>, DefaultDevice> Evaluator;
    TensorForcedEvalOp<const decltype(tensor_row_major)> eval = tensor_row_major.eval();
    Evaluator tensor(eval, DefaultDevice());
    tensor.evalSubExprsIfNeeded(NULL);
    internal::TensorPrinter<Evaluator, rank, Format>::run(os, tensor, wf.t_format);
    // Cleanup.
    tensor.cleanup();
    return os;
  }

 protected:
  T t_tensor;
  Format t_format;
};
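
// Note (added for exposition): for rank 3, `id` above is {0, 1, 2} and copying it
// into shuffle.rbegin() yields shuffle = {2, 1, 0}. swap_layout() reinterprets the
// data as RowMajor, which reverses the dimension order, and the reversing shuffle
// restores the logical dimension order, so the RowMajor printer above produces the
// same output for either storage layout.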

template <typename T, typename Format>
class TensorWithFormat<T, ColMajor, 0, Format> {
 public:
  TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}

  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, 0, Format>& wf) {
    // Evaluate the expression if needed
    typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
    TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
    Evaluator tensor(eval, DefaultDevice());
    tensor.evalSubExprsIfNeeded(NULL);
    internal::TensorPrinter<Evaluator, 0, Format>::run(os, tensor, wf.t_format);
    // Cleanup.
    tensor.cleanup();
    return os;
  }

 protected:
  T t_tensor;
  Format t_format;
};

namespace internal {

// Default scalar printer.
template <typename Scalar, typename Format, typename EnableIf = void>
struct ScalarPrinter {
  static void run(std::ostream& stream, const Scalar& scalar, const Format&) { stream << scalar; }
};

template <typename Scalar>
struct ScalarPrinter<Scalar, TensorIOFormatNumpy, std::enable_if_t<NumTraits<Scalar>::IsComplex>> {
  static void run(std::ostream& stream, const Scalar& scalar, const TensorIOFormatNumpy&) {
    stream << numext::real(scalar) << "+" << numext::imag(scalar) << "j";
  }
};

template <typename Scalar>
struct ScalarPrinter<Scalar, TensorIOFormatNative, std::enable_if_t<NumTraits<Scalar>::IsComplex>> {
  static void run(std::ostream& stream, const Scalar& scalar, const TensorIOFormatNative&) {
    stream << "{" << numext::real(scalar) << ", " << numext::imag(scalar) << "}";
  }
};
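
// Illustrative sketch (added for exposition): given std::complex<float> c(1.0f, 2.0f),
// the Numpy specialization above writes "1+2j" and the Native specialization writes
// "{1, 2}"; every other format falls back to the generic ScalarPrinter, i.e. the
// scalar type's own operator<<.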

template <typename Tensor, std::size_t rank, typename Format, typename EnableIf>
struct TensorPrinter {
  using Scalar = std::remove_const_t<typename Tensor::Scalar>;

  static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) {
    typedef typename Tensor::Index IndexType;

    eigen_assert(Tensor::Layout == RowMajor);
    typedef std::conditional_t<is_same<Scalar, char>::value || is_same<Scalar, unsigned char>::value ||
                                   is_same<Scalar, numext::int8_t>::value || is_same<Scalar, numext::uint8_t>::value,
                               int,
                               std::conditional_t<is_same<Scalar, std::complex<char>>::value ||
                                                      is_same<Scalar, std::complex<unsigned char>>::value ||
                                                      is_same<Scalar, std::complex<numext::int8_t>>::value ||
                                                      is_same<Scalar, std::complex<numext::uint8_t>>::value,
                                                  std::complex<int>, const Scalar&>>
        PrintType;

    const IndexType total_size = array_prod(tensor.dimensions());

    std::streamsize explicit_precision;
    if (fmt.precision == StreamPrecision) {
      explicit_precision = 0;
    } else if (fmt.precision == FullPrecision) {
      if (NumTraits<Scalar>::IsInteger) {
        explicit_precision = 0;
      } else {
        explicit_precision = significant_decimals_impl<Scalar>::run();
      }
    } else {
      explicit_precision = fmt.precision;
    }

    std::streamsize old_precision = 0;
    if (explicit_precision) old_precision = s.precision(explicit_precision);

    IndexType width = 0;
    bool align_cols = !(fmt.flags & DontAlignCols);
    if (align_cols) {
      // compute the largest width
      for (IndexType i = 0; i < total_size; i++) {
        std::stringstream sstr;
        sstr.copyfmt(s);
        ScalarPrinter<Scalar, Format>::run(sstr, static_cast<PrintType>(tensor.data()[i]), fmt);
        width = std::max<IndexType>(width, IndexType(sstr.str().length()));
      }
    }
    s << fmt.tenPrefix;
    for (IndexType i = 0; i < total_size; i++) {
      std::array<bool, rank> is_at_end{};
      std::array<bool, rank> is_at_begin{};

      // is the ith element the end of a coeff (always true), of a row, of a matrix, ...?
      for (std::size_t k = 0; k < rank; k++) {
        if ((i + 1) % (std::accumulate(tensor.dimensions().rbegin(), tensor.dimensions().rbegin() + k, 1,
                                       std::multiplies<IndexType>())) ==
            0) {
          is_at_end[k] = true;
        }
      }

      // is the ith element the beginning of a coeff (always true), of a row, of a matrix, ...?
      for (std::size_t k = 0; k < rank; k++) {
        if (i % (std::accumulate(tensor.dimensions().rbegin(), tensor.dimensions().rbegin() + k, 1,
                                 std::multiplies<IndexType>())) ==
            0) {
          is_at_begin[k] = true;
        }
      }

      // do we have a line break?
      bool is_at_begin_after_newline = false;
      for (std::size_t k = 0; k < rank; k++) {
        if (is_at_begin[k]) {
          std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
          if (fmt.separator[separator_index].find('\n') != std::string::npos) {
            is_at_begin_after_newline = true;
          }
        }
      }

      bool is_at_end_before_newline = false;
      for (std::size_t k = 0; k < rank; k++) {
        if (is_at_end[k]) {
          std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
          if (fmt.separator[separator_index].find('\n') != std::string::npos) {
            is_at_end_before_newline = true;
          }
        }
      }

      std::stringstream suffix, prefix, separator;
      for (std::size_t k = 0; k < rank; k++) {
        std::size_t suffix_index = (k < fmt.suffix.size()) ? k : fmt.suffix.size() - 1;
        if (is_at_end[k]) {
          suffix << fmt.suffix[suffix_index];
        }
      }
      for (std::size_t k = 0; k < rank; k++) {
        std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
        if (is_at_end[k] &&
            (!is_at_end_before_newline || fmt.separator[separator_index].find('\n') != std::string::npos)) {
          separator << fmt.separator[separator_index];
        }
      }
      for (std::size_t k = 0; k < rank; k++) {
        std::size_t spacer_index = (k < fmt.spacer.size()) ? k : fmt.spacer.size() - 1;
        if (i != 0 && is_at_begin_after_newline && (!is_at_begin[k] || k == 0)) {
          prefix << fmt.spacer[spacer_index];
        }
      }
      for (int k = rank - 1; k >= 0; k--) {
        std::size_t prefix_index = (static_cast<std::size_t>(k) < fmt.prefix.size()) ? k : fmt.prefix.size() - 1;
        if (is_at_begin[k]) {
          prefix << fmt.prefix[prefix_index];
        }
      }

      s << prefix.str();
      // To avoid disturbing the output stream's formatting state, print the scalar to a string stream and apply the
      // width/fill adjustment manually.
      std::stringstream sstr;
      sstr.copyfmt(s);
      ScalarPrinter<Scalar, Format>::run(sstr, static_cast<PrintType>(tensor.data()[i]), fmt);
      std::string scalar_str = sstr.str();
      IndexType scalar_width = scalar_str.length();
      if (width && scalar_width < width) {
        std::string filler;
        for (IndexType j = scalar_width; j < width; ++j) {
          filler.push_back(fmt.fill);
        }
        s << filler;
      }
      s << scalar_str;
      s << suffix.str();
      if (i < total_size - 1) {
        s << separator.str();
      }
    }
    s << fmt.tenSuffix;
    if (explicit_precision) s.precision(old_precision);
  }
};

template <typename Tensor, std::size_t rank>
struct TensorPrinter<Tensor, rank, TensorIOFormatLegacy, std::enable_if_t<rank != 0>> {
  using Format = TensorIOFormatLegacy;
  using Scalar = std::remove_const_t<typename Tensor::Scalar>;

  static void run(std::ostream& s, const Tensor& tensor, const Format&) {
    typedef typename Tensor::Index IndexType;
    // backwards compatibility case: print tensor after reshaping to matrix of size dim(0) x
    // (dim(1)*dim(2)*...*dim(rank-1)).
    const IndexType total_size = internal::array_prod(tensor.dimensions());
    if (total_size > 0) {
      const IndexType first_dim = Eigen::internal::array_get<0>(tensor.dimensions());
      Map<const Array<Scalar, Dynamic, Dynamic, Tensor::Layout>> matrix(tensor.data(), first_dim,
                                                                        total_size / first_dim);
      s << matrix;
      return;
    }
  }
};
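
// Illustrative sketch (added for exposition): with the Legacy format a rank-3 tensor
// of dimensions 2x3x4 is printed by mapping its 24 coefficients onto a
// dim(0) x (dim(1)*dim(2)) = 2x12 array and streaming that with the regular Array
// operator<<; this specialization ignores the format's prefixes and suffixes entirely.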

template <typename Tensor, typename Format>
struct TensorPrinter<Tensor, 0, Format> {
  static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) {
    using Scalar = std::remove_const_t<typename Tensor::Scalar>;

    std::streamsize explicit_precision;
    if (fmt.precision == StreamPrecision) {
      explicit_precision = 0;
    } else if (fmt.precision == FullPrecision) {
      if (NumTraits<Scalar>::IsInteger) {
        explicit_precision = 0;
      } else {
        explicit_precision = significant_decimals_impl<Scalar>::run();
      }
    } else {
      explicit_precision = fmt.precision;
    }

    std::streamsize old_precision = 0;
    if (explicit_precision) old_precision = s.precision(explicit_precision);
    s << fmt.tenPrefix;
    ScalarPrinter<Scalar, Format>::run(s, tensor.coeff(0), fmt);
    s << fmt.tenSuffix;
    if (explicit_precision) s.precision(old_precision);
  }
};

}  // end namespace internal
template <typename T>
std::ostream& operator<<(std::ostream& s, const TensorBase<T, ReadOnlyAccessors>& t) {
  s << t.format(TensorIOFormat::Plain());
  return s;
}
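
// Illustrative usage sketch (added for exposition): streaming a tensor without an
// explicit format goes through the operator above and therefore uses the Plain
// preset, e.g.
//
//   Eigen::Tensor<int, 2> t(2, 2);
//   t.setValues({{1, 2}, {3, 4}});
//   std::cout << t;
//   // 1 2
//   // 3 4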
}  // end namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_IO_H