[Quality,BE] Better doc for step_mdp #2639

Merged · 4 commits · Dec 12, 2024
1 change: 1 addition & 0 deletions torchrl/_utils.py
@@ -829,6 +829,7 @@ def _can_be_pickled(obj):
def _make_ordinal_device(device: torch.device):
if device is None:
return device
device = torch.device(device)
if device.type == "cuda" and device.index is None:
return torch.device("cuda", index=torch.cuda.current_device())
if device.type == "mps" and device.index is None:
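The one-line addition above routes the argument through ``torch.device`` before the type/index checks, so ``_make_ordinal_device`` now accepts plain strings as well as ``torch.device`` objects. A minimal sketch of the effect (assuming a CUDA-capable runtime; ``_make_ordinal_device`` is a private helper, so this is illustrative only):

import torch
from torchrl._utils import _make_ordinal_device

# A bare "cuda" string is now coerced to a torch.device first, so the
# index-less device resolves to the current CUDA device instead of failing
# on the string's missing ``.type`` attribute.
print(_make_ordinal_device("cuda"))                  # e.g. device(type='cuda', index=0)

# A device that already carries an index falls through the checks shown
# above; None is returned as-is.
print(_make_ordinal_device(torch.device("cuda:1")))  # device(type='cuda', index=1)
print(_make_ordinal_device(None))                    # None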
51 changes: 25 additions & 26 deletions torchrl/envs/utils.py
@@ -14,7 +14,7 @@
import re
import warnings
from enum import Enum
from typing import Any, Dict, List, Union
from typing import Any, Dict, List

import torch

@@ -339,48 +339,47 @@ def step_mdp(
exclude_reward: bool = True,
exclude_done: bool = False,
exclude_action: bool = True,
reward_keys: Union[NestedKey, List[NestedKey]] = "reward",
done_keys: Union[NestedKey, List[NestedKey]] = "done",
action_keys: Union[NestedKey, List[NestedKey]] = "action",
reward_keys: NestedKey | List[NestedKey] = "reward",
done_keys: NestedKey | List[NestedKey] = "done",
action_keys: NestedKey | List[NestedKey] = "action",
) -> TensorDictBase:
"""Creates a new tensordict that reflects a step in time of the input tensordict.

Given a tensordict retrieved after a step, returns the :obj:`"next"` indexed-tensordict.
The arguments allow for a precise control over what should be kept and what
The arguments allow for precise control over what should be kept and what
should be copied from the ``"next"`` entry. The default behavior is:
move the observation entries, reward and done states to the root, exclude
the current action and keep all extra keys (non-action, non-done, non-reward).
move the observation entries, reward, and done states to the root, exclude
the current action, and keep all extra keys (non-action, non-done, non-reward).

Args:
tensordict (TensorDictBase): tensordict with keys to be renamed
next_tensordict (TensorDictBase, optional): destination tensordict
keep_other (bool, optional): if ``True``, all keys that do not start with :obj:`'next_'` will be kept.
tensordict (TensorDictBase): The tensordict with keys to be renamed.
next_tensordict (TensorDictBase, optional): The destination tensordict. If `None`, a new tensordict is created.
keep_other (bool, optional): If ``True``, all keys that do not start with :obj:`'next_'` will be kept.
Default is ``True``.
exclude_reward (bool, optional): if ``True``, the :obj:`"reward"` key will be discarded
exclude_reward (bool, optional): If ``True``, the :obj:`"reward"` key will be discarded
from the resulting tensordict. If ``False``, it will be copied (and replaced)
from the ``"next"`` entry (if present).
Default is ``True``.
exclude_done (bool, optional): if ``True``, the :obj:`"done"` key will be discarded
from the ``"next"`` entry (if present). Default is ``True``.
exclude_done (bool, optional): If ``True``, the :obj:`"done"` key will be discarded
from the resulting tensordict. If ``False``, it will be copied (and replaced)
from the ``"next"`` entry (if present).
Default is ``False``.
exclude_action (bool, optional): if ``True``, the :obj:`"action"` key will
from the ``"next"`` entry (if present). Default is ``False``.
exclude_action (bool, optional): If ``True``, the :obj:`"action"` key will
be discarded from the resulting tensordict. If ``False``, it will
be kept in the root tensordict (since it should not be present in
the ``"next"`` entry).
Default is ``True``.
reward_keys (NestedKey or list of NestedKey, optional): the keys where the reward is written. Defaults
the ``"next"`` entry). Default is ``True``.
reward_keys (NestedKey or list of NestedKey, optional): The keys where the reward is written. Defaults
to "reward".
done_keys (NestedKey or list of NestedKey, optional): the keys where the done is written. Defaults
done_keys (NestedKey or list of NestedKey, optional): The keys where the done is written. Defaults
to "done".
action_keys (NestedKey or list of NestedKey, optional): the keys where the action is written. Defaults
action_keys (NestedKey or list of NestedKey, optional): The keys where the action is written. Defaults
to "action".

Returns:
A new tensordict (or next_tensordict) containing the tensors of the t+1 step.
TensorDictBase: A new tensordict (or `next_tensordict` if provided) containing the tensors of the t+1 step.

.. seealso:: :meth:`EnvBase.step_mdp` is the class-based version of this free function. It will attempt to cache the
key values to reduce the overhead of making a step in the MDP.

Examples:
This function allows for this kind of loop to be used:
>>> from tensordict import TensorDict
>>> import torch
>>> td = TensorDict({
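The doctest in the ``Examples`` block above is cut off by the collapsed diff. For context, a hedged sketch of the rollout loop the docstring describes, spelling out the keyword arguments documented above (the environment and random policy are illustrative placeholders, not part of this diff; assumes gymnasium is installed):

import torch
from torchrl.envs import GymEnv
from torchrl.envs.utils import step_mdp

env = GymEnv("Pendulum-v1")                # placeholder environment
td = env.reset()
for _ in range(10):
    td["action"] = env.action_spec.rand()  # random policy, for illustration
    td = env.step(td)                      # results are written under td["next"]
    # Promote the "next" entries to the root. With the defaults shown here,
    # reward and action are dropped, done states are copied over, and all
    # extra keys are kept.
    td = step_mdp(td, reward_keys="reward", done_keys="done", action_keys="action")

As the ``seealso`` note says, ``EnvBase.step_mdp`` performs the same operation but caches the key values to reduce the per-step overhead.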
@@ -784,8 +783,8 @@ def check_env_specs(

if _has_dynamic_specs(env.specs):
for real, fake in zip(
real_tensordict.filter_non_tensor_data().unbind(-1),
fake_tensordict.filter_non_tensor_data().unbind(-1),
real_tensordict_select.filter_non_tensor_data().unbind(-1),
fake_tensordict_select.filter_non_tensor_data().unbind(-1),
):
fake = fake.apply(lambda x, y: x.expand_as(y), real)
if (torch.zeros_like(real) != torch.zeros_like(fake)).any():
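The replaced lines above switch the dynamic-spec comparison from the full rollouts to their ``_select`` counterparts. The comparison itself zeroes both sides before checking for inequality, so differences in the sampled values cannot trigger a mismatch. A small sketch of that idea with plain tensors (illustrative only, not the torchrl code path):

import torch

real = torch.randn(3, 4)
fake = torch.randn(1, 4)

# Mirror the two steps in the hunk: broadcast the fake data to the real
# shape, then compare zeroed copies so the sampled values cannot matter.
fake = fake.expand_as(real)
mismatch = (torch.zeros_like(real) != torch.zeros_like(fake)).any()
print(mismatch)  # tensor(False): same layout, different values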