From 21eeca42ca715e6b5b80560713e0f280cb825002 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 20 Dec 2024 10:26:42 +0000 Subject: [PATCH] [BugFix] Avoid KeyError in slice sampler (for compile) ghstack-source-id: 6e2a3036f0e50d365387cced50a761b97a47317d Pull Request resolved: https://github.com/pytorch/rl/pull/2670 --- torchrl/data/replay_buffers/samplers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index bbdf2387683..2ad0550ed06 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1485,13 +1485,13 @@ def _get_index( truncated[seq_length.cumsum(0) - 1] = 1 index = index.to(torch.long).unbind(-1) st_index = storage[index] - try: - done = st_index[done_key] | truncated - except KeyError: + done = st_index.get(done_key, default=None) + if done is None: done = truncated.clone() - try: - terminated = st_index[terminated_key] - except KeyError: + else: + done = done | truncated + terminated = st_index.get(terminated_key, default=None) + if terminated is None: terminated = torch.zeros_like(truncated) return index, { truncated_key: truncated,