Skip to content

Commit

Permalink
[Quality,BE] Better doc for step_mdp
Browse files Browse the repository at this point in the history
ghstack-source-id: c56b9ba7742142d84dd4998f4f8acf6226c7f74e
Pull Request resolved: #2639
  • Loading branch information
vmoens committed Dec 12, 2024
1 parent d30b1b6 commit 00bf79a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 25 deletions.
1 change: 1 addition & 0 deletions torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,7 @@ def _can_be_pickled(obj):
def _make_ordinal_device(device: torch.device):
if device is None:
return device
device = torch.device(device)
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
if device.type == "mps" and device.index is None:
Expand Down
53 changes: 28 additions & 25 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import re
import warnings
from enum import Enum
from typing import Any, Dict, List, Union
from typing import Any, Dict, List

import torch

Expand Down Expand Up @@ -339,48 +339,47 @@ def step_mdp(
exclude_reward: bool = True,
exclude_done: bool = False,
exclude_action: bool = True,
reward_keys: Union[NestedKey, List[NestedKey]] = "reward",
done_keys: Union[NestedKey, List[NestedKey]] = "done",
action_keys: Union[NestedKey, List[NestedKey]] = "action",
reward_keys: NestedKey | List[NestedKey] = "reward",
done_keys: NestedKey | List[NestedKey] = "done",
action_keys: NestedKey | List[NestedKey] = "action",
) -> TensorDictBase:
"""Creates a new tensordict that reflects a step in time of the input tensordict.
Given a tensordict retrieved after a step, returns the :obj:`"next"` indexed-tensordict.
The arguments allow for a precise control over what should be kept and what
The arguments allow for precise control over what should be kept and what
should be copied from the ``"next"`` entry. The default behavior is:
move the observation entries, reward and done states to the root, exclude
the current action and keep all extra keys (non-action, non-done, non-reward).
move the observation entries, reward, and done states to the root, exclude
the current action, and keep all extra keys (non-action, non-done, non-reward).
Args:
tensordict (TensorDictBase): tensordict with keys to be renamed
next_tensordict (TensorDictBase, optional): destination tensordict
keep_other (bool, optional): if ``True``, all keys that do not start with :obj:`'next_'` will be kept.
tensordict (TensorDictBase): The tensordict with keys to be renamed.
next_tensordict (TensorDictBase, optional): The destination tensordict. If ``None``, a new tensordict is created.
keep_other (bool, optional): If ``True``, all keys that do not start with :obj:`'next_'` will be kept.
Default is ``True``.
exclude_reward (bool, optional): if ``True``, the :obj:`"reward"` key will be discarded
exclude_reward (bool, optional): If ``True``, the :obj:`"reward"` key will be discarded
from the resulting tensordict. If ``False``, it will be copied (and replaced)
from the ``"next"`` entry (if present).
Default is ``True``.
exclude_done (bool, optional): if ``True``, the :obj:`"done"` key will be discarded
from the ``"next"`` entry (if present). Default is ``True``.
exclude_done (bool, optional): If ``True``, the :obj:`"done"` key will be discarded
from the resulting tensordict. If ``False``, it will be copied (and replaced)
from the ``"next"`` entry (if present).
Default is ``False``.
exclude_action (bool, optional): if ``True``, the :obj:`"action"` key will
from the ``"next"`` entry (if present). Default is ``False``.
exclude_action (bool, optional): If ``True``, the :obj:`"action"` key will
be discarded from the resulting tensordict. If ``False``, it will
be kept in the root tensordict (since it should not be present in
the ``"next"`` entry).
Default is ``True``.
reward_keys (NestedKey or list of NestedKey, optional): the keys where the reward is written. Defaults
the ``"next"`` entry). Default is ``True``.
reward_keys (NestedKey or list of NestedKey, optional): The keys where the reward is written. Defaults
to "reward".
done_keys (NestedKey or list of NestedKey, optional): the keys where the done is written. Defaults
done_keys (NestedKey or list of NestedKey, optional): The keys where the done is written. Defaults
to "done".
action_keys (NestedKey or list of NestedKey, optional): the keys where the action is written. Defaults
action_keys (NestedKey or list of NestedKey, optional): The keys where the action is written. Defaults
to "action".
Returns:
A new tensordict (or next_tensordict) containing the tensors of the t+1 step.
TensorDictBase: A new tensordict (or ``next_tensordict`` if provided) containing the tensors of the t+1 step.
.. seealso:: :meth:`EnvBase.step_mdp` is the class-based version of this free function. It will attempt to cache the
key values to reduce the overhead of making a step in the MDP.
Examples:
This function allows for this kind of loop to be used:
>>> from tensordict import TensorDict
>>> import torch
>>> td = TensorDict({
Expand Down Expand Up @@ -783,7 +782,9 @@ def check_env_specs(
from torchrl.envs.common import _has_dynamic_specs

if _has_dynamic_specs(env.specs):
for real, fake in zip(real_tensordict.unbind(-1), fake_tensordict.unbind(-1)):
for real, fake in zip(
real_tensordict_select.unbind(-1), fake_tensordict_select.unbind(-1)
):
fake = fake.apply(lambda x, y: x.expand_as(y), real)
if (torch.zeros_like(real) != torch.zeros_like(fake)).any():
raise AssertionError(zeroing_err_msg)
Expand Down Expand Up @@ -1367,6 +1368,8 @@ def _update_during_reset(
reset_keys: List[NestedKey],
):
"""Updates the input tensordict with the reset data, based on the reset keys."""
if not reset_keys:
return tensordict.update(tensordict_reset)
roots = set()
for reset_key in reset_keys:
# get the node of the reset key
Expand Down

0 comments on commit 00bf79a

Please sign in to comment.