Skip to content

Commit

Permalink
Fix the serialization of scalar valued tensors (#431)
Browse files Browse the repository at this point in the history
  • Loading branch information
enrico-stauss authored Dec 3, 2024
1 parent 7df78d3 commit fa5b13c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/litdata/streaming/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def deserialize(self, data: bytes) -> torch.Tensor:
return torch.reshape(tensor, shape)

def can_serialize(self, item: torch.Tensor) -> bool:
return isinstance(item, torch.Tensor) and type(item) == torch.Tensor and len(item.shape) > 1
return isinstance(item, torch.Tensor) and type(item) == torch.Tensor and len(item.shape) != 1


class NoHeaderTensorSerializer(Serializer):
Expand Down
17 changes: 17 additions & 0 deletions tests/streaming/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,14 @@ def test_deserialize_empty_tensor():
assert torch.equal(t, new_t)


def test_deserialize_scalar_tensor():
serializer = TensorSerializer()
t = torch.tensor(0)
data, _ = serializer.serialize(t)
new_t = serializer.deserialize(data)
assert torch.equal(t, new_t)


def test_deserialize_empty_no_header_tensor():
serializer = NoHeaderTensorSerializer()
t = torch.ones((0,)).int()
Expand All @@ -271,6 +279,15 @@ def test_deserialize_empty_no_header_tensor():
assert torch.equal(t, new_t)


def test_can_serialize_tensor():
serializer = TensorSerializer()
# Check that the TensorSerializer can serialize scalar valued tensors as well as higher order (>1) Tensors
assert serializer.can_serialize(torch.tensor(0))
assert serializer.can_serialize(torch.tensor([[0, 0]]))
# Check that it does not serialize Tensors of order 1, those are treated by the dedicated NoHeaderTensorSerializer
assert not serializer.can_serialize(torch.tensor([0, 0]))


@pytest.mark.skipif(not _TIFFFILE_AVAILABLE, reason="Requires: ['tifffile']")
def test_tiff_serializer():
serializer = TIFFSerializer()
Expand Down

0 comments on commit fa5b13c

Please sign in to comment.