Skip to content

Commit

Permalink
[Feature, Test] Adding tests for envs that have no specs
Browse files Browse the repository at this point in the history
ghstack-source-id: 4c75691baa1e70f417e518df15c4208cff189950
Pull Request resolved: #2621
  • Loading branch information
vmoens committed Dec 2, 2024
1 parent 830f2f2 commit c72583f
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 2 deletions.
14 changes: 14 additions & 0 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1996,3 +1996,17 @@ def _step(

def _set_seed(self, seed: Optional[int]):
...


class EnvThatDoesNothing(EnvBase):
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
return TensorDict(batch_size=self.batch_size, device=self.device)

def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
return TensorDict(batch_size=self.batch_size, device=self.device)

def _set_seed(self, seed):
...
30 changes: 30 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
DiscreteActionConvMockEnvNumpy,
DiscreteActionVecMockEnv,
DummyModelBasedEnvBase,
EnvThatDoesNothing,
EnvWithDynamicSpec,
EnvWithMetadata,
HeterogeneousCountingEnv,
Expand Down Expand Up @@ -81,6 +82,7 @@
DiscreteActionConvMockEnvNumpy,
DiscreteActionVecMockEnv,
DummyModelBasedEnvBase,
EnvThatDoesNothing,
EnvWithDynamicSpec,
EnvWithMetadata,
HeterogeneousCountingEnv,
Expand Down Expand Up @@ -3554,6 +3556,34 @@ def test_auto_spec():
env.check_env_specs(tensordict=td.copy())


def test_env_that_does_nothing():
env = EnvThatDoesNothing()
env.check_env_specs()
r = env.rollout(3)
r.exclude(
"done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True
)
assert r.is_empty()
p_env = SerialEnv(2, EnvThatDoesNothing)
p_env.check_env_specs()
r = p_env.rollout(3)
r.exclude(
"done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True
)
assert r.is_empty()
p_env = ParallelEnv(2, EnvThatDoesNothing)
try:
p_env.check_env_specs()
r = p_env.rollout(3)
r.exclude(
"done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True
)
assert r.is_empty()
finally:
p_env.close()
del p_env


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
8 changes: 6 additions & 2 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2434,8 +2434,12 @@ def _register_gym( # noqa: F811
apply_api_compatibility=apply_api_compatibility,
)

def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
raise NotImplementedError("EnvBase.forward is not implemented")
def forward(self, *args, **kwargs):
raise NotImplementedError(
"EnvBase.forward is not implemented. If you ended here during a call to `ParallelEnv(...)`, please use "
"a constructor such as `ParallelEnv(num_env, lambda env=env: env)` instead. "
"Batched envs require constructors because environment instances may not always be serializable."
)

@abc.abstractmethod
def _step(
Expand Down
2 changes: 2 additions & 0 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ def __call__(self, tensordict):
if self.validate(tensordict):
if self.keep_other:
out = self._exclude(self.exclude_from_root, tensordict, out=None)
if out is None:
out = tensordict.empty()
else:
out = next_td.empty()
self._grab_and_place(
Expand Down

1 comment on commit c72583f

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: c72583f Previous: 830f2f2 Ratio
benchmarks/test_replaybuffer_benchmark.py::test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] 37.10445721004974 iter/sec (stddev: 0.1615408179199506) 250.33690089770667 iter/sec (stddev: 0.0005382685575183127) 6.75

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.