diff --git a/heat/core/printing.py b/heat/core/printing.py index e06db65c5..5f9f95218 100644 --- a/heat/core/printing.py +++ b/heat/core/printing.py @@ -303,6 +303,9 @@ def _tensor_str(dndarray, indent: int) -> str: # to do so, we slice up the torch data and forward it to torch internal printing mechanism summarize = elements > get_printoptions()["threshold"] torch_data = _torch_data(dndarray, summarize) + if not dndarray.is_distributed(): + # let torch handle formatting on non-distributed data + # formatter gets too slow for even moderately large tensors + return torch._tensor_str._tensor_str(torch_data, indent) formatter = torch._tensor_str._Formatter(torch_data) - return torch._tensor_str._tensor_str_with_formatter(torch_data, indent, summarize, formatter)