Skip to content

Commit 2c717c8

Browse files
committed
amend
1 parent 2fbb08b commit 2c717c8

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

test/test_rb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ def test_replay_buffer_trajectories(stack, reduction, datatype):
914914
rb.extend(traj_td)
915915
if datatype == "tc":
916916
sampled_td, info = rb.sample(return_info=True)
917-
index= info["index"]
917+
index = info["index"]
918918
else:
919919
sampled_td = rb.sample()
920920
if datatype == "tc":

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -986,8 +986,16 @@ def _set_index_in_td(self, tensordict, index):
986986
return
987987
if _is_int(index):
988988
index = torch.as_tensor(index, device=tensordict.device)
989-
elif index.shape[0] != tensordict.shape:
990-
index = index.unflatten(0, tensordict.shape)
989+
elif index.ndim == 2 and index.shape[:1] != tensordict.shape[:1]:
990+
for dim in range(2, tensordict.ndim + 1):
991+
if index.shape[:1].numel() == tensordict.shape[:dim].numel():
992+
# if index has 2 dims and is in a non-zero format
993+
index = index.unflatten(0, tensordict.shape[:dim])
994+
break
995+
else:
996+
raise RuntimeError(
997+
f"could not find how to reshape index with shape {index.shape} to fit in tensordict with shape {tensordict.shape}"
998+
)
991999
tensordict.set("index", index)
9921000
return
9931001
tensordict.set("index", expand_as_right(index, tensordict))

0 commit comments

Comments
 (0)