-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Doc] Minor fixes to the docs and type hints #2548
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -516,7 +516,7 @@ def append_transform( | |
self, | ||
transform: "Transform" # noqa: F821 | ||
| Callable[[TensorDictBase], TensorDictBase], | ||
) -> None: | ||
) -> EnvBase: | ||
"""Returns a transformed environment where the callable/transform passed is applied. | ||
|
||
Args: | ||
|
@@ -1482,7 +1482,8 @@ def full_state_spec(self, spec: Composite) -> None: | |
|
||
# Single-env specs can be used to remove the batch size from the spec | ||
@property | ||
def batch_dims(self): | ||
def batch_dims(self) -> int: | ||
"""Number of batch dimensions of the env.""" | ||
return len(self.batch_size) | ||
|
||
def _make_single_env_spec(self, spec: TensorSpec) -> TensorSpec: | ||
|
@@ -2444,11 +2445,11 @@ def rollout( | |
set_truncated: bool = False, | ||
out=None, | ||
trust_policy: bool = False, | ||
): | ||
) -> TensorDictBase: | ||
"""Executes a rollout in the environment. | ||
|
||
The function will stop as soon as one of the contained environments | ||
returns done=True. | ||
The function will return as soon as any of the contained environments | ||
reaches any of the done states. | ||
|
||
Args: | ||
max_steps (int): maximum number of steps to be executed. The actual number of steps can be smaller if | ||
|
@@ -2464,14 +2465,16 @@ def rollout( | |
the call to ``rollout``. | ||
|
||
Keyword Args: | ||
auto_reset (bool, optional): if ``True``, resets automatically the environment | ||
if it is in a done state when the rollout is initiated. | ||
Default is ``True``. | ||
Comment on lines
-2467
to
-2469
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO this docstring was wrong and caused some confusion on our side. I hope that I was able to get the correct meaning of this argument. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That looks great! |
||
auto_reset (bool, optional): if ``True``, the contained environments will be reset before starting the | ||
rollout. If ``False``, then the rollout will continue from a previous state, which requires the | ||
``tensordict`` argument to be passed with the previous rollout. Default is ``True``. | ||
auto_cast_to_device (bool, optional): if ``True``, the device of the tensordict is automatically cast to the | ||
policy device before the policy is used. Default is ``False``. | ||
break_when_any_done (bool): breaks if any of the done state is True. If False, a reset() is | ||
called on the sub-envs that are done. Default is True. | ||
break_when_all_done (bool): TODO | ||
Comment on lines
-2472
to
-2474
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Took the liberty of adding a docstring here and making some things a bit clearer. Open to suggestions for improvements! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
break_when_any_done (bool): if ``True``, break when any of the contained environments reaches any of the | ||
done states. If ``False``, then the done environments are reset automatically. Default is ``True``. | ||
break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any | ||
of the done states. If ``False``, break if at least one environment reaches any of the done states. | ||
Default is ``False``. | ||
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True. | ||
tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial | ||
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should actually return
TransformedEnv
.However, I'm guessing we don't want to import it at the top to prevent loading a lot of modules that won't be accessed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Either
EnvBase
or "torchrl.envs.TransformedEnv" but we can't import itEnvBase
could be the best compromise!