diff --git a/torchrl/_utils.py b/torchrl/_utils.py index d37aebb862f..c81ffcc962b 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -829,6 +829,7 @@ def _can_be_pickled(obj): def _make_ordinal_device(device: torch.device): if device is None: return device + device = torch.device(device) if device.type == "cuda" and device.index is None: return torch.device("cuda", index=torch.cuda.current_device()) if device.type == "mps" and device.index is None: diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index d2ec66475ab..f7403e6a69e 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -14,7 +14,7 @@ import re import warnings from enum import Enum -from typing import Any, Dict, List, Union +from typing import Any, Dict, List import torch @@ -339,48 +339,47 @@ def step_mdp( exclude_reward: bool = True, exclude_done: bool = False, exclude_action: bool = True, - reward_keys: Union[NestedKey, List[NestedKey]] = "reward", - done_keys: Union[NestedKey, List[NestedKey]] = "done", - action_keys: Union[NestedKey, List[NestedKey]] = "action", + reward_keys: NestedKey | List[NestedKey] = "reward", + done_keys: NestedKey | List[NestedKey] = "done", + action_keys: NestedKey | List[NestedKey] = "action", ) -> TensorDictBase: """Creates a new tensordict that reflects a step in time of the input tensordict. Given a tensordict retrieved after a step, returns the :obj:`"next"` indexed-tensordict. - The arguments allow for a precise control over what should be kept and what + The arguments allow for precise control over what should be kept and what should be copied from the ``"next"`` entry. The default behavior is: - move the observation entries, reward and done states to the root, exclude - the current action and keep all extra keys (non-action, non-done, non-reward). + move the observation entries, reward, and done states to the root, exclude + the current action, and keep all extra keys (non-action, non-done, non-reward). 
Args: - tensordict (TensorDictBase): tensordict with keys to be renamed - next_tensordict (TensorDictBase, optional): destination tensordict - keep_other (bool, optional): if ``True``, all keys that do not start with :obj:`'next_'` will be kept. + tensordict (TensorDictBase): The tensordict with keys to be renamed. + next_tensordict (TensorDictBase, optional): The destination tensordict. If `None`, a new tensordict is created. + keep_other (bool, optional): If ``True``, all keys that do not start with :obj:`'next_'` will be kept. Default is ``True``. - exclude_reward (bool, optional): if ``True``, the :obj:`"reward"` key will be discarded + exclude_reward (bool, optional): If ``True``, the :obj:`"reward"` key will be discarded from the resulting tensordict. If ``False``, it will be copied (and replaced) - from the ``"next"`` entry (if present). - Default is ``True``. - exclude_done (bool, optional): if ``True``, the :obj:`"done"` key will be discarded + from the ``"next"`` entry (if present). Default is ``True``. + exclude_done (bool, optional): If ``True``, the :obj:`"done"` key will be discarded from the resulting tensordict. If ``False``, it will be copied (and replaced) - from the ``"next"`` entry (if present). - Default is ``False``. - exclude_action (bool, optional): if ``True``, the :obj:`"action"` key will + from the ``"next"`` entry (if present). Default is ``False``. + exclude_action (bool, optional): If ``True``, the :obj:`"action"` key will be discarded from the resulting tensordict. If ``False``, it will be kept in the root tensordict (since it should not be present in - the ``"next"`` entry). - Default is ``True``. - reward_keys (NestedKey or list of NestedKey, optional): the keys where the reward is written. Defaults + the ``"next"`` entry). Default is ``True``. + reward_keys (NestedKey or list of NestedKey, optional): The keys where the reward is written. Defaults to "reward". 
- done_keys (NestedKey or list of NestedKey, optional): the keys where the done is written. Defaults + done_keys (NestedKey or list of NestedKey, optional): The keys where the done state is written. Defaults to "done". - action_keys (NestedKey or list of NestedKey, optional): the keys where the action is written. Defaults + action_keys (NestedKey or list of NestedKey, optional): The keys where the action is written. Defaults to "action". Returns: - A new tensordict (or next_tensordict) containing the tensors of the t+1 step. + TensorDictBase: A new tensordict (or ``next_tensordict`` if provided) containing the tensors of the t+1 step. + + .. seealso:: :meth:`EnvBase.step_mdp` is the class-based version of this free function. It will attempt to cache the + key values to reduce the overhead of making a step in the MDP. Examples: - This funtion allows for this kind of loop to be used: >>> from tensordict import TensorDict >>> import torch >>> td = TensorDict({ @@ -784,8 +783,8 @@ def check_env_specs( if _has_dynamic_specs(env.specs): for real, fake in zip( - real_tensordict.filter_non_tensor_data().unbind(-1), - fake_tensordict.filter_non_tensor_data().unbind(-1), + real_tensordict_select.filter_non_tensor_data().unbind(-1), + fake_tensordict_select.filter_non_tensor_data().unbind(-1), ): fake = fake.apply(lambda x, y: x.expand_as(y), real) if (torch.zeros_like(real) != torch.zeros_like(fake)).any():