[dataproto] fix: add assertion for uneven chunk (#115)
- forbid uneven chunk for DataProto
vermouth1992 authored Jan 18, 2025
1 parent 5a94e14 commit 1ec5eb5
Showing 2 changed files with 38 additions and 2 deletions.
23 changes: 23 additions & 0 deletions tests/utility/test_tensor_dict_utilities.py
@@ -108,6 +108,9 @@ def test_chunk_concat():
     labels = ['a', 'b', 'c', 'd', 'e', 'f']
     data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'name': 'abdce'})
 
+    with pytest.raises(AssertionError):
+        data.chunk(5)
+
     data_split = data.chunk(2)
     assert len(data_split) == 2
     assert torch.all(torch.eq(data_split[0].batch['obs'], torch.tensor([1, 2, 3])))
@@ -237,3 +240,23 @@ def test_torch_save_data_proto():
 
     import os
     os.remove('test_data.pt')
+
+
+def test_len():
+    obs = torch.tensor([[1, 2], [3, 4], [5, 6]])
+    labels = np.array(['a', 'b', 'c'], dtype=object)
+    data = DataProto.from_dict(tensors={'obs': obs}, non_tensors={'labels': labels}, meta_info={'info': 'test_info'})
+
+    assert len(data) == 3
+
+    data = DataProto(batch=None, non_tensor_batch={'labels': labels}, meta_info={'info': 'test_info'})
+
+    assert len(data) == 3
+
+    data = DataProto(batch=None, non_tensor_batch={}, meta_info={'info': 'test_info'})
+
+    assert len(data) == 0
+
+    data = DataProto(batch=None, non_tensor_batch=None, meta_info={'info': 'test_info'})
+
+    assert len(data) == 0
17 changes: 15 additions & 2 deletions verl/protocol.py
@@ -178,7 +178,13 @@ def __post_init__(self):
         self.check_consistency()
 
     def __len__(self):
-        return self.batch.batch_size[0]
+        if self.batch is not None:
+            return self.batch.batch_size[0]
+        elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
+            random_key = list(self.non_tensor_batch.keys())[0]
+            return self.non_tensor_batch[random_key].shape[0]
+        else:
+            return 0
 
     def __getitem__(self, item):
         tensor_data = self.batch[item]
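For illustration, a minimal sketch of the new __len__ fallback. The construction mirrors the test_len test added above; treat meta_info={} and the keyword-only construction as assumptions about the API rather than documented behavior.

    # Sketch: with batch=None, len() falls back to the leading dimension
    # of the first non-tensor array, and to 0 when nothing is batched.
    import numpy as np
    from verl.protocol import DataProto

    labels = np.array(['a', 'b', 'c'], dtype=object)
    data = DataProto(batch=None, non_tensor_batch={'labels': labels}, meta_info={})
    assert len(data) == 3  # labels.shape[0]

    empty = DataProto(batch=None, non_tensor_batch=None, meta_info={})
    assert len(empty) == 0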
@@ -240,7 +246,11 @@ def check_consistency(self):
         if self.batch is not None:
             assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1'
 
-        if len(self.non_tensor_batch) != 0:
+        if self.non_tensor_batch is not None:
+            for key, val in self.non_tensor_batch.items():
+                assert isinstance(val, np.ndarray)
+
+        if self.batch is not None and len(self.non_tensor_batch) != 0:
             # TODO: we can actually lift this restriction if needed
             assert len(self.batch.batch_size) == 1, 'only support num_batch_dims=1 when non_tensor_batch is not empty.'
 
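A hypothetical negative example of the tightened check_consistency: since __post_init__ calls check_consistency (see the first hunk), a non-ndarray value now fails at construction time.

    # Sketch: values of non_tensor_batch must be np.ndarray; a plain
    # Python list trips the new isinstance assertion in __post_init__.
    from verl.protocol import DataProto

    DataProto(batch=None, non_tensor_batch={'labels': ['a', 'b']}, meta_info={})
    # -> AssertionError (isinstance(val, np.ndarray) is False)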
@@ -478,6 +488,9 @@ def chunk(self, chunks: int) -> List['DataProto']:
         Returns:
             List[DataProto]: a list of DataProto after splitting
         """
+        assert len(
+            self) % chunks == 0, f'only support equal chunk. Got size of DataProto {len(self)} and chunk {chunks}.'
+
         if self.batch is not None:
             batch_lst = self.batch.chunk(chunks=chunks, dim=0)
         else:
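Finally, a usage sketch of the new equal-chunk guard. Calling from_dict with tensors alone is an assumption extrapolated from the tests above, not documented API.

    # Sketch: chunk() now requires chunks to evenly divide len(self).
    import torch
    from verl.protocol import DataProto

    data = DataProto.from_dict(tensors={'obs': torch.arange(12).reshape(6, 2)})
    parts = data.chunk(3)  # OK: 6 % 3 == 0 -> three DataProto of length 2
    data.chunk(4)          # AssertionError: only support equal chunk ...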
