Skip to content

Commit

Permalink
feat: improve error message when refolding with missing dims
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Sep 16, 2024
1 parent 4b1721c commit 8755cbf
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions foldedtensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,16 @@ def refold(self, *dims: Union[Sequence[Union[int, str]], int, str]):
"sequence or each arguments to be ints or strings"
)
dims = dims[0]
dims = tuple(
dim if isinstance(dim, int) else self.full_names.index(dim) for dim in dims
)
try:
dims = tuple(
dim if isinstance(dim, int) else self.full_names.index(dim)
for dim in dims
)
except ValueError:
raise Exception(
f"Folded tensor with available dimensions {self.full_names} "
f"could not be refolded with dimensions {list(dims)}"
)

if dims == self.data_dims:
return self
Expand Down

0 comments on commit 8755cbf

Please sign in to comment.