From 9b95196b8e75be28303e3e47d263c7be2015657c Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 10 Feb 2024 06:49:34 +0000 Subject: [PATCH] amend --- docs/source/reference/collectors.rst | 2 + torchrl/data/replay_buffers/samplers.py | 2 +- .../sphinx-tutorials/getting-started-0.py | 6 +- .../sphinx-tutorials/getting-started-1.py | 22 +-- tutorials/sphinx-tutorials/rb_tutorial.py | 131 +++++++++++++----- 5 files changed, 116 insertions(+), 47 deletions(-) diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index aa8de179f20..43146060edb 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -3,6 +3,8 @@ torchrl.collectors package ========================== +.. _data_collectors: + Data collectors are somewhat equivalent to pytorch dataloaders, except that (1) they collect data over non-static data sources and (2) the data is collected using a model (likely a version of the model that is being trained). diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 96d73375ea9..2a169cbd332 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -718,7 +718,7 @@ def __init__( if end_key is None: end_key = ("next", "done") if traj_key is None: - traj_key = "run" + traj_key = "episode" self.end_key = end_key self.traj_key = traj_key diff --git a/tutorials/sphinx-tutorials/getting-started-0.py b/tutorials/sphinx-tutorials/getting-started-0.py index b601c866871..cce340fe961 100644 --- a/tutorials/sphinx-tutorials/getting-started-0.py +++ b/tutorials/sphinx-tutorials/getting-started-0.py @@ -114,7 +114,7 @@ # our action. # # We call this format TED, for -# :ref:`TorchRL Episode Data format `. It is +# :ref:`TorchRL Episode Data format `. It is # the ubiquitous way of representing data in the library, both dynamically like # here, or statically with offline datasets. # @@ -180,7 +180,7 @@ # In this section, we'll examine a simple transform, the # :class:`~torchrl.envs.transforms.StepCounter` transform. # The complete list of transforms can be found -# :ref:`here `. +# :ref:`here `. # # The transform is integrated with the environment through a # :class:`~torchrl.envs.TransformedEnv`: @@ -214,7 +214,7 @@ # # - The :meth:`~torchrl.envs.EnvBase.step_and_maybe_reset` method that packs # together :meth:`~torchrl.envs.EnvBase.step`, -# :func:`~torchrl.envs.step_mdp` and +# :func:`~torchrl.envs.utils.step_mdp` and # :meth:`~torchrl.envs.EnvBase.reset`. # - Some environments like :class:`~torchrl.envs.GymEnv` support rendering # through the ``from_pixels`` argument. Check the class docstrings to know diff --git a/tutorials/sphinx-tutorials/getting-started-1.py b/tutorials/sphinx-tutorials/getting-started-1.py index 228583bc9a7..7c001e887d8 100644 --- a/tutorials/sphinx-tutorials/getting-started-1.py +++ b/tutorials/sphinx-tutorials/getting-started-1.py @@ -65,13 +65,13 @@ # # To simplify the incorporation of :class:`~torch.nn.Module`s into your # codebase, TorchRL offers a range of specialized wrappers designed to be -# used as actors, including :class:`~torchrl.modules.Actor`, -# # :class:`~torchrl.modules.ProbabilisticActor`, -# # :class:`~torchrl.modules.ActorValueOperator` or -# # :class:`~torchrl.modules.ActorCriticOperator`. 
-# For example, :class:`~torchrl.modules.Actor` provides default values for the
-# ``in_keys`` and ``out_keys``, making integration with many common
-# environments straightforward:
+# used as actors, including :class:`~torchrl.modules.tensordict_module.Actor`,
+# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor`,
+# :class:`~torchrl.modules.tensordict_module.ActorValueOperator` or
+# :class:`~torchrl.modules.tensordict_module.ActorCriticOperator`.
+# For example, :class:`~torchrl.modules.tensordict_module.Actor` provides
+# default values for the ``in_keys`` and ``out_keys``, making integration
+# with many common environments straightforward:
 #

 from torchrl.modules import Actor

@@ -82,7 +82,7 @@

 ###################################
 # The list of available specialized TensorDictModules is available in the
-# :ref:`API reference `.
+# :ref:`API reference `.
 #
 # Networks
 # --------
@@ -126,8 +126,8 @@
 #   will split this output on two chunks, a mean and a standard deviation of
 #   size ``[1]``;
 # - A :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` that will
-#   read those parameters, create a distribution with them and populate our
-#   tensordict with samples and log-probabilities.
+#   read those parameters as ``in_keys``, create a distribution with them and
+#   populate our tensordict with samples and log-probabilities.
 #

 from tensordict.nn.distributions import NormalParamExtractor
@@ -140,7 +140,7 @@
 td_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"])
 policy = ProbabilisticActor(
     td_module,
-    in_keys=["observation"],
+    in_keys=["loc", "scale"],
     out_keys=["action"],
     distribution_class=Normal,
     return_log_prob=True,
diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py
index 24cd35ff424..1ed8513f3c8 100644
--- a/tutorials/sphinx-tutorials/rb_tutorial.py
+++ b/tutorials/sphinx-tutorials/rb_tutorial.py
@@ -33,21 +33,22 @@
 #
 # In this tutorial, you will learn:
 #
-# - How to build a :ref:`Replay Buffer (RB) <vanilla_rb>` and use it with
+# - How to build a :ref:`Replay Buffer (RB) <tuto_rb_vanilla>` and use it with
 #   any datatype;
-# - How to customize the :ref:`buffer's storage <buffer_storage>`;
-# - How to use :ref:`RBs with TensorDict <td_rb>`;
-# - How to :ref:`sample from or iterate over a replay buffer <sampling_rb>`,
+# - How to customize the :ref:`buffer's storage <tuto_rb_storage>`;
+# - How to use :ref:`RBs with TensorDict <tuto_rb_td>`;
+# - How to :ref:`sample from or iterate over a replay buffer <tuto_rb_sampling>`,
 #   and how to define the sampling strategy;
-# - How to use prioritized replay buffers;
-# - How to transform data coming in and out from the buffer;
-# - How to store trajectories in the buffer.
+# - How to use :ref:`prioritized replay buffers <tuto_rb_prb>`;
+# - How to :ref:`transform data <tuto_rb_transform>` coming in and out from
+#   the buffer;
+# - How to store :ref:`trajectories <tuto_rb_traj>` in the buffer.
 #
 #
 # Basics: building a vanilla replay buffer
 # ----------------------------------------
 #
-# .. _vanilla_rb:
+# .. _tuto_rb_vanilla:
 #
 # TorchRL's replay buffers are designed to prioritize modularity,
 # composability, efficiency, and simplicity. For instance, creating a basic
@@ -112,7 +113,7 @@
 # Customizing the storage
 # -----------------------
 #
-# .. _buffer_storage:
+# .. _tuto_rb_storage:
 #
 # We see that the buffer has been capped to the first 1000 elements that we
 # passed to it.
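As a side note to the storage customization discussed in the hunk above, the following minimal sketch (illustrative only, not part of the patch) contrasts the two lazy storages; the ``scratch_dir`` path is an arbitrary example and default settings are otherwise assumed:

import torch
from torchrl.data import LazyMemmapStorage, LazyTensorStorage, ReplayBuffer

# keeps data as a contiguous tensor in RAM, allocated lazily on the first extend
rb_ram = ReplayBuffer(storage=LazyTensorStorage(max_size=1000))
# memory-maps the data to disk instead, trading some speed for a smaller RAM footprint
rb_disk = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000, scratch_dir="/tmp/rb"))

rb_ram.extend(torch.randn(256, 4))
print(rb_ram.sample(32).shape)  # expected: torch.Size([32, 4])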
@@ -227,7 +228,7 @@
 # Integration with TensorDict
 # ---------------------------
 #
-# .. _td_rb:
+# .. _tuto_rb_td:
 #
 # The tensor location follows the same structure as the TensorDict that
 # contains them: this makes it easy to save and load buffers during training.
@@ -361,7 +362,7 @@ def assert0(x):
 # Sampling and iterating over buffers
 # -----------------------------------
 #
-# .. _sampling_rb:
+# .. _tuto_rb_sampling:
 #
 # Replay Buffers support multiple sampling strategies:
 #
@@ -382,21 +383,21 @@ def assert0(x):
 # Fixed batch-size
 # ~~~~~~~~~~~~~~~~
 #
-# If the batch-size is passed during construction, it should be omited when
+# If the batch-size is passed during construction, it should be omitted when
 # sampling:

 data = MyData(
     images=torch.randint(
         255,
-        (10, 64, 64, 3),
+        (200, 64, 64, 3),
     ),
-    labels=torch.randint(100, (10,)),
-    batch_size=[10],
+    labels=torch.randint(100, (200,)),
+    batch_size=[200],
 )

 buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size), batch_size=128)
-buffer_lazymemmap.add(data)
-buffer_lazymemmap.sample()  # will produces 128 identical samples
+buffer_lazymemmap.extend(data)
+buffer_lazymemmap.sample()


 ######################################################################
 # Multithreaded sampling
 # ~~~~~~~~~~~~~~~~~~~~~~
 #
 # To enable multithreaded sampling, just pass a positive integer to the
 # ``prefetch`` keyword argument during construction. This should speed up
-# sampling considerably:
+# sampling considerably whenever sampling is time consuming (e.g., when
+# using prioritized samplers):

 buffer_lazymemmap = ReplayBuffer(
     storage=LazyMemmapStorage(size), batch_size=128, prefetch=10
 )  # creates a queue of 10 elements to be prefetched in the background
-buffer_lazymemmap.add(data)
+buffer_lazymemmap.extend(data)
 print(buffer_lazymemmap.sample())


 ######################################################################
-# Fixed batch-size, iterating over the buffer
-# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Iterating over the buffer with a fixed batch-size
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
 # We can also iterate over the buffer like we would do with a regular
 # dataloader, as long as the batch-size is predefined:
@@ -462,7 +464,7 @@ def assert0(x):
 # ~~~~~~~~~~~~~~~~~~
 #
 # In contrast to what we have seen earlier, the ``batch_size`` keyword
-# argument can be omitted and passed directly to the `sample` method:
+# argument can be omitted and passed directly to the ``sample`` method:


 buffer_lazymemmap = ReplayBuffer(
@@ -476,7 +478,10 @@ def assert0(x):
 # Prioritized Replay buffers
 # --------------------------
 #
-# TorchRL also provides an interface for prioritized replay buffers.
+# .. _tuto_rb_prb:
+#
+# TorchRL also provides an interface for
+# `prioritized replay buffers `_.
 # This buffer class samples data according to a priority signal that is passed
 # through the data.
 #
@@ -510,8 +515,8 @@ def assert0(x):
 # buffer, the priority is set to a default value of 1. Once the priority has
 # been computed (usually through the loss), it must be updated in the buffer.
 #
-# This is done via the `update_priority` method, which requires the indices
-# as well as the priority.
+# This is done via the :meth:`~torchrl.data.ReplayBuffer.update_priority`
+# method, which requires the indices as well as the priority.
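As a rough, self-contained sketch of the prioritized workflow described above (not part of the patch; the ``alpha``/``beta`` values and the priorities are arbitrary):

import torch
from torchrl.data import ListStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import PrioritizedSampler

size = 100
rb = ReplayBuffer(
    storage=ListStorage(size),
    sampler=PrioritizedSampler(max_capacity=size, alpha=0.7, beta=1.1),
    collate_fn=lambda x: x,
)
indices = rb.extend([1, 2, 3])
# freshly written items carry a default priority; overwrite it explicitly
rb.update_priority(index=indices, priority=torch.tensor([0.1, 10.0, 0.1]))
sample = rb.sample(8)  # the high-priority item should dominate the batch
print(sample)

Sampling with ``return_info=True`` should additionally return the sampled indices and importance weights, which is what one would feed back into ``update_priority`` after computing the loss.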
# We assign an artificially high priority to the second sample in the dataset
# to observe its effect on sampling:
#
@@ -533,6 +538,7 @@ def assert0(x):
 ######################################################################
 # We see that using a prioritized replay buffer requires a series of extra
 # steps in the training loop compared with a regular buffer:
+#
 # - After collecting data and extending the buffer, the priority of the
 #   items must be updated;
 # - After computing the loss and getting a "priority signal" from it, we must
@@ -545,10 +551,10 @@ def assert0(x):
 # that the appropriate methods are called at the appropriate place, if and
 # only if a prioritized buffer is being used.
 #
-# Let us see how we can improve this with TensorDict. We saw that the
-# :class:`~torchrl.data.TensorDictReplayBuffer` returns data augmented with
-# their relative storage indices. One feature we did not mention is that
-# this class also ensures that the priority
+# Let us see how we can improve this with :class:`~tensordict.TensorDict`.
+# We saw that the :class:`~torchrl.data.TensorDictReplayBuffer` returns data
+# augmented with their relative storage indices. One feature we did not mention
+# is that this class also ensures that the priority
 # signal is automatically parsed to the prioritized sampler if present during
 # extension.
 #
@@ -616,6 +622,8 @@ def assert0(x):
 # Using transforms
 # ----------------
 #
+# .. _tuto_rb_transform:
+#
 # The data stored in a replay buffer may not be ready to be presented to a
 # loss module.
 # In some cases, the data produced by a collector can be too heavy to be
@@ -639,7 +647,7 @@ def assert0(x):

 from torchrl.collectors import RandomPolicy, SyncDataCollector
-from torchrl.envs import Compose, GrayScale, Resize, ToTensorImage, TransformedEnv
+from torchrl.envs.transforms import Compose, GrayScale, Resize, ToTensorImage, TransformedEnv
 from torchrl.envs.libs.gym import GymEnv

 env = TransformedEnv(
@@ -664,7 +672,7 @@ def assert0(x):
 # To do this, we will append a transform to the collector to select the keys
 # we want to see appearing:

-from torchrl.envs import ExcludeTransform
+from torchrl.envs.transforms import ExcludeTransform

 collector = SyncDataCollector(
     env,
@@ -719,7 +727,7 @@ def assert0(x):
 # A more complex examples: using CatFrames
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
-# The :class:`~torchrl.envs.CatFrames` transform unfolds the observations
+# The :class:`~torchrl.envs.transforms.CatFrames` transform unfolds the observations
 # through time, creating a n-back memory of past events that allow the model
 # to take the past events into account (in the case of POMDPs or with
 # recurrent policies such as Decision Transformers). Storing these concatenated
@@ -786,6 +794,55 @@ def assert0(x):
     assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all()

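Before moving on to trajectories, here is a minimal, illustrative sketch of the buffer-side transform idea discussed above: raw frames are stored untouched and only the sampled mini-batch is processed. It is not part of the patch; the ``pixels`` key and the tensor sizes are made up, and torchvision is assumed to be available for ``Resize``:

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.envs.transforms import Compose, GrayScale, Resize, ToTensorImage

rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(1000),
    transform=Compose(
        ToTensorImage(in_keys=["pixels"]),
        Resize(64, 64, in_keys=["pixels"]),
        GrayScale(in_keys=["pixels"]),
    ),
    batch_size=16,
)
# store cheap uint8 frames; the transform only runs at sampling time
frames = TensorDict(
    {"pixels": torch.randint(255, (128, 200, 200, 3), dtype=torch.uint8)}, [128]
)
rb.extend(frames)
print(rb.sample()["pixels"].shape)  # expected: torch.Size([16, 1, 64, 64])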
+######################################################################
+# Storing trajectories
+# --------------------
+#
+# .. _tuto_rb_traj:
+#
+# In many cases, it is desirable to access trajectories from the buffer rather
+# than simple transitions. TorchRL offers multiple ways of achieving this.
+#
+# The preferred way is currently to store trajectories along the first
+# dimension of the buffer and use a :class:`~torchrl.data.SliceSampler` to
+# sample these batches of data. This class only needs a few pieces of
+# information about your data structure to do its job (note that, as of now,
+# it is only compatible with tensordict-structured data): the number of slices
+# or their length, and some information about where the separation between
+# episodes can be found (e.g. :ref:`recall that ` with a
+# :ref:`DataCollector <data_collectors>`, the trajectory id is stored in
+# ``("collector", "traj_ids")``). In this simple example, we construct data
+# with 4 consecutive short trajectories and sample 4 slices out of it, each of
+# length 2 (since the batch size is 8, and 8 items // 4 slices = 2 time steps).
+# We mark the steps as well.
+
+from torchrl.data import SliceSampler
+
+rb = TensorDictReplayBuffer(
+    storage=LazyMemmapStorage(size),
+    sampler=SliceSampler(traj_key="episode", num_slices=4),
+    batch_size=8,
+)
+episode = torch.zeros(10, dtype=torch.int)
+episode[:3] = 1
+episode[3:5] = 2
+episode[5:7] = 3
+episode[7:] = 4
+steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)])
+data = TensorDict(
+    {
+        "episode": episode,
+        "obs": torch.randn((3, 4, 5)).expand(10, 3, 4, 5),
+        "act": torch.randn((20,)).expand(10, 20),
+        "other": torch.randn((20, 50)).expand(10, 20, 50),
+        "steps": steps,
+    }, [10]
+)
+rb.extend(data)
+sample = rb.sample()
+print("episodes are grouped", sample["episode"])
+print("steps are successive", sample["steps"])
+
 ######################################################################
 # Conclusion
 # ----------
@@ -799,3 +856,13 @@
 # - Choose the best storage type for your problem (list, memory or disk-based);
 # - Minimize the memory footprint of your buffer.
 #
+# Next steps
+# ----------
+#
+# - Check the data API reference to learn about offline datasets in TorchRL,
+#   which are based on our Replay Buffer API;
+# - Check other samplers such as
+#   :class:`~torchrl.data.SamplerWithoutReplacement`,
+#   :class:`~torchrl.data.PrioritizedSliceSampler` and
+#   :class:`~torchrl.data.SliceSamplerWithoutReplacement`, or other writers
+#   such as :class:`~torchrl.data.TensorDictMaxValueWriter`.
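To complement the pointers above, a last illustrative sketch (not part of the patch) uses :class:`~torchrl.data.SamplerWithoutReplacement` for epoch-style iteration; the buffer size and batch size are arbitrary:

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement

rb = ReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SamplerWithoutReplacement(drop_last=True),
    batch_size=128,
)
rb.extend(torch.randn(1000, 4))
# each item should be seen at most once per pass; iteration stops once the
# buffer has been consumed
for i, batch in enumerate(rb):
    print(i, batch.shape)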