Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] CatFrames.make_rb_transform_and_sampler #2643

Merged
merged 2 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions examples/replay-buffers/catframes-in-buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import (
CatFrames,
Compose,
DMControlEnv,
StepCounter,
ToTensorImage,
TransformedEnv,
UnsqueezeTransform,
)

# Number of frames to stack together
frame_stack = 4
# Dimension along which the stack should occur
stack_dim = -4
# Max size of the buffer
max_size = 100_000
# Batch size of the replay buffer
training_batch_size = 32

seed = 123


def main():
catframes = CatFrames(
N=frame_stack,
dim=stack_dim,
in_keys=["pixels_trsf"],
out_keys=["pixels_trsf"],
)
env = TransformedEnv(
DMControlEnv(
env_name="cartpole",
task_name="balance",
device="cpu",
from_pixels=True,
pixels_only=True,
),
Compose(
ToTensorImage(
from_int=True,
dtype=torch.float32,
in_keys=["pixels"],
out_keys=["pixels_trsf"],
shape_tolerant=True,
),
UnsqueezeTransform(
dim=stack_dim, in_keys=["pixels_trsf"], out_keys=["pixels_trsf"]
),
catframes,
StepCounter(),
),
)
env.set_seed(seed)
vmoens marked this conversation as resolved.
Show resolved Hide resolved

transform, sampler = catframes.make_rb_transform_and_sampler(
batch_size=training_batch_size,
traj_key=("collector", "traj_ids"),
strict_length=True,
)

rb_transforms = Compose(
ToTensorImage(
from_int=True,
dtype=torch.float32,
in_keys=["pixels", ("next", "pixels")],
out_keys=["pixels_trsf", ("next", "pixels_trsf")],
shape_tolerant=True,
), # C W' H' -> C W' H' (unchanged due to shape_tolerant)
UnsqueezeTransform(
dim=stack_dim,
in_keys=["pixels_trsf", ("next", "pixels_trsf")],
out_keys=["pixels_trsf", ("next", "pixels_trsf")],
), # 1 C W' H'
transform,
)

rb = ReplayBuffer(
storage=LazyTensorStorage(max_size=max_size, device="cpu"),
sampler=sampler,
batch_size=training_batch_size,
transform=rb_transforms,
)

data = env.rollout(1000, break_when_any_done=False)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should perform 2k steps to be guaranteed to have 2 trajs (max is 1000 for DMC).
This will allow the addition of assertions that could be used in the doc to explain what is going on further (eg: As you can see, we check that the stacked frames for the frame at index 0 in the sampled batch is actually ...)
OR
Keep using 1000 in the example but for the unit tests I would make sure to have a test that ensures that CatFrames did not simply use the previous frame when that previous frame did not belong to the same traj.

rb.extend(data)

training_batch = rb.sample()
print(training_batch)


if __name__ == "__main__":
main()
23 changes: 23 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,29 @@ def test_transform_rb(self, dim, N, padding, rbclass):
assert (tdsample["out_" + key1] == td["out_" + key1]).all()
assert (tdsample["next", "out_" + key1] == td["next", "out_" + key1]).all()

def test_transform_rb_maker(self):
env = CountingEnv(max_steps=10)
catframes = CatFrames(
in_keys=["observation"], out_keys=["observation_stack"], dim=-1, N=4
)
env.append_transform(catframes)
policy = lambda td: td.update(env.full_action_spec.zeros() + 1)
rollout = env.rollout(150, policy, break_when_any_done=False)
transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32)
rb = ReplayBuffer(
sampler=sampler, storage=LazyTensorStorage(150), transform=transform
)
rb.extend(rollout)
sample = rb.sample(32)
assert "observation_stack" not in rb._storage._storage
assert sample.shape == (32,)
assert sample["observation_stack"].shape == (32, 4)
assert sample["next", "observation_stack"].shape == (32, 4)
assert (
sample["observation_stack"]
== sample["observation_stack"][:, :1] + torch.arange(4)
).all()

@pytest.mark.parametrize("dim", [-1])
@pytest.mark.parametrize("N", [3, 4])
@pytest.mark.parametrize("padding", ["same", "constant"])
Expand Down
7 changes: 7 additions & 0 deletions torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,9 @@ class SliceSampler(Sampler):

"""

# We use this whenever we need to sample N times too many transitions to then select only a 1/N fraction of them
vmoens marked this conversation as resolved.
Show resolved Hide resolved
_batch_size_multiplier: int | None = 1

def __init__(
self,
*,
Expand Down Expand Up @@ -1295,6 +1298,8 @@ def _adjusted_batch_size(self, batch_size):
return seq_length, num_slices

def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
if self._batch_size_multiplier is not None:
vmoens marked this conversation as resolved.
Show resolved Hide resolved
batch_size = batch_size * self._batch_size_multiplier
# pick up as many trajs as we need
start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
# we have to make sure that the number of dims of the storage
Expand Down Expand Up @@ -1747,6 +1752,8 @@ def _storage_len(self, storage):
def sample(
self, storage: Storage, batch_size: int
) -> Tuple[Tuple[torch.Tensor, ...], dict]:
if self._batch_size_multiplier is not None:
batch_size = batch_size * self._batch_size_multiplier
start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
# we have to make sure that the number of dims of the storage
# is the same as the stop/start signals since we will
Expand Down
83 changes: 80 additions & 3 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2825,9 +2825,9 @@ def _reset(
class CatFrames(ObservationTransform):
"""Concatenates successive observation frames into a single tensor.

