Skip to content

Commit

Permalink
[BugFix] Fix lib tests (#2218)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 8, 2024
1 parent 1bd3814 commit c67ad59
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2812,14 +2812,14 @@ def _minari_selected_datasets():
torch.manual_seed(0)

keys = list(minari.list_remote_datasets())
indices = torch.randperm(len(keys))[:10]
indices = torch.randperm(len(keys))[:20]
keys = [keys[idx] for idx in indices]
keys = [
key
for key in keys
if "=0.4" in minari.list_remote_datasets()[key]["minari_version"]
]
assert len(keys) > 5
assert len(keys) > 5, keys
_MINARI_DATASETS += keys


Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,7 +1696,7 @@ def _padded_indices(self, shapes, arange) -> torch.Tensor:
return pad

@implement_for("torch", None, "2.4")
def _padded_indices(self, shapes, arange) -> torch.Tensor:
def _padded_indices(self, shapes, arange) -> torch.Tensor: # noqa: F811
arange = arange.flip(0).split(shapes.flip(0).squeeze().unbind())
return (
torch.nn.utils.rnn.pad_sequence(arange, batch_first=True, padding_value=-1)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class Storage:

ndim = 1
max_size: int
_default_checkpointer: StorageCheckpointerBase
_default_checkpointer: StorageCheckpointerBase = StorageCheckpointerBase

def __init__(
self, max_size: int, checkpointer: StorageCheckpointerBase | None = None
Expand Down

0 comments on commit c67ad59

Please sign in to comment.