Commit 9b95196

amend

vmoens committed Feb 10, 2024
1 parent cdd51f0 commit 9b95196
Showing 5 changed files with 116 additions and 47 deletions.
2 changes: 2 additions & 0 deletions docs/source/reference/collectors.rst
@@ -3,6 +3,8 @@
torchrl.collectors package
==========================

.. _data_collectors:

Data collectors are somewhat equivalent to pytorch dataloaders, except that (1) they
collect data over non-static data sources and (2) the data is collected using a model
(likely a version of the model that is being trained).
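As a minimal, hypothetical sketch of the idea (assuming a Gym backend with the
Pendulum environment is installed; the exact arguments are documented in the
classes below)::

    from torchrl.collectors import RandomPolicy, SyncDataCollector
    from torchrl.envs.libs.gym import GymEnv

    env = GymEnv("Pendulum-v1")
    # gather 64 transitions per iteration, 256 in total, with a random policy
    collector = SyncDataCollector(
        env, RandomPolicy(env.action_spec), frames_per_batch=64, total_frames=256
    )
    for batch in collector:
        print(batch)  # a TensorDict holding 64 collected transitions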
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/samplers.py
@@ -718,7 +718,7 @@ def __init__(
if end_key is None:
end_key = ("next", "done")
if traj_key is None:
traj_key = "run"
traj_key = "episode"
self.end_key = end_key
self.traj_key = traj_key

6 changes: 3 additions & 3 deletions tutorials/sphinx-tutorials/getting-started-0.py
@@ -114,7 +114,7 @@
# our action.
#
# We call this format TED, for
# :ref:`TorchRL Episode Data format <reference/data:TED-format>`. It is
# :ref:`TorchRL Episode Data format <TED-format>`. It is
# the ubiquitous way of representing data in the library, both dynamically like
# here, or statically with offline datasets.
#
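# As a minimal, hypothetical sketch (assuming the Pendulum Gym environment is
# installed), a short rollout makes the TED layout visible:

from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
rollout = env.rollout(3)
print(rollout)  # "observation" and "action" at the root, plus a nested "next"
print(rollout["next", "reward"])  # results of the step live under "next"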
@@ -180,7 +180,7 @@
# In this section, we'll examine a simple transform, the
# :class:`~torchrl.envs.transforms.StepCounter` transform.
# The complete list of transforms can be found
# :ref:`here <reference/envs:transforms>`.
# :ref:`here <transforms>`.
#
# The transform is integrated with the environment through a
# :class:`~torchrl.envs.TransformedEnv`:
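# A minimal, hypothetical sketch of that wrapping (assuming the Pendulum Gym
# environment is installed):

from torchrl.envs import GymEnv, StepCounter, TransformedEnv

env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=10))
rollout = env.rollout(20)
print(rollout["step_count"])  # the transform adds a "step_count" entry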
@@ -214,7 +214,7 @@
#
# - The :meth:`~torchrl.envs.EnvBase.step_and_maybe_reset` method that packs
# together :meth:`~torchrl.envs.EnvBase.step`,
# :func:`~torchrl.envs.step_mdp` and
# :func:`~torchrl.envs.utils.step_mdp` and
# :meth:`~torchrl.envs.EnvBase.reset`.
# - Some environments like :class:`~torchrl.envs.GymEnv` support rendering
# through the ``from_pixels`` argument. Check the class docstrings to know
22 changes: 11 additions & 11 deletions tutorials/sphinx-tutorials/getting-started-1.py
@@ -65,13 +65,13 @@
#
# To simplify the incorporation of :class:`~torch.nn.Module`s into your
# codebase, TorchRL offers a range of specialized wrappers designed to be
# used as actors, including :class:`~torchrl.modules.Actor`,
# # :class:`~torchrl.modules.ProbabilisticActor`,
# # :class:`~torchrl.modules.ActorValueOperator` or
# # :class:`~torchrl.modules.ActorCriticOperator`.
# For example, :class:`~torchrl.modules.Actor` provides default values for the
# ``in_keys`` and ``out_keys``, making integration with many common
# environments straightforward:
# used as actors, including :class:`~torchrl.modules.tensordict_module.Actor`,
# :class:`~torchrl.modules.tensordict_module.ProbabilisticActor`,
# :class:`~torchrl.modules.tensordict_module.ActorValueOperator` or
# :class:`~torchrl.modules.tensordict_module.ActorCriticOperator`.
# For example, :class:`~torchrl.modules.tensordict_module.Actor` provides
# default values for the ``in_keys`` and ``out_keys``, making integration
# with many common environments straightforward:
#

from torchrl.modules import Actor
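# A minimal, hypothetical sketch of such an actor (assuming the Pendulum Gym
# environment is installed):

import torch
from torchrl.envs import GymEnv

env = GymEnv("Pendulum-v1")
actor = Actor(torch.nn.LazyLinear(env.action_spec.shape[-1]))
rollout = env.rollout(3, actor)  # reads "observation", writes "action" by default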
@@ -82,7 +82,7 @@

###################################
# The list of available specialized TensorDictModules is available in the
# :ref:`API reference <reference/modules:tdmodules>`.
# :ref:`API reference <tdmodules>`.
#
# Networks
# --------
@@ -126,8 +126,8 @@
# will split this output into two chunks, a mean and a standard deviation of
# size ``[1]``;
# - A :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` that will
# read those parameters, create a distribution with them and populate our
# tensordict with samples and log-probabilities.
# read those parameters as ``in_keys``, create a distribution with them and
# populate our tensordict with samples and log-probabilities.
#

from tensordict.nn.distributions import NormalParamExtractor
@@ -140,7 +140,7 @@
td_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"])
policy = ProbabilisticActor(
td_module,
in_keys=["observation"],
in_keys=["loc", "scale"],
out_keys=["action"],
distribution_class=Normal,
return_log_prob=True,
131 changes: 99 additions & 32 deletions tutorials/sphinx-tutorials/rb_tutorial.py
@@ -33,21 +33,22 @@
#
# In this tutorial, you will learn:
#
# - How to build a :ref:`Replay Buffer (RB) <vanilla_rb>` and use it with
# - How to build a :ref:`Replay Buffer (RB) <tuto_rb_vanilla>` and use it with
# any datatype;
# - How to customize the :ref:`buffer's storage <buffer_storage>`;
# - How to use :ref:`RBs with TensorDict <td_rb>`;
# - How to :ref:`sample from or iterate over a replay buffer <sampling_rb>`,
# - How to customize the :ref:`buffer's storage <tuto_rb_storage>`;
# - How to use :ref:`RBs with TensorDict <tuto_rb_td>`;
# - How to :ref:`sample from or iterate over a replay buffer <tuto_rb_sampling>`,
# and how to define the sampling strategy;
# - How to use prioritized replay buffers;
# - How to transform data coming in and out from the buffer;
# - How to store trajectories in the buffer.
# - How to use :ref:`prioritized replay buffers <tuto_rb_prb>`;
# - How to :ref:`transform data <tuto_rb_transform>` coming in and out from
# the buffer;
# - How to store :ref:`trajectories <tuto_rb_traj>` in the buffer.
#
#
# Basics: building a vanilla replay buffer
# ----------------------------------------
#
# .. _vanilla_rb:
# .. _tuto_rb_vanilla:
#
# TorchRL's replay buffers are designed to prioritize modularity,
# composability, efficiency, and simplicity. For instance, creating a basic
@@ -112,7 +113,7 @@
# Customizing the storage
# -----------------------
#
# .. _buffer_storage:
# .. _tuto_rb_storage:
#
# We see that the buffer has been capped to the first 1000 elements that we
# passed to it.
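# A minimal sketch of swapping in a different storage (hypothetical size of
# 1000 elements; a plain list-based storage is the default):

from torchrl.data import LazyMemmapStorage, ReplayBuffer

buffer_memmap = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000))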
@@ -227,7 +228,7 @@
# Integration with TensorDict
# ---------------------------
#
# .. _td_rb:
# .. _tuto_rb_td:
#
# The tensor location follows the same structure as the TensorDict that
# contains them: this makes it easy to save and load buffers during training.
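# A minimal, hypothetical sketch of a buffer fed with TensorDict data:

import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer

td_buffer = TensorDictReplayBuffer(storage=LazyMemmapStorage(1000), batch_size=32)
td_buffer.extend(TensorDict({"obs": torch.randn(128, 4)}, batch_size=[128]))
sample = td_buffer.sample()  # a TensorDict of batch size [32], with an "index" entry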
@@ -361,7 +362,7 @@ def assert0(x):
# Sampling and iterating over buffers
# -----------------------------------
#
# .. _sampling_rb:
# .. _tuto_rb_sampling:
#
# Replay Buffers support multiple sampling strategies:
#
@@ -382,41 +383,42 @@ def assert0(x):
# Fixed batch-size
# ~~~~~~~~~~~~~~~~
#
# If the batch-size is passed during construction, it should be omited when
# If the batch-size is passed during construction, it should be omitted when
# sampling:

data = MyData(
images=torch.randint(
255,
(10, 64, 64, 3),
(200, 64, 64, 3),
),
labels=torch.randint(100, (10,)),
batch_size=[10],
labels=torch.randint(100, (200,)),
batch_size=[200],
)

buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size), batch_size=128)
buffer_lazymemmap.add(data)
buffer_lazymemmap.sample() # will produces 128 identical samples
buffer_lazymemmap.extend(data)
buffer_lazymemmap.sample()


######################################################################
# This batch of data has the size that we wanted it to have (128).
#
# To enable multithreaded sampling, just pass a positive integer to the
# ``prefetch`` keyword argument during construction. This should speed up
# sampling considerably:
# sampling considerably whenever sampling is time consuming (e.g., when
# using prioritized samplers):


buffer_lazymemmap = ReplayBuffer(
storage=LazyMemmapStorage(size), batch_size=128, prefetch=10
) # creates a queue of 10 elements to be prefetched in the background
buffer_lazymemmap.add(data)
buffer_lazymemmap.extend(data)
print(buffer_lazymemmap.sample())


######################################################################
# Fixed batch-size, iterating over the buffer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Iterating over the buffer with a fixed batch-size
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We can also iterate over the buffer like we would do with a regular
# dataloader, as long as the batch-size is predefined:
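# A minimal sketch of that iteration, reusing the (hypothetical) buffer built
# above with ``batch_size=128``:

for i, batch in enumerate(buffer_lazymemmap):
    print(i, batch)  # each batch holds 128 elements
    if i == 2:
        break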
@@ -462,7 +464,7 @@ def assert0(x):
# ~~~~~~~~~~~~~~~~~~
#
# In contrast to what we have seen earlier, the ``batch_size`` keyword
# argument can be omitted and passed directly to the `sample` method:
# argument can be omitted and passed directly to the ``sample`` method:


buffer_lazymemmap = ReplayBuffer(
@@ -476,7 +478,10 @@ def assert0(x):
# Prioritized Replay buffers
# --------------------------
#
# TorchRL also provides an interface for prioritized replay buffers.
# .. _tuto_rb_prb:
#
# TorchRL also provides an interface for
# `prioritized replay buffers <https://arxiv.org/abs/1511.05952>`_.
# This buffer class samples data according to a priority signal that is passed
# through the data.
#
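# A minimal, hypothetical sketch of such a buffer (the alpha/beta values here
# are arbitrary):

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import PrioritizedSampler

size = 100
rb = ReplayBuffer(
    storage=LazyTensorStorage(size),
    sampler=PrioritizedSampler(max_capacity=size, alpha=0.8, beta=1.1),
)
rb.extend(torch.arange(size))  # each new item starts with the default priority
sample, info = rb.sample(10, return_info=True)  # info carries indices and weights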
@@ -510,8 +515,8 @@ def assert0(x):
# buffer, the priority is set to a default value of 1. Once the priority has
# been computed (usually through the loss), it must be updated in the buffer.
#
# This is done via the `update_priority` method, which requires the indices
# as well as the priority.
# This is done via the :meth:`~torchrl.data.ReplayBuffer.update_priority`
# method, which requires the indices as well as the priority.
# We assign an artificially high priority to the second sample in the dataset
# to observe its effect on sampling:
#
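# A minimal sketch of that call, reusing the hypothetical prioritized buffer
# from the sketch above:

import torch

rb.update_priority(index=torch.tensor([1]), priority=torch.tensor([1e3]))
oversampled, _ = rb.sample(10, return_info=True)
print(oversampled)  # item 1 now appears in most of the samples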
@@ -533,6 +538,7 @@ def assert0(x):
######################################################################
# We see that using a prioritized replay buffer requires a series of extra
# steps in the training loop compared with a regular buffer:
#
# - After collecting data and extending the buffer, the priority of the
# items must be updated;
# - After computing the loss and getting a "priority signal" from it, we must
@@ -545,10 +551,10 @@ def assert0(x):
# that the appropriate methods are called at the appropriate place, if and
# only if a prioritized buffer is being used.
#
# Let us see how we can improve this with TensorDict. We saw that the
# :class:`~torchrl.data.TensorDictReplayBuffer` returns data augmented with
# their relative storage indices. One feature we did not mention is that
# this class also ensures that the priority
# Let us see how we can improve this with :class:`~tensordict.TensorDict`.
# We saw that the :class:`~torchrl.data.TensorDictReplayBuffer` returns data
# augmented with their relative storage indices. One feature we did not mention
# is that this class also ensures that the priority
# signal is automatically parsed to the prioritized sampler if present during
# extension.
#
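# A minimal, hypothetical sketch of that mechanism (the "td_error" values here
# are random placeholders):

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import PrioritizedSampler

rb_td = TensorDictReplayBuffer(
    storage=LazyTensorStorage(100),
    sampler=PrioritizedSampler(max_capacity=100, alpha=0.8, beta=1.1),
    priority_key="td_error",
    batch_size=10,
)
data = TensorDict({"obs": torch.randn(20, 4), "td_error": torch.rand(20)}, [20])
rb_td.extend(data)  # priorities are read from data["td_error"] during extension
sample = rb_td.sample()
sample["td_error"] = torch.rand(sample.numel())  # e.g., recomputed from the loss
rb_td.update_tensordict_priority(sample)  # writes the new priorities back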
@@ -616,6 +622,8 @@ def assert0(x):
# Using transforms
# ----------------
#
# .. _tuto_rb_transform:
#
# The data stored in a replay buffer may not be ready to be presented to a
# loss module.
# In some cases, the data produced by a collector can be too heavy to be
@@ -639,7 +647,7 @@ def assert0(x):


from torchrl.collectors import RandomPolicy, SyncDataCollector
from torchrl.envs import Compose, GrayScale, Resize, ToTensorImage, TransformedEnv
from torchrl.envs.transforms import Compose, GrayScale, Resize, ToTensorImage, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

env = TransformedEnv(
@@ -664,7 +672,7 @@ def assert0(x):
# To do this, we will append a transform to the collector to select the keys
# we want to see appearing:

from torchrl.envs import ExcludeTransform
from torchrl.envs.transforms import ExcludeTransform

collector = SyncDataCollector(
env,
@@ -719,7 +727,7 @@ def assert0(x):
# A more complex example: using CatFrames
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The :class:`~torchrl.envs.CatFrames` transform unfolds the observations
# The :class:`~torchrl.envs.transforms.CatFrames` transform unfolds the observations
# through time, creating an n-back memory of past events that allows the model
# to take the past events into account (in the case of POMDPs or with
# recurrent policies such as Decision Transformers). Storing these concatenated
@@ -786,6 +794,55 @@ def assert0(x):

assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all()

######################################################################
# Storing trajectories
# --------------------
#
# .. _tuto_rb_traj:
#
# In many cases, it is desirable to access trajectories from the buffer rather
# than simple transitions. TorchRL offers multiple ways of achieving this.
#
# The preferred way is currently to store trajectories along the first
# dimension of the buffer and use a :class:`~torchrl.data.SliceSampler` to
# sample these batches of data. This class only needs a few pieces of
# information about your data structure to do its job (note that, as of now,
# it is only compatible with tensordict-structured data): the number of slices
# or their length, and some information about where the separation between the
# episodes can be found (e.g. :ref:`recall that <gs_storage_collector>` with a
# :ref:`DataCollector <data_collectors>`, the trajectory id is stored in
# ``("collector", "traj_ids")``). In this simple example, we construct data
# made of 4 consecutive short trajectories and sample 4 slices out of it, each
# of length 2 (since the batch size is 8, and 8 items // 4 slices = 2 time steps).
# We mark the steps as well.

from torchrl.data import SliceSampler

rb = TensorDictReplayBuffer(
storage=LazyMemmapStorage(size),
sampler=SliceSampler(traj_key="episode", num_slices=4),
batch_size=8,
)
episode = torch.zeros(10, dtype=torch.int)
episode[:3] = 1
episode[3:5] = 2
episode[5:7] = 3
episode[7:] = 4
steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)])
data = TensorDict(
{
"episode": episode,
"obs": torch.randn((3, 4, 5)).expand(10, 3, 4, 5),
"act": torch.randn((20,)).expand(10, 20),
"other": torch.randn((20, 50)).expand(10, 20, 50),
"steps": steps,
}, [10]
)
rb.extend(data)
sample = rb.sample()
print("episode are grouped", sample["episode"])
print("steps are successive", sample["steps"])

######################################################################
# Conclusion
# ----------
@@ -799,3 +856,13 @@ def assert0(x):
# - Choose the best storage type for your problem (list, memory or disk-based);
# - Minimize the memory footprint of your buffer.
#
# Next steps
# ----------
#
# - Check the data API reference to learn about offline datasets in TorchRL,
# which are based on our Replay Buffer API;
# - Check other samplers such as
# :class:`~torchrl.data.SamplerWithoutReplacement`,
# :class:`~torchrl.data.PrioritizedSliceSampler` and
# :class:`~torchrl.data.SliceSamplerWithoutReplacement`, or other writers
# such as :class:`~torchrl.data.TensorDictMaxValueWriter`.
