Skip to content

Commit

Permalink
Update
Browse files (browse the repository at this point in the history)
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 25, 2024
1 parent 1a4c9f6 commit 924e062
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
6 changes: 4 additions & 2 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,8 +1277,10 @@ def __init__(
max_steps = torch.tensor(5)
if start_val is None:
start_val = torch.zeros((), dtype=torch.int32)
if not max_steps.shape == self.batch_size:
raise RuntimeError("batch_size and max_steps shape must match.")
if max_steps.shape != self.batch_size:
raise RuntimeError(
f"batch_size and max_steps shape must match. Got self.batch_size={self.batch_size} and max_steps.shape={max_steps.shape}."
)

self.max_steps = max_steps

Expand Down
7 changes: 6 additions & 1 deletion torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,14 @@ def __init__(
self.__dict__.setdefault("_batch_size", None)
self.__dict__.setdefault("_device", None)

if device is not None:
if batch_size is not None:
# we want an error to be raised if we pass batch_size but
# it's already been set
batch_size = self.batch_size = torch.Size(batch_size)
else:
batch_size = torch.Size(())

if device is not None:
device = self.__dict__["_device"] = _make_ordinal_device(
torch.device(device)
)
Expand Down

0 comments on commit 924e062

Please sign in to comment.