Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Doc] Minor fixes to the docs and type hints #2548

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sota-implementations/ppo/ppo_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

"""
This script reproduces the Proximal Policy Optimization (PPO) Algorithm
results from Schulman et al. 2017 for the on Atari Environments.
results from Schulman et al. 2017 for the Atari Environments.
"""
import hydra
from torchrl._utils import logger as torchrl_logger
Expand Down
25 changes: 14 additions & 11 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def append_transform(
self,
transform: "Transform" # noqa: F821
| Callable[[TensorDictBase], TensorDictBase],
) -> None:
) -> EnvBase:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should actually return TransformedEnv.

However, I'm guessing we don't want to import it at the top to prevent loading a lot of modules that won't be accessed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either EnvBase or "torchrl.envs.TransformedEnv" but we can't import it
EnvBase could be the best compromise!

"""Returns a transformed environment where the callable/transform passed is applied.

Args:
Expand Down Expand Up @@ -1482,7 +1482,8 @@ def full_state_spec(self, spec: Composite) -> None:

# Single-env specs can be used to remove the batch size from the spec
@property
def batch_dims(self):
def batch_dims(self) -> int:
"""Number of batch dimensions of the env."""
return len(self.batch_size)

def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec:
Expand Down Expand Up @@ -2444,11 +2445,11 @@ def rollout(
set_truncated: bool = False,
out=None,
trust_policy: bool = False,
):
) -> TensorDictBase:
"""Executes a rollout in the environment.

The function will stop as soon as one of the contained environments
returns done=True.
The function will return as soon as any of the contained environments
reaches any of the done states.

Args:
max_steps (int): maximum number of steps to be executed. The actual number of steps can be smaller if
Expand All @@ -2464,14 +2465,16 @@ def rollout(
the call to ``rollout``.

Keyword Args:
auto_reset (bool, optional): if ``True``, resets automatically the environment
if it is in a done state when the rollout is initiated.
Default is ``True``.
Comment on lines -2467 to -2469
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this docstring was wrong and caused some confusion on our side. I hope that I was able to get the correct meaning of this argument.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That looks great!

auto_reset (bool, optional): if ``True``, the contained environments will be reset before starting the
rollout. If ``False``, then the rollout will continue from a previous state, which requires the
``tensordict`` argument to be passed with the previous rollout. Default is ``True``.
auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to the
policy device before the policy is used. Default is ``False``.
break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is
called on the sub-envs that are done. Default is True.
break_when_all_done (bool): TODO
Comment on lines -2472 to -2474
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Took the liberty of adding a docstring here and making some things a bit clearer. Open to suggestions for improvements!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haha yeah much better!
200w

break_when_any_done (bool): if ``True``, break when any of the contained environments reaches any of the
done states. If ``False``, then the done environments are reset automatically. Default is ``True``.
break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any
of the done states. If ``False``, break if at least one environment reaches any of the done states.
Default is ``False``.
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
Expand Down