Commit cfecd6f

Vincent Moens (vmoens) authored and committed
[Doc] Getting started tutos (#1886)
1 parent 6bd9296 commit cfecd6f

30 files changed: +1478 / -109 lines

README.md

Lines changed: 5 additions & 0 deletions
@@ -33,6 +33,11 @@ On the low-level end, torchrl comes with a set of highly re-usable functionals f
 
 TorchRL aims at (1) a high modularity and (2) good runtime performance. Read the [full paper](https://arxiv.org/abs/2306.00577) for a more curated description of the library.
 
+## Getting started
+
+Check our [Getting Started tutorials](https://pytorch.org/rl/index.html#getting-started) to quickly ramp up with the basic
+features of the library!
+
 ## Documentation and knowledge base
 
 The TorchRL documentation can be found [here](https://pytorch.org/rl).

docs/source/index.rst

Lines changed: 17 additions & 0 deletions
@@ -62,6 +62,23 @@ or via a ``git clone`` if you're willing to contribute to the library:
     $ cd ../rl
     $ python setup.py develop
 
+Getting started
+===============
+
+A series of quick tutorials to get ramped up with the basic features of the
+library. If you're in a hurry, you can start by
+:ref:`the last item of the series <gs_first_training>`
+and navigate to the previous ones whenever you want to learn more!
+
+.. toctree::
+   :maxdepth: 1
+
+   tutorials/getting-started-0
+   tutorials/getting-started-1
+   tutorials/getting-started-2
+   tutorials/getting-started-3
+   tutorials/getting-started-4
+   tutorials/getting-started-5
 
 Tutorials
 =========

docs/source/reference/collectors.rst

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,8 @@
 torchrl.collectors package
 ==========================
 
+.. _ref_collectors:
+
 Data collectors are somewhat equivalent to pytorch dataloaders, except that (1) they
 collect data over non-static data sources and (2) the data is collected using a model
 (likely a version of the model that is being trained).
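
A minimal sketch of that analogy (assuming the Gym backend is installed; the environment name and batch sizes are illustrative):

    from torchrl.collectors import SyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    env = GymEnv("CartPole-v1")
    # Like a dataloader, the collector is iterated over; unlike one, each
    # batch is produced by rolling a policy in the environment (policy=None
    # falls back to random actions drawn from the env's action spec).
    collector = SyncDataCollector(
        env,
        policy=None,
        frames_per_batch=64,
        total_frames=256,
    )
    for data in collector:
        print(data)  # a TensorDict of 64 transitions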

docs/source/reference/data.rst

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,8 @@
 torchrl.data package
 ====================
 
+.. _ref_data:
+
 Replay Buffers
 --------------

docs/source/reference/envs.rst

Lines changed: 3 additions & 0 deletions
@@ -475,6 +475,9 @@ single agent standards.
 
 Transforms
 ----------
+
+.. _transforms:
+
 .. currentmodule:: torchrl.envs.transforms
 
 In most cases, the raw output of an environment must be treated before being passed to another object (such as a
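
A hedged sketch of that treatment (the Gym backend and the choice of transforms are illustrative):

    from torchrl.envs import Compose, ObservationNorm, StepCounter, TransformedEnv
    from torchrl.envs.libs.gym import GymEnv

    # Normalize observations and count steps before the data reaches
    # downstream modules such as the policy or the replay buffer.
    env = TransformedEnv(
        GymEnv("Pendulum-v1"),
        Compose(
            ObservationNorm(in_keys=["observation"], loc=0.0, scale=1.0),
            StepCounter(),
        ),
    )
    td = env.reset()
    print(td["step_count"])  # tensor([0])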

docs/source/reference/modules.rst

Lines changed: 4 additions & 0 deletions
@@ -3,9 +3,13 @@
 torchrl.modules package
 =======================
 
+.. _ref_modules:
+
 TensorDict modules: Actors, exploration, value models and generative models
 ----------------------------------------------------------------------------
 
+.. _tdmodules:
+
 TorchRL offers a series of module wrappers aimed at making it easy to build
 RL models from the ground up. These wrappers are exclusively based on
 :class:`tensordict.nn.TensorDictModule` and :class:`tensordict.nn.TensorDictSequential`.
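
For instance (a minimal sketch; the layer sizes are arbitrary), wrapping a plain ``nn.Module`` makes it tensordict-compatible:

    import torch
    import torch.nn as nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    # The wrapper reads "observation" from the input tensordict, runs the
    # wrapped network and writes its output back under "action".
    policy = TensorDictModule(
        nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"]
    )
    td = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
    td = policy(td)
    print(td["action"].shape)  # torch.Size([8, 2])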

docs/source/reference/objectives.rst

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,8 @@
 torchrl.objectives package
 ==========================
 
+.. _ref_objectives:
+
 TorchRL provides a series of losses to use in your training scripts.
 The aim is to have losses that are easily reusable/swappable and that have
 a simple signature.

docs/source/reference/trainers.rst

Lines changed: 2 additions & 0 deletions
@@ -218,6 +218,8 @@ Utils
 Loggers
 -------
 
+.. _ref_loggers:
+
 .. currentmodule:: torchrl.record.loggers
 
 .. autosummary::

torchrl/data/replay_buffers/samplers.py

Lines changed: 1 addition & 1 deletion
@@ -718,7 +718,7 @@ def __init__(
         if end_key is None:
             end_key = ("next", "done")
         if traj_key is None:
-            traj_key = "run"
+            traj_key = "episode"
         self.end_key = end_key
         self.traj_key = traj_key
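
This default appears to belong to the slice-based samplers, which group stored frames by trajectory; the change means the trajectory id is now looked up under ``"episode"`` by default. A minimal sketch with :class:`~torchrl.data.SliceSampler` (sizes and keys are illustrative):

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, SliceSampler, TensorDictReplayBuffer

    # With the new default, the sampler reads trajectory membership from the
    # "episode" entry of the stored data; traj_key no longer needs passing.
    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(100),
        sampler=SliceSampler(num_slices=4),
        batch_size=32,
    )
    data = TensorDict(
        {"obs": torch.randn(100, 3), "episode": torch.arange(100) // 25},
        batch_size=[100],
    )
    rb.extend(data)
    sample = rb.sample()  # 4 contiguous sub-trajectories of 8 steps each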

torchrl/envs/common.py

Lines changed: 2 additions & 2 deletions
@@ -2055,8 +2055,8 @@ def reset(
 
         tensordict_reset = self._reset(tensordict, **kwargs)
         # We assume that this is done properly
-        # if tensordict_reset.device != self.device:
-        #     tensordict_reset = tensordict_reset.to(self.device, non_blocking=True)
+        # if reset.device != self.device:
+        #     reset = reset.to(self.device, non_blocking=True)
         if tensordict_reset is tensordict:
             raise RuntimeError(
                 "EnvBase._reset should return outplace changes to the input "

torchrl/envs/transforms/transforms.py

Lines changed: 2 additions & 2 deletions
@@ -791,7 +791,7 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
         return tensordict_reset
 
     def _reset_proc_data(self, tensordict, tensordict_reset):
-        # self._complete_done(self.full_done_spec, tensordict_reset)
+        # self._complete_done(self.full_done_spec, reset)
         self._reset_check_done(tensordict, tensordict_reset)
         if tensordict is not None:
             tensordict_reset = _update_during_reset(
@@ -802,7 +802,7 @@ def _reset_proc_data(self, tensordict, tensordict_reset):
         # # doesn't do anything special
         # mt_mode = self.transform.missing_tolerance
         # self.set_missing_tolerance(True)
-        # tensordict_reset = self.transform._call(tensordict_reset)
+        # reset = self.transform._call(reset)
         # self.set_missing_tolerance(mt_mode)
         return tensordict_reset

torchrl/modules/tensordict_module/actors.py

Lines changed: 29 additions & 23 deletions
@@ -33,12 +33,13 @@
 class Actor(SafeModule):
     """General class for deterministic actors in RL.
 
-    The Actor class comes with default values for the out_keys (["action"])
-    and if the spec is provided but not as a CompositeSpec object, it will be
-    automatically translated into :obj:`spec = CompositeSpec(action=spec)`
+    The Actor class comes with default values for the out_keys (``["action"]``)
+    and if the spec is provided but not as a
+    :class:`~torchrl.data.CompositeSpec` object, it will be
+    automatically translated into ``spec = CompositeSpec(action=spec)``.
 
     Args:
-        module (nn.Module): a :class:`torch.nn.Module` used to map the input to
+        module (nn.Module): a :class:`~torch.nn.Module` used to map the input to
             the output parameter space.
         in_keys (iterable of str, optional): keys to be read from input
             tensordict and passed to the module. If it
@@ -47,9 +48,11 @@ class Actor(SafeModule):
             Defaults to ``["observation"]``.
         out_keys (iterable of str): keys to be written to the input tensordict.
             The length of out_keys must match the
-            number of tensors returned by the embedded module. Using "_" as a
+            number of tensors returned by the embedded module. Using ``"_"`` as a
             key avoid writing tensor to output.
             Defaults to ``["action"]``.
+
+    Keyword Args:
         spec (TensorSpec, optional): Keyword-only argument.
             Specs of the output tensor. If the module
             outputs multiple output tensors,
@@ -59,7 +62,7 @@ class Actor(SafeModule):
             input spec. Out-of-domain sampling can
             occur because of exploration policies or numerical under/overflow
             issues. If this value is out of bounds, it is projected back onto the
-            desired space using the :obj:`TensorSpec.project`
+            desired space using the :meth:`~torchrl.data.TensorSpec.project`
             method. Default is ``False``.
 
     Examples:
@@ -148,17 +151,23 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
             issues. If this value is out of bounds, it is projected back onto the
             desired space using the :obj:`TensorSpec.project`
             method. Default is ``False``.
-        default_interaction_type=InteractionType.RANDOM (str, optional): keyword-only argument.
+        default_interaction_type (str, optional): keyword-only argument.
             Default method to be used to retrieve
-            the output value. Should be one of: 'mode', 'median', 'mean' or 'random'
-            (in which case the value is sampled randomly from the distribution). Default
-            is 'mode'.
-            Note: When a sample is drawn, the :obj:`ProbabilisticTDModule` instance will
-            first look for the interaction mode dictated by the `interaction_typ()`
-            global function. If this returns `None` (its default value), then the
-            `default_interaction_type` of the `ProbabilisticTDModule` instance will be
-            used. Note that DataCollector instances will use `set_interaction_type` to
-            :class:`tensordict.nn.InteractionType.RANDOM` by default.
+            the output value. Should be one of: 'InteractionType.MODE',
+            'InteractionType.MEDIAN', 'InteractionType.MEAN' or
+            'InteractionType.RANDOM' (in which case the value is sampled
+            randomly from the distribution). Defaults to 'InteractionType.RANDOM'.
+
+            .. note:: When a sample is drawn, the :class:`ProbabilisticActor` instance will
+                first look for the interaction mode dictated by the
+                :func:`~tensordict.nn.probabilistic.interaction_type`
+                global function. If this returns `None` (its default value), then the
+                `default_interaction_type` of the `ProbabilisticTDModule`
+                instance will be used. Note that
+                :class:`~torchrl.collectors.collectors.DataCollectorBase`
+                instances will use `set_interaction_type` to
+                :class:`tensordict.nn.InteractionType.RANDOM` by default.
+
         distribution_class (Type, optional): keyword-only argument.
             A :class:`torch.distributions.Distribution` class to
             be used for sampling.
@@ -197,9 +206,7 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential):
         ...     in_keys=["loc", "scale"],
         ...     distribution_class=TanhNormal,
         ...     )
-        >>> params = TensorDict.from_module(td_module)
-        >>> with params.to_module(td_module):
-        ...     td = td_module(td)
+        >>> td = td_module(td)
         >>> td
         TensorDict(
             fields={
@@ -315,7 +322,8 @@ class ValueOperator(TensorDictModule):
             The length of out_keys must match the
             number of tensors returned by the embedded module. Using "_" as a
             key avoid writing tensor to output.
-            Defaults to ``["action"]``.
+            Defaults to ``["state_value"]`` or
+            ``["state_action_value"]`` if ``"action"`` is part of the ``in_keys``.
 
     Examples:
         >>> import torch
@@ -334,9 +342,7 @@ class ValueOperator(TensorDictModule):
         >>> td_module = ValueOperator(
         ...     in_keys=["observation", "action"], module=module
        ... )
-        >>> params = TensorDict.from_module(td_module)
-        >>> with params.to_module(td_module):
-        ...     td = td_module(td)
+        >>> td = td_module(td)
        >>> print(td)
        TensorDict(
            fields={
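
The simplified docstring examples reflect that these modules can now be called directly, without the functional parameter-swap boilerplate. The same pattern holds for the plain ``Actor`` (a minimal sketch; sizes are arbitrary):

    import torch
    import torch.nn as nn
    from tensordict import TensorDict
    from torchrl.modules import Actor

    # Defaults: reads "observation" from the tensordict and writes "action".
    actor = Actor(nn.Linear(4, 2))
    td = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
    td = actor(td)
    print(td["action"].shape)  # torch.Size([8, 2])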

torchrl/objectives/dqn.py

Lines changed: 4 additions & 1 deletion
@@ -213,7 +213,10 @@ def __init__(
         try:
             action_space = value_network.action_space
         except AttributeError:
-            raise ValueError(self.ACTION_SPEC_ERROR)
+            raise ValueError(
+                "The action space could not be retrieved from the value_network. "
+                "Make sure it is available to the DQN loss module."
+            )
         if action_space is None:
             warnings.warn(
                 "action_space was not specified. DQNLoss will default to 'one-hot'."

torchrl/objectives/utils.py

Lines changed: 1 addition & 2 deletions
@@ -300,8 +300,7 @@ def __init__(
     ):
         if eps is None and tau is None:
             raise RuntimeError(
-                "Neither eps nor tau was provided. " "This behaviour is deprecated.",
-                category=DeprecationWarning,
+                "Neither eps nor tau was provided. This behaviour is deprecated.",
             )
             eps = 0.999
         if (eps is None) ^ (tau is None):
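
This error path belongs to the target-network updaters, so providing one of ``eps`` or ``tau`` avoids the deprecated branch. A hedged sketch with :class:`~torchrl.objectives.SoftUpdate` (``loss_module`` is assumed to be a target-bearing loss built elsewhere, e.g. DDPG or SAC):

    from torchrl.objectives import SoftUpdate

    # eps is the retain factor applied to the old target weights each step:
    #   target <- eps * target + (1 - eps) * param
    # tau = 1 - eps may be given instead; omitting both is an error.
    updater = SoftUpdate(loss_module, eps=0.99)

    # ... after each optimizer step on the loss module's parameters:
    updater.step()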

torchrl/record/loggers/csv.py

Lines changed: 3 additions & 1 deletion
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import os
 from collections import defaultdict
 from pathlib import Path
@@ -126,7 +128,7 @@ class CSVLogger(Logger):
     def __init__(
         self,
         exp_name: str,
-        log_dir: Optional[str] = None,
+        log_dir: str | None = None,
         video_format: str = "pt",
         video_fps: int = 30,
     ) -> None:
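
Usage is unchanged by the modernized annotation; a small sketch (paths and values are illustrative):

    from torchrl.record.loggers.csv import CSVLogger

    # Scalars are written as CSV files under <log_dir>/<exp_name>.
    logger = CSVLogger(exp_name="demo", log_dir="./logs")
    logger.log_scalar("reward", 0.5, step=1)
    logger.log_scalar("reward", 0.7, step=2)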

tutorials/sphinx-tutorials/README.rst

Lines changed: 2 additions & 0 deletions
@@ -1,2 +1,4 @@
 README Tutos
 ============
+
+Check the tutorials in the torchrl documentation: https://pytorch.org/rl

tutorials/sphinx-tutorials/coding_ddpg.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@
 ======================================
 **Author**: `Vincent Moens <https://github.com/vmoens>`_
 
+.. _coding_ddpg:
+
 """
 
 ##############################################################################

tutorials/sphinx-tutorials/coding_dqn.py

Lines changed: 5 additions & 3 deletions
@@ -4,6 +4,8 @@
 ==============================
 **Author**: `Vincent Moens <https://github.com/vmoens>`_
 
+.. _coding_dqn:
+
 """
 
 ##############################################################################
@@ -404,9 +406,9 @@ def get_replay_buffer(buffer_size, n_optim, batch_size):
 # environment executed in parallel in each collector (controlled by the
 # ``num_workers`` hyperparameter).
 #
-# When building the collector, we can choose on which device we want the
-# environment and policy to execute the operations through the ``device``
-# keyword argument. The ``storing_devices`` argument will modify the
+# A collector's devices are fully parametrizable through the ``device`` (general),
+# ``policy_device``, ``env_device`` and ``storing_device`` arguments.
+# The ``storing_device`` argument will modify the
 # location of the data being collected: if the batches that we are gathering
 # have a considerable size, we may want to store them on a different location
 # than the device where the computation is happening. For asynchronous data
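
For instance (a sketch; ``env_maker`` and ``policy`` are assumed to be defined earlier, and a CUDA device to be available), computation can run on GPU while the collected batches land on CPU:

    from torchrl.collectors import SyncDataCollector

    # Policy and env ops execute on the GPU; gathered batches are moved to
    # CPU so large rollouts do not accumulate in device memory.
    collector = SyncDataCollector(
        env_maker,  # assumed: an env constructor or instance
        policy,     # assumed: the policy module
        frames_per_batch=256,
        total_frames=10_000,
        device="cuda:0",
        storing_device="cpu",
    )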

tutorials/sphinx-tutorials/coding_ppo.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@
 ==================================================
 **Author**: `Vincent Moens <https://github.com/vmoens>`_
 
+.. _coding_ppo:
+
 This tutorial demonstrates how to use PyTorch and :py:mod:`torchrl` to train a parametric policy
 network to solve the Inverted Pendulum task from the `OpenAI-Gym/Farama-Gymnasium
 control library <https://github.com/Farama-Foundation/Gymnasium>`__.

tutorials/sphinx-tutorials/dqn_with_rnn.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@
 
 **Author**: `Vincent Moens <https://github.com/vmoens>`_
 
+.. _RNN_tuto:
+
 .. grid:: 2
 
     .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn