From 7425b3699daf2e749d597a8af8c7947580826dc0 Mon Sep 17 00:00:00 2001
From: shengguangming
Date: Sat, 7 Dec 2024 12:03:01 +0800
Subject: [PATCH] [dataproto] update repeat and unpad/pad

---
 verl/protocol.py | 111 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 111 insertions(+)

diff --git a/verl/protocol.py b/verl/protocol.py
index 827c44e..8c2946d 100644
--- a/verl/protocol.py
+++ b/verl/protocol.py
@@ -36,6 +36,32 @@
     pass
 
 
+def pad_dataproto_to_divisor(data: 'DataProto', size_divisor: int):
+    """Pad a DataProto so that its length is divisible by size_divisor.
+
+    Args:
+        size_divisor (int): size divisor
+
+    Returns:
+        data (DataProto): the padded DataProto
+        pad_size (int): number of rows appended as padding
+    """
+    assert isinstance(data, DataProto), 'data must be a DataProto'
+    if len(data) % size_divisor != 0:
+        pad_size = size_divisor - len(data) % size_divisor
+        data_padded = DataProto.concat([data, data[:pad_size]])
+    else:
+        pad_size = 0
+        data_padded = data
+    return data_padded, pad_size
+
+
+def unpad_dataproto(data: 'DataProto', pad_size):
+    if pad_size != 0:
+        data = data[:-pad_size]
+    return data
+
+
 def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
     """Union two tensordicts."""
     assert tensor_dict1.batch_size == tensor_dict2.batch_size, \
@@ -403,6 +429,47 @@ def chunk(self, chunks: int) -> List['DataProto']:
 
         return output
 
+    def unfold_column_chunks(self, n_split, split_keys=None):
+        """Split along the second dim into `n_split` chunks and unfold them into the first (batch) dim.
+        Useful for passing grouped tensors that should not be shuffled in the dataset.
+        Keys not in `split_keys` are repeat-interleaved along the batch dim to match the new shape.
+        """
+        if split_keys is None:
+            split_keys = list(self.batch.keys()) if self.batch is not None else []
+
+        if self.batch is not None:
+            unfolded_batch = {}
+            for key in self.batch.keys():
+                if key in split_keys:
+                    shape = list(self.batch[key].shape)
+                    shape[0] = self.batch[key].shape[0] * n_split
+                    shape[1] = self.batch[key].shape[1] // n_split
+                    unfolded_batch[key] = self.batch[key].reshape(*shape)
+                else:
+                    unfolded_batch[key] = torch.repeat_interleave(self.batch[key], n_split, dim=0)
+            unfolded_batch = TensorDict(
+                source=unfolded_batch,
+                batch_size=(self.batch.batch_size[0] * n_split,),
+            )
+        else:
+            unfolded_batch = None
+
+        repeated_non_tensor_batch = {}
+        for key, val in self.non_tensor_batch.items():
+            if key in split_keys:
+                shape = list(val.shape)
+                shape[0] = val.shape[0] * n_split
+                shape[1] = val.shape[1] // n_split
+                repeated_non_tensor_batch[key] = val.reshape(*shape)
+            else:
+                repeated_non_tensor_batch[key] = np.repeat(val, n_split, axis=0)
+
+        return DataProto(
+            batch=unfolded_batch,
+            non_tensor_batch=repeated_non_tensor_batch,
+            meta_info=self.meta_info,
+        )
+
     @staticmethod
     def concat(data: List['DataProto']) -> 'DataProto':
         """Concat a list of DataProto. The batch is concatenated among dim=0.
@@ -436,6 +503,50 @@ def reorder(self, indices):
         self.batch = self.batch[indices]
         self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}
 
+    def repeat(self, repeat_times=2, interleave=True):
+        """
+        Repeat the batch data a specified number of times.
+
+        Args:
+            repeat_times (int): Number of times to repeat the data.
+            interleave (bool): Whether to interleave the repeated data.
+
+        Returns:
+            DataProto: A new DataProto with repeated data.
+ """ + if self.batch is not None: + if interleave: + # Interleave the data + repeated_tensors = { + key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items() + } + else: + # Stack the data + repeated_tensors = { + key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:]) + for key, tensor in self.batch.items() + } + + repeated_batch = TensorDict( + source=repeated_tensors, + batch_size=(self.batch.batch_size[0] * repeat_times,), + ) + else: + repeated_batch = None + + repeated_non_tensor_batch = {} + for key, val in self.non_tensor_batch.items(): + if interleave: + repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0) + else: + repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1)) + + return DataProto( + batch=repeated_batch, + non_tensor_batch=repeated_non_tensor_batch, + meta_info=self.meta_info, + ) + import ray