From 924e0628a806e9b1e893f7e92dea903d1ca55d5e Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 25 Nov 2024 16:18:01 +0000
Subject: [PATCH] Update

[ghstack-poisoned]
---
 test/mocking_classes.py | 6 ++++--
 torchrl/envs/common.py  | 7 ++++++-
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/test/mocking_classes.py b/test/mocking_classes.py
index 8d4c5fe961e..d78e2f27184 100644
--- a/test/mocking_classes.py
+++ b/test/mocking_classes.py
@@ -1277,8 +1277,10 @@ def __init__(
             max_steps = torch.tensor(5)
         if start_val is None:
             start_val = torch.zeros((), dtype=torch.int32)
-        if not max_steps.shape == self.batch_size:
-            raise RuntimeError("batch_size and max_steps shape must match.")
+        if max_steps.shape != self.batch_size:
+            raise RuntimeError(
+                f"batch_size and max_steps shape must match. Got self.batch_size={self.batch_size} and max_steps.shape={max_steps.shape}."
+            )
 
         self.max_steps = max_steps
 
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 860b2c122ff..d5a062bc11e 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -372,9 +372,14 @@ def __init__(
         self.__dict__.setdefault("_batch_size", None)
         self.__dict__.setdefault("_device", None)
 
-        if device is not None:
+        if batch_size is not None:
             # we want an error to be raised if we pass batch_size but
             # it's already been set
+            batch_size = self.batch_size = torch.Size(batch_size)
+        else:
+            batch_size = torch.Size(())
+
+        if device is not None:
             device = self.__dict__["_device"] = _make_ordinal_device(
                 torch.device(device)
             )
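
---
Illustration (not part of the patch): a minimal standalone sketch of the two
behaviours the patch introduces, assuming only that torch is installed. The
helper name check_max_steps is hypothetical and exists purely for this demo;
in the patch the check lives inline in the mock env's __init__, and the
batch_size normalisation lives in EnvBase.__init__.

    import torch

    def check_max_steps(max_steps, batch_size):
        # Mirrors the new check in test/mocking_classes.py: compare shapes
        # and report both sides on mismatch instead of a bare message.
        if max_steps.shape != batch_size:
            raise RuntimeError(
                f"batch_size and max_steps shape must match. "
                f"Got batch_size={batch_size} and max_steps.shape={max_steps.shape}."
            )

    check_max_steps(torch.tensor([5, 5]), torch.Size([2]))  # ok: shapes match
    try:
        check_max_steps(torch.tensor(5), torch.Size([2]))   # scalar vs batch of 2
    except RuntimeError as err:
        print(err)  # message now includes both shapes

    # Mirrors the new fallback in torchrl/envs/common.py: an unset batch_size
    # becomes an empty torch.Size rather than staying None.
    batch_size = None
    batch_size = torch.Size(batch_size) if batch_size is not None else torch.Size(())
    print(batch_size)  # torch.Size([])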