diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index bbdf2387683..2ad0550ed06 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1485,13 +1485,13 @@ def _get_index( truncated[seq_length.cumsum(0) - 1] = 1 index = index.to(torch.long).unbind(-1) st_index = storage[index] - try: - done = st_index[done_key] | truncated - except KeyError: + done = st_index.get(done_key, default=None) + if done is None: done = truncated.clone() - try: - terminated = st_index[terminated_key] - except KeyError: + else: + done = done | truncated + terminated = st_index.get(terminated_key, default=None) + if terminated is None: terminated = torch.zeros_like(truncated) return index, { truncated_key: truncated,