This can, for instance, account for movement/velocity of the observed
feature. Proposed in "Playing Atari with Deep Reinforcement Learning" (
https://arxiv.org/pdf/1312.5602.pdf).
This transform is useful for creating a sense of movement or velocity in the observed features.
It can also be used with models that require access to past observations such as transformers and the like.
It was first proposed in "Playing Atari with Deep Reinforcement Learning" (https://arxiv.org/pdf/1312.5602.pdf).

When used within a transformed environment,
:class:`CatFrames` is a stateful class, and it can be reset to its native state by
Expand Down Expand Up @@ -2915,6 +2915,14 @@ class CatFrames(ObservationTransform):
such as those found in MARL settings, are currently not supported.
If this feature is needed, please raise an issue on TorchRL repo.

.. note:: Storing stacks of frames in the replay buffer can significantly increase memory consumption (by N times).
To mitigate this, you can store trajectories directly in the replay buffer and apply :class:`CatFrames` at sampling time.
This approach involves sampling slices of the stored trajectories and then applying the frame stacking transform.
For convenience, :class:`CatFrames` provides a :meth:`~.make_rb_transform_and_sampler` method that creates:

- A modified version of the transform suitable for use in replay buffers
- A corresponding :class:`SliceSampler` to use with the buffer

"""

inplace = False
Expand Down Expand Up @@ -2964,6 +2972,75 @@ def __init__(
self.reset_key = reset_key
self.done_key = done_key

def make_rb_transform_and_sampler(
self, batch_size: int, **sampler_kwargs
) -> Tuple[Transform, "torchrl.data.replay_buffers.SliceSampler"]: # noqa: F821
"""Creates a transform and sampler to be used with a replay buffer when storing frame-stacked data.

This method helps reduce redundancy in stored data by avoiding the need to
store the entire stack of frames in the buffer. Instead, it creates a
transform that stacks frames on-the-fly during sampling, and a sampler that
ensures the correct sequence length is maintained.

Args:
batch_size (int): The batch size to use for the sampler.
**sampler_kwargs: Additional keyword arguments to pass to the
:class:`~torchrl.data.replay_buffers.SliceSampler` constructor.

Returns:
A tuple containing:
- transform (Transform): A transform that stacks frames on-the-fly during sampling.
- sampler (SliceSampler): A sampler that ensures the correct sequence length is maintained.

Example:
>>> env = TransformedEnv(...)
>>> catframes = CatFrames(N=4, ...)
>>> transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32)
>>> rb = ReplayBuffer(..., sampler=sampler, transform=transform)

.. note:: When working with images, it's recommended to use distinct ``in_keys`` and ``out_keys`` in the preceding
:class:`~torchrl.envs.ToTensorImage` transform. This ensures that the tensors stored in the buffer are separate
from their processed counterparts, which we don't want to store.
For non-image data, consider inserting a :class:`~torchrl.envs.RenameTransform` before :class:`CatFrames` to create
a copy of the data that will be stored in the buffer.

.. note:: When adding the transform to the replay buffer, one should pay attention to also pass the transforms
that precede CatFrames, such as :class:`~torchrl.envs.ToTensorImage` or :class:`~torchrl.envs.UnsqueezeTransform`
in such a way that the :class:`~torchrl.envs.CatFrames` transforms sees data formatted as it was during data
collection.

.. note:: For a more complete example, refer to torchrl's github repo `examples` folder:
https://github.com/pytorch/rl/tree/main/examples/replay-buffers/catframes-in-buffer.py

"""
from torchrl.data.replay_buffers import SliceSampler

in_keys = self.in_keys
in_keys = in_keys + [unravel_key(("next", key)) for key in in_keys]
out_keys = self.out_keys
out_keys = out_keys + [unravel_key(("next", key)) for key in out_keys]
catframes = type(self)(
N=self.N,
in_keys=in_keys,
out_keys=out_keys,
dim=self.dim,
padding=self.padding,
padding_value=self.padding_value,
as_inverse=False,
reset_key=self.reset_key,
done_key=self.done_key,
)
sampler = SliceSampler(slice_len=self.N, **sampler_kwargs)
vmoens marked this conversation as resolved.
Show resolved Hide resolved
sampler._batch_size_multiplier = self.N
transform = Compose(
lambda td: td.reshape(-1, self.N),
catframes,
lambda td: td[:, -1],
# We only store "pixels" to the replay buffer to save memory
ExcludeTransform(*out_keys, inverse=True),
)
return transform, sampler

@property
def done_key(self):
done_key = self.__dict__.get("_done_key", None)
Expand Down
Loading