Skip to content

Commit

Permalink
[Feature] EnvBase.auto_specs_
Browse files Browse the repository at this point in the history
ghstack-source-id: 9cb10eeb50fa5e4108ceaddad83e717316c77cb3
Pull Request resolved: #2601
  • Loading branch information
vmoens committed Nov 25, 2024
1 parent c8676f4 commit f572281
Show file tree
Hide file tree
Showing 6 changed files with 257 additions and 68 deletions.
8 changes: 5 additions & 3 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,11 +1038,13 @@ def _step(
tensordict: TensorDictBase,
) -> TensorDictBase:
action = tensordict.get(self.action_key)
try:
device = self.full_action_spec[self.action_key].device
except KeyError:
device = self.device
self.count += action.to(
dtype=torch.int,
device=self.full_action_spec[self.action_key].device
if self.device is None
else self.device,
device=device if self.device is None else self.device,
)
tensordict = TensorDict(
source={
Expand Down
28 changes: 28 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3526,6 +3526,34 @@ def test_single_env_spec():
assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape))


def test_auto_spec():
env = CountingEnv()
td = env.reset()

policy = lambda td, action_spec=env.full_action_spec.clone(): td.update(
action_spec.rand()
)

env.full_observation_spec = Composite(
shape=env.full_observation_spec.shape, device=env.full_observation_spec.device
)
env.full_action_spec = Composite(
shape=env.full_action_spec.shape, device=env.full_action_spec.device
)
env.full_reward_spec = Composite(
shape=env.full_reward_spec.shape, device=env.full_reward_spec.device
)
env.full_done_spec = Composite(
shape=env.full_done_spec.shape, device=env.full_done_spec.device
)
env.full_state_spec = Composite(
shape=env.full_state_spec.shape, device=env.full_state_spec.device
)
env._action_keys = ["action"]
env.auto_specs_(policy, tensordict=td.copy())
env.check_env_specs(tensordict=td.copy())


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
9 changes: 9 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,15 @@ def test_getitem(self, shape, is_complete, device, dtype):
with pytest.raises(KeyError):
_ = ts["UNK"]

def test_setitem_newshape(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
new_spec = ts.clone()
new_spec.shape = torch.Size(())
new_spec.clear_device_()
ts["new_spec"] = new_spec
assert ts["new_spec"].shape == ts.shape
assert ts["new_spec"].device == ts.device

def test_setitem_forbidden_keys(self, shape, is_complete, device, dtype):
ts = self._composite_spec(shape, is_complete, device, dtype)
for key in {"shape", "device", "dtype", "space"}:
Expand Down
25 changes: 20 additions & 5 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4372,11 +4372,20 @@ def set(self, name, spec):
if spec is not None:
shape = spec.shape
if shape[: self.ndim] != self.shape:
raise ValueError(
"The shape of the spec and the Composite mismatch: the first "
f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
f"Composite.shape={self.shape}."
)
if (
isinstance(spec, Composite)
and spec.ndim < self.ndim
and self.shape[: spec.ndim] == spec.shape
):
# Try to set the composite shape
spec = spec.clone()
spec.shape = self.shape
else:
raise ValueError(
"The shape of the spec and the Composite mismatch: the first "
f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
f"Composite.shape={self.shape}."
)
self._specs[name] = spec

def __init__(
Expand Down Expand Up @@ -4448,6 +4457,8 @@ def clear_device_(self):
"""Clears the device of the Composite."""
self._device = None
for spec in self._specs.values():
if spec is None:
continue
spec.clear_device_()
return self

Expand Down Expand Up @@ -4530,6 +4541,10 @@ def __setitem__(self, key, value):
and value.device != self.device
):
if isinstance(value, Composite) and value.device is None:
# We make a clone not to mess up the spec that was provided.
# in set() we do the same for shape - these two ops should be grouped.
# we don't care about the overhead of cloning twice though because in theory
# we don't set specs often.
value = value.clone().to(self.device)
else:
raise RuntimeError(
Expand Down
Loading

0 comments on commit f572281

Please sign in to comment.