
Commit
[dataproto] update repeat and unpad/pad
PeterSH6 committed Dec 7, 2024
1 parent 3c729fd commit 7425b36
Showing 1 changed file with 111 additions and 0 deletions.
verl/protocol.py
@@ -36,6 +36,32 @@
    pass


def pad_dataproto_to_divisor(data: 'DataProto', size_divisor: int):
    """Pad a DataProto so that its length is divisible by size_divisor.

    Args:
        data (DataProto): the DataProto to pad
        size_divisor (int): the divisor the padded length must be a multiple of

    Returns:
        data_padded (DataProto): the padded DataProto
        pad_size (int): the number of rows appended as padding
    """
    assert isinstance(data, DataProto), 'data must be a DataProto'
    if len(data) % size_divisor != 0:
        pad_size = size_divisor - len(data) % size_divisor
        # pad by re-appending the first pad_size rows (assumes len(data) >= pad_size)
        data_padded = DataProto.concat([data, data[:pad_size]])
    else:
        pad_size = 0
        data_padded = data
    return data_padded, pad_size


def unpad_dataproto(data: 'DataProto', pad_size):
    """Remove the pad_size rows appended by pad_dataproto_to_divisor."""
    if pad_size != 0:
        data = data[:-pad_size]
    return data
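
A minimal usage sketch of the pad/unpad pair above (illustrative only; it assumes the existing DataProto.from_dict constructor, and the 'obs' key and sizes are invented for the example):

    data = DataProto.from_dict(tensors={'obs': torch.randn(10, 4)})
    padded, pad_size = pad_dataproto_to_divisor(data, size_divisor=8)
    assert len(padded) == 16 and pad_size == 6   # 10 rows padded up to a multiple of 8
    restored = unpad_dataproto(padded, pad_size=pad_size)
    assert len(restored) == len(data)            # the padding rows are removed again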


def union_tensor_dict(tensor_dict1: TensorDict, tensor_dict2: TensorDict) -> TensorDict:
    """Union two tensordicts."""
    assert tensor_dict1.batch_size == tensor_dict2.batch_size, \
@@ -403,6 +429,47 @@ def chunk(self, chunks: int) -> List['DataProto']:

        return output

    def unfold_column_chunks(self, n_split, split_keys=None):
        """Split along the second dim into `n_split` chunks and unfold them into the first (batch) dim.
        Useful for passing grouped tensors that should not be shuffled in the dataset.
        Keys not in `split_keys` are repeated to match the new batch size.
        """
        if split_keys is None:
            split_keys = list(self.batch.keys())

        if self.batch is not None:
            unfolded_batch = {}
            for key in self.batch.keys():
                if key in split_keys:
                    shape = list(self.batch[key].shape)
                    shape[0] = self.batch[key].shape[0] * n_split
                    shape[1] = self.batch[key].shape[1] // n_split
                    unfolded_batch[key] = self.batch[key].reshape(*shape)
                else:
                    unfolded_batch[key] = torch.repeat_interleave(self.batch[key], n_split, dim=0)

            unfolded_batch = TensorDict(
                source=unfolded_batch,
                batch_size=(self.batch.batch_size[0] * n_split,),
            )
        else:
            unfolded_batch = None

        repeated_non_tensor_batch = {}
        for key, val in self.non_tensor_batch.items():
            if key in split_keys:
                shape = list(val.shape)
                shape[0] = val.shape[0] * n_split
                shape[1] = val.shape[1] // n_split
                repeated_non_tensor_batch[key] = val.reshape(*shape)
            else:
                repeated_non_tensor_batch[key] = np.repeat(val, n_split, axis=0)

        return DataProto(
            batch=unfolded_batch,
            non_tensor_batch=repeated_non_tensor_batch,
            meta_info=self.meta_info,
        )
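
A sketch of what unfold_column_chunks does to shapes (hypothetical keys and sizes, not from this diff): a (2, 6) tensor with n_split=3 becomes (6, 2), with the chunks of each original row staying adjacent in the new batch dim, while keys left out of split_keys are repeated to match:

    proto = DataProto.from_dict(tensors={'ids': torch.arange(12).reshape(2, 6), 'mask': torch.ones(2, 4)})
    out = proto.unfold_column_chunks(n_split=3, split_keys=['ids'])
    assert out.batch['ids'].shape == (6, 2)    # each row split into 3 chunks of width 2
    assert out.batch['mask'].shape == (6, 4)   # repeated 3x along the batch dim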

    @staticmethod
    def concat(data: List['DataProto']) -> 'DataProto':
        """Concat a list of DataProto. The batch is concatenated among dim=0.
@@ -436,6 +503,50 @@ def reorder(self, indices):
        self.batch = self.batch[indices]
        self.non_tensor_batch = {key: val[indices_np] for key, val in self.non_tensor_batch.items()}

    def repeat(self, repeat_times=2, interleave=True):
        """
        Repeat the batch data a specified number of times.

        Args:
            repeat_times (int): Number of times to repeat the data.
            interleave (bool): Whether to interleave the repeated data.

        Returns:
            DataProto: A new DataProto with repeated data.
        """
        if self.batch is not None:
            if interleave:
                # Interleave the data
                repeated_tensors = {
                    key: tensor.repeat_interleave(repeat_times, dim=0) for key, tensor in self.batch.items()
                }
            else:
                # Stack the data
                repeated_tensors = {
                    key: tensor.unsqueeze(0).expand(repeat_times, *tensor.shape).reshape(-1, *tensor.shape[1:])
                    for key, tensor in self.batch.items()
                }

            repeated_batch = TensorDict(
                source=repeated_tensors,
                batch_size=(self.batch.batch_size[0] * repeat_times,),
            )
        else:
            repeated_batch = None

        repeated_non_tensor_batch = {}
        for key, val in self.non_tensor_batch.items():
            if interleave:
                repeated_non_tensor_batch[key] = np.repeat(val, repeat_times, axis=0)
            else:
                repeated_non_tensor_batch[key] = np.tile(val, (repeat_times,) + (1,) * (val.ndim - 1))

        return DataProto(
            batch=repeated_batch,
            non_tensor_batch=repeated_non_tensor_batch,
            meta_info=self.meta_info,
        )
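
The two repeat modes in the method above, on a toy batch (hypothetical 'x' key; values chosen so the ordering is visible): interleave=True duplicates each row in place, interleave=False tiles the whole batch:

    proto = DataProto.from_dict(tensors={'x': torch.tensor([[1], [2]])})
    a = proto.repeat(repeat_times=2, interleave=True)    # rows of x: [1], [1], [2], [2]
    b = proto.repeat(repeat_times=2, interleave=False)   # rows of x: [1], [2], [1], [2]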


import ray
