diff --git a/test/mocking_classes.py b/test/mocking_classes.py index bb902f879b1..3c30286c419 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1996,3 +1996,17 @@ def _step( def _set_seed(self, seed: Optional[int]): ... + + +class EnvThatDoesNothing(EnvBase): + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + return TensorDict(batch_size=self.batch_size, device=self.device) + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + return TensorDict(batch_size=self.batch_size, device=self.device) + + def _set_seed(self, seed): + ... diff --git a/test/test_env.py b/test/test_env.py index 81708b0b9a6..b48b1a1cf8f 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -44,6 +44,7 @@ DiscreteActionConvMockEnvNumpy, DiscreteActionVecMockEnv, DummyModelBasedEnvBase, + EnvThatDoesNothing, EnvWithDynamicSpec, EnvWithMetadata, HeterogeneousCountingEnv, @@ -81,6 +82,7 @@ DiscreteActionConvMockEnvNumpy, DiscreteActionVecMockEnv, DummyModelBasedEnvBase, + EnvThatDoesNothing, EnvWithDynamicSpec, EnvWithMetadata, HeterogeneousCountingEnv, @@ -3554,6 +3556,34 @@ def test_auto_spec(): env.check_env_specs(tensordict=td.copy()) +def test_env_that_does_nothing(): + env = EnvThatDoesNothing() + env.check_env_specs() + r = env.rollout(3) + r.exclude( + "done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True + ) + assert r.is_empty() + p_env = SerialEnv(2, EnvThatDoesNothing) + p_env.check_env_specs() + r = p_env.rollout(3) + r.exclude( + "done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True + ) + assert r.is_empty() + p_env = ParallelEnv(2, EnvThatDoesNothing) + try: + p_env.check_env_specs() + r = p_env.rollout(3) + r.exclude( + "done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True + ) + assert r.is_empty() + finally: + p_env.close() + del p_env + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index d5a062bc11e..bafe88b639a 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2434,8 +2434,12 @@ def _register_gym( # noqa: F811 apply_api_compatibility=apply_api_compatibility, ) - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - raise NotImplementedError("EnvBase.forward is not implemented") + def forward(self, *args, **kwargs): + raise NotImplementedError( + "EnvBase.forward is not implemented. If you ended here during a call to `ParallelEnv(...)`, please use " + "a constructor such as `ParallelEnv(num_env, lambda env=env: env)` instead. " + "Batched envs require constructors because environment instances may not always be serializable." + ) @abc.abstractmethod def _step( diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 7454bce99b3..209349878ec 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -287,6 +287,8 @@ def __call__(self, tensordict): if self.validate(tensordict): if self.keep_other: out = self._exclude(self.exclude_from_root, tensordict, out=None) + if out is None: + out = tensordict.empty() else: out = next_td.empty() self._grab_and_place(