diff --git a/README.md b/README.md index 2e1d08a0757..6adbc2decfe 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,11 @@ On the low-level end, torchrl comes with a set of highly re-usable functionals f TorchRL aims at (1) a high modularity and (2) good runtime performance. Read the [full paper](https://arxiv.org/abs/2306.00577) for a more curated description of the library. +## Getting started + +Check our [Getting Started tutorials](https://pytorch.org/rl/index.html#getting-started) for quickly ramp up with the basic +features of the library! + ## Documentation and knowledge base The TorchRL documentation can be found [here](https://pytorch.org/rl). diff --git a/docs/source/index.rst b/docs/source/index.rst index 49bcde82488..ab1cee681db 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -62,6 +62,23 @@ or via a ``git clone`` if you're willing to contribute to the library: $ cd ../rl $ python setup.py develop +Getting started +=============== + +A series of quick tutorials to get ramped up with the basic features of the +library. If you're in a hurry, you can start by +:ref:`the last item of the series ` +and navigate to the previous ones whenever you want to learn more! + +.. toctree:: + :maxdepth: 1 + + tutorials/getting-started-0 + tutorials/getting-started-1 + tutorials/getting-started-2 + tutorials/getting-started-3 + tutorials/getting-started-4 + tutorials/getting-started-5 Tutorials ========= diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index aa8de179f20..982b8664862 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -3,6 +3,8 @@ torchrl.collectors package ========================== +.. _ref_collectors: + Data collectors are somewhat equivalent to pytorch dataloaders, except that (1) they collect data over non-static data sources and (2) the data is collected using a model (likely a version of the model that is being trained). diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 6ed32ebe921..80b58f96989 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -3,6 +3,8 @@ torchrl.data package ==================== +.. _ref_data: + Replay Buffers -------------- diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index cce34e14b14..4dbb5a5da57 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -475,6 +475,9 @@ single agent standards. Transforms ---------- + +.. _transforms: + .. currentmodule:: torchrl.envs.transforms In most cases, the raw output of an environment must be treated before being passed to another object (such as a diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index d859140bb70..bcd234a7ff9 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -3,9 +3,13 @@ torchrl.modules package ======================= +.. _ref_modules: + TensorDict modules: Actors, exploration, value models and generative models --------------------------------------------------------------------------- +.. _tdmodules: + TorchRL offers a series of module wrappers aimed at making it easy to build RL models from the ground up. These wrappers are exclusively based on :class:`tensordict.nn.TensorDictModule` and :class:`tensordict.nn.TensorDictSequential`. diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 1aec88f2d11..c2f43d8e9b6 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -3,6 +3,8 @@ torchrl.objectives package ========================== +.. _ref_objectives: + TorchRL provides a series of losses to use in your training scripts. The aim is to have losses that are easily reusable/swappable and that have a simple signature. diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index eb857f15a0f..04d4386c631 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -218,6 +218,8 @@ Utils Loggers ------- +.. _ref_loggers: + .. currentmodule:: torchrl.record.loggers .. autosummary:: diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 21245e37acd..0022fe41569 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -718,7 +718,7 @@ def __init__( if end_key is None: end_key = ("next", "done") if traj_key is None: - traj_key = "run" + traj_key = "episode" self.end_key = end_key self.traj_key = traj_key diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 61cd211b6ae..aaba2981047 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2055,8 +2055,8 @@ def reset( tensordict_reset = self._reset(tensordict, **kwargs) # We assume that this is done properly - # if tensordict_reset.device != self.device: - # tensordict_reset = tensordict_reset.to(self.device, non_blocking=True) + # if reset.device != self.device: + # reset = reset.to(self.device, non_blocking=True) if tensordict_reset is tensordict: raise RuntimeError( "EnvBase._reset should return outplace changes to the input " diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index efa59e25c26..5c62f37db9f 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -791,7 +791,7 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): return tensordict_reset def _reset_proc_data(self, tensordict, tensordict_reset): - # self._complete_done(self.full_done_spec, tensordict_reset) + # self._complete_done(self.full_done_spec, reset) self._reset_check_done(tensordict, tensordict_reset) if tensordict is not None: tensordict_reset = _update_during_reset( @@ -802,7 +802,7 @@ def _reset_proc_data(self, tensordict, tensordict_reset): # # doesn't do anything special # mt_mode = self.transform.missing_tolerance # self.set_missing_tolerance(True) - # tensordict_reset = self.transform._call(tensordict_reset) + # reset = self.transform._call(reset) # self.set_missing_tolerance(mt_mode) return tensordict_reset diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index b7a044cae7d..8d9855283f5 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -33,12 +33,13 @@ class Actor(SafeModule): """General class for deterministic actors in RL. - The Actor class comes with default values for the out_keys (["action"]) - and if the spec is provided but not as a CompositeSpec object, it will be - automatically translated into :obj:`spec = CompositeSpec(action=spec)` + The Actor class comes with default values for the out_keys (``["action"]``) + and if the spec is provided but not as a + :class:`~torchrl.data.CompositeSpec` object, it will be + automatically translated into ``spec = CompositeSpec(action=spec)``. Args: - module (nn.Module): a :class:`torch.nn.Module` used to map the input to + module (nn.Module): a :class:`~torch.nn.Module` used to map the input to the output parameter space. in_keys (iterable of str, optional): keys to be read from input tensordict and passed to the module. If it @@ -47,9 +48,11 @@ class Actor(SafeModule): Defaults to ``["observation"]``. out_keys (iterable of str): keys to be written to the input tensordict. The length of out_keys must match the - number of tensors returned by the embedded module. Using "_" as a + number of tensors returned by the embedded module. Using ``"_"`` as a key avoid writing tensor to output. Defaults to ``["action"]``. + + Keyword Args: spec (TensorSpec, optional): Keyword-only argument. Specs of the output tensor. If the module outputs multiple output tensors, @@ -59,7 +62,7 @@ class Actor(SafeModule): input spec. Out-of-domain sampling can occur because of exploration policies or numerical under/overflow issues. If this value is out of bounds, it is projected back onto the - desired space using the :obj:`TensorSpec.project` + desired space using the :meth:`~torchrl.data.TensorSpec.project` method. Default is ``False``. Examples: @@ -148,17 +151,23 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): issues. If this value is out of bounds, it is projected back onto the desired space using the :obj:`TensorSpec.project` method. Default is ``False``. - default_interaction_type=InteractionType.RANDOM (str, optional): keyword-only argument. + default_interaction_type (str, optional): keyword-only argument. Default method to be used to retrieve - the output value. Should be one of: 'mode', 'median', 'mean' or 'random' - (in which case the value is sampled randomly from the distribution). Default - is 'mode'. - Note: When a sample is drawn, the :obj:`ProbabilisticTDModule` instance will - first look for the interaction mode dictated by the `interaction_typ()` - global function. If this returns `None` (its default value), then the - `default_interaction_type` of the `ProbabilisticTDModule` instance will be - used. Note that DataCollector instances will use `set_interaction_type` to - :class:`tensordict.nn.InteractionType.RANDOM` by default. + the output value. Should be one of: 'InteractionType.MODE', + 'InteractionType.MEDIAN', 'InteractionType.MEAN' or + 'InteractionType.RANDOM' (in which case the value is sampled + randomly from the distribution). Defaults to is 'InteractionType.RANDOM'. + + .. note:: When a sample is drawn, the :class:`ProbabilisticActor` instance will + first look for the interaction mode dictated by the + :func:`~tensordict.nn.probabilistic.interaction_type` + global function. If this returns `None` (its default value), then the + `default_interaction_type` of the `ProbabilisticTDModule` + instance will be used. Note that + :class:`~torchrl.collectors.collectors.DataCollectorBase` + instances will use `set_interaction_type` to + :class:`tensordict.nn.InteractionType.RANDOM` by default. + distribution_class (Type, optional): keyword-only argument. A :class:`torch.distributions.Distribution` class to be used for sampling. @@ -197,9 +206,7 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): ... in_keys=["loc", "scale"], ... distribution_class=TanhNormal, ... ) - >>> params = TensorDict.from_module(td_module) - >>> with params.to_module(td_module): - ... td = td_module(td) + >>> td = td_module(td) >>> td TensorDict( fields={ @@ -315,7 +322,8 @@ class ValueOperator(TensorDictModule): The length of out_keys must match the number of tensors returned by the embedded module. Using "_" as a key avoid writing tensor to output. - Defaults to ``["action"]``. + Defaults to ``["state_value"]`` or + ``["state_action_value"]`` if ``"action"`` is part of the ``in_keys``. Examples: >>> import torch @@ -334,9 +342,7 @@ class ValueOperator(TensorDictModule): >>> td_module = ValueOperator( ... in_keys=["observation", "action"], module=module ... ) - >>> params = TensorDict.from_module(td_module) - >>> with params.to_module(td_module): - ... td = td_module(td) + >>> td = td_module(td) >>> print(td) TensorDict( fields={ diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 37fd1cbdaea..2298c262368 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -213,7 +213,10 @@ def __init__( try: action_space = value_network.action_space except AttributeError: - raise ValueError(self.ACTION_SPEC_ERROR) + raise ValueError( + "The action space could not be retrieved from the value_network. " + "Make sure it is available to the DQN loss module." + ) if action_space is None: warnings.warn( "action_space was not specified. DQNLoss will default to 'one-hot'." diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 43dfa65c0c4..b234af6a804 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -300,8 +300,7 @@ def __init__( ): if eps is None and tau is None: raise RuntimeError( - "Neither eps nor tau was provided. " "This behaviour is deprecated.", - category=DeprecationWarning, + "Neither eps nor tau was provided. This behaviour is deprecated.", ) eps = 0.999 if (eps is None) ^ (tau is None): diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index 256d0a2e840..6bcd3f50c86 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import os from collections import defaultdict from pathlib import Path @@ -126,7 +128,7 @@ class CSVLogger(Logger): def __init__( self, exp_name: str, - log_dir: Optional[str] = None, + log_dir: str | None = None, video_format: str = "pt", video_fps: int = 30, ) -> None: diff --git a/tutorials/sphinx-tutorials/README.rst b/tutorials/sphinx-tutorials/README.rst index a7e41cccf45..7995a1fbb2e 100644 --- a/tutorials/sphinx-tutorials/README.rst +++ b/tutorials/sphinx-tutorials/README.rst @@ -1,2 +1,4 @@ README Tutos ============ + +Check the tutorials on torchrl documentation: https://pytorch.org/rl diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 5f8bf2c0830..252b4fd2146 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -4,6 +4,8 @@ ====================================== **Author**: `Vincent Moens `_ +.. _coding_ddpg: + """ ############################################################################## diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index f85f6bf1e14..eb476dfcc15 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -4,6 +4,8 @@ ============================== **Author**: `Vincent Moens `_ +.. _coding_dqn: + """ ############################################################################## @@ -404,9 +406,9 @@ def get_replay_buffer(buffer_size, n_optim, batch_size): # environment executed in parallel in each collector (controlled by the # ``num_workers`` hyperparameter). # -# When building the collector, we can choose on which device we want the -# environment and policy to execute the operations through the ``device`` -# keyword argument. The ``storing_devices`` argument will modify the +# Collector's devices are fully parametrizable through the ``device`` (general), +# ``policy_device``, ``env_device`` and ``storing_device`` arguments. +# The ``storing_device`` argument will modify the # location of the data being collected: if the batches that we are gathering # have a considerable size, we may want to store them on a different location # than the device where the computation is happening. For asynchronous data diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index be82bbd3bd8..6f31a0aed1a 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -4,6 +4,8 @@ ================================================== **Author**: `Vincent Moens `_ +.. _coding_ppo: + This tutorial demonstrates how to use PyTorch and :py:mod:`torchrl` to train a parametric policy network to solve the Inverted Pendulum task from the `OpenAI-Gym/Farama-Gymnasium control library `__. diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index b71a112c91a..a2b2b12b562 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -6,6 +6,8 @@ **Author**: `Vincent Moens `_ +.. _RNN_tuto: + .. grid:: 2 .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn diff --git a/tutorials/sphinx-tutorials/getting-started-0.py b/tutorials/sphinx-tutorials/getting-started-0.py new file mode 100644 index 00000000000..e81b10ec381 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-0.py @@ -0,0 +1,245 @@ +# -*- coding: utf-8 -*- +""" + +Get started with Environments, TED and transforms +================================================= + +**Author**: `Vincent Moens `_ + +.. _gs_env_ted: + +""" + +################################ +# Welcome to the getting started tutorials! +# +# Below is the list of the topics we will be covering. +# +# - :ref:`Environments, TED and transforms `; +# - :ref:`TorchRL's modules `; +# - :ref:`Losses and optimization `; +# - :ref:`Data collection and storage `; +# - :ref:`TorchRL's logging API `. +# +# If you are in a hurry, you can jump straight away to the last tutorial, +# :ref:`Your onw first training loop `, from where you can +# backtrack every other "Getting Started" tutorial if things are not clear or +# if you want to learn more about a specific topic! +# +# Environments in RL +# ------------------ +# +# The standard RL (Reinforcement Learning) training loop involves a model, +# also known as a policy, which is trained to accomplish a task within a +# specific environment. Often, this environment is a simulator that accepts +# actions as input and produces an observation along with some metadata as +# output. +# +# In this document, we will explore the environment API of TorchRL: we will +# learn how to create an environment, interact with it, and understand the +# data format it uses. +# +# Creating an environment +# ----------------------- +# +# In essence, TorchRL does not directly provide environments, but instead +# offers wrappers for other libraries that encapsulate the simulators. The +# :mod:`~torchrl.envs` module can be viewed as a provider for a generic +# environment API, as well as a central hub for simulation backends like +# `gym `_ (:class:`~torchrl.envs.GymEnv`), +# `Brax `_ (:class:`~torchrl.envs.BraxEnv`) +# or `DeepMind Control Suite `_ +# (:class:`~torchrl.envs.DMControlEnv`). +# +# Creating your environment is typically as straightforward as the underlying +# backend API allows. Here's an example using gym: + +from torchrl.envs import GymEnv + +env = GymEnv("Pendulum-v1") + +################################ +# +# Running an environment +# ---------------------- +# +# Environments in TorchRL have two crucial methods: +# :meth:`~torchrl.envs.EnvBase.reset`, which initiates +# an episode, and :meth:`~torchrl.envs.EnvBase.step`, which executes an +# action selected by the actor. +# In TorchRL, environment methods read and write +# :class:`~tensordict.TensorDict` instances. +# Essentially, :class:`~tensordict.TensorDict` is a generic key-based data +# carrier for tensors. +# The benefit of using TensorDict over plain tensors is that it enables us to +# handle simple and complex data structures interchangeably. As our function +# signatures are very generic, it eliminates the challenge of accommodating +# different data formats. In simpler terms, after this brief tutorial, +# you will be capable of operating on both simple and highly complex +# environments, as their user-facing API is identical and simple! +# +# Let's put the environment into action and see what a tensordict instance +# looks like: + +reset = env.reset() +print(reset) + +################################ +# Now let's take a random action in the action space. First, sample the action: +reset_with_action = env.rand_action(reset) +print(reset_with_action) + +################################ +# This tensordict has the same structure as the one obtained from +# :meth:`~torchrl.envs.EnvBase` with an additional ``"action"`` entry. +# You can access the action easily, like you would do with a regular +# dictionary: +# + +print(reset_with_action["action"]) + +################################ +# We now need to pass this action tp the environment. +# We'll be passing the entire tensordict to the ``step`` method, since there +# might be more than one tensor to be read in more advanced cases like +# Multi-Agent RL or stateless environments: + +stepped_data = env.step(reset_with_action) +print(stepped_data) + +################################ +# Again, this new tensordict is identical to the previous one except for the +# fact that it has a ``"next"`` entry (itself a tensordict!) containing the +# observation, reward and done state resulting from +# our action. +# +# We call this format TED, for +# :ref:`TorchRL Episode Data format `. It is +# the ubiquitous way of representing data in the library, both dynamically like +# here, or statically with offline datasets. +# +# The last bit of information you need to run a rollout in the environment is +# how to bring that ``"next"`` entry at the root to perform the next step. +# TorchRL provides a dedicated :func:`~torchrl.envs.utils.step_mdp` function +# that does just that: it filters out the information you won't need and +# delivers a data structure corresponding to your observation after a step in +# the Markov Decision Process, or MDP. + +from torchrl.envs import step_mdp + +data = step_mdp(stepped_data) +print(data) + +################################ +# Environment rollouts +# -------------------- +# +# .. _gs_env_ted_rollout: +# +# Writing down those three steps (computing an action, making a step, +# moving in the MDP) can be a bit tedious and repetitive. Fortunately, +# TorchRL provides a nice :meth:`~torchrl.envs.EnvBase.rollout` function that +# allows you to run them in a closed loop at will: +# + +rollout = env.rollout(max_steps=10) +print(rollout) + +################################ +# This data looks pretty much like the ``stepped_data`` above with the +# exception of its batch-size, which now equates the number of steps we +# provided through the ``max_steps`` argument. The magic of tensordict +# doesn't end there: if you're interested in a single transition of this +# environment, you can index the tensordict like you would index a tensor: + +transition = rollout[3] +print(transition) + +################################ +# :class:`~tensordict.TensorDict` will automatically check if the index you +# provided is a key (in which case we index along the key-dimension) or a +# spatial index like here. +# +# Executed as such (without a policy), the ``rollout`` method may seem rather +# useless: it just runs random actions. If a policy is available, it can +# be passed to the method and used to collect data. +# +# Nevertheless, it can useful to run a naive, policyless rollout at first to +# check what is to be expected from an environment at a glance. +# +# To appreciate the versatility of TorchRL's API, consider the fact that the +# rollout method is universally applicable. It functions across **all** use +# cases, whether you're working with a single environment like this one, +# multiple copies across various processes, a multi-agent environment, or even +# a stateless version of it! +# +# +# Transforming an environment +# --------------------------- +# +# Most of the time, you'll want to modify the output of the environment to +# better suit your requirements. For example, you might want to monitor the +# number of steps executed since the last reset, resize images, or stack +# consecutive observations together. +# +# In this section, we'll examine a simple transform, the +# :class:`~torchrl.envs.transforms.StepCounter` transform. +# The complete list of transforms can be found +# :ref:`here `. +# +# The transform is integrated with the environment through a +# :class:`~torchrl.envs.transforms.TransformedEnv`: +# + +from torchrl.envs import StepCounter, TransformedEnv + +transformed_env = TransformedEnv(env, StepCounter(max_steps=10)) +rollout = transformed_env.rollout(max_steps=100) +print(rollout) + +################################ +# As you can see, our environment now has one more entry, ``"step_count"`` that +# tracks the number of steps since the last reset. +# Given that we passed the optional +# argument ``max_steps=10`` to the transform constructor, we also truncated the +# trajectory after 10 steps (not completing a full rollout of 100 steps like +# we asked with the ``rollout`` call). We can see that the trajectory was +# truncated by looking at the truncated entry: + +print(rollout["next", "truncated"]) + +################################ +# +# This is all for this short introduction to TorchRL's environment API! +# +# Next steps +# ---------- +# +# To explore further what TorchRL's environments can do, go and check: +# +# - The :meth:`~torchrl.envs.EnvBase.step_and_maybe_reset` method that packs +# together :meth:`~torchrl.envs.EnvBase.step`, +# :func:`~torchrl.envs.utils.step_mdp` and +# :meth:`~torchrl.envs.EnvBase.reset`. +# - Some environments like :class:`~torchrl.envs.GymEnv` support rendering +# through the ``from_pixels`` argument. Check the class docstrings to know +# more! +# - The batched environments, in particular :class:`~torchrl.envs.ParallelEnv` +# which allows you to run multiple copies of one same (or different!) +# environments on multiple processes. +# - Design your own environment with the +# :ref:`Pendulum tutorial ` and learn about specs and +# stateless environments. +# - See the more in-depth tutorial about environments +# :ref:`in the dedicated tutorial `; +# - Check the +# :ref:`multi-agent environment API ` +# if you're interested in MARL; +# - TorchRL has many tools to interact with the Gym API such as +# a way to register TorchRL envs in the Gym register through +# :meth:`~torchrl.envs.EnvBase.register_gym`, an API to read +# the info dictionaries through +# :meth:`~torchrl.envs.EnvBase.set_info_dict_reader` or a way +# to control the gym backend thanks to +# :func:`~torchrl.envs.set_gym_backend`. +# diff --git a/tutorials/sphinx-tutorials/getting-started-1.py b/tutorials/sphinx-tutorials/getting-started-1.py new file mode 100644 index 00000000000..136deeb5cd9 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-1.py @@ -0,0 +1,309 @@ +# -*- coding: utf-8 -*- +""" +Get started with TorchRL's modules +================================== + +**Author**: `Vincent Moens `_ + +.. _gs_modules: + +""" +################################### +# Reinforcement Learning is designed to create policies that can effectively +# tackle specific tasks. Policies can take various forms, from a differentiable +# map transitioning from the observation space to the action space, to a more +# ad-hoc method like an argmax over a list of values computed for each possible +# action. Policies can be deterministic or stochastic, and may incorporate +# complex elements such as Recurrent Neural Networks (RNNs) or transformers. +# +# Accommodating all these scenarios can be quite intricate. In this succinct +# tutorial, we will delve into the core functionality of TorchRL in terms of +# policy construction. We will primarily focus on stochastic and Q-Value +# policies in two common scenarios: using a Multi-Layer Perceptron (MLP) or +# a Convolutional Neural Network (CNN) as backbones. +# +# TensorDictModules +# ----------------- +# +# Similar to how environments interact with instances of +# :class:`~tensordict.TensorDict`, the modules used to represent policies and +# value functions also do the same. The core idea is simple: encapsulate a +# standard :class:`~torch.nn.Module` (or any other function) within a class +# that knows which entries need to be read and passed to the module, and then +# records the results with the assigned entries. To illustrate this, we will +# use the simplest policy possible: a deterministic map from the observation +# space to the action space. For maximum generality, we will use a +# :class:`~torch.nn.LazyLinear` module with the Pendulum environment we +# instantiated in the previous tutorial. +# + +import torch + +from tensordict.nn import TensorDictModule +from torchrl.envs import GymEnv + +env = GymEnv("Pendulum-v1") +module = torch.nn.LazyLinear(out_features=env.action_spec.shape[-1]) +policy = TensorDictModule( + module, + in_keys=["observation"], + out_keys=["action"], +) + +################################### +# This is all that's required to execute our policy! The use of a lazy module +# allows us to bypass the need to fetch the shape of the observation space, as +# the module will automatically determine it. This policy is now ready to be +# run in the environment: + +rollout = env.rollout(max_steps=10, policy=policy) +print(rollout) + +################################### +# Specialized wrappers +# -------------------- +# +# To simplify the incorporation of :class:`~torch.nn.Module`s into your +# codebase, TorchRL offers a range of specialized wrappers designed to be +# used as actors, including :class:`~torchrl.modules.tensordict_module.Actor`, +# # :class:`~torchrl.modules.tensordict_module.ProbabilisticActor`, +# # :class:`~torchrl.modules.tensordict_module.ActorValueOperator` or +# # :class:`~torchrl.modules.tensordict_module.ActorCriticOperator`. +# For example, :class:`~torchrl.modules.tensordict_module.Actor` provides +# default values for the ``in_keys`` and ``out_keys``, making integration +# with many common environments straightforward: +# + +from torchrl.modules import Actor + +policy = Actor(module) +rollout = env.rollout(max_steps=10, policy=policy) +print(rollout) + +################################### +# The list of available specialized TensorDictModules is available in the +# :ref:`API reference `. +# +# Networks +# -------- +# +# TorchRL also provides regular modules that can be used without recurring to +# tensordict features. The two most common networks you will encounter are +# the :class:`~torchrl.modules.MLP` and the :class:`~torchrl.modules.ConvNet` +# (CNN) modules. We can substitute our policy module with one of these: +# + +from torchrl.modules import MLP + +module = MLP( + out_features=env.action_spec.shape[-1], + num_cells=[32, 64], + activation_class=torch.nn.Tanh, +) +policy = Actor(module) +rollout = env.rollout(max_steps=10, policy=policy) + +################################### +# TorchRL also supports RNN-based policies. Since this is a more technical +# topic, it is treated in :ref:`a separate tutorial `. +# +# Probabilistic policies +# ---------------------- +# +# Policy-optimization algorithms like +# `PPO `_ require the policy to be +# stochastic: unlike in the examples above, the module now encodes a map from +# the observation space to a parameter space encoding a distribution over the +# possible actions. TorchRL facilitates the design of such modules by grouping +# under a single class the various operations such as building the distribution +# from the parameters, sampling from that distribution and retrieving the +# log-probability. Here, we'll be building an actor that relies on a regular +# normal distribution using three components: +# +# - An :class:`~torchrl.modules.MLP` backbone reading observations of size +# ``[3]`` and outputting a single tensor of size ``[2]``; +# - A :class:`~tensordict.nn.distributions.NormalParamExtractor` module that +# will split this output on two chunks, a mean and a standard deviation of +# size ``[1]``; +# - A :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` that will +# read those parameters as ``in_keys``, create a distribution with them and +# populate our tensordict with samples and log-probabilities. +# + +from tensordict.nn.distributions import NormalParamExtractor +from torch.distributions import Normal +from torchrl.modules import ProbabilisticActor + +backbone = MLP(in_features=3, out_features=2) +extractor = NormalParamExtractor() +module = torch.nn.Sequential(backbone, extractor) +td_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"]) +policy = ProbabilisticActor( + td_module, + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=Normal, + return_log_prob=True, +) + +rollout = env.rollout(max_steps=10, policy=policy) +print(rollout) + +################################### +# There are a few things to note about this rollout: +# +# - Since we asked for it during the construction of the actor, the +# log-probability of the actions given the distribution at that time is +# also written. This is necessary for algorithms like PPO. +# - The parameters of the distribution are returned within the output +# tensordict too under the ``"loc"`` and ``"scale"`` entries. +# +# You can control the sampling of the action to use the expected value or +# other properties of the distribution instead of using random samples if +# your application requires it. This can be controlled via the +# :func:`~torchrl.envs.utils.set_exploration_type` function: + +from torchrl.envs.utils import ExplorationType, set_exploration_type + +with set_exploration_type(ExplorationType.MEAN): + # takes the mean as action + rollout = env.rollout(max_steps=10, policy=policy) +with set_exploration_type(ExplorationType.RANDOM): + # Samples actions according to the dist + rollout = env.rollout(max_steps=10, policy=policy) + +################################### +# Check the ``default_interaction_type`` keyword argument in +# the docstrings to know more. +# +# Exploration +# ----------- +# +# Stochastic policies like this somewhat naturally trade off exploration and +# exploitation, but deterministic policies won't. Fortunately, TorchRL can +# also palliate to this with its exploration modules. +# We will take the example of the :class:`~torchrl.modules.EGreedyModule` +# exploration module (check also +# :class:`~torchrl.modules.AdditiveGaussianWrapper` and +# :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper`). +# To see this module in action, let's revert to a deterministic policy: + +from tensordict.nn import TensorDictSequential +from torchrl.modules import EGreedyModule + +policy = Actor(MLP(3, 1, num_cells=[32, 64])) + +################################### +# Our :math:`\epsilon`-greedy exploration module will usually be customized +# with a number of annealing frames and an initial value for the +# :math:`\epsilon` parameter. A value of :math:`\epsilon = 1` means that every +# action taken is random, while :math:`\epsilon=0` means that there is no +# exploration at all. To anneal (i.e., decrease) the exploration factor, a call +# to :meth:`~torchrl.modules.EGreedyModule.step` is required (see the last +# :ref:`tutorial ` for an example). +# +exploration_module = EGreedyModule( + spec=env.action_spec, annealing_num_steps=1000, eps_init=0.5 +) + +################################### +# To build our explorative policy, we only had to concatenate the +# deterministic policy module with the exploration module within a +# :class:`~tensordict.nn.TensorDictSequential` module (which is the analogous +# to :class:`~torch.nn.Sequential` in the tensordict realm). + +exploration_policy = TensorDictSequential(policy, exploration_module) + +with set_exploration_type(ExplorationType.MEAN): + # Turns off exploration + rollout = env.rollout(max_steps=10, policy=exploration_policy) +with set_exploration_type(ExplorationType.RANDOM): + # Turns on exploration + rollout = env.rollout(max_steps=10, policy=exploration_policy) + +################################### +# Because it must be able to sample random actions in the action space, the +# :class:`~torchrl.modules.EGreedyModule` must be equipped with the +# ``action_space`` from the environment to know what strategy to use to +# sample actions randomly. +# +# Q-Value actors +# -------------- +# +# In some settings, the policy isn't a standalone module but is constructed on +# top of another module. This is the case with **Q-Value actors**. In short, these +# actors require an estimate of the action value (most of the time discrete) +# and will greedily pick up the action with the highest value. In some +# settings (finite discrete action space and finite discrete state space), +# one can just store a 2D table of state-action pairs and pick up the +# action with the highest value. The innovation brought by +# `DQN `_ was to scale this up to continuous +# state spaces by utilizing a neural network to encode for the ``Q(s, a)`` +# value map. Let's consider another environment with a discrete action space +# for a clearer understanding: + +env = GymEnv("CartPole-v1") +print(env.action_spec) + +################################### +# We build a value network that produces one value per action when it reads a +# state from the environment: + +num_actions = 2 +value_net = TensorDictModule( + MLP(out_features=num_actions, num_cells=[32, 32]), + in_keys=["observation"], + out_keys=["action_value"], +) + +################################### +# We can easily build our Q-Value actor by adding a +# :class:`~torchrl.modules.tensordict_module.QValueModule` after our value +# network: + +from torchrl.modules import QValueModule + +policy = TensorDictSequential( + value_net, # writes action values in our tensordict + QValueModule( + action_space=env.action_spec + ), # Reads the "action_value" entry by default +) + +################################### +# Let's check it out! We run the policy for a couple of steps and look at the +# output. We should find an ``"action_value"`` as well as a +# ``"chosen_action_value"`` entries in the rollout that we obtain: +# + +rollout = env.rollout(max_steps=3, policy=policy) +print(rollout) + +################################### +# Because it relies on the ``argmax`` operator, this policy is deterministic. +# During data collection, we will need to explore the environment. For that, +# we are using the :class:`~torchrl.modules.tensordict_module.EGreedyModule` +# once again: + +policy_explore = TensorDictSequential(policy, EGreedyModule(env.action_spec)) + +with set_exploration_type(ExplorationType.RANDOM): + rollout_explore = env.rollout(max_steps=3, policy=policy_explore) + +################################### +# This is it for our short tutorial on building a policy with TorchRL! +# +# There are many more things you can do with the library. A good place to start +# is to look at the :ref:`API reference for modules `. +# +# Next steps: +# +# - Check how to use compound distributions with +# :class:`~tensordict.nn.distributions.CompositeDistribution` when the +# action is composite (e.g., a discrete and a continuous action are +# required by the env); +# - Have a look at how you can use an RNN within the policy (a +# :ref:`tutorial `); +# - Compare this to the usage of transformers with the Decision Transformers +# examples (see the ``example`` directory on GitHub). +# diff --git a/tutorials/sphinx-tutorials/getting-started-2.py b/tutorials/sphinx-tutorials/getting-started-2.py new file mode 100644 index 00000000000..1d903e67c01 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-2.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +""" +Getting started with model optimization +======================================= + +**Author**: `Vincent Moens `_ + +.. _gs_optim: + +""" + +################################### +# In TorchRL, we try to treat optimization as it is custom to do in PyTorch, +# using dedicated loss modules which are designed with the sole purpose of +# optimizing the model. This approach efficiently decouples the execution of +# the policy from its training and allows us to design training loops that are +# similar to what can be found in traditional supervised learning examples. +# +# The typical training loop therefore looks like this: +# +# >>> for i in range(n_collections): +# ... data = get_next_batch(env, policy) +# ... for j in range(n_optim): +# ... loss = loss_fn(data) +# ... loss.backward() +# ... optim.step() +# +# In this concise tutorial, you will receive a brief overview of the loss modules. Due to the typically +# straightforward nature of the API for basic usage, this tutorial will be kept brief. +# +# RL objective functions +# ---------------------- +# +# In RL, innovation typically involves the exploration of novel methods +# for optimizing a policy (i.e., new algorithms), rather than focusing +# on new architectures, as seen in other domains. Within TorchRL, +# these algorithms are encapsulated within loss modules. A loss +# module orchestrates the various components of your algorithm and +# yields a set of loss values that can be backpropagated +# through to train the corresponding components. +# +# In this tutorial, we will take a popular +# off-policy algorithm as an example, +# `DDPG `_. +# +# To build a loss module, the only thing one needs is a set of networks +# defined as :class:`~tensordict.nn.TensorDictModule`s. Most of the time, one +# of these modules will be the policy. Other auxiliary networks such as +# Q-Value networks or critics of some kind may be needed as well. Let's see +# what this looks like in practice: DDPG requires a deterministic +# map from the observation space to the action space as well as a value +# network that predicts the value of a state-action pair. The DDPG loss will +# attempt to find the policy parameters that output actions that maximize the +# value for a given state. +# +# To build the loss, we need both the actor and value networks. +# If they are built according to DDPG's expectations, it is all +# we need to get a trainable loss module: + +from torchrl.envs import GymEnv + +env = GymEnv("Pendulum-v1") + +from torchrl.modules import Actor, MLP, ValueOperator +from torchrl.objectives import DDPGLoss + +n_obs = env.observation_spec["observation"].shape[-1] +n_act = env.action_spec.shape[-1] +actor = Actor(MLP(in_features=n_obs, out_features=n_act, num_cells=[32, 32])) +value_net = ValueOperator( + MLP(in_features=n_obs + n_act, out_features=1, num_cells=[32, 32]), + in_keys=["observation", "action"], +) + +ddpg_loss = DDPGLoss(actor_network=actor, value_network=value_net) + +################################### +# And that is it! Our loss module can now be run with data coming from the +# environment (we omit exploration, storage and other features to focus on +# the loss functionality): +# + +rollout = env.rollout(max_steps=100, policy=actor) +loss_vals = ddpg_loss(rollout) +print(loss_vals) + +################################### +# LossModule's output +# ------------------- +# +# As you can see, the value we received from the loss isn't a single scalar +# but a dictionary containing multiple losses. +# +# The reason is simple: because more than one network may be trained at a time, +# and since some users may wish to separate the optimization of each module +# in distinct steps, TorchRL's objectives will return dictionaries containing +# the various loss components. +# +# This format also allows us to pass metadata along with the loss values. In +# general, we make sure that only the loss values are differentiable such that +# you can simply sum over the values of the dictionary to obtain the total +# loss. If you want to make sure you're fully in control of what is happening, +# you can sum over only the entries which keys start with the ``"loss_"`` prefix: +# + +total_loss = 0 +for key, val in loss_vals.items(): + if key.startswith("loss_"): + total_loss += val + +################################### +# Training a LossModule +# --------------------- +# +# Given all this, training the modules is not so different from what would be +# done in any other training loop. Because it wraps the modules, +# the easiest way to get the list of trainable parameters is to query +# the :meth:`~torchrl.objectives.LossModule.parameters` method. +# +# We'll need an optimizer (or one optimizer +# per module if that is your choice). +# + +from torch.optim import Adam + +optim = Adam(ddpg_loss.parameters()) +total_loss.backward() + +################################### +# The following items will typically be +# found in your training loop: + +optim.step() +optim.zero_grad() + +################################### +# Further considerations: Target parameters +# ----------------------------------------- +# +# Another important aspect to consider is the presence of target parameters +# in off-policy algorithms like DDPG. Target parameters typically represent +# a delayed or smoothed version of the parameters over time, and they play +# a crucial role in value estimation during policy training. Utilizing target +# parameters for policy training often proves to be significantly more +# efficient compared to using the current configuration of value network +# parameters. Generally, managing target parameters is handled by the loss +# module, relieving users of direct concern. However, it remains the user's +# responsibility to update these values as necessary based on specific +# requirements. TorchRL offers a couple of updaters, namely +# :class:`~torchrl.objectives.HardUpdate` and +# :class:`~torchrl.objectives.SoftUpdate`, +# which can be easily instantiated without requiring in-depth +# knowledge of the underlying mechanisms of the loss module. +# +from torchrl.objectives import SoftUpdate + +updater = SoftUpdate(ddpg_loss, eps=0.99) + +################################### +# In your training loop, you will need to update the target parameters at each +# optimization step or each collection step: + +updater.step() + +################################### +# This is all you need to know about loss modules to get started! +# +# To further explore the topic, have a look at: +# +# - The :ref:`loss module reference page `; +# - The :ref:`Coding a DDPG loss tutorial `; +# - Losses in action in :ref:`PPO ` or :ref:`DQN `. +# diff --git a/tutorials/sphinx-tutorials/getting-started-3.py b/tutorials/sphinx-tutorials/getting-started-3.py new file mode 100644 index 00000000000..97934ef424d --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-3.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- +""" +Get started with data collection and storage +============================================ + +**Author**: `Vincent Moens `_ + +.. _gs_storage: + +""" + +################################# +# +# There is no learning without data. In supervised learning, users are +# accustomed to using :class:`~torch.utils.data.DataLoader` and the like +# to integrate data in their training loop. +# Dataloaders are iterable objects that provide you with the data that you will +# be using to train your model. +# +# TorchRL approaches the problem of dataloading in a similar manner, although +# it is surprisingly unique in the ecosystem of RL libraries. TorchRL's +# dataloaders are referred to as ``DataCollectors``. Most of the time, +# data collection does not stop at the collection of raw data, +# as the data needs to be stored temporarily in a buffer +# (or equivalent structure for on-policy algorithms) before being consumed +# by the :ref:`loss module `. This tutorial will explore +# these two classes. +# +# Data collectors +# --------------- +# +# .. _gs_storage_collector: +# +# +# The primary data collector discussed here is the +# :class:`~torchrl.collectors.SyncDataCollector`, which is the focus of this +# documentation. At a fundamental level, a collector is a straightforward +# class responsible for executing your policy within the environment, +# resetting the environment when necessary, and providing batches of a +# predefined size. Unlike the :meth:`~torchrl.envs.EnvBase.rollout` method +# demonstrated in :ref:`the env tutorial `, collectors do not +# reset between consecutive batches of data. Consequently, two successive +# batches of data may contain elements from the same trajectory. +# +# The basic arguments you need to pass to your collector are the size of the +# batches you want to collect (``frames_per_batch``), the length (possibly +# infinite) of the iterator, the policy and the environment. For simplicity, +# we will use a dummy, random policy in this example. + +import torch + +torch.manual_seed(0) + +from torchrl.collectors import RandomPolicy, SyncDataCollector +from torchrl.envs import GymEnv + +env = GymEnv("CartPole-v1") +env.set_seed(0) + +policy = RandomPolicy(env.action_spec) +collector = SyncDataCollector(env, policy, frames_per_batch=200, total_frames=-1) + +################################# +# We now expect that our collector will deliver batches of size ``200`` no +# matter what happens during collection. In other words, we may have multiple +# trajectories in this batch! The ``total_frames`` indicates how long the +# collector should be. A value of ``-1`` will produce a never +# ending collector. +# +# Let's iterate over the collector to get a sense +# of what this data looks like: + +for data in collector: + print(data) + break + +################################# +# As you can see, our data is augmented with some collector-specific metadata +# grouped in a ``"collector"`` sub-tensordict that we did not see during +# :ref:`environment rollouts `. This is useful to keep track of +# the trajectory ids. In the following list, each item marks the trajectory +# number the corresponding transition belongs to: + +print(data["collector", "traj_ids"]) + +################################# +# Data collectors are very useful when it comes to coding state-of-the-art +# algorithms, as performance is usually measured by the capability of a +# specific technique to solve a problem in a given number of interactions with +# the environment (the ``total_frames`` argument in the collector). +# For this reason, most training loops in our examples look like this: +# +# >>> for data in collector: +# ... # your algorithm here +# +# +# Replay Buffers +# -------------- +# +# .. _gs_storage_rb: +# +# Now that we have explored how to collect data, we would like to know how to +# store it. In RL, the typical setting is that the data is collected, stored +# temporarily and cleared after a little while given some heuristic: +# first-in first-out or other. A typical pseudo-code would look like this: +# +# >>> for data in collector: +# ... storage.store(data) +# ... for i in range(n_optim): +# ... sample = storage.sample() +# ... loss_val = loss_fn(sample) +# ... loss_val.backward() +# ... optim.step() # etc +# +# The parent class that stores the data in TorchRL +# is referred to as :class:`~torchrl.data.ReplayBuffer`. TorchRL's replay +# buffers are composable: you can edit the storage type, their sampling +# technique, the writing heuristic or the transforms applied to them. We will +# leave the fancy stuff for a dedicated in-depth tutorial. The generic replay +# buffer only needs to know what storage it has to use. In general, we +# recommend a :class:`~torchrl.data.TensorStorage` subclass, which will work +# fine in most cases. We'll be using +# :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` +# in this tutorial, which enjoys two nice properties: first, being "lazy", +# you don't need to explicitly tell it what your data looks like in advance. +# Second, it uses :class:`~tensordict.MemoryMappedTensor` as a backend to save +# your data on disk in an efficient way. The only thing you need to know is +# how big you want your buffer to be. + +from torchrl.data.replay_buffers import LazyMemmapStorage, ReplayBuffer + +buffer = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000)) + +################################# +# Populating the buffer can be done via the +# :meth:`~torchrl.data.ReplayBuffer.add` (single element) or +# :meth:`~torchrl.data.ReplayBuffer.extend` (multiple elements) methods. Using +# the data we just collected, we initialize and populate the buffer in one go: + +indices = buffer.extend(data) + +################################# +# We can check that the buffer now has the same number of elements than what +# we got from the collector: + +assert len(buffer) == collector.frames_per_batch + +################################# +# The only thing left to know is how to gather data from the buffer. +# Naturally, this relies on the :meth:`~torchrl.data.ReplayBuffer.sample` +# method. Because we did not specify that sampling had to be done without +# repetitions, it is not guaranteed that the samples gathered from our buffer +# will be unique: + +sample = buffer.sample(batch_size=30) +print(sample) + +################################# +# Again, our sample looks exactly the same as the data we gathered from the +# collector! +# +# Next steps +# ---------- +# +# - You can have look at other multirpocessed +# collectors such as :class:`~torchrl.collectors.collectors.MultiSyncDataCollector` or +# :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`. +# - TorchRL also offers distributed collectors if you have multiple nodes to +# use for inference. Check them out in the +# :ref:`API reference `. +# - Check the dedicated :ref:`Replay Buffer tutorial ` to know +# more about the options you have when building a buffer, or the +# :ref:`API reference ` which covers all the features in +# details. Replay buffers have countless features such as multithreaded +# sampling, prioritized experience replay, and many more... +# - We left out the capacity of replay buffers to be iterated over for +# simplicity. Try it out for yourself: build a buffer and indicate its +# batch-size in the constructor, then try to iterate over it. This is +# equivalent to calling ``rb.sample()`` within a loop! +# diff --git a/tutorials/sphinx-tutorials/getting-started-4.py b/tutorials/sphinx-tutorials/getting-started-4.py new file mode 100644 index 00000000000..bff30d79851 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-4.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +""" +Get started with logging +======================== + +**Author**: `Vincent Moens `_ + +.. _gs_logging: + +""" + +##################################### +# The final chapter of this series before we orchestrate everything in a +# training script is to learn about logging. +# +# Loggers +# ------- +# +# Logging is crucial for reporting your results to the outside world and for +# you to check that your algorithm is learning properly. TorchRL has several +# loggers that interface with custom backends such as +# wandb (:class:`~torchrl.record.loggers.wandb.WandbLogger`), +# tensorboard (:class:`~torchrl.record.loggers.tensorboard.TensorBoardLogger`) or a lightweight and +# portable CSV logger (:class:`~torchrl.record.loggers.csv.CSVLogger`) that you can use +# pretty much everywhere. +# +# Loggers are located in the ``torchrl.record`` module and the various classes +# can be found in the :ref:`API reference `. +# +# We tried to keep the loggers APIs as similar as we could, given the +# differences in the underlying backends. While execution of the loggers will +# mostly be interchangeable, their instantiation can differ. +# +# Usually, building a logger requires +# at least an experiment name and possibly a logging directory and other +# hyperapameters. +# + +from torchrl.record import CSVLogger + +logger = CSVLogger(exp_name="my_exp") + +##################################### +# Once the logger is instantiated, the only thing left to do is call the +# logging methods! For example, :meth:`~torchrl.record.CSVLogger.log_scalar` +# is used in several places across the training examples to log values such as +# reward, loss value or time elapsed for executing a piece of code. + +logger.log_scalar("my_scalar", 0.4) + +##################################### +# Recording videos +# ---------------- +# +# Finally, it can come in handy to record videos of a simulator. Some +# environments (e.g., Atari games) are already rendered as images whereas +# others require you to create them as such. Fortunately, in most common cases, +# rendering and recording videos isn't too difficult. +# +# Let's first see how we can create a Gym environment that outputs images +# alongside its observations. :class:`~torchrl.envs.GymEnv` accept two keywords +# for this purpose: ``from_pixels=True`` will make the env ``step`` function +# write a ``"pixels"`` entry containing the images corresponding to your +# observations, and the ``pixels_only=False`` will indicate that you want the +# observations to be returned as well. +# + +from torchrl.envs import GymEnv + +env = GymEnv("CartPole-v1", from_pixels=True, pixels_only=False) + +print(env.rollout(max_steps=3)) + +from torchrl.envs import TransformedEnv + +##################################### +# We now have built an environment that renders images with its observations. +# To record videos, we will need to combine that environment with a recorder +# and the logger (the logger providing the backend to save the video). +# This will happen within a transformed environment, like the one we saw in +# the :ref:`first tutorial `. + +from torchrl.record import VideoRecorder + +recorder = VideoRecorder(logger, tag="my_video") +record_env = TransformedEnv(env, recorder) + +##################################### +# When running this environment, all the ``"pixels"`` entries will be saved in +# a local buffer and dumped in a video on demand (it is important that you +# call this method when appropriate): + +rollout = record_env.rollout(max_steps=3) +# Uncomment this line to save the video on disk: +# recorder.dump() + +##################################### +# In this specific case, the video format can be chosen when instantiating +# the CSVLogger. +# +# This is all we wanted to cover in the getting started tutorial. +# You should now be ready to code your +# :ref:`first training loop with TorchRL `! +# diff --git a/tutorials/sphinx-tutorials/getting-started-5.py b/tutorials/sphinx-tutorials/getting-started-5.py new file mode 100644 index 00000000000..8413d0c9130 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-5.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +""" +Get started with your onw first training loop +============================================= + +**Author**: `Vincent Moens `_ + +.. _gs_first_training: + +""" + +################################# +# Time to wrap up everything we've learned so far in this Getting Started +# series! +# +# In this tutorial, we will be writing the most basic training loop there is +# using only components we have presented in the previous lessons. +# +# We'll be using DQN with a CartPole environment as a prototypical example. +# +# We will be voluntarily keeping the verbosity to its minimum, only linking +# each section to the related tutorial. +# +# Building the environment +# ------------------------ +# +# We'll be using a gym environment with a :class:`~torchrl.envs.transforms.StepCounter` +# transform. If you need a refresher, check our these features are presented in +# :ref:`the environment tutorial `. +# + +import torch + +torch.manual_seed(0) + +import time + +from torchrl.envs import GymEnv, StepCounter, TransformedEnv + +env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter()) +env.set_seed(0) + +from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq + +################################# +# Designing a policy +# ------------------ +# +# The next step is to build our policy. We'll be making a regular, deterministic +# version to be used within the :ref:`loss module ` and during +# :ref:`evaluation `, and one augmented by an exploration module +# for :ref:`inference `. + +from torchrl.modules import EGreedyModule, MLP, QValueModule + +value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64]) +value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"]) +policy = Seq(value_net, QValueModule(env.action_spec)) +exploration_module = EGreedyModule( + env.action_spec, annealing_num_steps=100_000, eps_init=0.5 +) +policy_explore = Seq(policy, exploration_module) + + +################################# +# Data Collector and replay buffer +# -------------------------------- +# +# Here comes the data part: we need a +# :ref:`data collector ` to easily get batches of data +# and a :ref:`replay buffer ` to store that data for training. +# + +from torchrl.collectors import SyncDataCollector +from torchrl.data import LazyTensorStorage, ReplayBuffer + +init_rand_steps = 5000 +frames_per_batch = 100 +optim_steps = 10 +collector = SyncDataCollector( + env, + policy, + frames_per_batch=frames_per_batch, + total_frames=-1, + init_random_frames=init_rand_steps, +) +rb = ReplayBuffer(storage=LazyTensorStorage(100_000)) + +from torch.optim import Adam + +################################# +# Loss module and optimizer +# ------------------------- +# +# We build our loss as indicated in the :ref:`dedicated tutorial `, with +# its optimizer and target parameter updater: + +from torchrl.objectives import DQNLoss, SoftUpdate + +loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True) +optim = Adam(loss.parameters(), lr=0.02) +updater = SoftUpdate(loss, eps=0.99) + +################################# +# Logger +# ------ +# +# We'll be using a CSV logger to log our results, and save rendered videos. +# + +from torchrl._utils import logger as torchrl_logger +from torchrl.record import CSVLogger, VideoRecorder + +path = "./training_loop" +logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4") +video_recorder = VideoRecorder(logger, tag="video") +record_env = TransformedEnv( + GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder +) + +################################# +# Training loop +# ------------- +# +# Instead of fixing a specific number of iterations to run, we will keep on +# training the network until it reaches a certain performance (arbitrarily +# defined as 200 steps in the environment -- with CartPole, success is defined +# as having longer trajectories). +# + +total_count = 0 +total_episodes = 0 +t0 = time.time() +for i, data in enumerate(collector): + # Write data in replay buffer + rb.extend(data) + max_length = rb[:]["next", "step_count"].max() + if len(rb) > init_rand_steps: + # Optim loop (we do several optim steps + # per batch collected for efficiency) + for _ in range(optim_steps): + sample = rb.sample(128) + loss_vals = loss(sample) + loss_vals["loss"].backward() + optim.step() + optim.zero_grad() + # Update exploration factor + exploration_module.step(data.numel()) + # Update target params + updater.step() + if i % 10: + torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}") + total_count += data.numel() + total_episodes += data["next", "done"].sum() + if max_length > 200: + break + +t1 = time.time() + +torchrl_logger.info( + f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s." +) + +################################# +# Rendering +# --------- +# +# Finally, we run the environment for as many steps as we can and save the +# video locally (notice that we are not exploring). + +record_env.rollout(max_steps=1000, policy=policy) +video_recorder.dump() + +################################# +# +# This is what your rendered CartPole video will look like after a full +# training loop: +# +# .. figure:: /_static/img/cartpole.gif +# +# This concludes our series of "Getting started with TorchRL" tutorials! +# Feel free to share feedback about it on GitHub. +# diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index a67976566d5..8e7817978e4 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -6,6 +6,8 @@ **Author**: `Vincent Moens `_ +.. _pendulum_tuto: + Creating an environment (a simulator or an interface to a physical control system) is an integrative part of reinforcement learning and control engineering. diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 3d37ce3de83..2c5cd95e780 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -5,6 +5,8 @@ **Author**: `Vincent Moens `_ +.. _rb_tuto: + """ ###################################################################### # Replay buffers are a central piece of any RL or control algorithm. @@ -30,17 +32,24 @@ # # # In this tutorial, you will learn: -# - How to build a Replay Buffer (RB) and use it with any datatype; -# - How to use RBs with TensorDict; -# - How to sample from or iterate over a replay buffer, and how to define the sampling strategy; -# - How to use prioritized replay buffers; -# - How to transform data coming in and out from the buffer; -# - How to store trajectories in the buffer. +# +# - How to build a :ref:`Replay Buffer (RB) ` and use it with +# any datatype; +# - How to customize the :ref:`buffer's storage `; +# - How to use :ref:`RBs with TensorDict `; +# - How to :ref:`sample from or iterate over a replay buffer `, +# and how to define the sampling strategy; +# - How to use :ref:`prioritized replay buffers `; +# - How to :ref:`transform data ` coming in and out from +# the buffer; +# - How to store :ref:`trajectories ` in the buffer. # # # Basics: building a vanilla replay buffer # ---------------------------------------- # +# .. _tuto_rb_vanilla: +# # TorchRL's replay buffers are designed to prioritize modularity, # composability, efficiency, and simplicity. For instance, creating a basic # replay buffer is a straightforward process, as shown in the following @@ -77,7 +86,7 @@ ###################################################################### # By default, this replay buffer will have a size of 1000. Let's check this -# by populating our buffer using the :meth:`torchrl.data.ReplayBuffer.extend` +# by populating our buffer using the :meth:`~torchrl.data.ReplayBuffer.extend` # method: # @@ -87,24 +96,24 @@ print("length after adding elements:", len(buffer)) -import torch -from tensordict import TensorDict - ###################################################################### -# We have used the :meth:`torchrl.data.ReplayBuffer.extend` method which is +# We have used the :meth:`~torchrl.data.ReplayBuffer.extend` method which is # designed to add multiple items all at once. If the object that is passed # to ``extend`` has more than one dimension, its first dimension is # considered to be the one to be split in separate elements in the buffer. +# # This essentially means that when adding multidimensional tensors or # tensordicts to the buffer, the buffer will only look at the first dimension # when counting the elements it holds in memory. # If the object passed it not iterable, an exception will be thrown. # -# To add items one at a time, the :meth:`torchrl.data.ReplayBuffer.add` method +# To add items one at a time, the :meth:`~torchrl.data.ReplayBuffer.add` method # should be used instead. # # Customizing the storage -# ~~~~~~~~~~~~~~~~~~~~~~~ +# ----------------------- +# +# .. _tuto_rb_storage: # # We see that the buffer has been capped to the first 1000 elements that we # passed to it. @@ -112,25 +121,27 @@ # # TorchRL proposes three types of storages: # -# - The :class:`torchrl.dataListStorage` stores elements independently in a +# - The :class:`~torchrl.data.ListStorage` stores elements independently in a # list. It supports any data type, but this flexibility comes at the cost # of efficiency; -# - The :class:`torchrl.dataLazyTensorStorage` stores tensors or -# :class:`tensordidct.TensorDict` (or :class:`torchrl.data.tensorclass`) +# - The :class:`~torchrl.data.LazyTensorStorage` stores tensors data +# structures contiguously. +# It works naturally with :class:`~tensordidct.TensorDict` +# (or :class:`~torchrl.data.tensorclass`) # objects. The storage is contiguous on a per-tensor basis, meaning that # sampling will be more efficient than when using a list, but the # implicit restriction is that any data passed to it must have the same -# basic properties as the -# first batch of data that was used to instantiate the buffer. +# basic properties (such as shape and dtype) as the first batch of data that +# was used to instantiate the buffer. # Passing data that does not match this requirement will either raise an # exception or lead to some undefined behaviours. -# - The :class:`torchrl.dataLazyMemmapStorage` works as the -# :class:`torchrl.data.LazyTensorStorage` in that it is lazy (ie. it +# - The :class:`~torchrl.data.LazyMemmapStorage` works as the +# :class:`~torchrl.data.LazyTensorStorage` in that it is lazy (i.e., it # expects the first batch of data to be instantiated), and it requires data # that match in shape and dtype for each batch stored. What makes this -# storage unique is that it points to disk files, meaning that it can -# support very large datasets while still accessing data in a contiguous -# manner. +# storage unique is that it points to disk files (or uses the filesystem +# storage), meaning that it can support very large datasets while still +# accessing data in a contiguous manner. # # Let us see how we can use each of these storages: @@ -149,9 +160,9 @@ ###################################################################### # Because it is the one with the lowest amount of assumption, the -# :class:`torchrl.data.ListStorage` is the default storage in TorchRL. +# :class:`~torchrl.data.ListStorage` is the default storage in TorchRL. # -# A :class:`torchrl.data.LazyTensorStorage` can store data contiguously. +# A :class:`~torchrl.data.LazyTensorStorage` can store data contiguously. # This should be the preferred option when dealing with complicated but # unchanging data structures of medium size: @@ -161,6 +172,10 @@ # Let us create a batch of data of size ``torch.Size([3])` with 2 tensors # stored in it: # + +import torch +from tensordict import TensorDict + data = TensorDict( { "a": torch.arange(12).view(3, 4), @@ -171,7 +186,7 @@ print(data) ###################################################################### -# The first call to :meth:`torchrl.data.ReplayBuffer.extend` will +# The first call to :meth:`~torchrl.data.ReplayBuffer.extend` will # instantiate the storage. The first dimension of the data is unbound into # separate datapoints: @@ -186,7 +201,7 @@ print("samples", sample["a"], sample["b", "c"]) ###################################################################### -# A :class:`torchrl.data.LazyMemmapStorage` is created in the same manner: +# A :class:`~torchrl.data.LazyMemmapStorage` is created in the same manner: # buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size)) @@ -213,16 +228,20 @@ # Integration with TensorDict # --------------------------- # +# .. _tuto_rb_td: +# # The tensor location follows the same structure as the TensorDict that # contains them: this makes it easy to save and load buffers during training. # -# To use :class:`tensordict.TensorDict` as a data carrier at its fullest -# potential, the :class:`torchrl.data.TensorDictReplayBuffer` class should +# To use :class:`~tensordict.TensorDict` as a data carrier at its fullest +# potential, the :class:`~torchrl.data.TensorDictReplayBuffer` class can # be used. # One of its key benefits is its ability to handle the organization of sampled # data, along with any additional information that may be required # (such as sample indices). -# It can be built in the same manner as a standard :class:`torchrl.data.ReplayBuffer` and can +# +# It can be built in the same manner as a standard +# :class:`~torchrl.data.ReplayBuffer` and can # generally be used interchangeably. # @@ -250,7 +269,7 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # The ReplayBuffer class and associated subclasses also work natively with -# :class:`tensordict.tensorclass` classes, which can conviniently be used to +# :class:`~tensordict.tensorclass` classes, which can conveniently be used to # encode datasets in a more explicit manner: from tensordict import tensorclass @@ -284,31 +303,28 @@ class MyData: ###################################################################### # As expected. the data has the proper class and shape! # -# Integration with PyTree -# ~~~~~~~~~~~~~~~~~~~~~~~ +# Integration with other tensor structures (PyTrees) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # TorchRL's replay buffers also work with any pytree data structure. # A PyTree is a nested structure of arbitrary depth made of dicts, lists and/or # tuples where the leaves are tensors. # This means that one can store in contiguous memory any such tree structure! # Various storages can be used: -# :class:`~torchrl.data.replay_buffers.TensorStorage`, :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` -# or :class:`~torchrl.data.replay_buffers.LazyTensorStorage` all accept this kind of data. +# :class:`~torchrl.data.replay_buffers.TensorStorage`, +# :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` +# or :class:`~torchrl.data.replay_buffers.LazyTensorStorage` all accept this +# kind of data. # -# Here is a bried demonstration of what this feature looks like: +# Here is a brief demonstration of what this feature looks like: # from torch.utils._pytree import tree_map -# With pytrees, any callable can be used as a transform: -def transform(x): - # Zeros all the data in the pytree - return tree_map(lambda y: y * 0, x) - - +###################################################################### # Let's build our replay buffer on disk: -rb = ReplayBuffer(storage=LazyMemmapStorage(size), transform=transform) +rb = ReplayBuffer(storage=LazyMemmapStorage(size)) data = { "a": torch.randn(3), "b": {"c": (torch.zeros(2), [torch.ones(1)])}, @@ -320,6 +336,20 @@ def transform(x): sample = rb.sample(10) +###################################################################### +# With pytrees, any callable can be used as a transform: + + +def transform(x): + # Zeros all the data in the pytree + return tree_map(lambda y: y * 0, x) + + +rb.append_transform(transform) +sample = rb.sample(batch_size=12) + + +###################################################################### # let's check that our transform did its job: def assert0(x): assert (x == 0).all() @@ -328,9 +358,12 @@ def assert0(x): tree_map(assert0, sample) +###################################################################### # Sampling and iterating over buffers # ----------------------------------- # +# .. _tuto_rb_sampling: +# # Replay Buffers support multiple sampling strategies: # # - If the batch-size is fixed and can be defined at construction time, it can @@ -338,7 +371,7 @@ def assert0(x): # - With a fixed batch-size, the replay buffer can be iterated over to gather # samples; # - If the batch-size is dynamic, it can be passed to the -# :class:`torchrl.data.ReplayBuffer.sample` method +# :class:`~torchrl.data.ReplayBuffer.sample` method # on-the-fly. # # Sampling can be done using multithreading, but this is incompatible with the @@ -349,21 +382,22 @@ def assert0(x): # # Fixed batch-size # ~~~~~~~~~~~~~~~~ -# If the batch-size is passed during construction, it should be omited when +# +# If the batch-size is passed during construction, it should be omitted when # sampling: data = MyData( images=torch.randint( 255, - (10, 64, 64, 3), + (200, 64, 64, 3), ), - labels=torch.randint(100, (10,)), - batch_size=[10], + labels=torch.randint(100, (200,)), + batch_size=[200], ) buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size), batch_size=128) -buffer_lazymemmap.add(data) -buffer_lazymemmap.sample() # will produces 128 identical samples +buffer_lazymemmap.extend(data) +buffer_lazymemmap.sample() ###################################################################### @@ -371,19 +405,20 @@ def assert0(x): # # To enable multithreaded sampling, just pass a positive integer to the # ``prefetch`` keyword argument during construction. This should speed up -# sampling considerably: +# sampling considerably whenever sampling is time consuming (e.g., when +# using prioritized samplers): buffer_lazymemmap = ReplayBuffer( storage=LazyMemmapStorage(size), batch_size=128, prefetch=10 ) # creates a queue of 10 elements to be prefetched in the background -buffer_lazymemmap.add(data) +buffer_lazymemmap.extend(data) print(buffer_lazymemmap.sample()) ###################################################################### -# Fixed batch-size, iterating over the buffer -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Iterating over the buffer with a fixed batch-size +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # We can also iterate over the buffer like we would do with a regular # dataloader, as long as the batch-size is predefined: @@ -398,7 +433,8 @@ def assert0(x): ###################################################################### # Due to the fact that our sampling technique is entirely random and does not # prevent replacement, the iterator in question is infinite. However, we can -# make use of the :class:`torchrl.data.replay_buffers.SamplerWithoutReplacement` +# make use of the +# :class:`~torchrl.data.replay_buffers.SamplerWithoutReplacement` # instead, which will transform our buffer into a finite iterator: # @@ -428,7 +464,7 @@ def assert0(x): # ~~~~~~~~~~~~~~~~~~ # # In contrast to what we have seen earlier, the ``batch_size`` keyword -# argument can be omitted and passed directly to the `sample` method: +# argument can be omitted and passed directly to the ``sample`` method: buffer_lazymemmap = ReplayBuffer( @@ -442,7 +478,10 @@ def assert0(x): # Prioritized Replay buffers # -------------------------- # -# TorchRL also provides an interface for prioritized replay buffers. +# .. _tuto_rb_prb: +# +# TorchRL also provides an interface for +# `prioritized replay buffers `_. # This buffer class samples data according to a priority signal that is passed # through the data. # @@ -476,8 +515,8 @@ def assert0(x): # buffer, the priority is set to a default value of 1. Once the priority has # been computed (usually through the loss), it must be updated in the buffer. # -# This is done via the `update_priority` method, which requires the indices -# as well as the priority. +# This is done via the :meth:`~torchrl.data.ReplayBuffer.update_priority` +# method, which requires the indices as well as the priority. # We assign an artificially high priority to the second sample in the dataset # to observe its effect on sampling: # @@ -499,6 +538,7 @@ def assert0(x): ###################################################################### # We see that using a prioritized replay buffer requires a series of extra # steps in the training loop compared with a regular buffer: +# # - After collecting data and extending the buffer, the priority of the # items must be updated; # - After computing the loss and getting a "priority signal" from it, we must @@ -511,10 +551,10 @@ def assert0(x): # that the appropriate methods are called at the appropriate place, if and # only if a prioritized buffer is being used. # -# Let us see how we can improve this with TensorDict. We saw that the -# :class:`torchrl.data.TensorDictReplayBuffer` returns data augmented with -# their relative storage indices. One feature we did not mention is that -# this class also ensures that the priority +# Let us see how we can improve this with :class:`~tensordict.TensorDict`. +# We saw that the :class:`~torchrl.data.TensorDictReplayBuffer` returns data +# augmented with their relative storage indices. One feature we did not mention +# is that this class also ensures that the priority # signal is automatically parsed to the prioritized sampler if present during # extension. # @@ -582,6 +622,8 @@ def assert0(x): # Using transforms # ---------------- # +# .. _tuto_rb_transform: +# # The data stored in a replay buffer may not be ready to be presented to a # loss module. # In some cases, the data produced by a collector can be too heavy to be @@ -605,8 +647,14 @@ def assert0(x): from torchrl.collectors import RandomPolicy, SyncDataCollector -from torchrl.envs import Compose, GrayScale, Resize, ToTensorImage, TransformedEnv from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.transforms import ( + Compose, + GrayScale, + Resize, + ToTensorImage, + TransformedEnv, +) env = TransformedEnv( GymEnv("CartPole-v1", from_pixels=True), @@ -630,7 +678,7 @@ def assert0(x): # To do this, we will append a transform to the collector to select the keys # we want to see appearing: -from torchrl.envs import ExcludeTransform +from torchrl.envs.transforms import ExcludeTransform collector = SyncDataCollector( env, @@ -685,7 +733,7 @@ def assert0(x): # A more complex examples: using CatFrames # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# The :class:`torchrl.envs.CatFrames` transform unfolds the observations +# The :class:`~torchrl.envs.transforms.CatFrames` transform unfolds the observations # through time, creating a n-back memory of past events that allow the model # to take the past events into account (in the case of POMDPs or with # recurrent policies such as Decision Transformers). Storing these concatenated @@ -752,6 +800,56 @@ def assert0(x): assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all() +###################################################################### +# Storing trajectories +# -------------------- +# +# .. _tuto_rb_traj: +# +# In many cases, it is desirable to access trajectories from the buffer rather +# than simple transitions. TorchRL offers multiple ways of achieving this. +# +# The preferred way is currently to store trajectories along the first +# dimension of the buffer and use a :class:`~torchrl.data.SliceSampler` to +# sample these batches of data. This class only needs a couple of information +# about your data structure to do its job (not that as of now it is only +# compatible with tensordict-structured data): the number of slices or their +# length and some information about where the separation between the +# episodes can be found (e.g. :ref:`recall that ` with a +# :ref:`DataCollector `, the trajectory id is stored in +# ``("collector", "traj_ids")``). In this simple example, we construct a data +# with 4 consecutive short trajectories and sample 4 slices out of it, each of +# length 2 (since the batch size is 8, and 8 items // 4 slices = 2 time steps). +# We mark the steps as well. + +from torchrl.data import SliceSampler + +rb = TensorDictReplayBuffer( + storage=LazyMemmapStorage(size), + sampler=SliceSampler(traj_key="episode", num_slices=4), + batch_size=8, +) +episode = torch.zeros(10, dtype=torch.int) +episode[:3] = 1 +episode[3:5] = 2 +episode[5:7] = 3 +episode[7:] = 4 +steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)]) +data = TensorDict( + { + "episode": episode, + "obs": torch.randn((3, 4, 5)).expand(10, 3, 4, 5), + "act": torch.randn((20,)).expand(10, 20), + "other": torch.randn((20, 50)).expand(10, 20, 50), + "steps": steps, + }, + [10], +) +rb.extend(data) +sample = rb.sample() +print("episode are grouped", sample["episode"]) +print("steps are successive", sample["steps"]) + ###################################################################### # Conclusion # ---------- @@ -765,3 +863,13 @@ def assert0(x): # - Choose the best storage type for your problem (list, memory or disk-based); # - Minimize the memory footprint of your buffer. # +# Next steps +# ---------- +# +# - Check the data API reference to learn about offline datasets in TorchRL, +# which are based on our Replay Buffer API; +# - Check other samplers such as +# :class:`~torchrl.data.SamplerWithoutReplacement`, +# :class:`~torchrl.data.PrioritizedSliceSampler` and +# :class:`~torchrl.data.SliceSamplerWithoutReplacement`, or other writers +# such as :class:`~torchrl.data.TensorDictMaxValueWriter`. diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 5e00442fe36..ce3f0bb4b98 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ Introduction to TorchRL -============================ +======================= This demo was presented at ICML 2022 on the industry demo day. """ ############################################################################## @@ -746,8 +746,7 @@ total_frames=240, max_frames_per_traj=-1, # envs are terminating, we don't need to stop them early frames_per_batch=60, # we want 60 frames at a time (we have 3 envs per sub-collector) - storing_devices=devices, # len must match len of env created - devices=devices, + device=devices, ) ############################################################################### @@ -769,8 +768,7 @@ total_frames=240, max_frames_per_traj=-1, # envs are terminating, we don't need to stop them early frames_per_batch=60, # we want 60 frames at a time (we have 3 envs per sub-collector) - storing_devices=devices, # len must match len of env created - devices=devices, + device=devices, ) for i, d in enumerate(collector): @@ -805,7 +803,8 @@ def forward(self, obs, action): value_module, in_keys=["observation", "action"], out_keys=["state_action_value"] ) -loss_fn = DDPGLoss(actor, value, gamma=0.99) +loss_fn = DDPGLoss(actor, value) +loss_fn.make_value_estimator(loss_fn.default_value_estimator, gamma=0.99) ############################################################################### diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index 56896637a87..4c792d44b80 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -1,9 +1,15 @@ # -*- coding: utf-8 -*- """ TorchRL envs -============================ +============ + +**Author**: `Vincent Moens `_ + +.. _envs_tuto: + """ ############################################################################## +# # Environments play a crucial role in RL settings, often somewhat similar to # datasets in supervised and unsupervised settings. The RL community has # become quite familiar with OpenAI gym API which offers a flexible way of @@ -19,7 +25,10 @@ # To run this part of the tutorial, you will need to have a recent version of # the gym library installed, as well as the atari suite. You can get this # installed by installing the following packages: -# $ pip install gym atari-py ale-py gym[accept-rom-license] pygame +# +# .. code-block:: +# $ pip install gym atari-py ale-py gym[accept-rom-license] pygame +# # To unify all frameworks, torchrl environments are built inside the # ``__init__`` method with a private method called ``_build_env`` that # will pass the arguments and keyword arguments to the root library builder.