diff --git a/examples/replay-buffers/catframes-in-buffer.py b/examples/replay-buffers/catframes-in-buffer.py
new file mode 100644
index 00000000000..916fc63bc50
--- /dev/null
+++ b/examples/replay-buffers/catframes-in-buffer.py
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torchrl.data import LazyTensorStorage, ReplayBuffer
+from torchrl.envs import (
+    CatFrames,
+    Compose,
+    DMControlEnv,
+    StepCounter,
+    ToTensorImage,
+    TransformedEnv,
+    UnsqueezeTransform,
+)
+
+# Number of frames to stack together
+frame_stack = 4
+# Dimension along which the stack should occur (the new frame dim sits ahead of C, H, W)
+stack_dim = -4
+# Max size of the buffer
+max_size = 100_000
+# Batch size of the replay buffer
+training_batch_size = 32
+
+seed = 123
+
+
+def main():
+    catframes = CatFrames(
+        N=frame_stack,
+        dim=stack_dim,
+        in_keys=["pixels_trsf"],
+        out_keys=["pixels_trsf"],
+    )
+    env = TransformedEnv(
+        DMControlEnv(
+            env_name="cartpole",
+            task_name="balance",
+            device="cpu",
+            from_pixels=True,
+            pixels_only=True,
+        ),
+        Compose(
+            ToTensorImage(
+                from_int=True,
+                dtype=torch.float32,
+                in_keys=["pixels"],
+                out_keys=["pixels_trsf"],
+                shape_tolerant=True,
+            ),
+            UnsqueezeTransform(
+                dim=stack_dim, in_keys=["pixels_trsf"], out_keys=["pixels_trsf"]
+            ),
+            catframes,
+            StepCounter(),
+        ),
+    )
+    env.set_seed(seed)
+
+    transform, sampler = catframes.make_rb_transform_and_sampler(
+        batch_size=training_batch_size,
+        traj_key=("collector", "traj_ids"),
+        strict_length=True,
+    )
+
+    rb_transforms = Compose(
+        ToTensorImage(
+            from_int=True,
+            dtype=torch.float32,
+            in_keys=["pixels", ("next", "pixels")],
+            out_keys=["pixels_trsf", ("next", "pixels_trsf")],
+            shape_tolerant=True,
+        ),  # C W' H' -> C W' H' (unchanged due to shape_tolerant)
+        UnsqueezeTransform(
+            dim=stack_dim,
+            in_keys=["pixels_trsf", ("next", "pixels_trsf")],
+            out_keys=["pixels_trsf", ("next", "pixels_trsf")],
+        ),  # 1 C W' H'
+        transform,
+    )
+
+    rb = ReplayBuffer(
+        storage=LazyTensorStorage(max_size=max_size, device="cpu"),
+        sampler=sampler,
+        batch_size=training_batch_size,
+        transform=rb_transforms,
+    )
+
+    data = env.rollout(1000, break_when_any_done=False)
+    rb.extend(data)
+
+    training_batch = rb.sample()
+    print(training_batch)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/test_transforms.py b/test/test_transforms.py
index d90c00b6a19..cc3ca40b059 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -933,6 +933,29 @@ def test_transform_rb(self, dim, N, padding, rbclass):
         assert (tdsample["out_" + key1] == td["out_" + key1]).all()
         assert (tdsample["next", "out_" + key1] == td["next", "out_" + key1]).all()
 
+    def test_transform_rb_maker(self):
+        env = CountingEnv(max_steps=10)
+        catframes = CatFrames(
+            in_keys=["observation"], out_keys=["observation_stack"], dim=-1, N=4
+        )
+        env.append_transform(catframes)
+        policy = lambda td: td.update(env.full_action_spec.zeros() + 1)
+        rollout = env.rollout(150, policy, break_when_any_done=False)
+        transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32)
+        rb = ReplayBuffer(
+            sampler=sampler, storage=LazyTensorStorage(150), transform=transform
+        )
+        rb.extend(rollout)
+        sample = rb.sample(32)
+        assert "observation_stack" not in rb._storage._storage  # stacks are never written to storage
+        assert sample.shape == (32,)
+        assert sample["observation_stack"].shape == (32, 4)
+        assert sample["next", "observation_stack"].shape == (32, 4)
+        assert (
+            sample["observation_stack"]
+            == sample["observation_stack"][:, :1] + torch.arange(4)
+        ).all()
+
     @pytest.mark.parametrize("dim", [-1])
     @pytest.mark.parametrize("N", [3, 4])
     @pytest.mark.parametrize("padding", ["same", "constant"])
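
The last assertion in test_transform_rb_maker relies on a simple invariant: with a counting env, every sampled slice of length N holds consecutive values, so each stack can be rebuilt from its first entry. The following standalone sketch reproduces that arithmetic with plain tensors; the start values are made up, and only the shapes match the test:

    import torch

    N, batch = 4, 32  # frame-stack size and batch size, as in the test

    # Stand-in for 32 sampled slices of 4 consecutive counter steps each
    starts = torch.randint(0, 100, (batch, 1))
    observation_stack = starts + torch.arange(N)  # broadcasts to shape (32, 4)

    # The invariant asserted in the test: each slice is its first entry plus 0..N-1
    assert (observation_stack == observation_stack[:, :1] + torch.arange(N)).all()
    assert observation_stack.shape == (batch, N)
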
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index b97b585aa3f..bbdf2387683 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -968,6 +968,9 @@ class SliceSampler(Sampler):
 
     """
 
+    # Used whenever we need to sample N times as many transitions as requested, keeping only a 1/N fraction of them
+    _batch_size_multiplier: int | None = 1
+
     def __init__(
         self,
         *,
@@ -1295,6 +1298,8 @@ def _adjusted_batch_size(self, batch_size):
         return seq_length, num_slices
 
     def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]:
+        if self._batch_size_multiplier is not None:
+            batch_size = batch_size * self._batch_size_multiplier
         # pick up as many trajs as we need
         start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
         # we have to make sure that the number of dims of the storage
@@ -1747,6 +1752,8 @@ def _storage_len(self, storage):
     def sample(
         self, storage: Storage, batch_size: int
     ) -> Tuple[Tuple[torch.Tensor, ...], dict]:
+        if self._batch_size_multiplier is not None:
+            batch_size = batch_size * self._batch_size_multiplier
         start_idx, stop_idx, lengths = self._get_stop_and_length(storage)
         # we have to make sure that the number of dims of the storage
         # is the same as the stop/start signals since we will
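
Both sample() overloads above scale the requested batch size by _batch_size_multiplier. The point of over-sampling shows up once the result is regrouped downstream: with a multiplier of N, the N-times-larger flat batch folds into batch_size slices of length N, of which one stacked step each survives. A plain-tensor sketch of that bookkeeping (values are illustrative; the real regrouping lives in the transform returned by make_rb_transform_and_sampler further down in this patch):

    import torch

    N = 4                      # slice length, i.e. CatFrames.N
    requested_batch_size = 32

    # sample() scales the request, as in the patched SliceSampler.sample
    effective_batch_size = requested_batch_size * N  # 128 flat transitions

    flat = torch.arange(effective_batch_size)  # stand-in for the sampled steps
    slices = flat.reshape(-1, N)               # regroup into (32, 4) slices
    final = slices[:, -1]                      # keep one transition per slice
    assert final.shape == (requested_batch_size,)
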
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 0bab5868ded..f3329d085df 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -2825,9 +2825,9 @@ def _reset(
 class CatFrames(ObservationTransform):
     """Concatenates successive observation frames into a single tensor.
 
-    This can, for instance, account for movement/velocity of the observed
-    feature. Proposed in "Playing Atari with Deep Reinforcement Learning" (
-    https://arxiv.org/pdf/1312.5602.pdf).
+    This transform is useful for creating a sense of movement or velocity in the observed features.
+    It can also be used with models that require access to past observations, such as transformers.
+    It was first proposed in "Playing Atari with Deep Reinforcement Learning" (https://arxiv.org/pdf/1312.5602.pdf).
 
     When used within a transformed environment,
     :class:`CatFrames` is a stateful class, and it can be reset to its native state by
@@ -2915,6 +2915,14 @@ class CatFrames(ObservationTransform):
         such as those found in MARL settings, are currently not supported.
         If this feature is needed, please raise an issue on TorchRL repo.
 
+    .. note:: Storing stacks of frames in the replay buffer can significantly increase memory consumption (by a factor of N).
+        To mitigate this, you can store trajectories directly in the replay buffer and apply :class:`CatFrames` at sampling time.
+        This approach involves sampling slices of the stored trajectories and then applying the frame-stacking transform.
+        For convenience, :class:`CatFrames` provides a :meth:`~.make_rb_transform_and_sampler` method that creates:
+
+        - A modified version of the transform suitable for use in replay buffers
+        - A corresponding :class:`SliceSampler` to use with the buffer
+
     """
 
     inplace = False
@@ -2964,6 +2972,75 @@ def __init__(
         self.reset_key = reset_key
         self.done_key = done_key
 
+    def make_rb_transform_and_sampler(
+        self, batch_size: int, **sampler_kwargs
+    ) -> Tuple[Transform, "torchrl.data.replay_buffers.SliceSampler"]:  # noqa: F821
+        """Creates a transform and sampler for a replay buffer that delivers frame-stacked samples without storing the stacks.
+
+        This method helps reduce redundancy in stored data by avoiding the need to
+        store the entire stack of frames in the buffer. Instead, it creates a
+        transform that stacks frames on-the-fly during sampling, and a sampler that
+        ensures the correct sequence length is maintained.
+
+        Args:
+            batch_size (int): The batch size to use for the sampler.
+            **sampler_kwargs: Additional keyword arguments to pass to the
+                :class:`~torchrl.data.replay_buffers.SliceSampler` constructor.
+
+        Returns:
+            A tuple containing:
+            - transform (Transform): A transform that stacks frames on-the-fly during sampling.
+            - sampler (SliceSampler): A sampler that ensures the correct sequence length is maintained.
+
+        Example:
+            >>> env = TransformedEnv(...)
+            >>> catframes = CatFrames(N=4, ...)
+            >>> transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32)
+            >>> rb = ReplayBuffer(..., sampler=sampler, transform=transform)
+
+        .. note:: When working with images, it's recommended to use distinct ``in_keys`` and ``out_keys`` in the preceding
+            :class:`~torchrl.envs.ToTensorImage` transform. This ensures that the tensors stored in the buffer are separate
+            from their processed counterparts, which do not need to be stored.
+            For non-image data, consider inserting a :class:`~torchrl.envs.RenameTransform` before :class:`CatFrames` to create
+            a copy of the data that will be stored in the buffer.
+
+        .. note:: When adding the transform to the replay buffer, make sure to also pass the transforms
+            that precede CatFrames, such as :class:`~torchrl.envs.ToTensorImage` or :class:`~torchrl.envs.UnsqueezeTransform`,
+            so that the :class:`~torchrl.envs.CatFrames` transform sees data formatted exactly as it was during data
+            collection.
+
+        .. note:: For a more complete example, refer to the ``examples`` folder in the TorchRL GitHub repo:
+            https://github.com/pytorch/rl/tree/main/examples/replay-buffers/catframes-in-buffer.py
+
+        """
+        from torchrl.data.replay_buffers import SliceSampler
+
+        in_keys = self.in_keys
+        in_keys = in_keys + [unravel_key(("next", key)) for key in in_keys]  # also transform the "next" entries
+        out_keys = self.out_keys
+        out_keys = out_keys + [unravel_key(("next", key)) for key in out_keys]
+        catframes = type(self)(
+            N=self.N,
+            in_keys=in_keys,
+            out_keys=out_keys,
+            dim=self.dim,
+            padding=self.padding,
+            padding_value=self.padding_value,
+            as_inverse=False,
+            reset_key=self.reset_key,
+            done_key=self.done_key,
+        )
+        sampler = SliceSampler(slice_len=self.N, **sampler_kwargs)
+        sampler._batch_size_multiplier = self.N  # over-sample N times more transitions per batch
+        transform = Compose(
+            lambda td: td.reshape(-1, self.N),  # fold the flat batch into (batch, N) slices
+            catframes,
+            lambda td: td[:, -1],  # keep the last step of each slice, which carries the full stack
+            # Exclude the stacked keys when writing to the buffer so that only raw frames are stored
+            ExcludeTransform(*out_keys, inverse=True),
+        )
+        return transform, sampler
+
     @property
     def done_key(self):
         done_key = self.__dict__.get("_done_key", None)
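
Shape-wise, the Compose built above turns the flat batch of batch_size * N steps delivered by the sampler into batch_size stacked samples. Below is a sketch of that flow on a bare TensorDict; the key name and feature size are invented, and the CatFrames/ExcludeTransform stages are elided, so only the reshape-then-select bookkeeping is shown:

    import torch
    from tensordict import TensorDict

    N, B = 4, 32
    # The sampler (with _batch_size_multiplier = N) returns B * N flat steps
    td = TensorDict({"observation": torch.randn(B * N, 7)}, batch_size=[B * N])

    td = td.reshape(-1, N)  # first stage of the Compose: (B, N) slices
    # ... CatFrames would now stack the N frames within each slice ...
    td = td[:, -1]          # third stage: keep the last, fully stacked step
    assert td.batch_size == torch.Size([B])
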