diff --git a/test/test_rb.py b/test/test_rb.py index 4568c654cc0..5a849224117 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -914,7 +914,7 @@ def test_replay_buffer_trajectories(stack, reduction, datatype): rb.extend(traj_td) if datatype == "tc": sampled_td, info = rb.sample(return_info=True) - index= info["index"] + index = info["index"] else: sampled_td = rb.sample() if datatype == "tc": diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py index 5a1eccb4f45..d507e02b24b 100644 --- a/torchrl/data/replay_buffers/replay_buffers.py +++ b/torchrl/data/replay_buffers/replay_buffers.py @@ -986,8 +986,16 @@ def _set_index_in_td(self, tensordict, index): return if _is_int(index): index = torch.as_tensor(index, device=tensordict.device) - elif index.shape[0] != tensordict.shape: - index = index.unflatten(0, tensordict.shape) + elif index.ndim == 2 and index.shape[:1] != tensordict.shape[:1]: + for dim in range(2, tensordict.ndim + 1): + if index.shape[:1].numel() == tensordict.shape[:dim].numel(): + # if index has 2 dims and is in a non-zero format + index = index.unflatten(0, tensordict.shape[:dim]) + break + else: + raise RuntimeError( + f"could not find how to reshape index with shape {index.shape} to fit in tensordict with shape {tensordict.shape}" + ) tensordict.set("index", index) return tensordict.set("index", expand_as_right(index, tensordict))