Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Dec 19, 2024
1 parent c0f7803 commit 6a7ad64
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,13 +1485,13 @@ def _get_index(
truncated[seq_length.cumsum(0) - 1] = 1
index = index.to(torch.long).unbind(-1)
st_index = storage[index]
try:
done = st_index[done_key] | truncated
except KeyError:
done = st_index.get(done_key, default=None)
if done is None:
done = truncated.clone()
try:
terminated = st_index[terminated_key]
except KeyError:
else:
done = done | truncated
terminated = st_index.get(terminated_key, default=None)
if terminated is None:
terminated = torch.zeros_like(truncated)
return index, {
truncated_key: truncated,
Expand Down

0 comments on commit 6a7ad64

Please sign in to comment.