Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 7, 2024
1 parent 91941ea commit 5e06cd4
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,24 +750,33 @@ def _find_start_stop_traj(*, trajectory=None, end=None):
dim=0,
value=1,
)
if end.ndim != 1:
ndim = end.ndim
if ndim == 0:
raise RuntimeError(
f"Expected the end-of-trajectory signal to be 1-dimensional. Got a {end.ndim} tensor instead."
f"Expected the end-of-trajectory signal to be at least 1-dimensional."
)
stop_idx = end.view(-1).nonzero().view(-1)
start_idx = torch.cat([torch.zeros_like(stop_idx[:1]), stop_idx[:-1] + 1])
lengths = stop_idx - start_idx + 1
stop_idx = end.nonzero()
beginnings = torch.cat([torch.ones_like(end[:1]), end[:-1]], 0)
start_idx = beginnings.nonzero()
sort = start_idx[:, 1].sort(-1)[1]
start_idx = start_idx[sort]
print("stop_idx", stop_idx)
print("start_idx", start_idx)
lengths = stop_idx[:, 0] - start_idx[:, 0] + 1
print("lengths", lengths)
return start_idx, stop_idx, lengths

def _tensor_slices_from_startend(self, seq_length, start):
if isinstance(seq_length, int):
return (
torch.arange(
print('seq_length', seq_length)
print('start', start)
arange = torch.arange(
seq_length, device=start.device, dtype=start.dtype
).unsqueeze(0)
+ start.unsqueeze(1)
).view(-1)
)
arange = torch.stack([arange, torch.zeros_like(arange)], -1)
return arange + start
else:
raise NotImplementedError
# when padding is needed
return torch.cat(
[
Expand Down Expand Up @@ -883,6 +892,7 @@ def _sample_slices(
else:
num_slices = traj_idx.shape[0]

print('lengths', lengths, 'seq_length', seq_length)
if (lengths < seq_length).any():
if self.strict_length:
raise RuntimeError(
Expand All @@ -901,8 +911,9 @@ def _sample_slices(
.floor()
.to(start_idx.dtype)
)
starts = start_idx[traj_idx] + relative_starts
starts = torch.stack([start_idx[traj_idx, 0] + relative_starts, start_idx[traj_idx, 1]], -1)
index = self._tensor_slices_from_startend(seq_length, starts)
print("index", index)
if self.truncated_key is not None:
truncated_key = self.truncated_key
done_key = _replace_last(truncated_key, "done")
Expand Down

0 comments on commit 5e06cd4

Please sign in to comment.