From 4d52d5ffe085aff01d6eee6d3e901aa25d4c7561 Mon Sep 17 00:00:00 2001 From: Remi Date: Wed, 7 Feb 2024 21:17:43 +0100 Subject: [PATCH 01/10] [Feature] Add PrioritizedSliceSampler (#1875) --- docs/source/reference/data.rst | 1 + test/test_rb.py | 149 ++++++++++++-- torchrl/data/replay_buffers/__init__.py | 1 + torchrl/data/replay_buffers/samplers.py | 246 +++++++++++++++++++++++- 4 files changed, 379 insertions(+), 18 deletions(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 6ed32ebe921..8ab6401b314 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -134,6 +134,7 @@ using the following components: Sampler PrioritizedSampler + PrioritizedSliceSampler RandomSampler SamplerWithoutReplacement SliceSampler diff --git a/test/test_rb.py b/test/test_rb.py index 548e4ba9726..697981909b5 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -40,6 +40,7 @@ from torchrl.data.replay_buffers import samplers, writers from torchrl.data.replay_buffers.samplers import ( PrioritizedSampler, + PrioritizedSliceSampler, RandomSampler, SamplerEnsemble, SamplerWithoutReplacement, @@ -1834,13 +1835,14 @@ def test_sampler_without_rep_state_dict(self, backend): assert (s.exclude("index") == 0).all() @pytest.mark.parametrize( - "batch_size,num_slices,slice_len", + "batch_size,num_slices,slice_len,prioritized", [ - [100, 20, None], - [120, 30, None], - [100, None, 5], - [120, None, 4], - [101, None, 101], + [100, 20, None, True], + [100, 20, None, False], + [120, 30, None, False], + [100, None, 5, False], + [120, None, 4, False], + [101, None, 101, False], ], ) @pytest.mark.parametrize("episode_key", ["episode", ("some", "episode")]) @@ -1853,6 +1855,7 @@ def test_slice_sampler( batch_size, num_slices, slice_len, + prioritized, episode_key, done_key, match_episode, @@ -1897,19 +1900,34 @@ def test_slice_sampler( else: strict_length = True - sampler = SliceSampler( - num_slices=num_slices, - traj_key=episode_key, - end_key=done_key, - slice_len=slice_len, - strict_length=strict_length, - ) + if prioritized: + num_steps = data.shape[0] + sampler = PrioritizedSliceSampler( + max_capacity=num_steps, + alpha=0.7, + beta=0.9, + num_slices=num_slices, + traj_key=episode_key, + end_key=done_key, + slice_len=slice_len, + strict_length=strict_length, + ) + index = torch.arange(0, num_steps, 1) + sampler.extend(index) + else: + sampler = SliceSampler( + num_slices=num_slices, + traj_key=episode_key, + end_key=done_key, + slice_len=slice_len, + strict_length=strict_length, + ) if slice_len is not None: num_slices = batch_size // slice_len trajs_unique_id = set() too_short = False count_unique = set() - for _ in range(10): + for _ in range(30): index, info = sampler.sample(storage, batch_size=batch_size) if _data_prefix: samples = storage._storage["_data"][index] @@ -1918,6 +1936,7 @@ def test_slice_sampler( if strict_length: # check that trajs are ok samples = samples.view(num_slices, -1) + assert samples["another_episode"].unique( dim=1 ).squeeze().shape == torch.Size([num_slices]) @@ -1936,6 +1955,7 @@ def test_slice_sampler( raise AssertionError( f"Not all items can be sampled: {set(range(100))-count_unique} are missing" ) + if strict_length: assert not too_short else: @@ -2071,6 +2091,107 @@ def test_slice_sampler_without_replacement( assert truncated.view(num_slices, -1)[:, -1].all() +def test_prioritized_slice_sampler_doc_example(): + sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9) + rb = TensorDictReplayBuffer( + storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6 + ) + data = TensorDict( + { + "observation": torch.randn(9, 16), + "action": torch.randn(9, 1), + "episode": torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2], dtype=torch.long), + "steps": torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2], dtype=torch.long), + ("next", "observation"): torch.randn(9, 16), + ("next", "reward"): torch.randn(9, 1), + ("next", "done"): torch.tensor( + [0, 0, 1, 0, 0, 1, 0, 0, 1], dtype=torch.bool + ).unsqueeze(1), + }, + batch_size=[9], + ) + rb.extend(data) + sample, info = rb.sample(return_info=True) + # print("episode", sample["episode"].tolist()) + # print("steps", sample["steps"].tolist()) + # print("weight", info["_weight"].tolist()) + + priority = torch.tensor([0, 3, 3, 0, 0, 0, 1, 1, 1]) + rb.update_priority(torch.arange(0, 9, 1), priority=priority) + sample, info = rb.sample(return_info=True) + # print("episode", sample["episode"].tolist()) + # print("steps", sample["steps"].tolist()) + # print("weight", info["_weight"].tolist()) + + +@pytest.mark.parametrize("device", get_default_devices()) +def test_prioritized_slice_sampler_episodes(device): + num_slices = 10 + batch_size = 20 + + episode = torch.zeros(100, dtype=torch.int, device=device) + episode[:30] = 1 + episode[30:55] = 2 + episode[55:70] = 3 + episode[70:] = 4 + steps = torch.cat( + [torch.arange(30), torch.arange(25), torch.arange(15), torch.arange(30)], 0 + ) + done = torch.zeros(100, 1, dtype=torch.bool) + done[torch.tensor([29, 54, 69])] = 1 + + data = TensorDict( + { + "observation": torch.randn(100, 16), + "action": torch.randn(100, 4), + "episode": episode, + "steps": steps, + ("next", "observation"): torch.randn(100, 16), + ("next", "reward"): torch.randn(100, 1), + ("next", "done"): done, + }, + batch_size=[100], + device=device, + ) + + num_steps = data.shape[0] + sampler = PrioritizedSliceSampler( + max_capacity=num_steps, + alpha=0.7, + beta=0.9, + num_slices=num_slices, + ) + + rb = TensorDictReplayBuffer( + storage=LazyMemmapStorage(100), + sampler=sampler, + batch_size=batch_size, + ) + rb.extend(data) + + episodes = [] + for _ in range(10): + sample = rb.sample() + episodes.append(sample["episode"]) + assert {1, 2, 3, 4} == set( + torch.cat(episodes).cpu().tolist() + ), "all episodes are expected to be sampled at least once" + + index = torch.arange(0, num_steps, 1) + new_priorities = torch.cat( + [torch.ones(30), torch.zeros(25), torch.ones(15), torch.zeros(30)], 0 + ) + sampler.update_priority(index, new_priorities) + + episodes = [] + for _ in range(10): + sample = rb.sample() + episodes.append(sample["episode"]) + assert {1, 3} == set( + torch.cat(episodes).cpu().tolist() + ), "after priority update, only episode 1 and 3 are expected to be sampled" + + class TestEnsemble: def _make_data(self, data_type): if data_type is torch.Tensor: diff --git a/torchrl/data/replay_buffers/__init__.py b/torchrl/data/replay_buffers/__init__.py index 77e7501de0c..c9aadcb992b 100644 --- a/torchrl/data/replay_buffers/__init__.py +++ b/torchrl/data/replay_buffers/__init__.py @@ -13,6 +13,7 @@ ) from .samplers import ( PrioritizedSampler, + PrioritizedSliceSampler, RandomSampler, Sampler, SamplerEnsemble, diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 21245e37acd..96d73375ea9 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -596,7 +596,7 @@ class SliceSampler(Sampler): allowed to appear in the batch. Be mindful that this can result in effective `batch_size` shorter than the one asked for! Trajectories can be split using - :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``. + :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``. .. note:: To recover the trajectory splits in the storage, :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` will first @@ -633,7 +633,7 @@ class SliceSampler(Sampler): >>> print("episodes", sample.get("episode").unique()) episodes tensor([1, 2, 3, 4], dtype=torch.int32) - :class:`torchrl.data.replay_buffers.SliceSampler` is default-compatible with + :class:`~torchrl.data.replay_buffers.SliceSampler` is default-compatible with most of TorchRL's datasets: Examples: @@ -1012,7 +1012,7 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): allowed to appear in the batch. Be mindful that this can result in effective `batch_size` shorter than the one asked for! Trajectories can be split using - :func:`torchrl.collectors.split_trajectories`. Defaults to ``True``. + :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``. shuffle (bool, optional): if ``False``, the order of the trajectories is not shuffled. Defaults to ``True``. @@ -1053,7 +1053,7 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): >>> print("sample:", sample) >>> print("trajectories in sample", sample.get("episode").unique()) - :class:`torchrl.data.replay_buffers.SliceSamplerWithoutReplacement` is default-compatible with + :class:`~torchrl.data.replay_buffers.SliceSamplerWithoutReplacement` is default-compatible with most of TorchRL's datasets, and allows users to consume datasets in a dataloader-like fashion: Examples: @@ -1129,6 +1129,244 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: return SamplerWithoutReplacement.load_state_dict(self, state_dict) +class PrioritizedSliceSampler(SliceSampler, PrioritizedSampler): + """Samples slices of data along the first dimension, given start and stop signals, using prioritized sampling. + + This class samples sub-trajectories with replacement following a priority weighting presented in "Schaul, T.; Quan, J.; Antonoglou, I.; and Silver, D. 2015. + Prioritized experience replay." + (https://arxiv.org/abs/1511.05952) + + For more info see :class:`~torchrl.data.replay_buffers.samplers.SliceSampler` and :class:`~torchrl.data.replay_buffers.samplers.PrioritizedSampler`. + + Args: + alpha (float): exponent α determines how much prioritization is used, + with α = 0 corresponding to the uniform case. + beta (float): importance sampling negative exponent. + eps (float, optional): delta added to the priorities to ensure that the buffer + does not contain null priorities. Defaults to 1e-8. + reduction (str, optional): the reduction method for multidimensional + tensordicts (i.e., stored trajectory). Can be one of "max", "min", + "median" or "mean". + + Keyword Args: + num_slices (int): the number of slices to be sampled. The batch-size + must be greater or equal to the ``num_slices`` argument. Exclusive + with ``slice_len``. + slice_len (int): the length of the slices to be sampled. The batch-size + must be greater or equal to the ``slice_len`` argument and divisible + by it. Exclusive with ``num_slices``. + end_key (NestedKey, optional): the key indicating the end of a + trajectory (or episode). Defaults to ``("next", "done")``. + traj_key (NestedKey, optional): the key indicating the trajectories. + Defaults to ``"episode"`` (commonly used across datasets in TorchRL). + ends (torch.Tensor, optional): a 1d boolean tensor containing the end of run signals. + To be used whenever the ``end_key`` or ``traj_key`` is expensive to get, + or when this signal is readily available. Must be used with ``cache_values=True`` + and cannot be used in conjunction with ``end_key`` or ``traj_key``. + trajectories (torch.Tensor, optional): a 1d integer tensor containing the run ids. + To be used whenever the ``end_key`` or ``traj_key`` is expensive to get, + or when this signal is readily available. Must be used with ``cache_values=True`` + and cannot be used in conjunction with ``end_key`` or ``traj_key``. + cache_values (bool, optional): to be used with static datasets. + Will cache the start and end signal of the trajectory. + truncated_key (NestedKey, optional): If not ``None``, this argument + indicates where a truncated signal should be written in the output + data. This is used to indicate to value estimators where the provided + trajectory breaks. Defaults to ``("next", "truncated")``. + This feature only works with :class:`~torchrl.data.replay_buffers.TensorDictReplayBuffer` + instances (otherwise the truncated key is returned in the info dictionary + returned by the :meth:`~torchrl.data.replay_buffers.ReplayBuffer.sample` method). + strict_length (bool, optional): if ``False``, trajectories of length + shorter than `slice_len` (or `batch_size // num_slices`) will be + allowed to appear in the batch. + Be mindful that this can result in effective `batch_size` shorter + than the one asked for! Trajectories can be split using + :func:`~torchrl.collectors.split_trajectories`. Defaults to ``True``. + + Examples: + >>> import torch + >>> from torchrl.data.replay_buffers import TensorDictReplayBuffer, LazyMemmapStorage, PrioritizedSliceSampler + >>> from tensordict import TensorDict + >>> sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9) + >>> rb = TensorDictReplayBuffer(storage=LazyMemmapStorage(9), sampler=sampler, batch_size=6) + >>> data = TensorDict( + ... { + ... "observation": torch.randn(9,16), + ... "action": torch.randn(9, 1), + ... "episode": torch.tensor([0,0,0,1,1,1,2,2,2], dtype=torch.long), + ... "steps": torch.tensor([0,1,2,0,1,2,0,1,2], dtype=torch.long), + ... ("next", "observation"): torch.randn(9,16), + ... ("next", "reward"): torch.randn(9,1), + ... ("next", "done"): torch.tensor([0,0,1,0,0,1,0,0,1], dtype=torch.bool).unsqueeze(1), + ... }, + ... batch_size=[9], + ... ) + >>> rb.extend(data) + >>> sample, info = rb.sample(return_info=True) + >>> print("episode", sample["episode"].tolist()) + episode [2, 2, 2, 2, 1, 1] + >>> print("steps", sample["steps"].tolist()) + steps [1, 2, 0, 1, 1, 2] + >>> print("weight", info["_weight"].tolist()) + weight [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + >>> priority = torch.tensor([0,3,3,0,0,0,1,1,1]) + >>> rb.update_priority(torch.arange(0,9,1), priority=priority) + >>> sample, info = rb.sample(return_info=True) + >>> print("episode", sample["episode"].tolist()) + episode [2, 2, 2, 2, 2, 2] + >>> print("steps", sample["steps"].tolist()) + steps [1, 2, 0, 1, 0, 1] + >>> print("weight", info["_weight"].tolist()) + weight [9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06, 9.120110917137936e-06] + """ + + def __init__( + self, + max_capacity: int, + alpha: float, + beta: float, + eps: float = 1e-8, + dtype: torch.dtype = torch.float, + reduction: str = "max", + *, + num_slices: int = None, + slice_len: int = None, + end_key: NestedKey | None = None, + traj_key: NestedKey | None = None, + ends: torch.Tensor | None = None, + trajectories: torch.Tensor | None = None, + cache_values: bool = False, + truncated_key: NestedKey | None = ("next", "truncated"), + strict_length: bool = True, + ) -> object: + SliceSampler.__init__( + self, + num_slices=num_slices, + slice_len=slice_len, + end_key=end_key, + traj_key=traj_key, + cache_values=cache_values, + truncated_key=truncated_key, + strict_length=strict_length, + ends=ends, + trajectories=trajectories, + ) + PrioritizedSampler.__init__( + self, + max_capacity=max_capacity, + alpha=alpha, + beta=beta, + eps=eps, + dtype=dtype, + reduction=reduction, + ) + + def __getstate__(self): + state = SliceSampler.__getstate__(self) + state.update(PrioritizedSampler.__getstate__(self)) + + def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]: + # Sample `batch_size` indices representing the start of a slice. + # The sampling is based on a weight vector. + start_idx, stop_idx, lengths = self._get_stop_and_length(storage) + seq_length, num_slices = self._adjusted_batch_size(batch_size) + + num_trajs = lengths.shape[0] + traj_idx = torch.arange(0, num_trajs, 1, device=lengths.device) + + if (lengths < seq_length).any(): + if self.strict_length: + raise RuntimeError( + "Some stored trajectories have a length shorter than the slice that was asked for. " + "Create the sampler with `strict_length=False` to allow shorter trajectories to appear " + "in you batch." + ) + # make seq_length a tensor with values clamped by lengths + seq_length = lengths[traj_idx].clamp_max(seq_length) + + # build a list of index that we dont want to sample: all the steps at a `seq_length` distance of + # the end the trajectory, with the end of trajectory (`stop_idx`) included + if isinstance(seq_length, int): + subtractive_idx = torch.arange( + 0, seq_length - 1, 1, device=stop_idx.device, dtype=stop_idx.dtype + ) + preceding_stop_idx = ( + stop_idx[..., None] - subtractive_idx[None, ...] + ).view(-1) + else: + raise NotImplementedError("seq_length as a list is not supported for now") + + # force to not sample index at the end of a trajectory + self._sum_tree[preceding_stop_idx] = 0.0 + # and no need to update self._min_tree + + starts, info = PrioritizedSampler.sample( + self, storage=storage, batch_size=batch_size // seq_length + ) + # TODO: update PrioritizedSampler.sample to return torch tensors + starts = torch.as_tensor(starts, device=lengths.device) + info["_weight"] = torch.as_tensor(info["_weight"], device=lengths.device) + + # extends starting indices of each slice with sequence_length to get indices of all steps + index = self._tensor_slices_from_startend(seq_length, starts) + # repeat the weight of each slice to match the number of steps + info["_weight"] = torch.repeat_interleave(info["_weight"], seq_length) + + # sanity check + if index.shape[0] != batch_size: + raise ValueError( + f"Number of indices is expected to match the batch size ({index.shape[0]} != {batch_size})." + ) + + if self.truncated_key is not None: + # following logics borrowed from SliceSampler + truncated_key = self.truncated_key + done_key = _replace_last(truncated_key, "done") + terminated_key = _replace_last(truncated_key, "terminated") + + truncated = torch.zeros( + (*index.shape, 1), dtype=torch.bool, device=index.device + ) + if isinstance(seq_length, int): + truncated.view(num_slices, -1)[:, -1] = 1 + else: + truncated[seq_length.cumsum(0) - 1] = 1 + traj_terminated = stop_idx[traj_idx] == start_idx[traj_idx] + seq_length - 1 + terminated = torch.zeros_like(truncated) + if traj_terminated.any(): + if isinstance(seq_length, int): + truncated.view(num_slices, -1)[traj_terminated] = 1 + else: + truncated[(seq_length.cumsum(0) - 1)[traj_terminated]] = 1 + truncated = truncated & ~terminated + done = terminated | truncated + + info.update( + { + truncated_key: truncated, + done_key: done, + terminated_key: terminated, + } + ) + return index.to(torch.long), info + + def _empty(self): + # no op for SliceSampler + PrioritizedSampler._empty(self) + + def dumps(self, path): + # no op for SliceSampler + PrioritizedSampler.dumps(self, path) + + def loads(self, path): + # no op for SliceSampler + return PrioritizedSampler.loads(self, path) + + def state_dict(self): + # no op for SliceSampler + return PrioritizedSampler.state_dict(self) + + class SamplerEnsemble(Sampler): """An ensemble of samplers. From 601867f9fcc6930619a2d0efabd9b28649203af9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 10 Feb 2024 10:33:04 +0000 Subject: [PATCH 02/10] [Doc] Getting started tutos (#1886) --- README.md | 5 + docs/source/index.rst | 17 + docs/source/reference/collectors.rst | 2 + docs/source/reference/data.rst | 2 + docs/source/reference/envs.rst | 3 + docs/source/reference/modules.rst | 4 + docs/source/reference/objectives.rst | 2 + docs/source/reference/trainers.rst | 2 + torchrl/data/replay_buffers/samplers.py | 2 +- torchrl/envs/common.py | 4 +- torchrl/envs/transforms/transforms.py | 4 +- torchrl/modules/tensordict_module/actors.py | 52 +-- torchrl/objectives/dqn.py | 5 +- torchrl/objectives/utils.py | 3 +- torchrl/record/loggers/csv.py | 4 +- tutorials/sphinx-tutorials/README.rst | 2 + tutorials/sphinx-tutorials/coding_ddpg.py | 2 + tutorials/sphinx-tutorials/coding_dqn.py | 8 +- tutorials/sphinx-tutorials/coding_ppo.py | 2 + tutorials/sphinx-tutorials/dqn_with_rnn.py | 2 + .../sphinx-tutorials/getting-started-0.py | 245 ++++++++++++++ .../sphinx-tutorials/getting-started-1.py | 309 ++++++++++++++++++ .../sphinx-tutorials/getting-started-2.py | 173 ++++++++++ .../sphinx-tutorials/getting-started-3.py | 180 ++++++++++ .../sphinx-tutorials/getting-started-4.py | 104 ++++++ .../sphinx-tutorials/getting-started-5.py | 183 +++++++++++ tutorials/sphinx-tutorials/pendulum.py | 2 + tutorials/sphinx-tutorials/rb_tutorial.py | 240 ++++++++++---- tutorials/sphinx-tutorials/torchrl_demo.py | 11 +- tutorials/sphinx-tutorials/torchrl_envs.py | 13 +- 30 files changed, 1478 insertions(+), 109 deletions(-) create mode 100644 tutorials/sphinx-tutorials/getting-started-0.py create mode 100644 tutorials/sphinx-tutorials/getting-started-1.py create mode 100644 tutorials/sphinx-tutorials/getting-started-2.py create mode 100644 tutorials/sphinx-tutorials/getting-started-3.py create mode 100644 tutorials/sphinx-tutorials/getting-started-4.py create mode 100644 tutorials/sphinx-tutorials/getting-started-5.py diff --git a/README.md b/README.md index 2e1d08a0757..6adbc2decfe 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,11 @@ On the low-level end, torchrl comes with a set of highly re-usable functionals f TorchRL aims at (1) a high modularity and (2) good runtime performance. Read the [full paper](https://arxiv.org/abs/2306.00577) for a more curated description of the library. +## Getting started + +Check our [Getting Started tutorials](https://pytorch.org/rl/index.html#getting-started) for quickly ramp up with the basic +features of the library! + ## Documentation and knowledge base The TorchRL documentation can be found [here](https://pytorch.org/rl). diff --git a/docs/source/index.rst b/docs/source/index.rst index 49bcde82488..ab1cee681db 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -62,6 +62,23 @@ or via a ``git clone`` if you're willing to contribute to the library: $ cd ../rl $ python setup.py develop +Getting started +=============== + +A series of quick tutorials to get ramped up with the basic features of the +library. If you're in a hurry, you can start by +:ref:`the last item of the series ` +and navigate to the previous ones whenever you want to learn more! + +.. toctree:: + :maxdepth: 1 + + tutorials/getting-started-0 + tutorials/getting-started-1 + tutorials/getting-started-2 + tutorials/getting-started-3 + tutorials/getting-started-4 + tutorials/getting-started-5 Tutorials ========= diff --git a/docs/source/reference/collectors.rst b/docs/source/reference/collectors.rst index aa8de179f20..982b8664862 100644 --- a/docs/source/reference/collectors.rst +++ b/docs/source/reference/collectors.rst @@ -3,6 +3,8 @@ torchrl.collectors package ========================== +.. _ref_collectors: + Data collectors are somewhat equivalent to pytorch dataloaders, except that (1) they collect data over non-static data sources and (2) the data is collected using a model (likely a version of the model that is being trained). diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index 8ab6401b314..d426a112b72 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -3,6 +3,8 @@ torchrl.data package ==================== +.. _ref_data: + Replay Buffers -------------- diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index cce34e14b14..4dbb5a5da57 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -475,6 +475,9 @@ single agent standards. Transforms ---------- + +.. _transforms: + .. currentmodule:: torchrl.envs.transforms In most cases, the raw output of an environment must be treated before being passed to another object (such as a diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index d859140bb70..bcd234a7ff9 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -3,9 +3,13 @@ torchrl.modules package ======================= +.. _ref_modules: + TensorDict modules: Actors, exploration, value models and generative models --------------------------------------------------------------------------- +.. _tdmodules: + TorchRL offers a series of module wrappers aimed at making it easy to build RL models from the ground up. These wrappers are exclusively based on :class:`tensordict.nn.TensorDictModule` and :class:`tensordict.nn.TensorDictSequential`. diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 1aec88f2d11..c2f43d8e9b6 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -3,6 +3,8 @@ torchrl.objectives package ========================== +.. _ref_objectives: + TorchRL provides a series of losses to use in your training scripts. The aim is to have losses that are easily reusable/swappable and that have a simple signature. diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index eb857f15a0f..04d4386c631 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -218,6 +218,8 @@ Utils Loggers ------- +.. _ref_loggers: + .. currentmodule:: torchrl.record.loggers .. autosummary:: diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 96d73375ea9..2a169cbd332 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -718,7 +718,7 @@ def __init__( if end_key is None: end_key = ("next", "done") if traj_key is None: - traj_key = "run" + traj_key = "episode" self.end_key = end_key self.traj_key = traj_key diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 746cc60f142..d928fd87500 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2055,8 +2055,8 @@ def reset( tensordict_reset = self._reset(tensordict, **kwargs) # We assume that this is done properly - # if tensordict_reset.device != self.device: - # tensordict_reset = tensordict_reset.to(self.device, non_blocking=True) + # if reset.device != self.device: + # reset = reset.to(self.device, non_blocking=True) if tensordict_reset is tensordict: raise RuntimeError( "EnvBase._reset should return outplace changes to the input " diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index b71c6fcffc3..40df963ec5e 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -791,7 +791,7 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): return tensordict_reset def _reset_proc_data(self, tensordict, tensordict_reset): - # self._complete_done(self.full_done_spec, tensordict_reset) + # self._complete_done(self.full_done_spec, reset) self._reset_check_done(tensordict, tensordict_reset) if tensordict is not None: tensordict_reset = _update_during_reset( @@ -802,7 +802,7 @@ def _reset_proc_data(self, tensordict, tensordict_reset): # # doesn't do anything special # mt_mode = self.transform.missing_tolerance # self.set_missing_tolerance(True) - # tensordict_reset = self.transform._call(tensordict_reset) + # reset = self.transform._call(reset) # self.set_missing_tolerance(mt_mode) return tensordict_reset diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index b7a044cae7d..8d9855283f5 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -33,12 +33,13 @@ class Actor(SafeModule): """General class for deterministic actors in RL. - The Actor class comes with default values for the out_keys (["action"]) - and if the spec is provided but not as a CompositeSpec object, it will be - automatically translated into :obj:`spec = CompositeSpec(action=spec)` + The Actor class comes with default values for the out_keys (``["action"]``) + and if the spec is provided but not as a + :class:`~torchrl.data.CompositeSpec` object, it will be + automatically translated into ``spec = CompositeSpec(action=spec)``. Args: - module (nn.Module): a :class:`torch.nn.Module` used to map the input to + module (nn.Module): a :class:`~torch.nn.Module` used to map the input to the output parameter space. in_keys (iterable of str, optional): keys to be read from input tensordict and passed to the module. If it @@ -47,9 +48,11 @@ class Actor(SafeModule): Defaults to ``["observation"]``. out_keys (iterable of str): keys to be written to the input tensordict. The length of out_keys must match the - number of tensors returned by the embedded module. Using "_" as a + number of tensors returned by the embedded module. Using ``"_"`` as a key avoid writing tensor to output. Defaults to ``["action"]``. + + Keyword Args: spec (TensorSpec, optional): Keyword-only argument. Specs of the output tensor. If the module outputs multiple output tensors, @@ -59,7 +62,7 @@ class Actor(SafeModule): input spec. Out-of-domain sampling can occur because of exploration policies or numerical under/overflow issues. If this value is out of bounds, it is projected back onto the - desired space using the :obj:`TensorSpec.project` + desired space using the :meth:`~torchrl.data.TensorSpec.project` method. Default is ``False``. Examples: @@ -148,17 +151,23 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): issues. If this value is out of bounds, it is projected back onto the desired space using the :obj:`TensorSpec.project` method. Default is ``False``. - default_interaction_type=InteractionType.RANDOM (str, optional): keyword-only argument. + default_interaction_type (str, optional): keyword-only argument. Default method to be used to retrieve - the output value. Should be one of: 'mode', 'median', 'mean' or 'random' - (in which case the value is sampled randomly from the distribution). Default - is 'mode'. - Note: When a sample is drawn, the :obj:`ProbabilisticTDModule` instance will - first look for the interaction mode dictated by the `interaction_typ()` - global function. If this returns `None` (its default value), then the - `default_interaction_type` of the `ProbabilisticTDModule` instance will be - used. Note that DataCollector instances will use `set_interaction_type` to - :class:`tensordict.nn.InteractionType.RANDOM` by default. + the output value. Should be one of: 'InteractionType.MODE', + 'InteractionType.MEDIAN', 'InteractionType.MEAN' or + 'InteractionType.RANDOM' (in which case the value is sampled + randomly from the distribution). Defaults to is 'InteractionType.RANDOM'. + + .. note:: When a sample is drawn, the :class:`ProbabilisticActor` instance will + first look for the interaction mode dictated by the + :func:`~tensordict.nn.probabilistic.interaction_type` + global function. If this returns `None` (its default value), then the + `default_interaction_type` of the `ProbabilisticTDModule` + instance will be used. Note that + :class:`~torchrl.collectors.collectors.DataCollectorBase` + instances will use `set_interaction_type` to + :class:`tensordict.nn.InteractionType.RANDOM` by default. + distribution_class (Type, optional): keyword-only argument. A :class:`torch.distributions.Distribution` class to be used for sampling. @@ -197,9 +206,7 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): ... in_keys=["loc", "scale"], ... distribution_class=TanhNormal, ... ) - >>> params = TensorDict.from_module(td_module) - >>> with params.to_module(td_module): - ... td = td_module(td) + >>> td = td_module(td) >>> td TensorDict( fields={ @@ -315,7 +322,8 @@ class ValueOperator(TensorDictModule): The length of out_keys must match the number of tensors returned by the embedded module. Using "_" as a key avoid writing tensor to output. - Defaults to ``["action"]``. + Defaults to ``["state_value"]`` or + ``["state_action_value"]`` if ``"action"`` is part of the ``in_keys``. Examples: >>> import torch @@ -334,9 +342,7 @@ class ValueOperator(TensorDictModule): >>> td_module = ValueOperator( ... in_keys=["observation", "action"], module=module ... ) - >>> params = TensorDict.from_module(td_module) - >>> with params.to_module(td_module): - ... td = td_module(td) + >>> td = td_module(td) >>> print(td) TensorDict( fields={ diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 37fd1cbdaea..2298c262368 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -213,7 +213,10 @@ def __init__( try: action_space = value_network.action_space except AttributeError: - raise ValueError(self.ACTION_SPEC_ERROR) + raise ValueError( + "The action space could not be retrieved from the value_network. " + "Make sure it is available to the DQN loss module." + ) if action_space is None: warnings.warn( "action_space was not specified. DQNLoss will default to 'one-hot'." diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 43dfa65c0c4..b234af6a804 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -300,8 +300,7 @@ def __init__( ): if eps is None and tau is None: raise RuntimeError( - "Neither eps nor tau was provided. " "This behaviour is deprecated.", - category=DeprecationWarning, + "Neither eps nor tau was provided. This behaviour is deprecated.", ) eps = 0.999 if (eps is None) ^ (tau is None): diff --git a/torchrl/record/loggers/csv.py b/torchrl/record/loggers/csv.py index 256d0a2e840..6bcd3f50c86 100644 --- a/torchrl/record/loggers/csv.py +++ b/torchrl/record/loggers/csv.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import os from collections import defaultdict from pathlib import Path @@ -126,7 +128,7 @@ class CSVLogger(Logger): def __init__( self, exp_name: str, - log_dir: Optional[str] = None, + log_dir: str | None = None, video_format: str = "pt", video_fps: int = 30, ) -> None: diff --git a/tutorials/sphinx-tutorials/README.rst b/tutorials/sphinx-tutorials/README.rst index a7e41cccf45..7995a1fbb2e 100644 --- a/tutorials/sphinx-tutorials/README.rst +++ b/tutorials/sphinx-tutorials/README.rst @@ -1,2 +1,4 @@ README Tutos ============ + +Check the tutorials on torchrl documentation: https://pytorch.org/rl diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 5f8bf2c0830..252b4fd2146 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -4,6 +4,8 @@ ====================================== **Author**: `Vincent Moens `_ +.. _coding_ddpg: + """ ############################################################################## diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index f85f6bf1e14..eb476dfcc15 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -4,6 +4,8 @@ ============================== **Author**: `Vincent Moens `_ +.. _coding_dqn: + """ ############################################################################## @@ -404,9 +406,9 @@ def get_replay_buffer(buffer_size, n_optim, batch_size): # environment executed in parallel in each collector (controlled by the # ``num_workers`` hyperparameter). # -# When building the collector, we can choose on which device we want the -# environment and policy to execute the operations through the ``device`` -# keyword argument. The ``storing_devices`` argument will modify the +# Collector's devices are fully parametrizable through the ``device`` (general), +# ``policy_device``, ``env_device`` and ``storing_device`` arguments. +# The ``storing_device`` argument will modify the # location of the data being collected: if the batches that we are gathering # have a considerable size, we may want to store them on a different location # than the device where the computation is happening. For asynchronous data diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index be82bbd3bd8..6f31a0aed1a 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -4,6 +4,8 @@ ================================================== **Author**: `Vincent Moens `_ +.. _coding_ppo: + This tutorial demonstrates how to use PyTorch and :py:mod:`torchrl` to train a parametric policy network to solve the Inverted Pendulum task from the `OpenAI-Gym/Farama-Gymnasium control library `__. diff --git a/tutorials/sphinx-tutorials/dqn_with_rnn.py b/tutorials/sphinx-tutorials/dqn_with_rnn.py index b71a112c91a..a2b2b12b562 100644 --- a/tutorials/sphinx-tutorials/dqn_with_rnn.py +++ b/tutorials/sphinx-tutorials/dqn_with_rnn.py @@ -6,6 +6,8 @@ **Author**: `Vincent Moens `_ +.. _RNN_tuto: + .. grid:: 2 .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn diff --git a/tutorials/sphinx-tutorials/getting-started-0.py b/tutorials/sphinx-tutorials/getting-started-0.py new file mode 100644 index 00000000000..e81b10ec381 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-0.py @@ -0,0 +1,245 @@ +# -*- coding: utf-8 -*- +""" + +Get started with Environments, TED and transforms +================================================= + +**Author**: `Vincent Moens `_ + +.. _gs_env_ted: + +""" + +################################ +# Welcome to the getting started tutorials! +# +# Below is the list of the topics we will be covering. +# +# - :ref:`Environments, TED and transforms `; +# - :ref:`TorchRL's modules `; +# - :ref:`Losses and optimization `; +# - :ref:`Data collection and storage `; +# - :ref:`TorchRL's logging API `. +# +# If you are in a hurry, you can jump straight away to the last tutorial, +# :ref:`Your onw first training loop `, from where you can +# backtrack every other "Getting Started" tutorial if things are not clear or +# if you want to learn more about a specific topic! +# +# Environments in RL +# ------------------ +# +# The standard RL (Reinforcement Learning) training loop involves a model, +# also known as a policy, which is trained to accomplish a task within a +# specific environment. Often, this environment is a simulator that accepts +# actions as input and produces an observation along with some metadata as +# output. +# +# In this document, we will explore the environment API of TorchRL: we will +# learn how to create an environment, interact with it, and understand the +# data format it uses. +# +# Creating an environment +# ----------------------- +# +# In essence, TorchRL does not directly provide environments, but instead +# offers wrappers for other libraries that encapsulate the simulators. The +# :mod:`~torchrl.envs` module can be viewed as a provider for a generic +# environment API, as well as a central hub for simulation backends like +# `gym `_ (:class:`~torchrl.envs.GymEnv`), +# `Brax `_ (:class:`~torchrl.envs.BraxEnv`) +# or `DeepMind Control Suite `_ +# (:class:`~torchrl.envs.DMControlEnv`). +# +# Creating your environment is typically as straightforward as the underlying +# backend API allows. Here's an example using gym: + +from torchrl.envs import GymEnv + +env = GymEnv("Pendulum-v1") + +################################ +# +# Running an environment +# ---------------------- +# +# Environments in TorchRL have two crucial methods: +# :meth:`~torchrl.envs.EnvBase.reset`, which initiates +# an episode, and :meth:`~torchrl.envs.EnvBase.step`, which executes an +# action selected by the actor. +# In TorchRL, environment methods read and write +# :class:`~tensordict.TensorDict` instances. +# Essentially, :class:`~tensordict.TensorDict` is a generic key-based data +# carrier for tensors. +# The benefit of using TensorDict over plain tensors is that it enables us to +# handle simple and complex data structures interchangeably. As our function +# signatures are very generic, it eliminates the challenge of accommodating +# different data formats. In simpler terms, after this brief tutorial, +# you will be capable of operating on both simple and highly complex +# environments, as their user-facing API is identical and simple! +# +# Let's put the environment into action and see what a tensordict instance +# looks like: + +reset = env.reset() +print(reset) + +################################ +# Now let's take a random action in the action space. First, sample the action: +reset_with_action = env.rand_action(reset) +print(reset_with_action) + +################################ +# This tensordict has the same structure as the one obtained from +# :meth:`~torchrl.envs.EnvBase` with an additional ``"action"`` entry. +# You can access the action easily, like you would do with a regular +# dictionary: +# + +print(reset_with_action["action"]) + +################################ +# We now need to pass this action tp the environment. +# We'll be passing the entire tensordict to the ``step`` method, since there +# might be more than one tensor to be read in more advanced cases like +# Multi-Agent RL or stateless environments: + +stepped_data = env.step(reset_with_action) +print(stepped_data) + +################################ +# Again, this new tensordict is identical to the previous one except for the +# fact that it has a ``"next"`` entry (itself a tensordict!) containing the +# observation, reward and done state resulting from +# our action. +# +# We call this format TED, for +# :ref:`TorchRL Episode Data format `. It is +# the ubiquitous way of representing data in the library, both dynamically like +# here, or statically with offline datasets. +# +# The last bit of information you need to run a rollout in the environment is +# how to bring that ``"next"`` entry at the root to perform the next step. +# TorchRL provides a dedicated :func:`~torchrl.envs.utils.step_mdp` function +# that does just that: it filters out the information you won't need and +# delivers a data structure corresponding to your observation after a step in +# the Markov Decision Process, or MDP. + +from torchrl.envs import step_mdp + +data = step_mdp(stepped_data) +print(data) + +################################ +# Environment rollouts +# -------------------- +# +# .. _gs_env_ted_rollout: +# +# Writing down those three steps (computing an action, making a step, +# moving in the MDP) can be a bit tedious and repetitive. Fortunately, +# TorchRL provides a nice :meth:`~torchrl.envs.EnvBase.rollout` function that +# allows you to run them in a closed loop at will: +# + +rollout = env.rollout(max_steps=10) +print(rollout) + +################################ +# This data looks pretty much like the ``stepped_data`` above with the +# exception of its batch-size, which now equates the number of steps we +# provided through the ``max_steps`` argument. The magic of tensordict +# doesn't end there: if you're interested in a single transition of this +# environment, you can index the tensordict like you would index a tensor: + +transition = rollout[3] +print(transition) + +################################ +# :class:`~tensordict.TensorDict` will automatically check if the index you +# provided is a key (in which case we index along the key-dimension) or a +# spatial index like here. +# +# Executed as such (without a policy), the ``rollout`` method may seem rather +# useless: it just runs random actions. If a policy is available, it can +# be passed to the method and used to collect data. +# +# Nevertheless, it can useful to run a naive, policyless rollout at first to +# check what is to be expected from an environment at a glance. +# +# To appreciate the versatility of TorchRL's API, consider the fact that the +# rollout method is universally applicable. It functions across **all** use +# cases, whether you're working with a single environment like this one, +# multiple copies across various processes, a multi-agent environment, or even +# a stateless version of it! +# +# +# Transforming an environment +# --------------------------- +# +# Most of the time, you'll want to modify the output of the environment to +# better suit your requirements. For example, you might want to monitor the +# number of steps executed since the last reset, resize images, or stack +# consecutive observations together. +# +# In this section, we'll examine a simple transform, the +# :class:`~torchrl.envs.transforms.StepCounter` transform. +# The complete list of transforms can be found +# :ref:`here `. +# +# The transform is integrated with the environment through a +# :class:`~torchrl.envs.transforms.TransformedEnv`: +# + +from torchrl.envs import StepCounter, TransformedEnv + +transformed_env = TransformedEnv(env, StepCounter(max_steps=10)) +rollout = transformed_env.rollout(max_steps=100) +print(rollout) + +################################ +# As you can see, our environment now has one more entry, ``"step_count"`` that +# tracks the number of steps since the last reset. +# Given that we passed the optional +# argument ``max_steps=10`` to the transform constructor, we also truncated the +# trajectory after 10 steps (not completing a full rollout of 100 steps like +# we asked with the ``rollout`` call). We can see that the trajectory was +# truncated by looking at the truncated entry: + +print(rollout["next", "truncated"]) + +################################ +# +# This is all for this short introduction to TorchRL's environment API! +# +# Next steps +# ---------- +# +# To explore further what TorchRL's environments can do, go and check: +# +# - The :meth:`~torchrl.envs.EnvBase.step_and_maybe_reset` method that packs +# together :meth:`~torchrl.envs.EnvBase.step`, +# :func:`~torchrl.envs.utils.step_mdp` and +# :meth:`~torchrl.envs.EnvBase.reset`. +# - Some environments like :class:`~torchrl.envs.GymEnv` support rendering +# through the ``from_pixels`` argument. Check the class docstrings to know +# more! +# - The batched environments, in particular :class:`~torchrl.envs.ParallelEnv` +# which allows you to run multiple copies of one same (or different!) +# environments on multiple processes. +# - Design your own environment with the +# :ref:`Pendulum tutorial ` and learn about specs and +# stateless environments. +# - See the more in-depth tutorial about environments +# :ref:`in the dedicated tutorial `; +# - Check the +# :ref:`multi-agent environment API ` +# if you're interested in MARL; +# - TorchRL has many tools to interact with the Gym API such as +# a way to register TorchRL envs in the Gym register through +# :meth:`~torchrl.envs.EnvBase.register_gym`, an API to read +# the info dictionaries through +# :meth:`~torchrl.envs.EnvBase.set_info_dict_reader` or a way +# to control the gym backend thanks to +# :func:`~torchrl.envs.set_gym_backend`. +# diff --git a/tutorials/sphinx-tutorials/getting-started-1.py b/tutorials/sphinx-tutorials/getting-started-1.py new file mode 100644 index 00000000000..136deeb5cd9 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-1.py @@ -0,0 +1,309 @@ +# -*- coding: utf-8 -*- +""" +Get started with TorchRL's modules +================================== + +**Author**: `Vincent Moens `_ + +.. _gs_modules: + +""" +################################### +# Reinforcement Learning is designed to create policies that can effectively +# tackle specific tasks. Policies can take various forms, from a differentiable +# map transitioning from the observation space to the action space, to a more +# ad-hoc method like an argmax over a list of values computed for each possible +# action. Policies can be deterministic or stochastic, and may incorporate +# complex elements such as Recurrent Neural Networks (RNNs) or transformers. +# +# Accommodating all these scenarios can be quite intricate. In this succinct +# tutorial, we will delve into the core functionality of TorchRL in terms of +# policy construction. We will primarily focus on stochastic and Q-Value +# policies in two common scenarios: using a Multi-Layer Perceptron (MLP) or +# a Convolutional Neural Network (CNN) as backbones. +# +# TensorDictModules +# ----------------- +# +# Similar to how environments interact with instances of +# :class:`~tensordict.TensorDict`, the modules used to represent policies and +# value functions also do the same. The core idea is simple: encapsulate a +# standard :class:`~torch.nn.Module` (or any other function) within a class +# that knows which entries need to be read and passed to the module, and then +# records the results with the assigned entries. To illustrate this, we will +# use the simplest policy possible: a deterministic map from the observation +# space to the action space. For maximum generality, we will use a +# :class:`~torch.nn.LazyLinear` module with the Pendulum environment we +# instantiated in the previous tutorial. +# + +import torch + +from tensordict.nn import TensorDictModule +from torchrl.envs import GymEnv + +env = GymEnv("Pendulum-v1") +module = torch.nn.LazyLinear(out_features=env.action_spec.shape[-1]) +policy = TensorDictModule( + module, + in_keys=["observation"], + out_keys=["action"], +) + +################################### +# This is all that's required to execute our policy! The use of a lazy module +# allows us to bypass the need to fetch the shape of the observation space, as +# the module will automatically determine it. This policy is now ready to be +# run in the environment: + +rollout = env.rollout(max_steps=10, policy=policy) +print(rollout) + +################################### +# Specialized wrappers +# -------------------- +# +# To simplify the incorporation of :class:`~torch.nn.Module`s into your +# codebase, TorchRL offers a range of specialized wrappers designed to be +# used as actors, including :class:`~torchrl.modules.tensordict_module.Actor`, +# # :class:`~torchrl.modules.tensordict_module.ProbabilisticActor`, +# # :class:`~torchrl.modules.tensordict_module.ActorValueOperator` or +# # :class:`~torchrl.modules.tensordict_module.ActorCriticOperator`. +# For example, :class:`~torchrl.modules.tensordict_module.Actor` provides +# default values for the ``in_keys`` and ``out_keys``, making integration +# with many common environments straightforward: +# + +from torchrl.modules import Actor + +policy = Actor(module) +rollout = env.rollout(max_steps=10, policy=policy) +print(rollout) + +################################### +# The list of available specialized TensorDictModules is available in the +# :ref:`API reference `. +# +# Networks +# -------- +# +# TorchRL also provides regular modules that can be used without recurring to +# tensordict features. The two most common networks you will encounter are +# the :class:`~torchrl.modules.MLP` and the :class:`~torchrl.modules.ConvNet` +# (CNN) modules. We can substitute our policy module with one of these: +# + +from torchrl.modules import MLP + +module = MLP( + out_features=env.action_spec.shape[-1], + num_cells=[32, 64], + activation_class=torch.nn.Tanh, +) +policy = Actor(module) +rollout = env.rollout(max_steps=10, policy=policy) + +################################### +# TorchRL also supports RNN-based policies. Since this is a more technical +# topic, it is treated in :ref:`a separate tutorial `. +# +# Probabilistic policies +# ---------------------- +# +# Policy-optimization algorithms like +# `PPO `_ require the policy to be +# stochastic: unlike in the examples above, the module now encodes a map from +# the observation space to a parameter space encoding a distribution over the +# possible actions. TorchRL facilitates the design of such modules by grouping +# under a single class the various operations such as building the distribution +# from the parameters, sampling from that distribution and retrieving the +# log-probability. Here, we'll be building an actor that relies on a regular +# normal distribution using three components: +# +# - An :class:`~torchrl.modules.MLP` backbone reading observations of size +# ``[3]`` and outputting a single tensor of size ``[2]``; +# - A :class:`~tensordict.nn.distributions.NormalParamExtractor` module that +# will split this output on two chunks, a mean and a standard deviation of +# size ``[1]``; +# - A :class:`~torchrl.modules.tensordict_module.ProbabilisticActor` that will +# read those parameters as ``in_keys``, create a distribution with them and +# populate our tensordict with samples and log-probabilities. +# + +from tensordict.nn.distributions import NormalParamExtractor +from torch.distributions import Normal +from torchrl.modules import ProbabilisticActor + +backbone = MLP(in_features=3, out_features=2) +extractor = NormalParamExtractor() +module = torch.nn.Sequential(backbone, extractor) +td_module = TensorDictModule(module, in_keys=["observation"], out_keys=["loc", "scale"]) +policy = ProbabilisticActor( + td_module, + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=Normal, + return_log_prob=True, +) + +rollout = env.rollout(max_steps=10, policy=policy) +print(rollout) + +################################### +# There are a few things to note about this rollout: +# +# - Since we asked for it during the construction of the actor, the +# log-probability of the actions given the distribution at that time is +# also written. This is necessary for algorithms like PPO. +# - The parameters of the distribution are returned within the output +# tensordict too under the ``"loc"`` and ``"scale"`` entries. +# +# You can control the sampling of the action to use the expected value or +# other properties of the distribution instead of using random samples if +# your application requires it. This can be controlled via the +# :func:`~torchrl.envs.utils.set_exploration_type` function: + +from torchrl.envs.utils import ExplorationType, set_exploration_type + +with set_exploration_type(ExplorationType.MEAN): + # takes the mean as action + rollout = env.rollout(max_steps=10, policy=policy) +with set_exploration_type(ExplorationType.RANDOM): + # Samples actions according to the dist + rollout = env.rollout(max_steps=10, policy=policy) + +################################### +# Check the ``default_interaction_type`` keyword argument in +# the docstrings to know more. +# +# Exploration +# ----------- +# +# Stochastic policies like this somewhat naturally trade off exploration and +# exploitation, but deterministic policies won't. Fortunately, TorchRL can +# also palliate to this with its exploration modules. +# We will take the example of the :class:`~torchrl.modules.EGreedyModule` +# exploration module (check also +# :class:`~torchrl.modules.AdditiveGaussianWrapper` and +# :class:`~torchrl.modules.OrnsteinUhlenbeckProcessWrapper`). +# To see this module in action, let's revert to a deterministic policy: + +from tensordict.nn import TensorDictSequential +from torchrl.modules import EGreedyModule + +policy = Actor(MLP(3, 1, num_cells=[32, 64])) + +################################### +# Our :math:`\epsilon`-greedy exploration module will usually be customized +# with a number of annealing frames and an initial value for the +# :math:`\epsilon` parameter. A value of :math:`\epsilon = 1` means that every +# action taken is random, while :math:`\epsilon=0` means that there is no +# exploration at all. To anneal (i.e., decrease) the exploration factor, a call +# to :meth:`~torchrl.modules.EGreedyModule.step` is required (see the last +# :ref:`tutorial ` for an example). +# +exploration_module = EGreedyModule( + spec=env.action_spec, annealing_num_steps=1000, eps_init=0.5 +) + +################################### +# To build our explorative policy, we only had to concatenate the +# deterministic policy module with the exploration module within a +# :class:`~tensordict.nn.TensorDictSequential` module (which is the analogous +# to :class:`~torch.nn.Sequential` in the tensordict realm). + +exploration_policy = TensorDictSequential(policy, exploration_module) + +with set_exploration_type(ExplorationType.MEAN): + # Turns off exploration + rollout = env.rollout(max_steps=10, policy=exploration_policy) +with set_exploration_type(ExplorationType.RANDOM): + # Turns on exploration + rollout = env.rollout(max_steps=10, policy=exploration_policy) + +################################### +# Because it must be able to sample random actions in the action space, the +# :class:`~torchrl.modules.EGreedyModule` must be equipped with the +# ``action_space`` from the environment to know what strategy to use to +# sample actions randomly. +# +# Q-Value actors +# -------------- +# +# In some settings, the policy isn't a standalone module but is constructed on +# top of another module. This is the case with **Q-Value actors**. In short, these +# actors require an estimate of the action value (most of the time discrete) +# and will greedily pick up the action with the highest value. In some +# settings (finite discrete action space and finite discrete state space), +# one can just store a 2D table of state-action pairs and pick up the +# action with the highest value. The innovation brought by +# `DQN `_ was to scale this up to continuous +# state spaces by utilizing a neural network to encode for the ``Q(s, a)`` +# value map. Let's consider another environment with a discrete action space +# for a clearer understanding: + +env = GymEnv("CartPole-v1") +print(env.action_spec) + +################################### +# We build a value network that produces one value per action when it reads a +# state from the environment: + +num_actions = 2 +value_net = TensorDictModule( + MLP(out_features=num_actions, num_cells=[32, 32]), + in_keys=["observation"], + out_keys=["action_value"], +) + +################################### +# We can easily build our Q-Value actor by adding a +# :class:`~torchrl.modules.tensordict_module.QValueModule` after our value +# network: + +from torchrl.modules import QValueModule + +policy = TensorDictSequential( + value_net, # writes action values in our tensordict + QValueModule( + action_space=env.action_spec + ), # Reads the "action_value" entry by default +) + +################################### +# Let's check it out! We run the policy for a couple of steps and look at the +# output. We should find an ``"action_value"`` as well as a +# ``"chosen_action_value"`` entries in the rollout that we obtain: +# + +rollout = env.rollout(max_steps=3, policy=policy) +print(rollout) + +################################### +# Because it relies on the ``argmax`` operator, this policy is deterministic. +# During data collection, we will need to explore the environment. For that, +# we are using the :class:`~torchrl.modules.tensordict_module.EGreedyModule` +# once again: + +policy_explore = TensorDictSequential(policy, EGreedyModule(env.action_spec)) + +with set_exploration_type(ExplorationType.RANDOM): + rollout_explore = env.rollout(max_steps=3, policy=policy_explore) + +################################### +# This is it for our short tutorial on building a policy with TorchRL! +# +# There are many more things you can do with the library. A good place to start +# is to look at the :ref:`API reference for modules `. +# +# Next steps: +# +# - Check how to use compound distributions with +# :class:`~tensordict.nn.distributions.CompositeDistribution` when the +# action is composite (e.g., a discrete and a continuous action are +# required by the env); +# - Have a look at how you can use an RNN within the policy (a +# :ref:`tutorial `); +# - Compare this to the usage of transformers with the Decision Transformers +# examples (see the ``example`` directory on GitHub). +# diff --git a/tutorials/sphinx-tutorials/getting-started-2.py b/tutorials/sphinx-tutorials/getting-started-2.py new file mode 100644 index 00000000000..1d903e67c01 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-2.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +""" +Getting started with model optimization +======================================= + +**Author**: `Vincent Moens `_ + +.. _gs_optim: + +""" + +################################### +# In TorchRL, we try to treat optimization as it is custom to do in PyTorch, +# using dedicated loss modules which are designed with the sole purpose of +# optimizing the model. This approach efficiently decouples the execution of +# the policy from its training and allows us to design training loops that are +# similar to what can be found in traditional supervised learning examples. +# +# The typical training loop therefore looks like this: +# +# >>> for i in range(n_collections): +# ... data = get_next_batch(env, policy) +# ... for j in range(n_optim): +# ... loss = loss_fn(data) +# ... loss.backward() +# ... optim.step() +# +# In this concise tutorial, you will receive a brief overview of the loss modules. Due to the typically +# straightforward nature of the API for basic usage, this tutorial will be kept brief. +# +# RL objective functions +# ---------------------- +# +# In RL, innovation typically involves the exploration of novel methods +# for optimizing a policy (i.e., new algorithms), rather than focusing +# on new architectures, as seen in other domains. Within TorchRL, +# these algorithms are encapsulated within loss modules. A loss +# module orchestrates the various components of your algorithm and +# yields a set of loss values that can be backpropagated +# through to train the corresponding components. +# +# In this tutorial, we will take a popular +# off-policy algorithm as an example, +# `DDPG `_. +# +# To build a loss module, the only thing one needs is a set of networks +# defined as :class:`~tensordict.nn.TensorDictModule`s. Most of the time, one +# of these modules will be the policy. Other auxiliary networks such as +# Q-Value networks or critics of some kind may be needed as well. Let's see +# what this looks like in practice: DDPG requires a deterministic +# map from the observation space to the action space as well as a value +# network that predicts the value of a state-action pair. The DDPG loss will +# attempt to find the policy parameters that output actions that maximize the +# value for a given state. +# +# To build the loss, we need both the actor and value networks. +# If they are built according to DDPG's expectations, it is all +# we need to get a trainable loss module: + +from torchrl.envs import GymEnv + +env = GymEnv("Pendulum-v1") + +from torchrl.modules import Actor, MLP, ValueOperator +from torchrl.objectives import DDPGLoss + +n_obs = env.observation_spec["observation"].shape[-1] +n_act = env.action_spec.shape[-1] +actor = Actor(MLP(in_features=n_obs, out_features=n_act, num_cells=[32, 32])) +value_net = ValueOperator( + MLP(in_features=n_obs + n_act, out_features=1, num_cells=[32, 32]), + in_keys=["observation", "action"], +) + +ddpg_loss = DDPGLoss(actor_network=actor, value_network=value_net) + +################################### +# And that is it! Our loss module can now be run with data coming from the +# environment (we omit exploration, storage and other features to focus on +# the loss functionality): +# + +rollout = env.rollout(max_steps=100, policy=actor) +loss_vals = ddpg_loss(rollout) +print(loss_vals) + +################################### +# LossModule's output +# ------------------- +# +# As you can see, the value we received from the loss isn't a single scalar +# but a dictionary containing multiple losses. +# +# The reason is simple: because more than one network may be trained at a time, +# and since some users may wish to separate the optimization of each module +# in distinct steps, TorchRL's objectives will return dictionaries containing +# the various loss components. +# +# This format also allows us to pass metadata along with the loss values. In +# general, we make sure that only the loss values are differentiable such that +# you can simply sum over the values of the dictionary to obtain the total +# loss. If you want to make sure you're fully in control of what is happening, +# you can sum over only the entries which keys start with the ``"loss_"`` prefix: +# + +total_loss = 0 +for key, val in loss_vals.items(): + if key.startswith("loss_"): + total_loss += val + +################################### +# Training a LossModule +# --------------------- +# +# Given all this, training the modules is not so different from what would be +# done in any other training loop. Because it wraps the modules, +# the easiest way to get the list of trainable parameters is to query +# the :meth:`~torchrl.objectives.LossModule.parameters` method. +# +# We'll need an optimizer (or one optimizer +# per module if that is your choice). +# + +from torch.optim import Adam + +optim = Adam(ddpg_loss.parameters()) +total_loss.backward() + +################################### +# The following items will typically be +# found in your training loop: + +optim.step() +optim.zero_grad() + +################################### +# Further considerations: Target parameters +# ----------------------------------------- +# +# Another important aspect to consider is the presence of target parameters +# in off-policy algorithms like DDPG. Target parameters typically represent +# a delayed or smoothed version of the parameters over time, and they play +# a crucial role in value estimation during policy training. Utilizing target +# parameters for policy training often proves to be significantly more +# efficient compared to using the current configuration of value network +# parameters. Generally, managing target parameters is handled by the loss +# module, relieving users of direct concern. However, it remains the user's +# responsibility to update these values as necessary based on specific +# requirements. TorchRL offers a couple of updaters, namely +# :class:`~torchrl.objectives.HardUpdate` and +# :class:`~torchrl.objectives.SoftUpdate`, +# which can be easily instantiated without requiring in-depth +# knowledge of the underlying mechanisms of the loss module. +# +from torchrl.objectives import SoftUpdate + +updater = SoftUpdate(ddpg_loss, eps=0.99) + +################################### +# In your training loop, you will need to update the target parameters at each +# optimization step or each collection step: + +updater.step() + +################################### +# This is all you need to know about loss modules to get started! +# +# To further explore the topic, have a look at: +# +# - The :ref:`loss module reference page `; +# - The :ref:`Coding a DDPG loss tutorial `; +# - Losses in action in :ref:`PPO ` or :ref:`DQN `. +# diff --git a/tutorials/sphinx-tutorials/getting-started-3.py b/tutorials/sphinx-tutorials/getting-started-3.py new file mode 100644 index 00000000000..97934ef424d --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-3.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- +""" +Get started with data collection and storage +============================================ + +**Author**: `Vincent Moens `_ + +.. _gs_storage: + +""" + +################################# +# +# There is no learning without data. In supervised learning, users are +# accustomed to using :class:`~torch.utils.data.DataLoader` and the like +# to integrate data in their training loop. +# Dataloaders are iterable objects that provide you with the data that you will +# be using to train your model. +# +# TorchRL approaches the problem of dataloading in a similar manner, although +# it is surprisingly unique in the ecosystem of RL libraries. TorchRL's +# dataloaders are referred to as ``DataCollectors``. Most of the time, +# data collection does not stop at the collection of raw data, +# as the data needs to be stored temporarily in a buffer +# (or equivalent structure for on-policy algorithms) before being consumed +# by the :ref:`loss module `. This tutorial will explore +# these two classes. +# +# Data collectors +# --------------- +# +# .. _gs_storage_collector: +# +# +# The primary data collector discussed here is the +# :class:`~torchrl.collectors.SyncDataCollector`, which is the focus of this +# documentation. At a fundamental level, a collector is a straightforward +# class responsible for executing your policy within the environment, +# resetting the environment when necessary, and providing batches of a +# predefined size. Unlike the :meth:`~torchrl.envs.EnvBase.rollout` method +# demonstrated in :ref:`the env tutorial `, collectors do not +# reset between consecutive batches of data. Consequently, two successive +# batches of data may contain elements from the same trajectory. +# +# The basic arguments you need to pass to your collector are the size of the +# batches you want to collect (``frames_per_batch``), the length (possibly +# infinite) of the iterator, the policy and the environment. For simplicity, +# we will use a dummy, random policy in this example. + +import torch + +torch.manual_seed(0) + +from torchrl.collectors import RandomPolicy, SyncDataCollector +from torchrl.envs import GymEnv + +env = GymEnv("CartPole-v1") +env.set_seed(0) + +policy = RandomPolicy(env.action_spec) +collector = SyncDataCollector(env, policy, frames_per_batch=200, total_frames=-1) + +################################# +# We now expect that our collector will deliver batches of size ``200`` no +# matter what happens during collection. In other words, we may have multiple +# trajectories in this batch! The ``total_frames`` indicates how long the +# collector should be. A value of ``-1`` will produce a never +# ending collector. +# +# Let's iterate over the collector to get a sense +# of what this data looks like: + +for data in collector: + print(data) + break + +################################# +# As you can see, our data is augmented with some collector-specific metadata +# grouped in a ``"collector"`` sub-tensordict that we did not see during +# :ref:`environment rollouts `. This is useful to keep track of +# the trajectory ids. In the following list, each item marks the trajectory +# number the corresponding transition belongs to: + +print(data["collector", "traj_ids"]) + +################################# +# Data collectors are very useful when it comes to coding state-of-the-art +# algorithms, as performance is usually measured by the capability of a +# specific technique to solve a problem in a given number of interactions with +# the environment (the ``total_frames`` argument in the collector). +# For this reason, most training loops in our examples look like this: +# +# >>> for data in collector: +# ... # your algorithm here +# +# +# Replay Buffers +# -------------- +# +# .. _gs_storage_rb: +# +# Now that we have explored how to collect data, we would like to know how to +# store it. In RL, the typical setting is that the data is collected, stored +# temporarily and cleared after a little while given some heuristic: +# first-in first-out or other. A typical pseudo-code would look like this: +# +# >>> for data in collector: +# ... storage.store(data) +# ... for i in range(n_optim): +# ... sample = storage.sample() +# ... loss_val = loss_fn(sample) +# ... loss_val.backward() +# ... optim.step() # etc +# +# The parent class that stores the data in TorchRL +# is referred to as :class:`~torchrl.data.ReplayBuffer`. TorchRL's replay +# buffers are composable: you can edit the storage type, their sampling +# technique, the writing heuristic or the transforms applied to them. We will +# leave the fancy stuff for a dedicated in-depth tutorial. The generic replay +# buffer only needs to know what storage it has to use. In general, we +# recommend a :class:`~torchrl.data.TensorStorage` subclass, which will work +# fine in most cases. We'll be using +# :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` +# in this tutorial, which enjoys two nice properties: first, being "lazy", +# you don't need to explicitly tell it what your data looks like in advance. +# Second, it uses :class:`~tensordict.MemoryMappedTensor` as a backend to save +# your data on disk in an efficient way. The only thing you need to know is +# how big you want your buffer to be. + +from torchrl.data.replay_buffers import LazyMemmapStorage, ReplayBuffer + +buffer = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000)) + +################################# +# Populating the buffer can be done via the +# :meth:`~torchrl.data.ReplayBuffer.add` (single element) or +# :meth:`~torchrl.data.ReplayBuffer.extend` (multiple elements) methods. Using +# the data we just collected, we initialize and populate the buffer in one go: + +indices = buffer.extend(data) + +################################# +# We can check that the buffer now has the same number of elements than what +# we got from the collector: + +assert len(buffer) == collector.frames_per_batch + +################################# +# The only thing left to know is how to gather data from the buffer. +# Naturally, this relies on the :meth:`~torchrl.data.ReplayBuffer.sample` +# method. Because we did not specify that sampling had to be done without +# repetitions, it is not guaranteed that the samples gathered from our buffer +# will be unique: + +sample = buffer.sample(batch_size=30) +print(sample) + +################################# +# Again, our sample looks exactly the same as the data we gathered from the +# collector! +# +# Next steps +# ---------- +# +# - You can have look at other multirpocessed +# collectors such as :class:`~torchrl.collectors.collectors.MultiSyncDataCollector` or +# :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`. +# - TorchRL also offers distributed collectors if you have multiple nodes to +# use for inference. Check them out in the +# :ref:`API reference `. +# - Check the dedicated :ref:`Replay Buffer tutorial ` to know +# more about the options you have when building a buffer, or the +# :ref:`API reference ` which covers all the features in +# details. Replay buffers have countless features such as multithreaded +# sampling, prioritized experience replay, and many more... +# - We left out the capacity of replay buffers to be iterated over for +# simplicity. Try it out for yourself: build a buffer and indicate its +# batch-size in the constructor, then try to iterate over it. This is +# equivalent to calling ``rb.sample()`` within a loop! +# diff --git a/tutorials/sphinx-tutorials/getting-started-4.py b/tutorials/sphinx-tutorials/getting-started-4.py new file mode 100644 index 00000000000..bff30d79851 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-4.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +""" +Get started with logging +======================== + +**Author**: `Vincent Moens `_ + +.. _gs_logging: + +""" + +##################################### +# The final chapter of this series before we orchestrate everything in a +# training script is to learn about logging. +# +# Loggers +# ------- +# +# Logging is crucial for reporting your results to the outside world and for +# you to check that your algorithm is learning properly. TorchRL has several +# loggers that interface with custom backends such as +# wandb (:class:`~torchrl.record.loggers.wandb.WandbLogger`), +# tensorboard (:class:`~torchrl.record.loggers.tensorboard.TensorBoardLogger`) or a lightweight and +# portable CSV logger (:class:`~torchrl.record.loggers.csv.CSVLogger`) that you can use +# pretty much everywhere. +# +# Loggers are located in the ``torchrl.record`` module and the various classes +# can be found in the :ref:`API reference `. +# +# We tried to keep the loggers APIs as similar as we could, given the +# differences in the underlying backends. While execution of the loggers will +# mostly be interchangeable, their instantiation can differ. +# +# Usually, building a logger requires +# at least an experiment name and possibly a logging directory and other +# hyperapameters. +# + +from torchrl.record import CSVLogger + +logger = CSVLogger(exp_name="my_exp") + +##################################### +# Once the logger is instantiated, the only thing left to do is call the +# logging methods! For example, :meth:`~torchrl.record.CSVLogger.log_scalar` +# is used in several places across the training examples to log values such as +# reward, loss value or time elapsed for executing a piece of code. + +logger.log_scalar("my_scalar", 0.4) + +##################################### +# Recording videos +# ---------------- +# +# Finally, it can come in handy to record videos of a simulator. Some +# environments (e.g., Atari games) are already rendered as images whereas +# others require you to create them as such. Fortunately, in most common cases, +# rendering and recording videos isn't too difficult. +# +# Let's first see how we can create a Gym environment that outputs images +# alongside its observations. :class:`~torchrl.envs.GymEnv` accept two keywords +# for this purpose: ``from_pixels=True`` will make the env ``step`` function +# write a ``"pixels"`` entry containing the images corresponding to your +# observations, and the ``pixels_only=False`` will indicate that you want the +# observations to be returned as well. +# + +from torchrl.envs import GymEnv + +env = GymEnv("CartPole-v1", from_pixels=True, pixels_only=False) + +print(env.rollout(max_steps=3)) + +from torchrl.envs import TransformedEnv + +##################################### +# We now have built an environment that renders images with its observations. +# To record videos, we will need to combine that environment with a recorder +# and the logger (the logger providing the backend to save the video). +# This will happen within a transformed environment, like the one we saw in +# the :ref:`first tutorial `. + +from torchrl.record import VideoRecorder + +recorder = VideoRecorder(logger, tag="my_video") +record_env = TransformedEnv(env, recorder) + +##################################### +# When running this environment, all the ``"pixels"`` entries will be saved in +# a local buffer and dumped in a video on demand (it is important that you +# call this method when appropriate): + +rollout = record_env.rollout(max_steps=3) +# Uncomment this line to save the video on disk: +# recorder.dump() + +##################################### +# In this specific case, the video format can be chosen when instantiating +# the CSVLogger. +# +# This is all we wanted to cover in the getting started tutorial. +# You should now be ready to code your +# :ref:`first training loop with TorchRL `! +# diff --git a/tutorials/sphinx-tutorials/getting-started-5.py b/tutorials/sphinx-tutorials/getting-started-5.py new file mode 100644 index 00000000000..8413d0c9130 --- /dev/null +++ b/tutorials/sphinx-tutorials/getting-started-5.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +""" +Get started with your onw first training loop +============================================= + +**Author**: `Vincent Moens `_ + +.. _gs_first_training: + +""" + +################################# +# Time to wrap up everything we've learned so far in this Getting Started +# series! +# +# In this tutorial, we will be writing the most basic training loop there is +# using only components we have presented in the previous lessons. +# +# We'll be using DQN with a CartPole environment as a prototypical example. +# +# We will be voluntarily keeping the verbosity to its minimum, only linking +# each section to the related tutorial. +# +# Building the environment +# ------------------------ +# +# We'll be using a gym environment with a :class:`~torchrl.envs.transforms.StepCounter` +# transform. If you need a refresher, check our these features are presented in +# :ref:`the environment tutorial `. +# + +import torch + +torch.manual_seed(0) + +import time + +from torchrl.envs import GymEnv, StepCounter, TransformedEnv + +env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter()) +env.set_seed(0) + +from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq + +################################# +# Designing a policy +# ------------------ +# +# The next step is to build our policy. We'll be making a regular, deterministic +# version to be used within the :ref:`loss module ` and during +# :ref:`evaluation `, and one augmented by an exploration module +# for :ref:`inference `. + +from torchrl.modules import EGreedyModule, MLP, QValueModule + +value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64]) +value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"]) +policy = Seq(value_net, QValueModule(env.action_spec)) +exploration_module = EGreedyModule( + env.action_spec, annealing_num_steps=100_000, eps_init=0.5 +) +policy_explore = Seq(policy, exploration_module) + + +################################# +# Data Collector and replay buffer +# -------------------------------- +# +# Here comes the data part: we need a +# :ref:`data collector ` to easily get batches of data +# and a :ref:`replay buffer ` to store that data for training. +# + +from torchrl.collectors import SyncDataCollector +from torchrl.data import LazyTensorStorage, ReplayBuffer + +init_rand_steps = 5000 +frames_per_batch = 100 +optim_steps = 10 +collector = SyncDataCollector( + env, + policy, + frames_per_batch=frames_per_batch, + total_frames=-1, + init_random_frames=init_rand_steps, +) +rb = ReplayBuffer(storage=LazyTensorStorage(100_000)) + +from torch.optim import Adam + +################################# +# Loss module and optimizer +# ------------------------- +# +# We build our loss as indicated in the :ref:`dedicated tutorial `, with +# its optimizer and target parameter updater: + +from torchrl.objectives import DQNLoss, SoftUpdate + +loss = DQNLoss(value_network=policy, action_space=env.action_spec, delay_value=True) +optim = Adam(loss.parameters(), lr=0.02) +updater = SoftUpdate(loss, eps=0.99) + +################################# +# Logger +# ------ +# +# We'll be using a CSV logger to log our results, and save rendered videos. +# + +from torchrl._utils import logger as torchrl_logger +from torchrl.record import CSVLogger, VideoRecorder + +path = "./training_loop" +logger = CSVLogger(exp_name="dqn", log_dir=path, video_format="mp4") +video_recorder = VideoRecorder(logger, tag="video") +record_env = TransformedEnv( + GymEnv("CartPole-v1", from_pixels=True, pixels_only=False), video_recorder +) + +################################# +# Training loop +# ------------- +# +# Instead of fixing a specific number of iterations to run, we will keep on +# training the network until it reaches a certain performance (arbitrarily +# defined as 200 steps in the environment -- with CartPole, success is defined +# as having longer trajectories). +# + +total_count = 0 +total_episodes = 0 +t0 = time.time() +for i, data in enumerate(collector): + # Write data in replay buffer + rb.extend(data) + max_length = rb[:]["next", "step_count"].max() + if len(rb) > init_rand_steps: + # Optim loop (we do several optim steps + # per batch collected for efficiency) + for _ in range(optim_steps): + sample = rb.sample(128) + loss_vals = loss(sample) + loss_vals["loss"].backward() + optim.step() + optim.zero_grad() + # Update exploration factor + exploration_module.step(data.numel()) + # Update target params + updater.step() + if i % 10: + torchrl_logger.info(f"Max num steps: {max_length}, rb length {len(rb)}") + total_count += data.numel() + total_episodes += data["next", "done"].sum() + if max_length > 200: + break + +t1 = time.time() + +torchrl_logger.info( + f"solved after {total_count} steps, {total_episodes} episodes and in {t1-t0}s." +) + +################################# +# Rendering +# --------- +# +# Finally, we run the environment for as many steps as we can and save the +# video locally (notice that we are not exploring). + +record_env.rollout(max_steps=1000, policy=policy) +video_recorder.dump() + +################################# +# +# This is what your rendered CartPole video will look like after a full +# training loop: +# +# .. figure:: /_static/img/cartpole.gif +# +# This concludes our series of "Getting started with TorchRL" tutorials! +# Feel free to share feedback about it on GitHub. +# diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index a67976566d5..8e7817978e4 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -6,6 +6,8 @@ **Author**: `Vincent Moens `_ +.. _pendulum_tuto: + Creating an environment (a simulator or an interface to a physical control system) is an integrative part of reinforcement learning and control engineering. diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 3d37ce3de83..2c5cd95e780 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -5,6 +5,8 @@ **Author**: `Vincent Moens `_ +.. _rb_tuto: + """ ###################################################################### # Replay buffers are a central piece of any RL or control algorithm. @@ -30,17 +32,24 @@ # # # In this tutorial, you will learn: -# - How to build a Replay Buffer (RB) and use it with any datatype; -# - How to use RBs with TensorDict; -# - How to sample from or iterate over a replay buffer, and how to define the sampling strategy; -# - How to use prioritized replay buffers; -# - How to transform data coming in and out from the buffer; -# - How to store trajectories in the buffer. +# +# - How to build a :ref:`Replay Buffer (RB) ` and use it with +# any datatype; +# - How to customize the :ref:`buffer's storage `; +# - How to use :ref:`RBs with TensorDict `; +# - How to :ref:`sample from or iterate over a replay buffer `, +# and how to define the sampling strategy; +# - How to use :ref:`prioritized replay buffers `; +# - How to :ref:`transform data ` coming in and out from +# the buffer; +# - How to store :ref:`trajectories ` in the buffer. # # # Basics: building a vanilla replay buffer # ---------------------------------------- # +# .. _tuto_rb_vanilla: +# # TorchRL's replay buffers are designed to prioritize modularity, # composability, efficiency, and simplicity. For instance, creating a basic # replay buffer is a straightforward process, as shown in the following @@ -77,7 +86,7 @@ ###################################################################### # By default, this replay buffer will have a size of 1000. Let's check this -# by populating our buffer using the :meth:`torchrl.data.ReplayBuffer.extend` +# by populating our buffer using the :meth:`~torchrl.data.ReplayBuffer.extend` # method: # @@ -87,24 +96,24 @@ print("length after adding elements:", len(buffer)) -import torch -from tensordict import TensorDict - ###################################################################### -# We have used the :meth:`torchrl.data.ReplayBuffer.extend` method which is +# We have used the :meth:`~torchrl.data.ReplayBuffer.extend` method which is # designed to add multiple items all at once. If the object that is passed # to ``extend`` has more than one dimension, its first dimension is # considered to be the one to be split in separate elements in the buffer. +# # This essentially means that when adding multidimensional tensors or # tensordicts to the buffer, the buffer will only look at the first dimension # when counting the elements it holds in memory. # If the object passed it not iterable, an exception will be thrown. # -# To add items one at a time, the :meth:`torchrl.data.ReplayBuffer.add` method +# To add items one at a time, the :meth:`~torchrl.data.ReplayBuffer.add` method # should be used instead. # # Customizing the storage -# ~~~~~~~~~~~~~~~~~~~~~~~ +# ----------------------- +# +# .. _tuto_rb_storage: # # We see that the buffer has been capped to the first 1000 elements that we # passed to it. @@ -112,25 +121,27 @@ # # TorchRL proposes three types of storages: # -# - The :class:`torchrl.dataListStorage` stores elements independently in a +# - The :class:`~torchrl.data.ListStorage` stores elements independently in a # list. It supports any data type, but this flexibility comes at the cost # of efficiency; -# - The :class:`torchrl.dataLazyTensorStorage` stores tensors or -# :class:`tensordidct.TensorDict` (or :class:`torchrl.data.tensorclass`) +# - The :class:`~torchrl.data.LazyTensorStorage` stores tensors data +# structures contiguously. +# It works naturally with :class:`~tensordidct.TensorDict` +# (or :class:`~torchrl.data.tensorclass`) # objects. The storage is contiguous on a per-tensor basis, meaning that # sampling will be more efficient than when using a list, but the # implicit restriction is that any data passed to it must have the same -# basic properties as the -# first batch of data that was used to instantiate the buffer. +# basic properties (such as shape and dtype) as the first batch of data that +# was used to instantiate the buffer. # Passing data that does not match this requirement will either raise an # exception or lead to some undefined behaviours. -# - The :class:`torchrl.dataLazyMemmapStorage` works as the -# :class:`torchrl.data.LazyTensorStorage` in that it is lazy (ie. it +# - The :class:`~torchrl.data.LazyMemmapStorage` works as the +# :class:`~torchrl.data.LazyTensorStorage` in that it is lazy (i.e., it # expects the first batch of data to be instantiated), and it requires data # that match in shape and dtype for each batch stored. What makes this -# storage unique is that it points to disk files, meaning that it can -# support very large datasets while still accessing data in a contiguous -# manner. +# storage unique is that it points to disk files (or uses the filesystem +# storage), meaning that it can support very large datasets while still +# accessing data in a contiguous manner. # # Let us see how we can use each of these storages: @@ -149,9 +160,9 @@ ###################################################################### # Because it is the one with the lowest amount of assumption, the -# :class:`torchrl.data.ListStorage` is the default storage in TorchRL. +# :class:`~torchrl.data.ListStorage` is the default storage in TorchRL. # -# A :class:`torchrl.data.LazyTensorStorage` can store data contiguously. +# A :class:`~torchrl.data.LazyTensorStorage` can store data contiguously. # This should be the preferred option when dealing with complicated but # unchanging data structures of medium size: @@ -161,6 +172,10 @@ # Let us create a batch of data of size ``torch.Size([3])` with 2 tensors # stored in it: # + +import torch +from tensordict import TensorDict + data = TensorDict( { "a": torch.arange(12).view(3, 4), @@ -171,7 +186,7 @@ print(data) ###################################################################### -# The first call to :meth:`torchrl.data.ReplayBuffer.extend` will +# The first call to :meth:`~torchrl.data.ReplayBuffer.extend` will # instantiate the storage. The first dimension of the data is unbound into # separate datapoints: @@ -186,7 +201,7 @@ print("samples", sample["a"], sample["b", "c"]) ###################################################################### -# A :class:`torchrl.data.LazyMemmapStorage` is created in the same manner: +# A :class:`~torchrl.data.LazyMemmapStorage` is created in the same manner: # buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size)) @@ -213,16 +228,20 @@ # Integration with TensorDict # --------------------------- # +# .. _tuto_rb_td: +# # The tensor location follows the same structure as the TensorDict that # contains them: this makes it easy to save and load buffers during training. # -# To use :class:`tensordict.TensorDict` as a data carrier at its fullest -# potential, the :class:`torchrl.data.TensorDictReplayBuffer` class should +# To use :class:`~tensordict.TensorDict` as a data carrier at its fullest +# potential, the :class:`~torchrl.data.TensorDictReplayBuffer` class can # be used. # One of its key benefits is its ability to handle the organization of sampled # data, along with any additional information that may be required # (such as sample indices). -# It can be built in the same manner as a standard :class:`torchrl.data.ReplayBuffer` and can +# +# It can be built in the same manner as a standard +# :class:`~torchrl.data.ReplayBuffer` and can # generally be used interchangeably. # @@ -250,7 +269,7 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # The ReplayBuffer class and associated subclasses also work natively with -# :class:`tensordict.tensorclass` classes, which can conviniently be used to +# :class:`~tensordict.tensorclass` classes, which can conveniently be used to # encode datasets in a more explicit manner: from tensordict import tensorclass @@ -284,31 +303,28 @@ class MyData: ###################################################################### # As expected. the data has the proper class and shape! # -# Integration with PyTree -# ~~~~~~~~~~~~~~~~~~~~~~~ +# Integration with other tensor structures (PyTrees) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # TorchRL's replay buffers also work with any pytree data structure. # A PyTree is a nested structure of arbitrary depth made of dicts, lists and/or # tuples where the leaves are tensors. # This means that one can store in contiguous memory any such tree structure! # Various storages can be used: -# :class:`~torchrl.data.replay_buffers.TensorStorage`, :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` -# or :class:`~torchrl.data.replay_buffers.LazyTensorStorage` all accept this kind of data. +# :class:`~torchrl.data.replay_buffers.TensorStorage`, +# :class:`~torchrl.data.replay_buffers.LazyMemmapStorage` +# or :class:`~torchrl.data.replay_buffers.LazyTensorStorage` all accept this +# kind of data. # -# Here is a bried demonstration of what this feature looks like: +# Here is a brief demonstration of what this feature looks like: # from torch.utils._pytree import tree_map -# With pytrees, any callable can be used as a transform: -def transform(x): - # Zeros all the data in the pytree - return tree_map(lambda y: y * 0, x) - - +###################################################################### # Let's build our replay buffer on disk: -rb = ReplayBuffer(storage=LazyMemmapStorage(size), transform=transform) +rb = ReplayBuffer(storage=LazyMemmapStorage(size)) data = { "a": torch.randn(3), "b": {"c": (torch.zeros(2), [torch.ones(1)])}, @@ -320,6 +336,20 @@ def transform(x): sample = rb.sample(10) +###################################################################### +# With pytrees, any callable can be used as a transform: + + +def transform(x): + # Zeros all the data in the pytree + return tree_map(lambda y: y * 0, x) + + +rb.append_transform(transform) +sample = rb.sample(batch_size=12) + + +###################################################################### # let's check that our transform did its job: def assert0(x): assert (x == 0).all() @@ -328,9 +358,12 @@ def assert0(x): tree_map(assert0, sample) +###################################################################### # Sampling and iterating over buffers # ----------------------------------- # +# .. _tuto_rb_sampling: +# # Replay Buffers support multiple sampling strategies: # # - If the batch-size is fixed and can be defined at construction time, it can @@ -338,7 +371,7 @@ def assert0(x): # - With a fixed batch-size, the replay buffer can be iterated over to gather # samples; # - If the batch-size is dynamic, it can be passed to the -# :class:`torchrl.data.ReplayBuffer.sample` method +# :class:`~torchrl.data.ReplayBuffer.sample` method # on-the-fly. # # Sampling can be done using multithreading, but this is incompatible with the @@ -349,21 +382,22 @@ def assert0(x): # # Fixed batch-size # ~~~~~~~~~~~~~~~~ -# If the batch-size is passed during construction, it should be omited when +# +# If the batch-size is passed during construction, it should be omitted when # sampling: data = MyData( images=torch.randint( 255, - (10, 64, 64, 3), + (200, 64, 64, 3), ), - labels=torch.randint(100, (10,)), - batch_size=[10], + labels=torch.randint(100, (200,)), + batch_size=[200], ) buffer_lazymemmap = ReplayBuffer(storage=LazyMemmapStorage(size), batch_size=128) -buffer_lazymemmap.add(data) -buffer_lazymemmap.sample() # will produces 128 identical samples +buffer_lazymemmap.extend(data) +buffer_lazymemmap.sample() ###################################################################### @@ -371,19 +405,20 @@ def assert0(x): # # To enable multithreaded sampling, just pass a positive integer to the # ``prefetch`` keyword argument during construction. This should speed up -# sampling considerably: +# sampling considerably whenever sampling is time consuming (e.g., when +# using prioritized samplers): buffer_lazymemmap = ReplayBuffer( storage=LazyMemmapStorage(size), batch_size=128, prefetch=10 ) # creates a queue of 10 elements to be prefetched in the background -buffer_lazymemmap.add(data) +buffer_lazymemmap.extend(data) print(buffer_lazymemmap.sample()) ###################################################################### -# Fixed batch-size, iterating over the buffer -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Iterating over the buffer with a fixed batch-size +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # We can also iterate over the buffer like we would do with a regular # dataloader, as long as the batch-size is predefined: @@ -398,7 +433,8 @@ def assert0(x): ###################################################################### # Due to the fact that our sampling technique is entirely random and does not # prevent replacement, the iterator in question is infinite. However, we can -# make use of the :class:`torchrl.data.replay_buffers.SamplerWithoutReplacement` +# make use of the +# :class:`~torchrl.data.replay_buffers.SamplerWithoutReplacement` # instead, which will transform our buffer into a finite iterator: # @@ -428,7 +464,7 @@ def assert0(x): # ~~~~~~~~~~~~~~~~~~ # # In contrast to what we have seen earlier, the ``batch_size`` keyword -# argument can be omitted and passed directly to the `sample` method: +# argument can be omitted and passed directly to the ``sample`` method: buffer_lazymemmap = ReplayBuffer( @@ -442,7 +478,10 @@ def assert0(x): # Prioritized Replay buffers # -------------------------- # -# TorchRL also provides an interface for prioritized replay buffers. +# .. _tuto_rb_prb: +# +# TorchRL also provides an interface for +# `prioritized replay buffers `_. # This buffer class samples data according to a priority signal that is passed # through the data. # @@ -476,8 +515,8 @@ def assert0(x): # buffer, the priority is set to a default value of 1. Once the priority has # been computed (usually through the loss), it must be updated in the buffer. # -# This is done via the `update_priority` method, which requires the indices -# as well as the priority. +# This is done via the :meth:`~torchrl.data.ReplayBuffer.update_priority` +# method, which requires the indices as well as the priority. # We assign an artificially high priority to the second sample in the dataset # to observe its effect on sampling: # @@ -499,6 +538,7 @@ def assert0(x): ###################################################################### # We see that using a prioritized replay buffer requires a series of extra # steps in the training loop compared with a regular buffer: +# # - After collecting data and extending the buffer, the priority of the # items must be updated; # - After computing the loss and getting a "priority signal" from it, we must @@ -511,10 +551,10 @@ def assert0(x): # that the appropriate methods are called at the appropriate place, if and # only if a prioritized buffer is being used. # -# Let us see how we can improve this with TensorDict. We saw that the -# :class:`torchrl.data.TensorDictReplayBuffer` returns data augmented with -# their relative storage indices. One feature we did not mention is that -# this class also ensures that the priority +# Let us see how we can improve this with :class:`~tensordict.TensorDict`. +# We saw that the :class:`~torchrl.data.TensorDictReplayBuffer` returns data +# augmented with their relative storage indices. One feature we did not mention +# is that this class also ensures that the priority # signal is automatically parsed to the prioritized sampler if present during # extension. # @@ -582,6 +622,8 @@ def assert0(x): # Using transforms # ---------------- # +# .. _tuto_rb_transform: +# # The data stored in a replay buffer may not be ready to be presented to a # loss module. # In some cases, the data produced by a collector can be too heavy to be @@ -605,8 +647,14 @@ def assert0(x): from torchrl.collectors import RandomPolicy, SyncDataCollector -from torchrl.envs import Compose, GrayScale, Resize, ToTensorImage, TransformedEnv from torchrl.envs.libs.gym import GymEnv +from torchrl.envs.transforms import ( + Compose, + GrayScale, + Resize, + ToTensorImage, + TransformedEnv, +) env = TransformedEnv( GymEnv("CartPole-v1", from_pixels=True), @@ -630,7 +678,7 @@ def assert0(x): # To do this, we will append a transform to the collector to select the keys # we want to see appearing: -from torchrl.envs import ExcludeTransform +from torchrl.envs.transforms import ExcludeTransform collector = SyncDataCollector( env, @@ -685,7 +733,7 @@ def assert0(x): # A more complex examples: using CatFrames # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -# The :class:`torchrl.envs.CatFrames` transform unfolds the observations +# The :class:`~torchrl.envs.transforms.CatFrames` transform unfolds the observations # through time, creating a n-back memory of past events that allow the model # to take the past events into account (in the case of POMDPs or with # recurrent policies such as Decision Transformers). Storing these concatenated @@ -752,6 +800,56 @@ def assert0(x): assert (data.exclude("collector") == s.squeeze(0).exclude("index", "collector")).all() +###################################################################### +# Storing trajectories +# -------------------- +# +# .. _tuto_rb_traj: +# +# In many cases, it is desirable to access trajectories from the buffer rather +# than simple transitions. TorchRL offers multiple ways of achieving this. +# +# The preferred way is currently to store trajectories along the first +# dimension of the buffer and use a :class:`~torchrl.data.SliceSampler` to +# sample these batches of data. This class only needs a couple of information +# about your data structure to do its job (not that as of now it is only +# compatible with tensordict-structured data): the number of slices or their +# length and some information about where the separation between the +# episodes can be found (e.g. :ref:`recall that ` with a +# :ref:`DataCollector `, the trajectory id is stored in +# ``("collector", "traj_ids")``). In this simple example, we construct a data +# with 4 consecutive short trajectories and sample 4 slices out of it, each of +# length 2 (since the batch size is 8, and 8 items // 4 slices = 2 time steps). +# We mark the steps as well. + +from torchrl.data import SliceSampler + +rb = TensorDictReplayBuffer( + storage=LazyMemmapStorage(size), + sampler=SliceSampler(traj_key="episode", num_slices=4), + batch_size=8, +) +episode = torch.zeros(10, dtype=torch.int) +episode[:3] = 1 +episode[3:5] = 2 +episode[5:7] = 3 +episode[7:] = 4 +steps = torch.cat([torch.arange(3), torch.arange(2), torch.arange(2), torch.arange(3)]) +data = TensorDict( + { + "episode": episode, + "obs": torch.randn((3, 4, 5)).expand(10, 3, 4, 5), + "act": torch.randn((20,)).expand(10, 20), + "other": torch.randn((20, 50)).expand(10, 20, 50), + "steps": steps, + }, + [10], +) +rb.extend(data) +sample = rb.sample() +print("episode are grouped", sample["episode"]) +print("steps are successive", sample["steps"]) + ###################################################################### # Conclusion # ---------- @@ -765,3 +863,13 @@ def assert0(x): # - Choose the best storage type for your problem (list, memory or disk-based); # - Minimize the memory footprint of your buffer. # +# Next steps +# ---------- +# +# - Check the data API reference to learn about offline datasets in TorchRL, +# which are based on our Replay Buffer API; +# - Check other samplers such as +# :class:`~torchrl.data.SamplerWithoutReplacement`, +# :class:`~torchrl.data.PrioritizedSliceSampler` and +# :class:`~torchrl.data.SliceSamplerWithoutReplacement`, or other writers +# such as :class:`~torchrl.data.TensorDictMaxValueWriter`. diff --git a/tutorials/sphinx-tutorials/torchrl_demo.py b/tutorials/sphinx-tutorials/torchrl_demo.py index 5e00442fe36..ce3f0bb4b98 100644 --- a/tutorials/sphinx-tutorials/torchrl_demo.py +++ b/tutorials/sphinx-tutorials/torchrl_demo.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ Introduction to TorchRL -============================ +======================= This demo was presented at ICML 2022 on the industry demo day. """ ############################################################################## @@ -746,8 +746,7 @@ total_frames=240, max_frames_per_traj=-1, # envs are terminating, we don't need to stop them early frames_per_batch=60, # we want 60 frames at a time (we have 3 envs per sub-collector) - storing_devices=devices, # len must match len of env created - devices=devices, + device=devices, ) ############################################################################### @@ -769,8 +768,7 @@ total_frames=240, max_frames_per_traj=-1, # envs are terminating, we don't need to stop them early frames_per_batch=60, # we want 60 frames at a time (we have 3 envs per sub-collector) - storing_devices=devices, # len must match len of env created - devices=devices, + device=devices, ) for i, d in enumerate(collector): @@ -805,7 +803,8 @@ def forward(self, obs, action): value_module, in_keys=["observation", "action"], out_keys=["state_action_value"] ) -loss_fn = DDPGLoss(actor, value, gamma=0.99) +loss_fn = DDPGLoss(actor, value) +loss_fn.make_value_estimator(loss_fn.default_value_estimator, gamma=0.99) ############################################################################### diff --git a/tutorials/sphinx-tutorials/torchrl_envs.py b/tutorials/sphinx-tutorials/torchrl_envs.py index 56896637a87..4c792d44b80 100644 --- a/tutorials/sphinx-tutorials/torchrl_envs.py +++ b/tutorials/sphinx-tutorials/torchrl_envs.py @@ -1,9 +1,15 @@ # -*- coding: utf-8 -*- """ TorchRL envs -============================ +============ + +**Author**: `Vincent Moens `_ + +.. _envs_tuto: + """ ############################################################################## +# # Environments play a crucial role in RL settings, often somewhat similar to # datasets in supervised and unsupervised settings. The RL community has # become quite familiar with OpenAI gym API which offers a flexible way of @@ -19,7 +25,10 @@ # To run this part of the tutorial, you will need to have a recent version of # the gym library installed, as well as the atari suite. You can get this # installed by installing the following packages: -# $ pip install gym atari-py ale-py gym[accept-rom-license] pygame +# +# .. code-block:: +# $ pip install gym atari-py ale-py gym[accept-rom-license] pygame +# # To unify all frameworks, torchrl environments are built inside the # ``__init__`` method with a private method called ``_build_env`` that # will pass the arguments and keyword arguments to the root library builder. From 89213f9c153ba0c61ab4dbe029c88dec89ca3c66 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 10 Feb 2024 15:39:28 +0000 Subject: [PATCH 03/10] Update getting-started-5.py (#1894) --- tutorials/sphinx-tutorials/getting-started-0.py | 8 ++++++++ tutorials/sphinx-tutorials/getting-started-1.py | 8 ++++++++ tutorials/sphinx-tutorials/getting-started-2.py | 8 ++++++++ tutorials/sphinx-tutorials/getting-started-3.py | 8 ++++++++ tutorials/sphinx-tutorials/getting-started-4.py | 8 ++++++++ tutorials/sphinx-tutorials/getting-started-5.py | 17 ++++++++++++++--- 6 files changed, 54 insertions(+), 3 deletions(-) diff --git a/tutorials/sphinx-tutorials/getting-started-0.py b/tutorials/sphinx-tutorials/getting-started-0.py index e81b10ec381..3d479a2a67f 100644 --- a/tutorials/sphinx-tutorials/getting-started-0.py +++ b/tutorials/sphinx-tutorials/getting-started-0.py @@ -8,6 +8,14 @@ .. _gs_env_ted: +.. note:: To run this tutorial in a notebook, add an installation cell + at the beginning containing: + + .. code-block:: + + !pip install tensordict + !pip install torchrl + """ ################################ diff --git a/tutorials/sphinx-tutorials/getting-started-1.py b/tutorials/sphinx-tutorials/getting-started-1.py index 136deeb5cd9..75ccf7cf8e7 100644 --- a/tutorials/sphinx-tutorials/getting-started-1.py +++ b/tutorials/sphinx-tutorials/getting-started-1.py @@ -7,6 +7,14 @@ .. _gs_modules: +.. note:: To run this tutorial in a notebook, add an installation cell + at the beginning containing: + + .. code-block:: + + !pip install tensordict + !pip install torchrl + """ ################################### # Reinforcement Learning is designed to create policies that can effectively diff --git a/tutorials/sphinx-tutorials/getting-started-2.py b/tutorials/sphinx-tutorials/getting-started-2.py index 1d903e67c01..0a16071bed2 100644 --- a/tutorials/sphinx-tutorials/getting-started-2.py +++ b/tutorials/sphinx-tutorials/getting-started-2.py @@ -7,6 +7,14 @@ .. _gs_optim: +.. note:: To run this tutorial in a notebook, add an installation cell + at the beginning containing: + + .. code-block:: + + !pip install tensordict + !pip install torchrl + """ ################################### diff --git a/tutorials/sphinx-tutorials/getting-started-3.py b/tutorials/sphinx-tutorials/getting-started-3.py index 97934ef424d..829b22cf061 100644 --- a/tutorials/sphinx-tutorials/getting-started-3.py +++ b/tutorials/sphinx-tutorials/getting-started-3.py @@ -7,6 +7,14 @@ .. _gs_storage: +.. note:: To run this tutorial in a notebook, add an installation cell + at the beginning containing: + + .. code-block:: + + !pip install tensordict + !pip install torchrl + """ ################################# diff --git a/tutorials/sphinx-tutorials/getting-started-4.py b/tutorials/sphinx-tutorials/getting-started-4.py index bff30d79851..a7c6462375d 100644 --- a/tutorials/sphinx-tutorials/getting-started-4.py +++ b/tutorials/sphinx-tutorials/getting-started-4.py @@ -7,6 +7,14 @@ .. _gs_logging: +.. note:: To run this tutorial in a notebook, add an installation cell + at the beginning containing: + + .. code-block:: + + !pip install tensordict + !pip install torchrl + """ ##################################### diff --git a/tutorials/sphinx-tutorials/getting-started-5.py b/tutorials/sphinx-tutorials/getting-started-5.py index 8413d0c9130..447a78ae200 100644 --- a/tutorials/sphinx-tutorials/getting-started-5.py +++ b/tutorials/sphinx-tutorials/getting-started-5.py @@ -7,6 +7,14 @@ .. _gs_first_training: +.. note:: To run this tutorial in a notebook, add an installation cell + at the beginning containing: + + .. code-block:: + + !pip install tensordict + !pip install torchrl + """ ################################# @@ -46,9 +54,12 @@ # Designing a policy # ------------------ # -# The next step is to build our policy. We'll be making a regular, deterministic -# version to be used within the :ref:`loss module ` and during -# :ref:`evaluation `, and one augmented by an exploration module +# The next step is to build our policy. +# We'll be making a regular, deterministic +# version of the actor to be used within the +# :ref:`loss module ` and during +# :ref:`evaluation `. +# Next, we will augment it with an exploration module # for :ref:`inference `. from torchrl.modules import EGreedyModule, MLP, QValueModule From 2cfd9b6c8d831043949ee6d2a5122791542d8723 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 10 Feb 2024 21:01:35 +0000 Subject: [PATCH 04/10] [BugFix] Solve recursion issue in losses hook (#1897) --- torchrl/objectives/common.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 5d620b56227..b22c735bfac 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -97,7 +97,6 @@ def tensor_keys(self) -> _AcceptedKeys: return self._tensor_keys def __new__(cls, *args, **kwargs): - cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward) self = super().__new__(cls) return self @@ -110,7 +109,16 @@ def __init__(self): self.value_type = self.default_value_estimator self._tensor_keys = self._AcceptedKeys() self.register_forward_pre_hook(_updater_check_forward_prehook) - # self.register_forward_pre_hook(_parameters_to_tensordict) + expl_mode = set_exploration_type(ExplorationType.MODE) + + def _pre_hook(*args, expl_mode=expl_mode, **kwargs): + expl_mode.__enter__() + + def _post_hook(*args, expl_mode=expl_mode, **kwargs): + expl_mode.__exit__(exc_type=None, exc_value=None, traceback=None) + + self.register_forward_pre_hook(_pre_hook) + self.register_forward_hook(_post_hook) def _set_deprecated_ctor_keys(self, **kwargs) -> None: for key, value in kwargs.items(): From 1bd5ec640a4a37cbb95c402df91b199c9a8b1736 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 11 Feb 2024 17:04:42 +0000 Subject: [PATCH 05/10] [BugFix] Fix exploration in losses (#1898) --- test/test_cost.py | 17 +++++++++++++++++ test/test_exploration.py | 1 + torchrl/objectives/common.py | 19 ++++++++----------- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index dae1fa5f70c..064a38ced60 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -25,6 +25,7 @@ TensorDictSequential, TensorDictSequential as Seq, ) +from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type from torchrl.modules.models import QMixer @@ -12391,6 +12392,22 @@ def __init__(self): assert p.device == dest +def test_loss_exploration(): + class DummyLoss(LossModule): + def forward(self, td): + assert exploration_type() == InteractionType.MODE + with set_exploration_type(ExplorationType.RANDOM): + assert exploration_type() == ExplorationType.RANDOM + assert exploration_type() == ExplorationType.MODE + return td + + loss_fn = DummyLoss() + with set_exploration_type(ExplorationType.RANDOM): + assert exploration_type() == ExplorationType.RANDOM + loss_fn(None) + assert exploration_type() == ExplorationType.RANDOM + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_exploration.py b/test/test_exploration.py index d0735a53ae8..e6493bd1804 100644 --- a/test/test_exploration.py +++ b/test/test_exploration.py @@ -54,6 +54,7 @@ class TestEGreedy: @pytest.mark.parametrize("eps_init", [0.0, 0.5, 1]) @pytest.mark.parametrize("module", [True, False]) + @set_exploration_type(InteractionType.RANDOM) def test_egreedy(self, eps_init, module): torch.manual_seed(0) spec = BoundedTensorSpec(1, 1, torch.Size([4])) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index b22c735bfac..cc11ac8b29e 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -5,6 +5,7 @@ from __future__ import annotations +import abc import warnings from copy import deepcopy from dataclasses import dataclass @@ -31,7 +32,13 @@ def _updater_check_forward_prehook(module, *args, **kwargs): ) -class LossModule(TensorDictModuleBase): +class _LossMeta(abc.ABCMeta): + def __init__(cls, name, bases, attr_dict): + super().__init__(name, bases, attr_dict) + cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward) + + +class LossModule(TensorDictModuleBase, metaclass=_LossMeta): """A parent class for RL losses. LossModule inherits from nn.Module. It is designed to read an input @@ -109,16 +116,6 @@ def __init__(self): self.value_type = self.default_value_estimator self._tensor_keys = self._AcceptedKeys() self.register_forward_pre_hook(_updater_check_forward_prehook) - expl_mode = set_exploration_type(ExplorationType.MODE) - - def _pre_hook(*args, expl_mode=expl_mode, **kwargs): - expl_mode.__enter__() - - def _post_hook(*args, expl_mode=expl_mode, **kwargs): - expl_mode.__exit__(exc_type=None, exc_value=None, traceback=None) - - self.register_forward_pre_hook(_pre_hook) - self.register_forward_hook(_post_hook) def _set_deprecated_ctor_keys(self, **kwargs) -> None: for key, value in kwargs.items(): From 1647fa475cc0d329d00718f40c4872d408cdf0a7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 12 Feb 2024 09:25:17 +0000 Subject: [PATCH 06/10] [BugFix] Fix flaky rb tests (#1901) --- test/test_rb.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/test/test_rb.py b/test/test_rb.py index 697981909b5..1f5c4e90bd7 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -672,6 +672,8 @@ def test_storage_state_dict(self, storage_in, storage_out, init_out, backend): def test_storage_dumps_loads( self, device_data, storage_type, data_type, isinit, tmpdir ): + torch.manual_seed(0) + dir_rb = tmpdir / "rb" dir_save = tmpdir / "save" dir_rb.mkdir() @@ -716,15 +718,18 @@ class TC: ) else: raise NotImplementedError + if storage_type in (LazyMemmapStorage,): storage = storage_type(max_size=10, scratch_dir=dir_rb) else: storage = storage_type(max_size=10) + # We cast the device to CPU as CUDA isn't automatically cast to CPU when using range() index if data_type == "pytree": storage.set(range(3), tree_map(lambda x: x.cpu(), data)) else: storage.set(range(3), data.cpu()) + storage.dumps(dir_save) # check we can dump twice storage.dumps(dir_save) @@ -732,9 +737,11 @@ class TC: storage_recover = storage_type(max_size=10) if isinit: if data_type == "pytree": - storage_recover.set(range(3), tree_map(lambda x: x.cpu().zero_(), data)) + storage_recover.set( + range(3), tree_map(lambda x: x.cpu().clone().zero_(), data) + ) else: - storage_recover.set(range(3), data.cpu().zero_()) + storage_recover.set(range(3), data.cpu().clone().zero_()) if data_type in ("tensor", "pytree") and not isinit: with pytest.raises( From 6f6c896e4fcc405a98db979060c36a8f355398c8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 12 Feb 2024 09:43:18 +0000 Subject: [PATCH 07/10] [BugFix] Adaptable non-blocking for mps and non cuda device in batched-envs (#1900) --- torchrl/collectors/collectors.py | 8 ++-- torchrl/envs/batched_envs.py | 64 ++++++++++++++++++++------------ 2 files changed, 44 insertions(+), 28 deletions(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 98040d9640e..7ae016a3f69 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -577,7 +577,7 @@ def __init__( reset_when_done: bool = True, interruptor=None, ): - from torchrl.envs.batched_envs import _BatchedEnv + from torchrl.envs.batched_envs import BatchedEnvBase self.closed = True @@ -591,7 +591,7 @@ def __init__( else: env = create_env_fn if create_env_kwargs: - if not isinstance(env, _BatchedEnv): + if not isinstance(env, BatchedEnvBase): raise RuntimeError( "kwargs were passed to SyncDataCollector but they can't be set " f"on environment of type {type(create_env_fn)}." @@ -1201,11 +1201,11 @@ def state_dict(self) -> OrderedDict: `"env_state_dict"`. """ - from torchrl.envs.batched_envs import _BatchedEnv + from torchrl.envs.batched_envs import BatchedEnvBase if isinstance(self.env, TransformedEnv): env_state_dict = self.env.transform.state_dict() - elif isinstance(self.env, _BatchedEnv): + elif isinstance(self.env, BatchedEnvBase): env_state_dict = self.env.state_dict() else: env_state_dict = OrderedDict() diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 2a955af1261..67802f01620 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -48,7 +48,7 @@ def _check_start(fun): - def decorated_fun(self: _BatchedEnv, *args, **kwargs): + def decorated_fun(self: BatchedEnvBase, *args, **kwargs): if self.is_closed: self._create_td() self._start_workers() @@ -121,7 +121,7 @@ def __call__(cls, *args, **kwargs): return super().__call__(*args, **kwargs) -class _BatchedEnv(EnvBase): +class BatchedEnvBase(EnvBase): """Batched environments allow the user to query an arbitrary method / attribute of the environment running remotely. Those queries will return a list of length equal to the number of workers containing the @@ -169,6 +169,9 @@ class _BatchedEnv(EnvBase): serial_for_single (bool, optional): if ``True``, creating a parallel environment with a single worker will return a :class:`~SerialEnv` instead. This option has no effect with :class:`~SerialEnv`. Defaults to ``False``. + non_blocking (bool, optional): if ``True``, device moves will be done using the + ``non_blocking=True`` option. Defaults to ``True`` for batched environments + on cuda devices, and ``False`` otherwise. Examples: >>> from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, EnvCreator @@ -179,8 +182,8 @@ class _BatchedEnv(EnvBase): >>> env = ParallelEnv(2, [ ... lambda: DMControlEnv("humanoid", "stand"), ... lambda: DMControlEnv("humanoid", "walk")]) # Creates two independent copies of Humanoid, one that walks one that stands - >>> r = env.rollout(10) # executes 10 random steps in the environment - >>> r[0] # data for Humanoid stand + >>> rollout = env.rollout(10) # executes 10 random steps in the environment + >>> rollout[0] # data for Humanoid stand TensorDict( fields={ action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False), @@ -211,7 +214,7 @@ class _BatchedEnv(EnvBase): batch_size=torch.Size([10]), device=cpu, is_shared=False) - >>> r[1] # data for Humanoid walk + >>> rollout[1] # data for Humanoid walk TensorDict( fields={ action: Tensor(shape=torch.Size([10, 21]), device=cpu, dtype=torch.float64, is_shared=False), @@ -242,6 +245,7 @@ class _BatchedEnv(EnvBase): batch_size=torch.Size([10]), device=cpu, is_shared=False) + >>> # serial_for_single to avoid creating parallel envs if not necessary >>> env = ParallelEnv(1, make_env, serial_for_single=True) >>> assert isinstance(env, SerialEnv) # serial_for_single allows you to avoid creating parallel envs when not necessary """ @@ -270,6 +274,7 @@ def __init__( num_threads: int = None, num_sub_threads: int = 1, serial_for_single: bool = False, + non_blocking: bool = False, ): super().__init__(device=device) self.serial_for_single = serial_for_single @@ -327,6 +332,15 @@ def __init__( # self._prepare_dummy_env(create_env_fn, create_env_kwargs) self._properties_set = False self._get_metadata(create_env_fn, create_env_kwargs) + self._non_blocking = non_blocking + + @property + def non_blocking(self): + nb = self._non_blocking + if nb is None: + nb = self.device is not None and self.device.type == "cuda" + self._non_blocking = nb + return nb def _get_metadata( self, create_env_fn: List[Callable], create_env_kwargs: List[Dict] @@ -654,6 +668,7 @@ def start(self) -> None: self._start_workers() def to(self, device: DEVICE_TYPING): + self._non_blocking = None device = torch.device(device) if device == self.device: return self @@ -675,10 +690,10 @@ def to(self, device: DEVICE_TYPING): return self -class SerialEnv(_BatchedEnv): +class SerialEnv(BatchedEnvBase): """Creates a series of environments in the same process.""" - __doc__ += _BatchedEnv.__doc__ + __doc__ += BatchedEnvBase.__doc__ _share_memory = False @@ -769,7 +784,9 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: else: env_device = _env.device if env_device != self.device and env_device is not None: - tensordict_ = tensordict_.to(env_device, non_blocking=True) + tensordict_ = tensordict_.to( + env_device, non_blocking=self.non_blocking + ) else: tensordict_ = tensordict_.clone(False) else: @@ -798,7 +815,7 @@ def select_and_clone(name, tensor): if device is None: out = out.clear_device_() else: - out = out.to(device, non_blocking=True) + out = out.to(device, non_blocking=self.non_blocking) return out def _reset_proc_data(self, tensordict, tensordict_reset): @@ -819,7 +836,9 @@ def _step( # There may be unexpected keys, such as "_reset", that we should comfortably ignore here. env_device = self._envs[i].device if env_device != self.device and env_device is not None: - data_in = tensordict_in[i].to(env_device, non_blocking=True) + data_in = tensordict_in[i].to( + env_device, non_blocking=self.non_blocking + ) else: data_in = tensordict_in[i] out_td = self._envs[i]._step(data_in) @@ -839,7 +858,7 @@ def select_and_clone(name, tensor): if device is None: out = out.clear_device_() elif out.device != device: - out = out.to(device, non_blocking=True) + out = out.to(device, non_blocking=self.non_blocking) return out def __getattr__(self, attr: str) -> Any: @@ -885,14 +904,14 @@ def to(self, device: DEVICE_TYPING): return self -class ParallelEnv(_BatchedEnv, metaclass=_PEnvMeta): +class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta): """Creates one environment per process. TensorDicts are passed via shared memory or memory map. """ - __doc__ += _BatchedEnv.__doc__ + __doc__ += BatchedEnvBase.__doc__ __doc__ += """ .. warning:: @@ -1167,14 +1186,14 @@ def step_and_maybe_reset( tensordict_ = tensordict_.clone() elif device is not None: next_td = next_td._fast_apply( - lambda x: x.to(device, non_blocking=True) + lambda x: x.to(device, non_blocking=self.non_blocking) if x.device != device else x.clone(), device=device, filter_empty=True, ) tensordict_ = tensordict_._fast_apply( - lambda x: x.to(device, non_blocking=True) + lambda x: x.to(device, non_blocking=self.non_blocking) if x.device != device else x.clone(), device=device, @@ -1239,7 +1258,7 @@ def select_and_clone(name, tensor): if device is None: out.clear_device_() else: - out = out.to(device, non_blocking=True) + out = out.to(device, non_blocking=self.non_blocking) return out @_check_start @@ -1325,7 +1344,7 @@ def select_and_clone(name, tensor): if device is None: out.clear_device_() else: - out = out.to(device, non_blocking=True) + out = out.to(device, non_blocking=self.non_blocking) return out @_check_start @@ -1644,12 +1663,9 @@ def look_for_cuda(tensor, has_cuda=has_cuda): child_pipe.send(("_".join([cmd, "done"]), None)) -def _update_cuda(t_dest, t_source): - if t_source is None: - return - t_dest.copy_(t_source.pin_memory(), non_blocking=True) - return - - def _filter_empty(tensordict): return tensordict.select(*tensordict.keys(True, True)) + + +# Create an alias for possible imports +_BatchedEnv = BatchedEnvBase From 69d44f5cf4bf84eab0f21b0eea98112651f7f9a1 Mon Sep 17 00:00:00 2001 From: Albert Bou Date: Mon, 12 Feb 2024 11:37:49 +0100 Subject: [PATCH 08/10] [Feature] Replace RewardClipping with SignTransform in Atari examples (#1870) --- examples/a2c/utils_atari.py | 4 ++-- examples/dqn/utils_atari.py | 4 ++-- examples/impala/utils.py | 4 ++-- examples/ppo/utils_atari.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/a2c/utils_atari.py b/examples/a2c/utils_atari.py index 89a51f7e64b..0ddcd79123e 100644 --- a/examples/a2c/utils_atari.py +++ b/examples/a2c/utils_atari.py @@ -20,8 +20,8 @@ NoopResetEnv, ParallelEnv, Resize, - RewardClipping, RewardSum, + SignTransform, StepCounter, ToTensorImage, TransformedEnv, @@ -73,7 +73,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): env.append_transform(RewardSum()) env.append_transform(StepCounter(max_steps=4500)) if not is_test: - env.append_transform(RewardClipping(-1, 1)) + env.append_transform(SignTransform(in_keys=["reward"])) env.append_transform(DoubleToFloat()) env.append_transform(VecNorm(in_keys=["pixels"])) return env diff --git a/examples/dqn/utils_atari.py b/examples/dqn/utils_atari.py index 24b6509147c..b9805659e63 100644 --- a/examples/dqn/utils_atari.py +++ b/examples/dqn/utils_atari.py @@ -14,8 +14,8 @@ GymEnv, NoopResetEnv, Resize, - RewardClipping, RewardSum, + SignTransform, StepCounter, ToTensorImage, TransformedEnv, @@ -42,7 +42,7 @@ def make_env(env_name, frame_skip, device, is_test=False): env.append_transform(NoopResetEnv(noops=30, random=True)) if not is_test: env.append_transform(EndOfLifeTransform()) - env.append_transform(RewardClipping(-1, 1)) + env.append_transform(SignTransform(in_keys=["reward"])) env.append_transform(ToTensorImage()) env.append_transform(GrayScale()) env.append_transform(Resize(84, 84)) diff --git a/examples/impala/utils.py b/examples/impala/utils.py index 2983f8a0193..b365dca3867 100644 --- a/examples/impala/utils.py +++ b/examples/impala/utils.py @@ -16,8 +16,8 @@ GymEnv, NoopResetEnv, Resize, - RewardClipping, RewardSum, + SignTransform, StepCounter, ToTensorImage, TransformedEnv, @@ -46,7 +46,7 @@ def make_env(env_name, device, is_test=False): env.append_transform(NoopResetEnv(noops=30, random=True)) if not is_test: env.append_transform(EndOfLifeTransform()) - env.append_transform(RewardClipping(-1, 1)) + env.append_transform(SignTransform(in_keys=["reward"])) env.append_transform(ToTensorImage(from_int=False)) env.append_transform(GrayScale()) env.append_transform(Resize(84, 84)) diff --git a/examples/ppo/utils_atari.py b/examples/ppo/utils_atari.py index eaef640ebb0..5cb838cac47 100644 --- a/examples/ppo/utils_atari.py +++ b/examples/ppo/utils_atari.py @@ -19,8 +19,8 @@ NoopResetEnv, ParallelEnv, Resize, - RewardClipping, RewardSum, + SignTransform, StepCounter, ToTensorImage, TransformedEnv, @@ -71,7 +71,7 @@ def make_parallel_env(env_name, num_envs, device, is_test=False): env.append_transform(RewardSum()) env.append_transform(StepCounter(max_steps=4500)) if not is_test: - env.append_transform(RewardClipping(-1, 1)) + env.append_transform(SignTransform(in_keys=["reward"])) env.append_transform(DoubleToFloat()) env.append_transform(VecNorm(in_keys=["pixels"])) return env From 2f9e1ae05413ff8106d5fcac65245660d39c2d33 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 12 Feb 2024 14:59:32 +0000 Subject: [PATCH 09/10] [Minor] Remove warnings in test_cost (#1902) --- test/conftest.py | 7 +- test/test_cost.py | 147 ++++++++++++++++++--- torchrl/objectives/a2c.py | 25 ++-- torchrl/objectives/decision_transformer.py | 1 - torchrl/objectives/reinforce.py | 24 ++-- torchrl/objectives/sac.py | 4 +- 6 files changed, 165 insertions(+), 43 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 5ce980a4080..2dcd369003a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -53,7 +53,7 @@ def fin(): request.addfinalizer(fin) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(autouse=True) def set_warnings() -> None: warnings.filterwarnings( "ignore", @@ -65,6 +65,11 @@ def set_warnings() -> None: category=UserWarning, message=r"Couldn't cast the policy onto the desired device on remote process", ) + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=r"Skipping device Apple Paravirtual device", + ) warnings.filterwarnings( "ignore", category=DeprecationWarning, diff --git a/test/test_cost.py b/test/test_cost.py index 064a38ced60..3c07f5f79f4 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -156,6 +156,10 @@ ) +# Capture all warnings +pytestmark = pytest.mark.filterwarnings("error") + + class _check_td_steady: def __init__(self, td): self.td_clone = td.clone() @@ -501,6 +505,11 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est): else contextlib.nullcontext() ), _check_td_steady(td): loss = loss_fn(td) + + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + assert loss_fn.tensor_keys.priority in td.keys() sum([item for _, item in loss.items()]).backward() @@ -562,6 +571,10 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9): loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): loss = loss_fn(td) if n == 0: @@ -601,7 +614,7 @@ def test_dqn_tensordict_keys(self, td_est): torch.manual_seed(self.seed) action_spec_type = "one_hot" actor = self._create_mock_actor(action_spec_type=action_spec_type) - loss_fn = DQNLoss(actor) + loss_fn = DQNLoss(actor, delay_value=True) default_keys = { "advantage": "advantage", @@ -617,7 +630,7 @@ def test_dqn_tensordict_keys(self, td_est): self.tensordict_keys_test(loss_fn, default_keys=default_keys) - loss_fn = DQNLoss(actor) + loss_fn = DQNLoss(actor, delay_value=True) key_mapping = { "advantage": ("advantage", "advantage_2"), "value_target": ("value_target", ("value_target", "nested")), @@ -630,7 +643,7 @@ def test_dqn_tensordict_keys(self, td_est): actor = self._create_mock_actor( action_spec_type=action_spec_type, action_value_key="chosen_action_value_2" ) - loss_fn = DQNLoss(actor) + loss_fn = DQNLoss(actor, delay_value=True) key_mapping = { "value": ("value", "chosen_action_value_2"), } @@ -657,11 +670,14 @@ def test_dqn_tensordict_run(self, action_spec_type, td_est): action_value_key=tensor_keys["action_value"], ) - loss_fn = DQNLoss(actor, loss_function="l2") + loss_fn = DQNLoss(actor, loss_function="l2", delay_value=True) loss_fn.set_keys(**tensor_keys) if td_est is not None: loss_fn.make_value_estimator(td_est) + + SoftUpdate(loss_fn, eps=0.5) + with _check_td_steady(td): _ = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -707,6 +723,10 @@ def test_distributional_dqn( sum([item for _, item in loss.items()]).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + # Check param update effect on targets target_value = loss_fn.target_value_network_params.clone() for p in loss_fn.parameters(): @@ -744,7 +764,7 @@ def test_dqn_notensordict( module=module, in_keys=[observation_key], ) - dqn_loss = DQNLoss(actor) + dqn_loss = DQNLoss(actor, delay_value=True) dqn_loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) # define data observation = torch.randn(n_obs) @@ -762,6 +782,8 @@ def test_dqn_notensordict( "action": action, } td = TensorDict(kwargs, []).unflatten_keys("_") + # Disable warning + SoftUpdate(dqn_loss, eps=0.5) loss_val = dqn_loss(**kwargs) loss_val_td = dqn_loss(td) torch.testing.assert_close(loss_val_td.get("loss"), loss_val) @@ -775,7 +797,7 @@ def test_distributional_dqn_tensordict_keys(self): action_spec_type=action_spec_type, atoms=atoms ) - loss_fn = DistributionalDQNLoss(actor, gamma=gamma) + loss_fn = DistributionalDQNLoss(actor, gamma=gamma, delay_value=True) default_keys = { "priority": "td_error", @@ -810,11 +832,14 @@ def test_distributional_dqn_tensordict_run(self, action_spec_type, td_est): action_key=tensor_keys["action"], action_value_key=tensor_keys["action_value"], ) - loss_fn = DistributionalDQNLoss(actor, gamma=0.9) + loss_fn = DistributionalDQNLoss(actor, gamma=0.9, delay_value=True) loss_fn.set_keys(**tensor_keys) loss_fn.make_value_estimator(td_est) + # remove warnings + SoftUpdate(loss_fn, eps=0.5) + with _check_td_steady(td): _ = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys() @@ -984,6 +1009,10 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): sum([item for _, item in loss.items()]).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + # Check param update effect on targets target_value = loss_fn.target_local_value_network_params.clone() for p in loss_fn.parameters(): @@ -1051,6 +1080,11 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) else contextlib.nullcontext() ), _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) + + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + assert loss_fn.tensor_keys.priority in ms_td.keys() with torch.no_grad(): @@ -1105,7 +1139,7 @@ def test_qmix_tensordict_keys(self, td_est): action_spec_type = "one_hot" actor = self._create_mock_actor(action_spec_type=action_spec_type) mixer = self._create_mock_mixer() - loss_fn = QMixerLoss(actor, mixer) + loss_fn = QMixerLoss(actor, mixer, delay_value=True) default_keys = { "advantage": "advantage", @@ -1122,7 +1156,7 @@ def test_qmix_tensordict_keys(self, td_est): self.tensordict_keys_test(loss_fn, default_keys=default_keys) - loss_fn = QMixerLoss(actor, mixer) + loss_fn = QMixerLoss(actor, mixer, delay_value=True) key_mapping = { "advantage": ("advantage", "advantage_2"), "value_target": ("value_target", ("value_target", "nested")), @@ -1138,7 +1172,7 @@ def test_qmix_tensordict_keys(self, td_est): mixer = self._create_mock_mixer( global_chosen_action_value_key=("some", "nested") ) - loss_fn = QMixerLoss(actor, mixer) + loss_fn = QMixerLoss(actor, mixer, delay_value=True) key_mapping = { "global_value": ("value", ("some", "nested")), } @@ -1173,9 +1207,9 @@ def test_qmix_tensordict_run(self, action_spec_type, td_est): action_value_key=tensor_keys["action_value"], ) - loss_fn = QMixerLoss(actor, mixer, loss_function="l2") + loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=True) loss_fn.set_keys(**tensor_keys) - + SoftUpdate(loss_fn, eps=0.5) if td_est is not None: loss_fn.make_value_estimator(td_est) with _check_td_steady(td): @@ -1231,7 +1265,9 @@ def test_mixer_keys( ) td = actor(td) - loss = QMixerLoss(actor, mixer) + loss = QMixerLoss(actor, mixer, delay_value=True) + + SoftUpdate(loss, eps=0.5) # Wthout etting the keys if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"): @@ -1245,7 +1281,10 @@ def test_mixer_keys( else: loss(td) - loss = QMixerLoss(actor, mixer) + loss = QMixerLoss(actor, mixer, delay_value=True) + + SoftUpdate(loss, eps=0.5) + # When setting the key loss.set_keys(global_value=mixer_global_chosen_action_value_key) if mixer_local_chosen_action_value_key != ("agents", "chosen_action_value"): @@ -1466,6 +1505,10 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est): ): loss = loss_fn(td) + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + assert all( (p.grad is None) or (p.grad == 0).all() for p in loss_fn.value_network_params.values(True, True) @@ -1582,6 +1625,9 @@ def test_ddpg_separate_losses( with pytest.warns(UserWarning, match="No target network updater has been"): loss = loss_fn(td) + # remove warning + SoftUpdate(loss_fn, eps=0.5) + assert all( (p.grad is None) or (p.grad == 0).all() for p in loss_fn.value_network_params.values(True, True) @@ -1702,6 +1748,11 @@ def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9): else contextlib.nullcontext() ), _check_td_steady(ms_td): loss_ms = loss_fn(ms_td) + + if delay_value: + # remove warning + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): loss = loss_fn(td) if n == 0: @@ -2304,10 +2355,14 @@ def test_td3_batcher( loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + if delay_qvalue or delay_actor: + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) loss = loss_fn(td) + if n == 0: assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) _loss = sum([item for _, item in loss.items()]) @@ -3291,6 +3346,9 @@ def test_sac_notensordict( loss_val = loss(**kwargs) torch.manual_seed(self.seed) + + SoftUpdate(loss, eps=0.5) + loss_val_td = loss(td) if version == 1: @@ -3538,6 +3596,7 @@ def test_discrete_sac( target_entropy_weight=target_entropy_weight, target_entropy=target_entropy, loss_function="l2", + action_space="one-hot", **kwargs, ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): @@ -3648,6 +3707,7 @@ def test_discrete_sac_state_dict( target_entropy_weight=target_entropy_weight, target_entropy=target_entropy, loss_function="l2", + action_space="one-hot", **kwargs, ) sd = loss_fn.state_dict() @@ -3659,6 +3719,7 @@ def test_discrete_sac_state_dict( target_entropy_weight=target_entropy_weight, target_entropy=target_entropy, loss_function="l2", + action_space="one-hot", **kwargs, ) loss_fn2.load_state_dict(sd) @@ -3696,6 +3757,7 @@ def test_discrete_sac_batcher( loss_function="l2", target_entropy_weight=target_entropy_weight, target_entropy=target_entropy, + action_space="one-hot", **kwargs, ) @@ -3712,6 +3774,8 @@ def test_discrete_sac_batcher( loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) @@ -3800,6 +3864,7 @@ def test_discrete_sac_tensordict_keys(self, td_est): qvalue_network=qvalue, num_actions=actor.spec["action"].space.n, loss_function="l2", + action_space="one-hot", ) default_keys = { @@ -3822,6 +3887,7 @@ def test_discrete_sac_tensordict_keys(self, td_est): qvalue_network=qvalue, num_actions=actor.spec["action"].space.n, loss_function="l2", + action_space="one-hot", ) key_mapping = { @@ -3860,6 +3926,7 @@ def test_discrete_sac_notensordict( actor_network=actor, qvalue_network=qvalue, num_actions=actor.spec[action_key].space.n, + action_space="one-hot", ) loss.set_keys( action=action_key, @@ -4370,6 +4437,8 @@ def test_redq_deprecated_separate_losses(self, separate_losses): ): loss = loss_fn(td) + SoftUpdate(loss_fn, eps=0.5) + # check that losses are independent for k in loss.keys(): if not k.startswith("loss"): @@ -5408,6 +5477,9 @@ def test_dcql(self, delay_value, device, action_spec_type, td_est): loss = loss_fn(td) assert loss_fn.tensor_keys.priority in td.keys(True) + if delay_value: + SoftUpdate(loss_fn, eps=0.5) + sum([item for key, item in loss.items() if key.startswith("loss")]).backward() assert torch.nn.utils.clip_grad.clip_grad_norm_(actor.parameters(), 1.0) > 0.0 @@ -5467,6 +5539,9 @@ def test_dcql_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9) loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + if delay_value: + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): loss = loss_fn(td) if n == 0: @@ -5510,7 +5585,7 @@ def test_dcql_tensordict_keys(self, td_est): torch.manual_seed(self.seed) action_spec_type = "one_hot" actor = self._create_mock_actor(action_spec_type=action_spec_type) - loss_fn = DQNLoss(actor) + loss_fn = DQNLoss(actor, delay_value=True) default_keys = { "value_target": "value_target", @@ -5566,6 +5641,8 @@ def test_dcql_tensordict_run(self, action_spec_type, td_est): loss_fn = DiscreteCQLLoss(actor, loss_function="l2") loss_fn.set_keys(**tensor_keys) + SoftUpdate(loss_fn, eps=0.5) + if td_est is not None: loss_fn.make_value_estimator(td_est) with _check_td_steady(td): @@ -5590,6 +5667,9 @@ def test_dcql_notensordict( in_keys=[observation_key], ) loss = DiscreteCQLLoss(actor) + + SoftUpdate(loss, eps=0.5) + loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) # define data observation = torch.randn(n_obs) @@ -8783,6 +8863,9 @@ def test_iql_batcher( loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + # Remove warnings + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) @@ -9206,6 +9289,7 @@ def test_discrete_iql( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): @@ -9328,6 +9412,7 @@ def test_discrete_iql_state_dict( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) sd = loss_fn.state_dict() loss_fn2 = DiscreteIQLLoss( @@ -9338,6 +9423,7 @@ def test_discrete_iql_state_dict( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) loss_fn2.load_state_dict(sd) @@ -9351,6 +9437,7 @@ def test_discrete_iql_separate_losses(self, separate_losses): value_network=value, loss_function="l2", separate_losses=separate_losses, + action_space="one-hot", ) with pytest.warns(UserWarning, match="No target network updater has been"): loss = loss_fn(td) @@ -9529,6 +9616,7 @@ def test_discrete_iql_batcher( temperature=temperature, expectile=expectile, loss_function="l2", + action_space="one-hot", ) ms = MultiStep(gamma=gamma, n_steps=n).to(device) @@ -9544,6 +9632,8 @@ def test_discrete_iql_batcher( loss_ms = loss_fn(ms_td) assert loss_fn.tensor_keys.priority in ms_td.keys() + SoftUpdate(loss_fn, eps=0.5) + with torch.no_grad(): torch.manual_seed(0) # log-prob is computed with a random action np.random.seed(0) @@ -9615,6 +9705,7 @@ def test_discrete_iql_tensordict_keys(self, td_est): qvalue_network=qvalue, value_network=value, loss_function="l2", + action_space="one-hot", ) default_keys = { @@ -9640,6 +9731,7 @@ def test_discrete_iql_tensordict_keys(self, td_est): qvalue_network=qvalue, value_network=value, loss_function="l2", + action_space="one-hot", ) key_mapping = { @@ -9675,7 +9767,10 @@ def test_discrete_iql_notensordict( value = self._create_mock_value(observation_key=observation_key) loss = DiscreteIQLLoss( - actor_network=actor, qvalue_network=qvalue, value_network=value + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + action_space="one-hot", ) loss.set_keys( action=action_key, @@ -9744,6 +9839,10 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: out_keys=["action"], ) loss = MyLoss(actor_module) + + if create_target_params: + SoftUpdate(loss, eps=0.5) + if cast is not None: loss.to(cast) for name in ("weight", "bias"): @@ -9873,11 +9972,13 @@ def __init__(self, delay_module=True): self.convert_to_functional( module1, "module1", create_target_params=delay_module ) + module2 = torch.nn.BatchNorm2d(10).eval() self.module2 = module2 - iterator_params = self.target_module1_params.values( - include_nested=True, leaves_only=True - ) + tparam = self._modules.get("target_module1_params", None) + if tparam is None: + tparam = self._modules.get("module1_params").data + iterator_params = tparam.values(include_nested=True, leaves_only=True) for target in iterator_params: if target.dtype is not torch.int64: target.data.normal_() @@ -12285,10 +12386,14 @@ def test_single_call(self, has_target, value_key, single_call, detach_next=True) def test_instantiate_with_different_keys(): - loss_1 = DQNLoss(value_network=nn.Linear(3, 3), action_space="one_hot") + loss_1 = DQNLoss( + value_network=nn.Linear(3, 3), action_space="one_hot", delay_value=True + ) loss_1.set_keys(reward="a") assert loss_1.tensor_keys.reward == "a" - loss_2 = DQNLoss(value_network=nn.Linear(3, 3), action_space="one_hot") + loss_2 = DQNLoss( + value_network=nn.Linear(3, 3), action_space="one_hot", delay_value=True + ) loss_2.set_keys(reward="b") assert loss_1.tensor_keys.reward == "a" diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 6edcda5c800..8fcbd5a6699 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -255,7 +255,8 @@ def __init__( if functional: self.convert_to_functional( - actor_network, "actor_network", funs_to_decorate=["forward", "get_dist"] + actor_network, + "actor_network", ) else: self.actor_network = actor_network @@ -350,7 +351,7 @@ def in_keys(self): *[("next", key) for key in self.actor_network.in_keys], ] if self.critic_coef: - keys.extend(self.critic.in_keys) + keys.extend(self.critic_network.in_keys) return list(set(keys)) @property @@ -414,11 +415,11 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: # TODO: if the advantage is gathered by forward, this introduces an # overhead that we could easily reduce. target_return = tensordict.get(self.tensor_keys.value_target) - tensordict_select = tensordict.select(*self.critic.in_keys) + tensordict_select = tensordict.select(*self.critic_network.in_keys) with self.critic_network_params.to_module( - self.critic + self.critic_network ) if self.functional else contextlib.nullcontext(): - state_value = self.critic( + state_value = self.critic_network( tensordict_select, ).get(self.tensor_keys.value) loss_value = distance_loss( @@ -477,13 +478,19 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams if hasattr(self, "gamma"): hp["gamma"] = self.gamma if value_type == ValueEstimators.TD1: - self._value_estimator = TD1Estimator(value_network=self.critic, **hp) + self._value_estimator = TD1Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.TD0: - self._value_estimator = TD0Estimator(value_network=self.critic, **hp) + self._value_estimator = TD0Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.GAE: - self._value_estimator = GAE(value_network=self.critic, **hp) + self._value_estimator = GAE(value_network=self.critic_network, **hp) elif value_type == ValueEstimators.TDLambda: - self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + self._value_estimator = TDLambdaEstimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.VTrace: # VTrace currently does not support functional call on the actor if self.functional: diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index a24aa4a1271..954bd0b9a42 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -275,7 +275,6 @@ def __init__( actor_network, "actor_network", create_target_params=False, - funs_to_decorate=["forward"], ) self.loss_function = loss_function diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 4613810d0d3..9738b922c5d 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -351,7 +351,7 @@ def _set_in_keys(self): ("next", self.tensor_keys.terminated), *self.actor_network.in_keys, *[("next", key) for key in self.actor_network.in_keys], - *self.critic.in_keys, + *self.critic_network.in_keys, ] self._in_keys = list(set(keys)) @@ -398,11 +398,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: try: target_return = tensordict.get(self.tensor_keys.value_target) - tensordict_select = tensordict.select(*self.critic.in_keys) + tensordict_select = tensordict.select(*self.critic_network.in_keys) with self.critic_network_params.to_module( - self.critic + self.critic_network ) if self.functional else contextlib.nullcontext(): - state_value = self.critic(tensordict_select).get(self.tensor_keys.value) + state_value = self.critic_network(tensordict_select).get( + self.tensor_keys.value + ) loss_value = distance_loss( target_return, state_value, @@ -427,13 +429,19 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams hp["gamma"] = self.gamma hp.update(hyperparams) if value_type == ValueEstimators.TD1: - self._value_estimator = TD1Estimator(value_network=self.critic, **hp) + self._value_estimator = TD1Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.TD0: - self._value_estimator = TD0Estimator(value_network=self.critic, **hp) + self._value_estimator = TD0Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.GAE: - self._value_estimator = GAE(value_network=self.critic, **hp) + self._value_estimator = GAE(value_network=self.critic_network, **hp) elif value_type == ValueEstimators.TDLambda: - self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + self._value_estimator = TDLambdaEstimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.VTrace: # VTrace currently does not support functional call on the actor if self.functional: diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 053da9e53d2..5b722fd05f3 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -292,7 +292,6 @@ def __init__( actor_network, "actor_network", create_target_params=self.delay_actor, - funs_to_decorate=["forward", "get_dist"], ) if separate_losses: # we want to make sure there are no duplicates in the params: the @@ -980,7 +979,6 @@ def __init__( actor_network, "actor_network", create_target_params=self.delay_actor, - funs_to_decorate=["forward", "get_dist"], ) if separate_losses: # we want to make sure there are no duplicates in the params: the @@ -1036,7 +1034,7 @@ def __init__( if action_space is None: warnings.warn( "action_space was not specified. DiscreteSACLoss will default to 'one-hot'." - "This behaviour will be deprecated soon and a space will have to be passed." + "This behaviour will be deprecated soon and a space will have to be passed. " "Check the DiscreteSACLoss documentation to see how to pass the action space. " ) action_space = "one-hot" From 899af07fc10538af528e30e2caa8a67c18bf8164 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 12 Feb 2024 20:14:24 +0000 Subject: [PATCH 10/10] [BugFix] Make KL-controllers independent of the model (#1903) --- docs/source/reference/data.rst | 3 ++ examples/rlhf/train_rlhf.py | 4 +-- torchrl/data/rlhf/utils.py | 59 +++++++++++++++++++++++----------- 3 files changed, 44 insertions(+), 22 deletions(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index d426a112b72..47ffb64753b 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -702,6 +702,9 @@ efficient sampling. TokenizedDatasetLoader create_infinite_iterator get_dataloader + ConstantKLController + AdaptiveKLController + Utils ----- diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py index a921e58bad6..94d9234db2a 100644 --- a/examples/rlhf/train_rlhf.py +++ b/examples/rlhf/train_rlhf.py @@ -100,9 +100,7 @@ def main(cfg): # using a Gym-like API (querying steps etc) introduces some # extra code that we can spare. # - kl_scheduler = AdaptiveKLController( - model, init_kl_coef=0.1, target=6, horizon=10000 - ) + kl_scheduler = AdaptiveKLController(init_kl_coef=0.1, target=6, horizon=10000) rollout_from_model = RolloutFromModel( model, ref_model, diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 311b2584aa5..a4ccbfd8a1b 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -7,13 +7,13 @@ import abc import collections import importlib -from typing import Sequence, Tuple +from typing import List, Tuple import numpy as np import torch from tensordict import TensorDict -from torch import Tensor +from torch import nn, Tensor from torch.nn import functional as F from torchrl.data.rlhf.prompt import PromptData @@ -30,8 +30,8 @@ class KLControllerBase(abc.ABC): """ @abc.abstractmethod - def update(self, kl_values: float): - pass + def update(self, kl_values: List[float]) -> float: + ... class ConstantKLController(KLControllerBase): @@ -40,30 +40,39 @@ class ConstantKLController(KLControllerBase): This controller maintains a fixed coefficient no matter what values it is updated with. - Arguments: - model: wrapped model that needs to be controlled. Must have attribute 'kl_coef' + Keyword Arguments: kl_coef (float): The coefficient to multiply KL with when calculating the reward. + model (nn.Module, optional): wrapped model that needs to be controlled. + Must have an attribute ``"kl_coef"``. If provided, the ``"kl_coef"`` will + be updated in-place. """ - def __init__(self, model, kl_coef): + def __init__( + self, + *, + kl_coef: float = None, + model: nn.Module | None = None, + ): self.model = model - if not hasattr(model, "kl_coef"): + if model is not None and not hasattr(model, "kl_coef"): raise AttributeError( "Model input to ConstantKLController doesn't have attribute 'kl_coef'" ) self.coef = kl_coef - self.model.kl_coef = self.coef + if model is not None: + self.model.kl_coef = self.coef - def update(self, kl_values: Sequence[float] = None): - self.model.kl_coef = self.coef + def update(self, kl_values: List[float] = None) -> float: + if self.model is not None: + self.model.kl_coef = self.coef + return self.coef class AdaptiveKLController(KLControllerBase): """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences". - Arguments: - model: wrapped model that needs to be controlled. Must have attribute 'kl_coef' + Keyword Arguments: init_kl_coef (float): The starting value of the coefficient. target (float): The target KL value. When the observed KL is smaller, the coefficient is decreased, thereby relaxing the KL penalty in the training @@ -72,19 +81,30 @@ class AdaptiveKLController(KLControllerBase): increased, thereby pulling the model back towards the reference model. horizon (int): Scaling factor to control how aggressively we update the coefficient. + model (nn.Module, optional): wrapped model that needs to be controlled. + Must have an attribute ``"kl_coef"``. If provided, the ``"kl_coef"`` will + be updated in-place. Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2 Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py """ - def __init__(self, model, init_kl_coef: float, target: float, horizon: int): + def __init__( + self, + *, + init_kl_coef: float, + target: float, + horizon: int, + model: nn.Module | None = None, + ): self.model = model self.coef = init_kl_coef self.target = target self.horizon = horizon - self.model.kl_coef = self.coef + if model is not None: + self.model.kl_coef = self.coef - def update(self, kl_values: Sequence[float]): + def update(self, kl_values: List[float]): """Update ``self.coef`` adaptively. Arguments: @@ -104,6 +124,9 @@ def update(self, kl_values: Sequence[float]): proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2) # ϵₜ mult = 1 + proportional_error * n_steps / self.horizon self.coef *= mult # βₜ₊₁ + if self.model is not None: + self.model.kl_coef = self.coef + return self.coef class RolloutFromModel: @@ -233,8 +256,6 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio): log_ratio (torch.Tensor): The log ratio of the probabilities of the generated tokens according to the generative model and the reference model. Can be obtained by calling the ``generate`` method. - kl_coef (float, optional): Coefficient with which to multiply the KL term before subtracting - from the reward. Defaults to 0.1. Returns: A :class:`~tensordict.TensorDict` with the following keys: @@ -514,7 +535,7 @@ def generate(self, batch: PromptData, generation_config=None): def step_scheduler(self): # recover true kl - self.kl_scheduler.update(self._kl_queue) + self.kl_coef = self.kl_scheduler.update(self._kl_queue) if isinstance(self._kl_queue, (list, collections.deque)): # remove all values while len(self._kl_queue):