Skip to content

Commit

Permalink
[Feature] CatFrames.make_rb_transform_and_sampler
Browse files Browse the repository at this point in the history
ghstack-source-id: 11488a7c1d8ed1003148ff907d30195d153997f4
Pull Request resolved: #2643
  • Loading branch information
vmoens committed Dec 11, 2024
1 parent b840a77 commit bb76133
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 3 deletions.
108 changes: 108 additions & 0 deletions examples/replay-buffers/catframes-in-buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import (
CatFrames,
Compose,
DMControlEnv,
StepCounter,
ToTensorImage,
TransformedEnv,
UnsqueezeTransform,
)

# Number of frames to stack together
frame_stack = 4
# Dimension along which the stack should occur
stack_dim = -4
# Max size of the buffer
max_size = 100_000
# Batch size of the replay buffer
training_batch_size = 32

seed = 123


def main():
env = TransformedEnv(
DMControlEnv(
env_name="cartpole",
task_name="balance",
device="cpu",
from_pixels=True,
pixels_only=True,
),
Compose(
ToTensorImage(
from_int=True,
dtype=torch.float32,
in_keys=["pixels"],
out_keys=["pixels_trsf"],
shape_tolerant=True,
),
UnsqueezeTransform(
dim=stack_dim, in_keys=["pixels_trsf"], out_keys=["pixels_trsf"]
),
CatFrames(
N=frame_stack,
dim=stack_dim,
in_keys=["pixels_trsf"],
out_keys=["pixels_trsf"],
),
StepCounter(),
),
)
env.set_seed(seed)

sampler = None
transform = None

def collect_transform_sampler(module):
nonlocal sampler
nonlocal transform
if isinstance(module, CatFrames):
transform, sampler = module.make_rb_transform_and_sampler(
batch_size=training_batch_size,
traj_key=("collector", "traj_ids"),
strict_length=True,
)

env.apply(collect_transform_sampler)

rb_transforms = Compose(
ToTensorImage(
from_int=True,
dtype=torch.float32,
in_keys=["pixels", ("next", "pixels")],
out_keys=["pixels_trsf", ("next", "pixels_trsf")],
shape_tolerant=True,
), # C W' H' -> C W' H' (unchanged due to shape_tolerant)
UnsqueezeTransform(
dim=stack_dim,
in_keys=["pixels_trsf", ("next", "pixels_trsf")],
out_keys=["pixels_trsf", ("next", "pixels_trsf")],
), # 1 C W' H'
transform,
)

rb = ReplayBuffer(
storage=LazyTensorStorage(max_size=max_size, device="cpu"),
sampler=sampler,
batch_size=training_batch_size,
transform=rb_transforms,
)

data = env.rollout(1000, break_when_any_done=False)
rb.extend(data)

# This is where things do not work as expected right now.
training_batch = rb.sample()
print(training_batch)


