From 58c384713a303c0d41a9f0ba224ac3539eb4e8e1 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 11 Nov 2024 12:58:10 +0000
Subject: [PATCH] [Feature] single__spec

ghstack-source-id: 27e247ea1775e455999a114dd6d95fac748376c4
Pull Request resolved: https://github.com/pytorch/rl/pull/2549
---
 test/test_env.py       | 16 ++++++++++
 torchrl/envs/common.py | 71 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 87 insertions(+)

diff --git a/test/test_env.py b/test/test_env.py
index 1f95a55c2c7..14633e6b8d9 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -3510,6 +3510,22 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
         assert (td[3].get("next") != 0).any()
 
 
+def test_single_env_spec():
+    env = NestedCountingEnv(batch_size=[3, 1, 7])
+    assert not env.single_full_action_spec.shape
+    assert not env.single_full_done_spec.shape
+    assert not env.single_input_spec.shape
+    assert not env.single_full_observation_spec.shape
+    assert not env.single_output_spec.shape
+    assert not env.single_full_reward_spec.shape
+
+    assert env.single_action_spec.shape
+    assert env.single_reward_spec.shape
+
+    assert env.output_spec.is_in(env.single_output_spec.zeros(env.shape))
+    assert env.input_spec.is_in(env.single_input_spec.zeros(env.shape))
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index 5ac0e0dca17..22aec1cbb0d 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -1480,6 +1480,77 @@ def full_state_spec(self) -> Composite:
     def full_state_spec(self, spec: Composite) -> None:
         self.state_spec = spec
 
+    # Single-env specs can be used to remove the batch size from the spec
+    @property
+    def batch_dims(self):
+        return len(self.batch_size)
+
+    def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec:
+        if not self.batch_dims:
+            return spec
+        idx = tuple(0 for _ in range(self.batch_dims))  # 0 along each batch dim
+        return spec[idx]
+
+    @property
+    def single_full_action_spec(self) -> Composite:
+        """Returns the action spec of the env as if it had no batch dimensions."""
+        return self._make_single_env_spec(self.full_action_spec)
+
+    @property
+    def single_action_spec(self) -> TensorSpec:
+        """Returns the action spec of the env as if it had no batch dimensions."""
+        return self._make_single_env_spec(self.action_spec)
+
+    @property
+    def single_full_observation_spec(self) -> Composite:
+        """Returns the observation spec of the env as if it had no batch dimensions."""
+        return self._make_single_env_spec(self.full_observation_spec)
+
+    @property
+    def single_observation_spec(self) -> Composite:
+        """Returns the observation spec of the env as if it had no batch dimensions."""
+        return self._make_single_env_spec(self.observation_spec)
+
+    @property
+    def single_full_reward_spec(self) -> Composite:
+        """Returns the reward spec of the env as if it had no batch dimensions."""
+        return self._make_single_env_spec(self.full_reward_spec)
+
+    @property
+    def single_reward_spec(self) -> TensorSpec:
+        """Returns the reward spec of the env as if it had no batch dimensions."""
+        return self._make_single_env_spec(self.reward_spec)
+
+    @property
+    def single_full_done_spec(self) -> Composite:
+        """Returns the done spec of the env as if it had no batch dimensions."""
+        return self._make_single_env_spec(self.full_done_spec)
+
+    @property
+    def single_done_spec(self) -> TensorSpec:
+        """Returns the done spec of the env as if it had no batch dimensions."""
+        return self._make_single_env_spec(self.done_spec)
+
+    @property
+    def single_output_spec(self) -> Composite:
+        """Returns the output spec of the env as if it had no batch dimensions."""
+        return self._make_single_env_spec(self.output_spec)
+
+    @property
+    def single_input_spec(self) -> Composite:
+        """Returns the input spec of the env as if it had no batch dimensions."""
+        return self._make_single_env_spec(self.input_spec)
+
+    @property
+    def single_full_state_spec(self) -> Composite:
+        """Returns the state spec of the env as if it had no batch dimensions."""
+        return self._make_single_env_spec(self.full_state_spec)
+
+    @property
+    def single_state_spec(self) -> TensorSpec:
+        """Returns the state spec of the env as if it had no batch dimensions."""
+        return self._make_single_env_spec(self.state_spec)
+
     def step(self, tensordict: TensorDictBase) -> TensorDictBase:
         """Makes a step in the environment.
 