Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 15, 2024
1 parent 2fbb08b commit 2c717c8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
2 changes: 1 addition & 1 deletion test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ def test_replay_buffer_trajectories(stack, reduction, datatype):
rb.extend(traj_td)
if datatype == "tc":
sampled_td, info = rb.sample(return_info=True)
index= info["index"]
index = info["index"]
else:
sampled_td = rb.sample()
if datatype == "tc":
Expand Down
12 changes: 10 additions & 2 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,8 +986,16 @@ def _set_index_in_td(self, tensordict, index):
return
if _is_int(index):
index = torch.as_tensor(index, device=tensordict.device)
elif index.shape[0] != tensordict.shape:
index = index.unflatten(0, tensordict.shape)
elif index.ndim == 2 and index.shape[:1] != tensordict.shape[:1]:
for dim in range(2, tensordict.ndim + 1):
if index.shape[:1].numel() == tensordict.shape[:dim].numel():
# if index has 2 dims and is in a non-zero format
index = index.unflatten(0, tensordict.shape[:dim])
break
else:
raise RuntimeError(
f"could not find how to reshape index with shape {index.shape} to fit in tensordict with shape {tensordict.shape}"
)
tensordict.set("index", index)
return
tensordict.set("index", expand_as_right(index, tensordict))
Expand Down

0 comments on commit 2c717c8

Please sign in to comment.