Skip to content

Commit

Permalink
[BugFix] Avoid KeyError in slice sampler (for compile)
Browse files Browse the repository at this point in the history
ghstack-source-id: 6e2a3036f0e50d365387cced50a761b97a47317d
Pull Request resolved: #2670
  • Loading branch information
vmoens committed Dec 20, 2024
1 parent f4709c1 commit 21eeca4
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,13 +1485,13 @@ def _get_index(
truncated[seq_length.cumsum(0) - 1] = 1
index = index.to(torch.long).unbind(-1)
st_index = storage[index]
try:
done = st_index[done_key] | truncated
except KeyError:
done = st_index.get(done_key, default=None)
if done is None:
done = truncated.clone()
try:
terminated = st_index[terminated_key]
except KeyError:
else:
done = done | truncated
terminated = st_index.get(terminated_key, default=None)
if terminated is None:
terminated = torch.zeros_like(truncated)
return index, {
truncated_key: truncated,
Expand Down

0 comments on commit 21eeca4

Please sign in to comment.