if __name__ == "__main__":
main()
7 changes: 7 additions & 0 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,9 @@ class SliceSampler(Sampler):
"""

# We use this whenever we need to sample N times too many transitions to then select only a 1/N fraction of them
_batch_size_multiplier: int | None

def __init__(
self,
*,
Expand Down Expand Up @@ -1295,6 +1298,8 @@ def _adjusted_batch_size(self, batch_size):
return seq_length, num_slices

def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
if self._batch_size_multiplier is not None:
batch_size = batch_size * self._batch_size_multiplier
# pick up as many trajs as we need
start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
# we have to make sure that the number of dims of the storage
Expand Down Expand Up @@ -1747,6 +1752,8 @@ def _storage_len(self, storage):
def sample(
self, storage: Storage, batch_size: int
) -> Tuple[Tuple[torch.Tensor, ...], dict]:
if self._batch_size_multiplier is not None:
batch_size = batch_size * self._batch_size_multiplier
start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
# we have to make sure that the number of dims of the storage
# is the same as the stop/start signals since we will
Expand Down
64 changes: 61 additions & 3 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2825,9 +2825,9 @@ def _reset(
class CatFrames(ObservationTransform):
"""Concatenates successive observation frames into a single tensor.
This can, for instance, account for movement/velocity of the observed
feature. Proposed in "Playing Atari with Deep Reinforcement Learning" (
https://arxiv.org/pdf/1312.5602.pdf).
This transform is useful for creating a sense of movement or velocity in the observed features.
It can also be used with models that require access to past observations such as transformers and the like.
It was first proposed in "Playing Atari with Deep Reinforcement Learning" (https://arxiv.org/pdf/1312.5602.pdf).
When used within a transformed environment,
:class:`CatFrames` is a stateful class, and it can be reset to its native state by
Expand Down Expand Up @@ -2915,6 +2915,14 @@ class CatFrames(ObservationTransform):
such as those found in MARL settings, are currently not supported.
If this feature is needed, please raise an issue on TorchRL repo.
.. note:: Storing stacks of frames in the replay buffer can significantly increase memory consumption (by N times).
To mitigate this, you can store trajectories directly in the replay buffer and apply :class:`CatFrames` at sampling time.
This approach involves sampling slices of the stored trajectories and then applying the frame stacking transform.
For convenience, :class:`CatFrames` provides a :meth:`~.make_rb_transform_and_sampler` method that creates:
- A modified version of the transform suitable for use in replay buffers
- A corresponding :class:`SliceSampler` to use with the buffer
"""

inplace = False
Expand Down Expand Up @@ -2964,6 +2972,56 @@ def __init__(
self.reset_key = reset_key
self.done_key = done_key

def make_rb_transform_and_sampler(
self, batch_size: int, **sampler_kwargs
) -> Tuple[Transform, "torchrl.data.replay_buffers.SliceSampler"]: # noqa: F821
"""Creates a transform and sampler to be used with a replay buffer when storing frame-stacked data.
This method helps reduce redundancy in stored data by avoiding the need to
store the entire stack of frames in the buffer. Instead, it creates a
transform that stacks frames on-the-fly during sampling, and a sampler that
ensures the correct sequence length is maintained.
Args:
batch_size (int): The batch size to use for the sampler.
**sampler_kwargs: Additional keyword arguments to pass to the
:class:`~torchrl.data.replay_buffers.SliceSampler` constructor.
Returns:
A tuple containing:
- transform (Transform): A transform that stacks frames on-the-fly during sampling.
- sampler (SliceSampler): A sampler that ensures the correct sequence length is maintained.
Example:
>>> env = TransformedEnv(...)
>>> catframes = CatFrames(N=4, ...)
>>> transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32)
>>> rb = ReplayBuffer(..., sampler=sampler, transform=transform)
.. note:: When working with images, it's recommended to use distinct ``in_keys`` and ``out_keys`` in the preceding
:class:`~torchrl.envs.ToTensorImage` transform. This ensures that the tensors stored in the buffer are separate
from their processed counterparts, which we don't want to store.
For non-image data, consider inserting a :class:`~torchrl.envs.RenameTransform` before :class:`CatFrames` to create
a copy of the data that will be stored in the buffer.
.. note:: For a more complete example, refer to torchrl's github repo `examples` folder:
https://github.com/pytorch/rl/tree/main/examples/replay-buffers/catframes-in-buffer.py
"""
from torchrl.data.replay_buffers import SliceSampler

catframes = self.clone()
sampler = SliceSampler(slice_len=self.N, **sampler_kwargs)
sampler._batch_size_multiplier = self.N
transform = Compose(
lambda td: td.reshape(-1, self.N),
catframes,
lambda td: td[:, -1],
# We only store "pixels" to the replay buffer to save memory
ExcludeTransform(*self.in_keys, inverse=True),
)
return transform, sampler

@property
def done_key(self):
done_key = self.__dict__.get("_done_key", None)
Expand Down

0 comments on commit bb76133

Please sign in to comment.