diff --git a/examples/replay-buffers/catframes-in-buffer.py b/examples/replay-buffers/catframes-in-buffer.py new file mode 100644 index 00000000000..9f8892f7989 --- /dev/null +++ b/examples/replay-buffers/catframes-in-buffer.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torchrl.data import LazyTensorStorage, ReplayBuffer +from torchrl.envs import ( + CatFrames, + Compose, + DMControlEnv, + StepCounter, + ToTensorImage, + TransformedEnv, + UnsqueezeTransform, +) + +# Number of frames to stack together +frame_stack = 4 +# Dimension along which the stack should occur +stack_dim = -4 +# Max size of the buffer +max_size = 100_000 +# Batch size of the replay buffer +training_batch_size = 32 + +seed = 123 + + +def main(): + env = TransformedEnv( + DMControlEnv( + env_name="cartpole", + task_name="balance", + device="cpu", + from_pixels=True, + pixels_only=True, + ), + Compose( + ToTensorImage( + from_int=True, + dtype=torch.float32, + in_keys=["pixels"], + out_keys=["pixels_trsf"], + shape_tolerant=True, + ), + UnsqueezeTransform( + dim=stack_dim, in_keys=["pixels_trsf"], out_keys=["pixels_trsf"] + ), + CatFrames( + N=frame_stack, + dim=stack_dim, + in_keys=["pixels_trsf"], + out_keys=["pixels_trsf"], + ), + StepCounter(), + ), + ) + env.set_seed(seed) + + sampler = None + transform = None + + def collect_transform_sampler(module): + nonlocal sampler + nonlocal transform + if isinstance(module, CatFrames): + transform, sampler = module.make_rb_transform_and_sampler( + batch_size=training_batch_size, + traj_key=("collector", "traj_ids"), + strict_length=True, + ) + + env.apply(collect_transform_sampler) + + rb_transforms = Compose( + ToTensorImage( + from_int=True, + dtype=torch.float32, + in_keys=["pixels", ("next", "pixels")], + out_keys=["pixels_trsf", ("next", "pixels_trsf")], + shape_tolerant=True, + ), # C W' H' -> C W' H' (unchanged due to shape_tolerant) + UnsqueezeTransform( + dim=stack_dim, + in_keys=["pixels_trsf", ("next", "pixels_trsf")], + out_keys=["pixels_trsf", ("next", "pixels_trsf")], + ), # 1 C W' H' + transform, + ) + + rb = ReplayBuffer( + storage=LazyTensorStorage(max_size=max_size, device="cpu"), + sampler=sampler, + batch_size=training_batch_size, + transform=rb_transforms, + ) + + data = env.rollout(1000, break_when_any_done=False) + rb.extend(data) + + # This is where things do not work as expected right now. + training_batch = rb.sample() + print(training_batch) + + +if __name__ == "__main__": + main() diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index b97b585aa3f..625ad9b209e 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -968,6 +968,9 @@ class SliceSampler(Sampler): """ + # We use this whenever we need to sample N times too many transitions to then select only a 1/N fraction of them + _batch_size_multiplier: int | None + def __init__( self, *, @@ -1295,6 +1298,8 @@ def _adjusted_batch_size(self, batch_size): return seq_length, num_slices def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]: + if self._batch_size_multiplier is not None: + batch_size = batch_size * self._batch_size_multiplier # pick up as many trajs as we need start_idx, stop_idx, lengths = self._get_stop_and_length(storage) # we have to make sure that the number of dims of the storage @@ -1747,6 +1752,8 @@ def _storage_len(self, storage): def sample( self, storage: Storage, batch_size: int ) -> Tuple[Tuple[torch.Tensor, ...], dict]: + if self._batch_size_multiplier is not None: + batch_size = batch_size * self._batch_size_multiplier start_idx, stop_idx, lengths = self._get_stop_and_length(storage) # we have to make sure that the number of dims of the storage # is the same as the stop/start signals since we will diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 0bab5868ded..7ada415ccd4 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2825,9 +2825,9 @@ def _reset( class CatFrames(ObservationTransform): """Concatenates successive observation frames into a single tensor. - This can, for instance, account for movement/velocity of the observed - feature. Proposed in "Playing Atari with Deep Reinforcement Learning" ( - https://arxiv.org/pdf/1312.5602.pdf). + This transform is useful for creating a sense of movement or velocity in the observed features. + It can also be used with models that require access to past observations such as transformers and the like. + It was first proposed in "Playing Atari with Deep Reinforcement Learning" (https://arxiv.org/pdf/1312.5602.pdf). When used within a transformed environment, :class:`CatFrames` is a stateful class, and it can be reset to its native state by @@ -2915,6 +2915,14 @@ class CatFrames(ObservationTransform): such as those found in MARL settings, are currently not supported. If this feature is needed, please raise an issue on TorchRL repo. + .. note:: Storing stacks of frames in the replay buffer can significantly increase memory consumption (by N times). + To mitigate this, you can store trajectories directly in the replay buffer and apply :class:`CatFrames` at sampling time. + This approach involves sampling slices of the stored trajectories and then applying the frame stacking transform. + For convenience, :class:`CatFrames` provides a :meth:`~.make_rb_transform_and_sampler` method that creates: + + - A modified version of the transform suitable for use in replay buffers + - A corresponding :class:`SliceSampler` to use with the buffer + """ inplace = False @@ -2964,6 +2972,56 @@ def __init__( self.reset_key = reset_key self.done_key = done_key + def make_rb_transform_and_sampler( + self, batch_size: int, **sampler_kwargs + ) -> Tuple[Transform, "torchrl.data.replay_buffers.SliceSampler"]: # noqa: F821 + """Creates a transform and sampler to be used with a replay buffer when storing frame-stacked data. + + This method helps reduce redundancy in stored data by avoiding the need to + store the entire stack of frames in the buffer. Instead, it creates a + transform that stacks frames on-the-fly during sampling, and a sampler that + ensures the correct sequence length is maintained. + + Args: + batch_size (int): The batch size to use for the sampler. + **sampler_kwargs: Additional keyword arguments to pass to the + :class:`~torchrl.data.replay_buffers.SliceSampler` constructor. + + Returns: + A tuple containing: + - transform (Transform): A transform that stacks frames on-the-fly during sampling. + - sampler (SliceSampler): A sampler that ensures the correct sequence length is maintained. + + Example: + >>> env = TransformedEnv(...) + >>> catframes = CatFrames(N=4, ...) + >>> transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32) + >>> rb = ReplayBuffer(..., sampler=sampler, transform=transform) + + .. note:: When working with images, it's recommended to use distinct ``in_keys`` and ``out_keys`` in the preceding + :class:`~torchrl.envs.ToTensorImage` transform. This ensures that the tensors stored in the buffer are separate + from their processed counterparts, which we don't want to store. + For non-image data, consider inserting a :class:`~torchrl.envs.RenameTransform` before :class:`CatFrames` to create + a copy of the data that will be stored in the buffer. + + .. note:: For a more complete example, refer to torchrl's github repo `examples` folder: + https://github.com/pytorch/rl/tree/main/examples/replay-buffers/catframes-in-buffer.py + + """ + from torchrl.data.replay_buffers import SliceSampler + + catframes = self.clone() + sampler = SliceSampler(slice_len=self.N, **sampler_kwargs) + sampler._batch_size_multiplier = self.N + transform = Compose( + lambda td: td.reshape(-1, self.N), + catframes, + lambda td: td[:, -1], + # We only store "pixels" to the replay buffer to save memory + ExcludeTransform(*self.in_keys, inverse=True), + ) + return transform, sampler + @property def done_key(self): done_key = self.__dict__.get("_done_key", None)