Skip to content

Commit

Permalink
[Performance] Faster SliceSampler._tensor_slices_from_startend
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtamohler committed Sep 9, 2024
1 parent 57f0580 commit fec4f40
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,9 +1076,28 @@ def _tensor_slices_from_startend(self, seq_length, start, storage_length):
# seq_length is a 1d tensor indicating the desired length of each sequence

if isinstance(seq_length, int):
result = torch.cat(
[self._start_to_end(_start, length=seq_length) for _start in start]
arange = torch.arange(seq_length, device=start.device, dtype=start.dtype)
ndims = start.shape[-1] - 1 if (start.ndim - 1) else 0
if ndims:
arange_reshaped = torch.empty(
arange.shape + torch.Size([ndims + 1]),
device=start.device,
dtype=start.dtype,
)
arange_reshaped[..., 0] = arange
arange_reshaped[..., 1:] = 0
else:
arange_reshaped = arange.unsqueeze(-1)
arange_expanded = arange_reshaped.expand(
torch.Size([start.shape[0]]) + arange_reshaped.shape
)
if start.shape != arange_expanded.shape:
n_missing_dims = arange_expanded.dim() - start.dim()
start_expanded = start[
(slice(None),) + (None,) * n_missing_dims
].expand_as(arange_expanded)
result = (start_expanded + arange_expanded).flatten(0, 1)

else:
# when padding is needed
result = torch.cat(
Expand Down

0 comments on commit fec4f40

Please sign in to comment.