Skip to content

Commit

Permalink
[Feature, Test] Adding tests for envs that have no specs
Browse files Browse the repository at this point in the history
ghstack-source-id: 2968cc344b75042d3bb845bfda01e12bf00f4af7
Pull Request resolved: #2621
  • Loading branch information
vmoens committed Dec 2, 2024
1 parent 57ad9df commit 5816d75
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 2 deletions.
14 changes: 14 additions & 0 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1996,3 +1996,17 @@ def _step(

def _set_seed(self, seed: Optional[int]):
...


class EnvThatDoesNothing(EnvBase):
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
return TensorDict(batch_size=self.batch_size, device=self.device)

def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
return TensorDict(batch_size=self.batch_size, device=self.device)

def _set_seed(self, seed):
...
26 changes: 26 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
DiscreteActionConvMockEnvNumpy,
DiscreteActionVecMockEnv,
DummyModelBasedEnvBase,
EnvThatDoesNothing,
EnvWithDynamicSpec,
EnvWithMetadata,
HeterogeneousCountingEnv,
Expand Down Expand Up @@ -81,6 +82,7 @@
DiscreteActionConvMockEnvNumpy,
DiscreteActionVecMockEnv,
DummyModelBasedEnvBase,
EnvThatDoesNothing,
EnvWithDynamicSpec,
EnvWithMetadata,
HeterogeneousCountingEnv,
Expand Down Expand Up @@ -3554,6 +3556,30 @@ def test_auto_spec():
env.check_env_specs(tensordict=td.copy())


def test_env_that_does_nothing():
env = EnvThatDoesNothing()
env.check_env_specs()
r = env.rollout(3)
r_done = r.separates("done", "terminated", ("next", "done"), ("next", "terminated"))
assert r.is_empty()
p_env = SerialEnv(2, EnvThatDoesNothing)
p_env.check_env_specs()
r = p_env.rollout(3)
r_done = r.separates("done", "terminated", ("next", "done"), ("next", "terminated"))
assert r.is_empty()
p_env = ParallelEnv(2, EnvThatDoesNothing)
try:
p_env.check_env_specs()
r = p_env.rollout(3)
r_done = r.separates(
"done", "terminated", ("next", "done"), ("next", "terminated")
)
assert r.is_empty()
finally:
p_env.close()
del p_env


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
8 changes: 6 additions & 2 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2434,8 +2434,12 @@ def _register_gym( # noqa: F811
apply_api_compatibility=apply_api_compatibility,
)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
raise NotImplementedError("EnvBase.forward is not implemented")
def forward(self, *args, **kwargs):
raise NotImplementedError(
"EnvBase.forward is not implemented. If you ended here during a call to `ParallelEnv(...)`, please use "
"a constructor such as `ParallelEnv(num_env, lambda env=env: env)` instead. "
"Batched envs require constructors because environment instances may not always be serializable."
)

@abc.abstractmethod
def _step(
Expand Down
2 changes: 2 additions & 0 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ def __call__(self, tensordict):
if self.validate(tensordict):
if self.keep_other:
out = self._exclude(self.exclude_from_root, tensordict, out=None)
if out is None:
out = tensordict.empty()
else:
out = next_td.empty()
self._grab_and_place(
Expand Down

0 comments on commit 5816d75

Please sign in to comment.