Skip to content

Commit

Permalink
[Test] More comprehensive tests for auto_spec
Browse files Browse the repository at this point in the history
ghstack-source-id: 802c9ad76af873924be5414bfdb141f9ae73567f
Pull Request resolved: #2640
  • Loading branch information
vmoens committed Dec 6, 2024
1 parent 6092be4 commit 8c76e88
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
12 changes: 9 additions & 3 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from typing import Dict, List, Optional

import torch
Expand Down Expand Up @@ -1929,14 +1931,18 @@ def __init__(self):
tensor=Unbounded(3),
non_tensor=NonTensor(shape=()),
)
self._saved_obs_spec = self.observation_spec.clone()
self.state_spec = Composite(
non_tensor=NonTensor(shape=()),
)
self._saved_state_spec = self.state_spec.clone()
self.reward_spec = Unbounded(1)
self._saved_full_reward_spec = self.full_reward_spec.clone()
self.action_spec = Unbounded(1)
self._saved_full_action_spec = self.full_action_spec.clone()

def _reset(self, tensordict):
data = self.observation_spec.zero()
data = self._saved_obs_spec.zero()
data.set_non_tensor("non_tensor", 0)
data.update(self.full_done_spec.zero())
return data
Expand All @@ -1945,10 +1951,10 @@ def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
data = self.observation_spec.zero()
data = self._saved_obs_spec.zero()
data.set_non_tensor("non_tensor", tensordict["non_tensor"] + 1)
data.update(self.full_done_spec.zero())
data.update(self.full_reward_spec.zero())
data.update(self._saved_full_reward_spec.zero())
return data

def _set_seed(self, seed: Optional[int]):
Expand Down
11 changes: 8 additions & 3 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3528,8 +3528,13 @@ def test_single_env_spec():
assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape))


def test_auto_spec():
env = CountingEnv()
@pytest.mark.parametrize("env_type", [CountingEnv, EnvWithMetadata])
def test_auto_spec(env_type):
if env_type is EnvWithMetadata:
obs_vals = ["tensor", "non_tensor"]
else:
obs_vals = "observation"
env = env_type()
td = env.reset()

policy = lambda td, action_spec=env.full_action_spec.clone(): td.update(
Expand All @@ -3552,7 +3557,7 @@ def test_auto_spec():
shape=env.full_state_spec.shape, device=env.full_state_spec.device
)
env._action_keys = ["action"]
env.auto_specs_(policy, tensordict=td.copy())
env.auto_specs_(policy, tensordict=td.copy(), observation_key=obs_vals)
env.check_env_specs(tensordict=td.copy())


Expand Down

0 comments on commit 8c76e88

Please sign in to comment.