diff --git a/tests/utility/test_tensor_dict_utilities.py b/tests/utility/test_tensor_dict_utilities.py index 344cf3a8..dfd033c0 100644 --- a/tests/utility/test_tensor_dict_utilities.py +++ b/tests/utility/test_tensor_dict_utilities.py @@ -206,6 +206,20 @@ def test_dataproto_pad_unpad(): assert (unpadd_data.non_tensor_batch['labels'] == labels).all() assert unpadd_data.meta_info == {'info': 'test_info'} + padded_data, pad_size = pad_dataproto_to_divisor(data, size_divisor=7) + assert pad_size == 4 + + expected_obs = torch.tensor([[1, 2], [3, 4], [5, 6], [1, 2], [3, 4], [5, 6], [1, 2]]) + expected_labels = ['a', 'b', 'c', 'a', 'b', 'c', 'a'] + assert torch.all(torch.eq(padded_data.batch['obs'], expected_obs)) + assert (padded_data.non_tensor_batch['labels'] == expected_labels).all() + assert padded_data.meta_info == {'info': 'test_info'} + + unpadd_data = unpad_dataproto(padded_data, pad_size=pad_size) + assert torch.all(torch.eq(unpadd_data.batch['obs'], obs)) + assert (unpadd_data.non_tensor_batch['labels'] == labels).all() + assert unpadd_data.meta_info == {'info': 'test_info'} + def test_dataproto_fold_unfold(): from verl.protocol import fold_batch_dim, unfold_batch_dim, DataProto diff --git a/verl/protocol.py b/verl/protocol.py index 80626242..7f434465 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -51,7 +51,13 @@ def pad_dataproto_to_divisor(data: 'DataProto', size_divisor: int): assert isinstance(data, DataProto), 'data must be a DataProto' if len(data) % size_divisor != 0: pad_size = size_divisor - len(data) % size_divisor - data_padded = DataProto.concat([data, data[:pad_size]]) + padding_protos = [] + remaining_pad = pad_size + while remaining_pad > 0: + take_size = min(remaining_pad, len(data)) + padding_protos.append(data[:take_size]) + remaining_pad -= take_size + data_padded = DataProto.concat([data] + padding_protos) else: pad_size = 0 data_padded = data