10#ifndef EIGEN_CXX11_TENSOR_TENSOR_IO_H
11#define EIGEN_CXX11_TENSOR_TENSOR_IO_H
14#include "./InternalHeaderCheck.h"
// Template parameter list of a forward declaration whose declarator line is not present
// in this listing. It matches the parameter list of internal::TensorPrinter defined further
// down (which omits the EnableIf default), so this is presumably TensorPrinter's forward
// declaration -- confirm against the full header.
// EnableIf defaults to void so partial specializations can opt in through
// std::enable_if_t<...>, as the TensorIOFormatLegacy printer does with `rank != 0`.
21template <
typename Tensor, std::
size_t rank,
typename Format,
typename EnableIf =
void>
// Common base of the tensor I/O formats. It uses CRTP: the concrete format passes itself
// as Derived_ (e.g. TensorIOFormatNumpy : TensorIOFormatBase<TensorIOFormatNumpy>).
//
// Stored configuration, as consumed by internal::TensorPrinter:
//   separator - text written between consecutive coefficients; indexed by dimension k,
//               and the printer falls back to the last entry when k >= separator.size();
//   prefix    - text written before a coefficient that begins a block of dimension k;
//   suffix    - text written for dimension k when a block ends (the guarding condition
//               is not visible in this listing);
//   tenPrefix / tenSuffix - presumably written before / after the whole tensor (the code
//               that emits them is not shown);
//   fill      - padding character used when aligning columns;
//   precision / flags - stream precision selector and flags such as DontAlignCols.
//
// NOTE(review): this listing is missing lines of the original header (the embedded line
// numbers jump, e.g. 31 -> 43 and 65 -> 69). The remaining member initializers, loop
// bodies and some member declarations are not visible here.
25template <
typename Derived_>
26struct TensorIOFormatBase {
27 using Derived = Derived_;
28 TensorIOFormatBase(
const std::vector<std::string>& separator,
const std::vector<std::string>& prefix,
29 const std::vector<std::string>& suffix,
int precision = StreamPrecision,
int flags = 0,
30 const std::string& tenPrefix =
"",
const std::string& tenSuffix =
"",
const char fill =
' ')
31 : tenPrefix(tenPrefix),
// Column alignment disabled: return before sizing `spacer`, which therefore stays empty.
43 if ((flags & DontAlignCols))
return;
// One spacer entry per prefix entry.
44 spacer.resize(prefix.size());
// Scan backwards from the end of tenPrefix to its last '\n'. The loop body is not shown;
// presumably it builds spacer[0] as padding as wide as the text after that newline.
46 int i = int(tenPrefix.length()) - 1;
47 while (i >= 0 && tenPrefix[i] !=
'\n') {
// Same backwards scan for prefix[1..]; presumably fills spacer[k] (body not shown).
52 for (std::size_t k = 1; k < prefix.size(); k++) {
53 int j = int(prefix[k].length()) - 1;
54 while (j >= 0 && prefix[k][j] !=
'\n') {
// NOTE(review): TensorPrinter also reads fmt.fill, fmt.precision and fmt.flags; those
// member declarations are not visible in this listing.
61 std::string tenPrefix;
62 std::string tenSuffix;
63 std::vector<std::string> prefix;
64 std::vector<std::string> suffix;
65 std::vector<std::string> separator;
// Per-dimension indentation used after line breaks; empty when DontAlignCols is set.
69 std::vector<std::string> spacer{};
// NumPy-like format. Per-dimension strings; TensorPrinter reuses the last entry for any
// dimension index past the end of a vector:
//   separator = {" ", "\n"}, prefix = {"", "["}, suffix = {"", "]"}.
// Complex scalars are printed by the ScalarPrinter specialization for this format.
// NOTE(review): the constructor declarator and the trailing constructor arguments
// (original lines 74 and 76-78) are missing from this listing.
72struct TensorIOFormatNumpy :
public TensorIOFormatBase<TensorIOFormatNumpy> {
73 using Base = TensorIOFormatBase<TensorIOFormatNumpy>;
75 : Base({
" ",
"\n"}, {
"",
"["}, {
"",
"]"}, StreamPrecision,
// Brace-style format with:
//   separator = {", ", ",\n", "\n"}, prefix = {"", "{"}, suffix = {"", "}"},
//   default stream precision, no flags, tenPrefix = "{", tenSuffix = "}".
// Complex scalars are printed as "{<real>, <imag>}" by the ScalarPrinter specialization
// for this format.
79struct TensorIOFormatNative :
public TensorIOFormatBase<TensorIOFormatNative> {
80 using Base = TensorIOFormatBase<TensorIOFormatNative>;
81 TensorIOFormatNative()
82 : Base({
", ",
",\n",
"\n"}, {
"",
"{"}, {
"",
"}"},
83 StreamPrecision, 0,
"{",
"}") {}
// Plain format: whitespace-only decoration with
//   separator = {" ", "\n", "\n", ""}, prefix = {""}, suffix = {""}.
// NOTE(review): the constructor declarator and the trailing constructor arguments
// (original lines 88 and 90-92) are missing from this listing.
86struct TensorIOFormatPlain :
public TensorIOFormatBase<TensorIOFormatPlain> {
87 using Base = TensorIOFormatBase<TensorIOFormatPlain>;
89 : Base({
" ",
"\n",
"\n",
""}, {
""}, {
""}, StreamPrecision,
// Legacy format: separator = {", ", "\n"}, prefix = {"", "["}, suffix = {"", "]"}.
// NOTE(review): among the printers visible in this listing, none reads these strings for
// this format. For rank != 0, internal::TensorPrinter is specialized for
// TensorIOFormatLegacy and leaves its format parameter unnamed. For rank 0, the generic
// rank-0 printer reads only fmt.precision.
// The trailing constructor arguments (original lines 97-98) are missing from this listing.
93struct TensorIOFormatLegacy :
public TensorIOFormatBase<TensorIOFormatLegacy> {
94 using Base = TensorIOFormatBase<TensorIOFormatLegacy>;
95 TensorIOFormatLegacy()
96 : Base({
", ",
"\n"}, {
"",
"["}, {
"",
"]"}, StreamPrecision,
// User-configurable tensor I/O format. The constructor forwards every argument unchanged
// to TensorIOFormatBase:
//   separator / prefix / suffix - per-dimension strings. TensorPrinter indexes them by
//       dimension and uses the last entry when a vector is shorter than the rank.
//   precision - StreamPrecision leaves the stream's precision untouched; FullPrecision
//       selects significant_decimals_impl<Scalar> for non-integer scalars; any other
//       value is used as the precision directly.
//   flags     - e.g. DontAlignCols to disable column alignment.
//   tenPrefix / tenSuffix - presumably written around the whole tensor.
//   fill      - padding character used when aligning columns.
// NOTE(review): the printer's fallback index is `size() - 1`, so each of separator,
// prefix and suffix should contain at least one entry.
// The static factories return the predefined formats by value. Returning them `const`
// prevents callers from moving out of the result.
100struct TensorIOFormat :
public TensorIOFormatBase<TensorIOFormat> {
101 using Base = TensorIOFormatBase<TensorIOFormat>;
102 TensorIOFormat(
const std::vector<std::string>& separator,
const std::vector<std::string>& prefix,
103 const std::vector<std::string>& suffix,
int precision = StreamPrecision,
int flags = 0,
104 const std::string& tenPrefix =
"",
const std::string& tenSuffix =
"",
const char fill =
' ')
105 : Base(separator, prefix, suffix, precision, flags, tenPrefix, tenSuffix, fill) {}
107 static inline const TensorIOFormatNumpy Numpy() {
return TensorIOFormatNumpy{}; }
109 static inline const TensorIOFormatPlain Plain() {
return TensorIOFormatPlain{}; }
111 static inline const TensorIOFormatNative Native() {
return TensorIOFormatNative{}; }
113 static inline const TensorIOFormatLegacy Legacy() {
return TensorIOFormatLegacy{}; }
// Pairs a tensor expression with a print format. Specializations for RowMajor, ColMajor
// and ColMajor rank 0 each define a friend operator<< that prints via internal::TensorPrinter.
116template <
typename T,
int Layout,
int rank,
typename Format>
117class TensorWithFormat;
// RowMajor specialization. The expression is already in the layout TensorPrinter asserts
// (RowMajor), so operator<< only forces evaluation before printing.
// NOTE(review): the access specifiers, the t_tensor / t_format member declarations, and
// the end of operator<< are not present in this listing. The end of operator<< is
// presumably evaluator cleanup and `return os`.
119template <
typename T,
int rank,
typename Format>
120class TensorWithFormat<T,
RowMajor, rank, Format> {
122 TensorWithFormat(
const T& tensor,
const Format& format) : t_tensor(tensor), t_format(format) {}
124 friend std::ostream& operator<<(std::ostream& os,
const TensorWithFormat<T, RowMajor, rank, Format>& wf) {
// Materialize the expression on the default device. TensorPrinter reads the
// coefficients through the evaluator's data().
126 typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
127 TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
128 Evaluator tensor(eval, DefaultDevice());
129 tensor.evalSubExprsIfNeeded(NULL);
130 internal::TensorPrinter<Evaluator, rank, Format>::run(os, tensor, wf.t_format);
// ColMajor specialization. TensorPrinter asserts a RowMajor layout, so the expression is
// first re-expressed as a RowMajor one and then printed as in the RowMajor case.
// NOTE(review): the access specifiers, the member declarations and the end of operator<<
// are not present in this listing.
141template <
typename T,
int rank,
typename Format>
142class TensorWithFormat<T,
ColMajor, rank, Format> {
144 TensorWithFormat(
const T& tensor,
const Format& format) : t_tensor(tensor), t_format(format) {}
146 friend std::ostream& operator<<(std::ostream& os,
const TensorWithFormat<T, ColMajor, rank, Format>& wf) {
148 typedef typename T::Index IndexType;
// shuffle = {rank-1, ..., 1, 0}: iota fills id with 0..rank-1, and copying it through
// shuffle's reverse iterator stores it reversed.
149 std::array<IndexType, rank> shuffle;
150 std::array<IndexType, rank> id;
151 std::iota(
id.begin(),
id.end(), IndexType(0));
152 std::copy(
id.begin(),
id.end(), shuffle.rbegin());
// swap_layout() flips the layout tag. The reversing shuffle is presumably what restores
// the original dimension order, so coefficients print in the same logical order.
153 auto tensor_row_major = wf.t_tensor.swap_layout().shuffle(shuffle);
156 typedef TensorEvaluator<
const TensorForcedEvalOp<
const decltype(tensor_row_major)>, DefaultDevice> Evaluator;
157 TensorForcedEvalOp<
const decltype(tensor_row_major)> eval = tensor_row_major.eval();
158 Evaluator tensor(eval, DefaultDevice());
159 tensor.evalSubExprsIfNeeded(NULL);
160 internal::TensorPrinter<Evaluator, rank, Format>::run(os, tensor, wf.t_format);
// Rank-0 ColMajor specialization. A scalar tensor has no dimension order to fix up, so no
// layout swap or shuffle is applied. The rank-0 TensorPrinter does not assert a layout.
// NOTE(review): the access specifiers, the member declarations and the end of operator<<
// are not present in this listing.
171template <
typename T,
typename Format>
172class TensorWithFormat<T,
ColMajor, 0, Format> {
174 TensorWithFormat(
const T& tensor,
const Format& format) : t_tensor(tensor), t_format(format) {}
176 friend std::ostream& operator<<(std::ostream& os,
const TensorWithFormat<T, ColMajor, 0, Format>& wf) {
178 typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
179 TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
180 Evaluator tensor(eval, DefaultDevice());
181 tensor.evalSubExprsIfNeeded(NULL);
182 internal::TensorPrinter<Evaluator, 0, Format>::run(os, tensor, wf.t_format);
// Default scalar printer. It streams the value with operator<< and ignores the format.
// EnableIf lets the complex-scalar specializations below select themselves.
196template <
typename Scalar,
typename Format,
typename EnableIf =
void>
197struct ScalarPrinter {
198 static void run(std::ostream& stream,
const Scalar& scalar,
const Format&) { stream << scalar; }
// Numpy format, complex scalars: printed as <real>+<imag>j, e.g. "1+2j".
// NOTE(review): the "+" is written unconditionally. A negative imaginary part therefore
// prints as "1+-2j" rather than NumPy's "1-2j". Consider emitting the sign from the
// imaginary part instead.
201template <
typename Scalar>
202struct ScalarPrinter<Scalar, TensorIOFormatNumpy, std::enable_if_t<NumTraits<Scalar>::IsComplex>> {
203 static void run(std::ostream& stream,
const Scalar& scalar,
const TensorIOFormatNumpy&) {
204 stream << numext::real(scalar) <<
"+" << numext::imag(scalar) <<
"j";
// Native format, complex scalars: printed as "{<real>, <imag>}".
208template <
typename Scalar>
209struct ScalarPrinter<Scalar, TensorIOFormatNative, std::enable_if_t<NumTraits<Scalar>::IsComplex>> {
210 static void run(std::ostream& stream,
const Scalar& scalar,
const TensorIOFormatNative&) {
211 stream <<
"{" << numext::real(scalar) <<
", " << numext::imag(scalar) <<
"}";
// Generic printer for every format/rank combination without a more specific specialization.
// It walks the coefficients linearly (tensor.data()[i]) and decorates each one with the
// format's per-dimension prefix / separator / suffix strings. This requires a RowMajor
// evaluator (asserted below); TensorWithFormat swaps and shuffles ColMajor inputs to get one.
//
// NOTE(review): many lines of this function are missing from this listing, including
// closing braces, several guarding conditions, the declarations of `width` and `filler`,
// and the statements that write prefix / scalar / suffix. The comments below describe
// only what is visible.
215template <
typename Tensor, std::
size_t rank,
typename Format,
typename EnableIf>
216struct TensorPrinter {
217 using Scalar = std::remove_const_t<typename Tensor::Scalar>;
219 static void run(std::ostream& s,
const Tensor& tensor,
const Format& fmt) {
220 typedef typename Tensor::Index IndexType;
222 eigen_assert(Tensor::Layout ==
RowMajor);
// PrintType is the type each coefficient is cast to before printing. Complex char-sized
// scalars become std::complex<int>; other scalars print as `const Scalar&`. The type
// chosen for plain char-sized scalars (original line 225) is not shown -- presumably a
// wider integer so they print as numbers rather than characters.
223 typedef std::conditional_t<is_same<Scalar, char>::value || is_same<Scalar, unsigned char>::value ||
224 is_same<Scalar, numext::int8_t>::value || is_same<Scalar, numext::uint8_t>::value,
226 std::conditional_t<is_same<Scalar, std::complex<char>>::value ||
227 is_same<Scalar, std::complex<unsigned char>>::value ||
228 is_same<Scalar, std::complex<numext::int8_t>>::value ||
229 is_same<Scalar, std::complex<numext::uint8_t>>::value,
230 std::complex<int>,
const Scalar&>>
233 const IndexType total_size = array_prod(tensor.dimensions());
// Resolve the precision to apply; 0 means "leave the stream's precision alone".
// FullPrecision maps to significant_decimals_impl<Scalar> for non-integer scalars.
// The previous precision is restored at the end of run().
235 std::streamsize explicit_precision;
236 if (fmt.precision == StreamPrecision) {
237 explicit_precision = 0;
238 }
else if (fmt.precision == FullPrecision) {
239 if (NumTraits<Scalar>::IsInteger) {
240 explicit_precision = 0;
242 explicit_precision = significant_decimals_impl<Scalar>::run();
245 explicit_precision = fmt.precision;
248 std::streamsize old_precision = 0;
249 if (explicit_precision) old_precision = s.precision(explicit_precision);
252 bool align_cols = !(fmt.flags & DontAlignCols);
// First pass: the widest printed representation over all coefficients. Presumably this
// runs only when align_cols is true; the enclosing condition is not visible here.
255 for (IndexType i = 0; i < total_size; i++) {
256 std::stringstream sstr;
258 ScalarPrinter<Scalar, Format>::run(sstr,
static_cast<PrintType
>(tensor.data()[i]), fmt);
259 width = std::max<IndexType>(width, IndexType(sstr.str().length()));
// Second pass: print every coefficient together with its decorations.
263 for (IndexType i = 0; i < total_size; i++) {
264 std::array<bool, rank> is_at_end{};
265 std::array<bool, rank> is_at_begin{};
// is_at_end[k] / is_at_begin[k] compare (i + 1) % P and i % P, where P is the product
// of the last k dimensions. For k = 0, P = 1. Presumably each flag is set when the
// remainder is 0, i.e. coefficient i closes / opens a block spanning those dimensions;
// the right-hand side of the comparisons is not visible here.
// NOTE(review): std::accumulate's accumulator has the type of its init value. `1` is
// an int, so P is computed and stored as int even when IndexType is wider. Using
// IndexType(1) would avoid truncation for blocks larger than INT_MAX coefficients.
268 for (std::size_t k = 0; k < rank; k++) {
269 if ((i + 1) % (std::accumulate(tensor.dimensions().rbegin(), tensor.dimensions().rbegin() + k, 1,
270 std::multiplies<IndexType>())) ==
277 for (std::size_t k = 0; k < rank; k++) {
278 if (i % (std::accumulate(tensor.dimensions().rbegin(), tensor.dimensions().rbegin() + k, 1,
279 std::multiplies<IndexType>())) ==
281 is_at_begin[k] =
true;
// True when some dimension that begins at i has a separator containing '\n'. In that
// case the alignment spacer is emitted below.
286 bool is_at_begin_after_newline =
false;
287 for (std::size_t k = 0; k < rank; k++) {
288 if (is_at_begin[k]) {
289 std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
290 if (fmt.separator[separator_index].find(
'\n') != std::string::npos) {
291 is_at_begin_after_newline =
true;
// Same check for dimensions ending at i. The per-k guard (presumably is_at_end[k]) is
// not visible.
296 bool is_at_end_before_newline =
false;
297 for (std::size_t k = 0; k < rank; k++) {
299 std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
300 if (fmt.separator[separator_index].find(
'\n') != std::string::npos) {
301 is_at_end_before_newline =
true;
// Build this coefficient's decorations. These are the suffixes of dimensions that
// close here, the separator(s) that follow it, and the spacer plus prefixes that
// precede it. Some of the guarding conditions are not visible in this listing.
306 std::stringstream suffix, prefix, separator;
307 for (std::size_t k = 0; k < rank; k++) {
308 std::size_t suffix_index = (k < fmt.suffix.size()) ? k : fmt.suffix.size() - 1;
310 suffix << fmt.suffix[suffix_index];
313 for (std::size_t k = 0; k < rank; k++) {
314 std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
316 (!is_at_end_before_newline || fmt.separator[separator_index].find(
'\n') != std::string::npos)) {
317 separator << fmt.separator[separator_index];
// NOTE(review): with DontAlignCols, TensorIOFormatBase's constructor returns before
// resizing `spacer`, so fmt.spacer can be empty. Then `fmt.spacer.size() - 1` wraps to
// SIZE_MAX and indexing with it is out of bounds. Confirm this loop is guarded by
// align_cols on lines not shown. The same `size() - 1` fallback for separator / suffix /
// prefix assumes those vectors are non-empty.
320 for (std::size_t k = 0; k < rank; k++) {
321 std::size_t spacer_index = (k < fmt.spacer.size()) ? k : fmt.spacer.size() - 1;
322 if (i != 0 && is_at_begin_after_newline && (!is_at_begin[k] || k == 0)) {
323 prefix << fmt.spacer[spacer_index];
// Prefixes are appended from the largest block (k = rank - 1) down to k = 0, so outer
// prefixes come first.
326 for (
int k = rank - 1; k >= 0; k--) {
327 std::size_t prefix_index = (
static_cast<std::size_t
>(k) < fmt.prefix.size()) ? k : fmt.prefix.size() - 1;
328 if (is_at_begin[k]) {
329 prefix << fmt.prefix[prefix_index];
// Render the coefficient. When it is narrower than `width`, build padding out of
// fmt.fill characters.
335 std::stringstream sstr;
337 ScalarPrinter<Scalar, Format>::run(sstr,
static_cast<PrintType
>(tensor.data()[i]), fmt);
338 std::string scalar_str = sstr.str();
339 IndexType scalar_width = scalar_str.length();
340 if (width && scalar_width < width) {
342 for (IndexType j = scalar_width; j < width; ++j) {
343 filler.push_back(fmt.fill);
// No separator after the last coefficient.
349 if (i < total_size - 1) {
350 s << separator.str();
// Restore the caller's stream precision if it was changed above.
354 if (explicit_precision) s.precision(old_precision);
// Legacy printer for rank != 0. It ignores the format object entirely: its parameter is
// unnamed, so neither the strings nor fmt.precision are used.
// A non-empty tensor is viewed as a first_dim x (total_size / first_dim) Eigen Array in
// the tensor's own layout. That Map is presumably streamed with the dense Array operator<<
// on lines not shown in this listing.
358template <
typename Tensor, std::
size_t rank>
359struct TensorPrinter<Tensor, rank, TensorIOFormatLegacy, std::enable_if_t<rank != 0>> {
360 using Format = TensorIOFormatLegacy;
361 using Scalar = std::remove_const_t<typename Tensor::Scalar>;
363 static void run(std::ostream& s,
const Tensor& tensor,
const Format&) {
364 typedef typename Tensor::Index IndexType;
367 const IndexType total_size = internal::array_prod(tensor.dimensions());
// A non-zero product of the dimensions implies first_dim != 0, so the division is safe.
368 if (total_size > 0) {
369 const IndexType first_dim = Eigen::internal::array_get<0>(tensor.dimensions());
370 Map<const Array<Scalar, Dynamic, Dynamic, Tensor::Layout>> matrix(tensor.data(), first_dim,
371 total_size / first_dim);
// Rank-0 printer. It prints the single coefficient tensor.coeff(0) with the format's
// ScalarPrinter and restores the stream's previous precision afterwards. Precision is
// resolved the same way as in the generic printer.
// NOTE(review): unlike the generic printer, the value is not cast to a PrintType. A
// char-sized Scalar may therefore print differently here (as a character) than in higher-
// rank tensors. The precision-resolution code is also duplicated from the generic printer
// and could be shared.
378template <
typename Tensor,
typename Format>
379struct TensorPrinter<Tensor, 0, Format> {
380 static void run(std::ostream& s,
const Tensor& tensor,
const Format& fmt) {
381 using Scalar = std::remove_const_t<typename Tensor::Scalar>;
383 std::streamsize explicit_precision;
384 if (fmt.precision == StreamPrecision) {
385 explicit_precision = 0;
386 }
else if (fmt.precision == FullPrecision) {
387 if (NumTraits<Scalar>::IsInteger) {
388 explicit_precision = 0;
390 explicit_precision = significant_decimals_impl<Scalar>::run();
393 explicit_precision = fmt.precision;
396 std::streamsize old_precision = 0;
397 if (explicit_precision) old_precision = s.precision(explicit_precision);
399 ScalarPrinter<Scalar, Format>::run(s, tensor.coeff(0), fmt);
401 if (explicit_precision) s.precision(old_precision);
// Streams `t` using the predefined Plain format. The enclosing function's declaration is
// not present in this listing.
408 s << t.format(TensorIOFormat::Plain());
The tensor base class.
Definition TensorForwardDeclarations.h:68
Namespace containing all symbols from the Eigen library.