From 6092be439b598e02273cf7f17d863f2d5331dd6a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 6 Dec 2024 21:20:27 +0000 Subject: [PATCH] [Quality,BE] Better doc for step_mdp ghstack-source-id: 966e5edb2ea0fb462ae1265e244942dd92fab244 Pull Request resolved: https://github.com/pytorch/rl/pull/2639 --- torchrl/_utils.py | 1 + torchrl/envs/utils.py | 53 +++++++++++++++++++++++-------------------- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index d37aebb862f..c81ffcc962b 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -829,6 +829,7 @@ def _can_be_pickled(obj): def _make_ordinal_device(device: torch.device): if device is None: return device + device = torch.device(device) if device.type == "cuda" and device.index is None: return torch.device("cuda", index=torch.cuda.current_device()) if device.type == "mps" and device.index is None: diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 423b71e316e..9cba14c9690 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -14,7 +14,7 @@ import re import warnings from enum import Enum -from typing import Any, Dict, List, Union +from typing import Any, Dict, List import torch @@ -339,48 +339,47 @@ def step_mdp( exclude_reward: bool = True, exclude_done: bool = False, exclude_action: bool = True, - reward_keys: Union[NestedKey, List[NestedKey]] = "reward", - done_keys: Union[NestedKey, List[NestedKey]] = "done", - action_keys: Union[NestedKey, List[NestedKey]] = "action", + reward_keys: NestedKey | List[NestedKey] = "reward", + done_keys: NestedKey | List[NestedKey] = "done", + action_keys: NestedKey | List[NestedKey] = "action", ) -> TensorDictBase: """Creates a new tensordict that reflects a step in time of the input tensordict. Given a tensordict retrieved after a step, returns the :obj:`"next"` indexed-tensordict. 
- The arguments allow for a precise control over what should be kept and what + The arguments allow for precise control over what should be kept and what should be copied from the ``"next"`` entry. The default behavior is: - move the observation entries, reward and done states to the root, exclude - the current action and keep all extra keys (non-action, non-done, non-reward). + move the observation entries, reward, and done states to the root, exclude + the current action, and keep all extra keys (non-action, non-done, non-reward). Args: - tensordict (TensorDictBase): tensordict with keys to be renamed - next_tensordict (TensorDictBase, optional): destination tensordict - keep_other (bool, optional): if ``True``, all keys that do not start with :obj:`'next_'` will be kept. + tensordict (TensorDictBase): The tensordict with keys to be renamed. + next_tensordict (TensorDictBase, optional): The destination tensordict. If ``None``, a new tensordict is created. + keep_other (bool, optional): If ``True``, all keys that do not start with :obj:`'next_'` will be kept. Default is ``True``. - exclude_reward (bool, optional): if ``True``, the :obj:`"reward"` key will be discarded + exclude_reward (bool, optional): If ``True``, the :obj:`"reward"` key will be discarded from the resulting tensordict. If ``False``, it will be copied (and replaced) - from the ``"next"`` entry (if present). - Default is ``True``. - exclude_done (bool, optional): if ``True``, the :obj:`"done"` key will be discarded + from the ``"next"`` entry (if present). Default is ``True``. + exclude_done (bool, optional): If ``True``, the :obj:`"done"` key will be discarded from the resulting tensordict. If ``False``, it will be copied (and replaced) - from the ``"next"`` entry (if present). - Default is ``False``. - exclude_action (bool, optional): if ``True``, the :obj:`"action"` key will + from the ``"next"`` entry (if present). Default is ``False``. 
+ exclude_action (bool, optional): If ``True``, the :obj:`"action"` key will be discarded from the resulting tensordict. If ``False``, it will be kept in the root tensordict (since it should not be present in - the ``"next"`` entry). - Default is ``True``. - reward_keys (NestedKey or list of NestedKey, optional): the keys where the reward is written. Defaults + the ``"next"`` entry). Default is ``True``. + reward_keys (NestedKey or list of NestedKey, optional): The keys where the reward is written. Defaults to "reward". - done_keys (NestedKey or list of NestedKey, optional): the keys where the done is written. Defaults + done_keys (NestedKey or list of NestedKey, optional): The keys where the done is written. Defaults to "done". - action_keys (NestedKey or list of NestedKey, optional): the keys where the action is written. Defaults + action_keys (NestedKey or list of NestedKey, optional): The keys where the action is written. Defaults to "action". Returns: - A new tensordict (or next_tensordict) containing the tensors of the t+1 step. + TensorDictBase: A new tensordict (or ``next_tensordict`` if provided) containing the tensors of the t+1 step. + + .. seealso:: :meth:`EnvBase.step_mdp` is the class-based version of this free function. It will attempt to cache the + key values to reduce the overhead of making a step in the MDP. 
Examples: - This funtion allows for this kind of loop to be used: >>> from tensordict import TensorDict >>> import torch >>> td = TensorDict({ @@ -783,7 +782,9 @@ def check_env_specs( from torchrl.envs.common import _has_dynamic_specs if _has_dynamic_specs(env.specs): - for real, fake in zip(real_tensordict.unbind(-1), fake_tensordict.unbind(-1)): + for real, fake in zip( + real_tensordict_select.unbind(-1), fake_tensordict_select.unbind(-1) + ): fake = fake.apply(lambda x, y: x.expand_as(y), real) if (torch.zeros_like(real) != torch.zeros_like(fake)).any(): raise AssertionError(zeroing_err_msg) @@ -1367,6 +1368,8 @@ def _update_during_reset( reset_keys: List[NestedKey], ): """Updates the input tensordict with the reset data, based on the reset keys.""" + if not reset_keys: + return tensordict.update(tensordict_reset) roots = set() for reset_key in reset_keys: # get the node of the reset key