Skip to content

Commit

Permalink
feat: improve error message when refolding with missing dims
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Sep 16, 2024
1 parent 4b1721c commit 3ab7b1d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
13 changes: 10 additions & 3 deletions foldedtensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,16 @@ def refold(self, *dims: Union[Sequence[Union[int, str]], int, str]):
"sequence or each arguments to be ints or strings"
)
dims = dims[0]
dims = tuple(
dim if isinstance(dim, int) else self.full_names.index(dim) for dim in dims
)
try:
dims = tuple(
dim if isinstance(dim, int) else self.full_names.index(dim)
for dim in dims
)
except ValueError:
raise ValueError(
f"Folded tensor with available dimensions {self.full_names} "
f"could not be refolded with dimensions {list(dims)}"
)

if dims == self.data_dims:
return self
Expand Down
15 changes: 15 additions & 0 deletions tests/test_folded_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,18 @@ def test_hashable_lengths():
assert tensor.lengths is embedding(tensor).lengths
assert hash(tensor.lengths) is not None
assert hash(tensor.lengths) == hash(embedding(tensor).lengths)


def test_missing_dims():
tensor = as_folded_tensor(
[
[0, 1, 2],
[3, 4],
],
full_names=("sample", "token"),
dtype=torch.long,
)
with pytest.raises(ValueError) as e:
tensor.refold("line", "token")

assert "line" in str(e.value)

0 comments on commit 3ab7b1d

Please sign in to comment.