diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 21245e37acd..44a68d045cc 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -750,24 +750,33 @@ def _find_start_stop_traj(*, trajectory=None, end=None): dim=0, value=1, ) - if end.ndim != 1: + ndim = end.ndim + if ndim == 0: raise RuntimeError( - f"Expected the end-of-trajectory signal to be 1-dimensional. Got a {end.ndim} tensor instead." + f"Expected the end-of-trajectory signal to be at least 1-dimensional." ) - stop_idx = end.view(-1).nonzero().view(-1) - start_idx = torch.cat([torch.zeros_like(stop_idx[:1]), stop_idx[:-1] + 1]) - lengths = stop_idx - start_idx + 1 + stop_idx = end.nonzero() + beginnings = torch.cat([torch.ones_like(end[:1]), end[:-1]], 0) + start_idx = beginnings.nonzero() + sort = start_idx[:, 1].sort(-1)[1] + start_idx = start_idx[sort] + print("stop_idx", stop_idx) + print("start_idx", start_idx) + lengths = stop_idx[:, 0] - start_idx[:, 0] + 1 + print("lengths", lengths) return start_idx, stop_idx, lengths def _tensor_slices_from_startend(self, seq_length, start): if isinstance(seq_length, int): - return ( - torch.arange( + print('seq_length', seq_length) + print('start', start) + arange = torch.arange( seq_length, device=start.device, dtype=start.dtype - ).unsqueeze(0) - + start.unsqueeze(1) - ).view(-1) + ) + arange = torch.stack([arange, torch.zeros_like(arange)], -1) + return arange + start else: + raise NotImplementedError # when padding is needed return torch.cat( [ @@ -883,6 +892,7 @@ def _sample_slices( else: num_slices = traj_idx.shape[0] + print('lengths', lengths, 'seq_length', seq_length) if (lengths < seq_length).any(): if self.strict_length: raise RuntimeError( @@ -901,8 +911,9 @@ def _sample_slices( .floor() .to(start_idx.dtype) ) - starts = start_idx[traj_idx] + relative_starts + starts = torch.stack([start_idx[traj_idx, 0] + relative_starts, start_idx[traj_idx, 1]], -1) index = self._tensor_slices_from_startend(seq_length, starts) + print("index", index) if self.truncated_key is not None: truncated_key = self.truncated_key done_key = _replace_last(truncated_key, "done")