$darkmode
Eigen-unsupported  5.0.1-dev
TensorIO.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_IO_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_IO_H
12 
13 // IWYU pragma: private
14 #include "./InternalHeaderCheck.h"
15 
16 namespace Eigen {
17 
struct TensorIOFormat;

namespace internal {
// Prints an evaluated tensor of the given rank according to a format object. The trailing
// EnableIf parameter lets specializations be selected with std::enable_if_t.
template <class Tensor, std::size_t rank, class Format, class EnableIf = void>
struct TensorPrinter;
}  // namespace internal
24 
25 template <typename Derived_>
26 struct TensorIOFormatBase {
27  using Derived = Derived_;
28  TensorIOFormatBase(const std::vector<std::string>& separator, const std::vector<std::string>& prefix,
29  const std::vector<std::string>& suffix, int precision = StreamPrecision, int flags = 0,
30  const std::string& tenPrefix = "", const std::string& tenSuffix = "", const char fill = ' ')
31  : tenPrefix(tenPrefix),
32  tenSuffix(tenSuffix),
33  prefix(prefix),
34  suffix(suffix),
35  separator(separator),
36  fill(fill),
37  precision(precision),
38  flags(flags) {
39  init_spacer();
40  }
41 
42  void init_spacer() {
43  if ((flags & DontAlignCols)) return;
44  spacer.resize(prefix.size());
45  spacer[0] = "";
46  int i = int(tenPrefix.length()) - 1;
47  while (i >= 0 && tenPrefix[i] != '\n') {
48  spacer[0] += ' ';
49  i--;
50  }
51 
52  for (std::size_t k = 1; k < prefix.size(); k++) {
53  int j = int(prefix[k].length()) - 1;
54  while (j >= 0 && prefix[k][j] != '\n') {
55  spacer[k] += ' ';
56  j--;
57  }
58  }
59  }
60 
61  std::string tenPrefix;
62  std::string tenSuffix;
63  std::vector<std::string> prefix;
64  std::vector<std::string> suffix;
65  std::vector<std::string> separator;
66  char fill;
67  int precision;
68  int flags;
69  std::vector<std::string> spacer{};
70 };
71 
72 struct TensorIOFormatNumpy : public TensorIOFormatBase<TensorIOFormatNumpy> {
73  using Base = TensorIOFormatBase<TensorIOFormatNumpy>;
74  TensorIOFormatNumpy()
75  : Base(/*separator=*/{" ", "\n"}, /*prefix=*/{"", "["}, /*suffix=*/{"", "]"}, /*precision=*/StreamPrecision,
76  /*flags=*/0, /*tenPrefix=*/"[", /*tenSuffix=*/"]") {}
77 };
78 
79 struct TensorIOFormatNative : public TensorIOFormatBase<TensorIOFormatNative> {
80  using Base = TensorIOFormatBase<TensorIOFormatNative>;
81  TensorIOFormatNative()
82  : Base(/*separator=*/{", ", ",\n", "\n"}, /*prefix=*/{"", "{"}, /*suffix=*/{"", "}"},
83  /*precision=*/StreamPrecision, /*flags=*/0, /*tenPrefix=*/"{", /*tenSuffix=*/"}") {}
84 };
85 
86 struct TensorIOFormatPlain : public TensorIOFormatBase<TensorIOFormatPlain> {
87  using Base = TensorIOFormatBase<TensorIOFormatPlain>;
88  TensorIOFormatPlain()
89  : Base(/*separator=*/{" ", "\n", "\n", ""}, /*prefix=*/{""}, /*suffix=*/{""}, /*precision=*/StreamPrecision,
90  /*flags=*/0, /*tenPrefix=*/"", /*tenSuffix=*/"") {}
91 };
92 
93 struct TensorIOFormatLegacy : public TensorIOFormatBase<TensorIOFormatLegacy> {
94  using Base = TensorIOFormatBase<TensorIOFormatLegacy>;
95  TensorIOFormatLegacy()
96  : Base(/*separator=*/{", ", "\n"}, /*prefix=*/{"", "["}, /*suffix=*/{"", "]"}, /*precision=*/StreamPrecision,
97  /*flags=*/0, /*tenPrefix=*/"", /*tenSuffix=*/"") {}
98 };
99 
100 struct TensorIOFormat : public TensorIOFormatBase<TensorIOFormat> {
101  using Base = TensorIOFormatBase<TensorIOFormat>;
102  TensorIOFormat(const std::vector<std::string>& separator, const std::vector<std::string>& prefix,
103  const std::vector<std::string>& suffix, int precision = StreamPrecision, int flags = 0,
104  const std::string& tenPrefix = "", const std::string& tenSuffix = "", const char fill = ' ')
105  : Base(separator, prefix, suffix, precision, flags, tenPrefix, tenSuffix, fill) {}
106 
107  static inline const TensorIOFormatNumpy Numpy() { return TensorIOFormatNumpy{}; }
108 
109  static inline const TensorIOFormatPlain Plain() { return TensorIOFormatPlain{}; }
110 
111  static inline const TensorIOFormatNative Native() { return TensorIOFormatNative{}; }
112 
113  static inline const TensorIOFormatLegacy Legacy() { return TensorIOFormatLegacy{}; }
114 };
115 
// Pairs a tensor expression with an output format; streaming the pair prints the tensor.
// Only specializations are defined: Layout=RowMajor, Layout=ColMajor, and ColMajor with rank 0.
template <class T, int Layout, int rank, class Format>
class TensorWithFormat;
119 template <typename T, int rank, typename Format>
120 class TensorWithFormat<T, RowMajor, rank, Format> {
121  public:
122  TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}
123 
124  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, RowMajor, rank, Format>& wf) {
125  // Evaluate the expression if needed
126  typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
127  TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
128  Evaluator tensor(eval, DefaultDevice());
129  tensor.evalSubExprsIfNeeded(NULL);
130  internal::TensorPrinter<Evaluator, rank, Format>::run(os, tensor, wf.t_format);
131  // Cleanup.
132  tensor.cleanup();
133  return os;
134  }
135 
136  protected:
137  T t_tensor;
138  Format t_format;
139 };
140 
141 template <typename T, int rank, typename Format>
142 class TensorWithFormat<T, ColMajor, rank, Format> {
143  public:
144  TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}
145 
146  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, rank, Format>& wf) {
147  // Switch to RowMajor storage and print afterwards
148  typedef typename T::Index IndexType;
149  std::array<IndexType, rank> shuffle;
150  std::array<IndexType, rank> id;
151  std::iota(id.begin(), id.end(), IndexType(0));
152  std::copy(id.begin(), id.end(), shuffle.rbegin());
153  auto tensor_row_major = wf.t_tensor.swap_layout().shuffle(shuffle);
154 
155  // Evaluate the expression if needed
156  typedef TensorEvaluator<const TensorForcedEvalOp<const decltype(tensor_row_major)>, DefaultDevice> Evaluator;
157  TensorForcedEvalOp<const decltype(tensor_row_major)> eval = tensor_row_major.eval();
158  Evaluator tensor(eval, DefaultDevice());
159  tensor.evalSubExprsIfNeeded(NULL);
160  internal::TensorPrinter<Evaluator, rank, Format>::run(os, tensor, wf.t_format);
161  // Cleanup.
162  tensor.cleanup();
163  return os;
164  }
165 
166  protected:
167  T t_tensor;
168  Format t_format;
169 };
170 
171 template <typename T, typename Format>
172 class TensorWithFormat<T, ColMajor, 0, Format> {
173  public:
174  TensorWithFormat(const T& tensor, const Format& format) : t_tensor(tensor), t_format(format) {}
175 
176  friend std::ostream& operator<<(std::ostream& os, const TensorWithFormat<T, ColMajor, 0, Format>& wf) {
177  // Evaluate the expression if needed
178  typedef TensorEvaluator<const TensorForcedEvalOp<const T>, DefaultDevice> Evaluator;
179  TensorForcedEvalOp<const T> eval = wf.t_tensor.eval();
180  Evaluator tensor(eval, DefaultDevice());
181  tensor.evalSubExprsIfNeeded(NULL);
182  internal::TensorPrinter<Evaluator, 0, Format>::run(os, tensor, wf.t_format);
183  // Cleanup.
184  tensor.cleanup();
185  return os;
186  }
187 
188  protected:
189  T t_tensor;
190  Format t_format;
191 };
192 
193 namespace internal {
194 
// Fallback used for every (scalar, format) pair without a dedicated specialization:
// the value is written with the stream's own operator<<.
template <typename Scalar, typename Format, typename EnableIf = void>
struct ScalarPrinter {
  static void run(std::ostream& out, const Scalar& value, const Format& /*format*/) {
    out << value;
  }
};
200 
201 template <typename Scalar>
202 struct ScalarPrinter<Scalar, TensorIOFormatNumpy, std::enable_if_t<NumTraits<Scalar>::IsComplex>> {
203  static void run(std::ostream& stream, const Scalar& scalar, const TensorIOFormatNumpy&) {
204  stream << numext::real(scalar) << "+" << numext::imag(scalar) << "j";
205  }
206 };
207 
208 template <typename Scalar>
209 struct ScalarPrinter<Scalar, TensorIOFormatNative, std::enable_if_t<NumTraits<Scalar>::IsComplex>> {
210  static void run(std::ostream& stream, const Scalar& scalar, const TensorIOFormatNative&) {
211  stream << "{" << numext::real(scalar) << ", " << numext::imag(scalar) << "}";
212  }
213 };
214 
215 template <typename Tensor, std::size_t rank, typename Format, typename EnableIf>
216 struct TensorPrinter {
217  using Scalar = std::remove_const_t<typename Tensor::Scalar>;
218 
219  static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) {
220  typedef typename Tensor::Index IndexType;
221 
222  eigen_assert(Tensor::Layout == RowMajor);
223  typedef std::conditional_t<is_same<Scalar, char>::value || is_same<Scalar, unsigned char>::value ||
224  is_same<Scalar, numext::int8_t>::value || is_same<Scalar, numext::uint8_t>::value,
225  int,
226  std::conditional_t<is_same<Scalar, std::complex<char>>::value ||
227  is_same<Scalar, std::complex<unsigned char>>::value ||
228  is_same<Scalar, std::complex<numext::int8_t>>::value ||
229  is_same<Scalar, std::complex<numext::uint8_t>>::value,
230  std::complex<int>, const Scalar&>>
231  PrintType;
232 
233  const IndexType total_size = array_prod(tensor.dimensions());
234 
235  std::streamsize explicit_precision;
236  if (fmt.precision == StreamPrecision) {
237  explicit_precision = 0;
238  } else if (fmt.precision == FullPrecision) {
239  if (NumTraits<Scalar>::IsInteger) {
240  explicit_precision = 0;
241  } else {
242  explicit_precision = significant_decimals_impl<Scalar>::run();
243  }
244  } else {
245  explicit_precision = fmt.precision;
246  }
247 
248  std::streamsize old_precision = 0;
249  if (explicit_precision) old_precision = s.precision(explicit_precision);
250 
251  IndexType width = 0;
252  bool align_cols = !(fmt.flags & DontAlignCols);
253  if (align_cols) {
254  // compute the largest width
255  for (IndexType i = 0; i < total_size; i++) {
256  std::stringstream sstr;
257  sstr.copyfmt(s);
258  ScalarPrinter<Scalar, Format>::run(sstr, static_cast<PrintType>(tensor.data()[i]), fmt);
259  width = std::max<IndexType>(width, IndexType(sstr.str().length()));
260  }
261  }
262  s << fmt.tenPrefix;
263  for (IndexType i = 0; i < total_size; i++) {
264  std::array<bool, rank> is_at_end{};
265  std::array<bool, rank> is_at_begin{};
266 
267  // is the ith element the end of an coeff (always true), of a row, of a matrix, ...?
268  for (std::size_t k = 0; k < rank; k++) {
269  if ((i + 1) % (std::accumulate(tensor.dimensions().rbegin(), tensor.dimensions().rbegin() + k, 1,
270  std::multiplies<IndexType>())) ==
271  0) {
272  is_at_end[k] = true;
273  }
274  }
275 
276  // is the ith element the begin of an coeff (always true), of a row, of a matrix, ...?
277  for (std::size_t k = 0; k < rank; k++) {
278  if (i % (std::accumulate(tensor.dimensions().rbegin(), tensor.dimensions().rbegin() + k, 1,
279  std::multiplies<IndexType>())) ==
280  0) {
281  is_at_begin[k] = true;
282  }
283  }
284 
285  // do we have a line break?
286  bool is_at_begin_after_newline = false;
287  for (std::size_t k = 0; k < rank; k++) {
288  if (is_at_begin[k]) {
289  std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
290  if (fmt.separator[separator_index].find('\n') != std::string::npos) {
291  is_at_begin_after_newline = true;
292  }
293  }
294  }
295 
296  bool is_at_end_before_newline = false;
297  for (std::size_t k = 0; k < rank; k++) {
298  if (is_at_end[k]) {
299  std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
300  if (fmt.separator[separator_index].find('\n') != std::string::npos) {
301  is_at_end_before_newline = true;
302  }
303  }
304  }
305 
306  std::stringstream suffix, prefix, separator;
307  for (std::size_t k = 0; k < rank; k++) {
308  std::size_t suffix_index = (k < fmt.suffix.size()) ? k : fmt.suffix.size() - 1;
309  if (is_at_end[k]) {
310  suffix << fmt.suffix[suffix_index];
311  }
312  }
313  for (std::size_t k = 0; k < rank; k++) {
314  std::size_t separator_index = (k < fmt.separator.size()) ? k : fmt.separator.size() - 1;
315  if (is_at_end[k] &&
316  (!is_at_end_before_newline || fmt.separator[separator_index].find('\n') != std::string::npos)) {
317  separator << fmt.separator[separator_index];
318  }
319  }
320  for (std::size_t k = 0; k < rank; k++) {
321  std::size_t spacer_index = (k < fmt.spacer.size()) ? k : fmt.spacer.size() - 1;
322  if (i != 0 && is_at_begin_after_newline && (!is_at_begin[k] || k == 0)) {
323  prefix << fmt.spacer[spacer_index];
324  }
325  }
326  for (int k = rank - 1; k >= 0; k--) {
327  std::size_t prefix_index = (static_cast<std::size_t>(k) < fmt.prefix.size()) ? k : fmt.prefix.size() - 1;
328  if (is_at_begin[k]) {
329  prefix << fmt.prefix[prefix_index];
330  }
331  }
332 
333  s << prefix.str();
334  // So we don't mess around with formatting, output scalar to a string stream, and adjust the width/fill manually.
335  std::stringstream sstr;
336  sstr.copyfmt(s);
337  ScalarPrinter<Scalar, Format>::run(sstr, static_cast<PrintType>(tensor.data()[i]), fmt);
338  std::string scalar_str = sstr.str();
339  IndexType scalar_width = scalar_str.length();
340  if (width && scalar_width < width) {
341  std::string filler;
342  for (IndexType j = scalar_width; j < width; ++j) {
343  filler.push_back(fmt.fill);
344  }
345  s << filler;
346  }
347  s << scalar_str;
348  s << suffix.str();
349  if (i < total_size - 1) {
350  s << separator.str();
351  }
352  }
353  s << fmt.tenSuffix;
354  if (explicit_precision) s.precision(old_precision);
355  }
356 };
357 
358 template <typename Tensor, std::size_t rank>
359 struct TensorPrinter<Tensor, rank, TensorIOFormatLegacy, std::enable_if_t<rank != 0>> {
360  using Format = TensorIOFormatLegacy;
361  using Scalar = std::remove_const_t<typename Tensor::Scalar>;
362 
363  static void run(std::ostream& s, const Tensor& tensor, const Format&) {
364  typedef typename Tensor::Index IndexType;
365  // backwards compatibility case: print tensor after reshaping to matrix of size dim(0) x
366  // (dim(1)*dim(2)*...*dim(rank-1)).
367  const IndexType total_size = internal::array_prod(tensor.dimensions());
368  if (total_size > 0) {
369  const IndexType first_dim = Eigen::internal::array_get<0>(tensor.dimensions());
370  Map<const Array<Scalar, Dynamic, Dynamic, Tensor::Layout>> matrix(tensor.data(), first_dim,
371  total_size / first_dim);
372  s << matrix;
373  return;
374  }
375  }
376 };
377 
378 template <typename Tensor, typename Format>
379 struct TensorPrinter<Tensor, 0, Format> {
380  static void run(std::ostream& s, const Tensor& tensor, const Format& fmt) {
381  using Scalar = std::remove_const_t<typename Tensor::Scalar>;
382 
383  std::streamsize explicit_precision;
384  if (fmt.precision == StreamPrecision) {
385  explicit_precision = 0;
386  } else if (fmt.precision == FullPrecision) {
387  if (NumTraits<Scalar>::IsInteger) {
388  explicit_precision = 0;
389  } else {
390  explicit_precision = significant_decimals_impl<Scalar>::run();
391  }
392  } else {
393  explicit_precision = fmt.precision;
394  }
395 
396  std::streamsize old_precision = 0;
397  if (explicit_precision) old_precision = s.precision(explicit_precision);
398  s << fmt.tenPrefix;
399  ScalarPrinter<Scalar, Format>::run(s, tensor.coeff(0), fmt);
400  s << fmt.tenSuffix;
401  if (explicit_precision) s.precision(old_precision);
402  }
403 };
404 
405 } // end namespace internal
406 template <typename T>
407 std::ostream& operator<<(std::ostream& s, const TensorBase<T, ReadOnlyAccessors>& t) {
408  s << t.format(TensorIOFormat::Plain());
409  return s;
410 }
411 } // end namespace Eigen
412 
413 #endif // EIGEN_CXX11_TENSOR_TENSOR_IO_H
static constexpr lastp1_t end
Namespace containing all symbols from the Eigen library.
Definition: AutoDiffScalar.h:629