From 90572ac118c9e2afdb2d0aa53a981fef0d8399a5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 25 Nov 2024 14:24:25 +0000 Subject: [PATCH 01/27] [Doc] Better doc for SliceSampler ghstack-source-id: 7d79ef7d37c4dc2ffbdff5b422cf5da24d93c0da Pull Request resolved: https://github.com/pytorch/rl/pull/2607 --- torchrl/data/replay_buffers/samplers.py | 156 ++++++++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 51b84029766..b97b585aa3f 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -802,6 +802,112 @@ class SliceSampler(Sampler): attempt to find the ``traj_key`` entry in the storage. If it cannot be found, the ``end_key`` will be used to reconstruct the episodes. + .. note:: When using `strict_length=False`, it is recommended to use + :func:`~torchrl.collectors.utils.split_trajectories` to split the sampled trajectories. + However, if two samples from the same episode are placed next to each other, + this may produce incorrect results. To avoid this issue, consider one of these solutions: + + - using a :class:`~torchrl.data.TensorDictReplayBuffer` instance with the slice sampler + + >>> import torch + >>> from tensordict import TensorDict + >>> from torchrl.collectors.utils import split_trajectories + >>> from torchrl.data import TensorDictReplayBuffer, ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement + >>> + >>> rb = TensorDictReplayBuffer(storage=LazyTensorStorage(max_size=1000), + ... sampler=SliceSampler( + ... slice_len=5, traj_key="episode",strict_length=False, + ... )) + ... + >>> ep_1 = TensorDict( + ... {"obs": torch.arange(100), + ... "episode": torch.zeros(100),}, + ... batch_size=[100] + ... ) + >>> ep_2 = TensorDict( + ... {"obs": torch.arange(4), + ... "episode": torch.ones(4),}, + ... batch_size=[4] + ... ) + >>> rb.extend(ep_1) + >>> rb.extend(ep_2) + >>> + >>> s = rb.sample(50) + >>> print(s) + TensorDict( + fields={ + episode: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.float32, is_shared=False), + index: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.int64, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False), + truncated: Tensor(shape=torch.Size([46, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([46]), + device=cpu, + is_shared=False), + obs: Tensor(shape=torch.Size([46]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([46]), + device=cpu, + is_shared=False) + >>> t = split_trajectories(s, done_key="truncated") + >>> print(t["obs"]) + tensor([[73, 74, 75, 76, 77], + [ 0, 1, 2, 3, 0], + [ 0, 1, 2, 3, 0], + [41, 42, 43, 44, 45], + [ 0, 1, 2, 3, 0], + [67, 68, 69, 70, 71], + [27, 28, 29, 30, 31], + [80, 81, 82, 83, 84], + [17, 18, 19, 20, 21], + [ 0, 1, 2, 3, 0]]) + >>> print(t["episode"]) + tensor([[0., 0., 0., 0., 0.], + [1., 1., 1., 1., 0.], + [1., 1., 1., 1., 0.], + [0., 0., 0., 0., 0.], + [1., 1., 1., 1., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [1., 1., 1., 1., 0.]]) + + - using a :class:`~torchrl.data.replay_buffers.samplers.SliceSamplerWithoutReplacement` + + >>> import torch + >>> from tensordict import TensorDict + >>> from torchrl.collectors.utils import split_trajectories + >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement + >>> + >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000), + ... sampler=SliceSamplerWithoutReplacement( + ... slice_len=5, traj_key="episode",strict_length=False + ... )) + ... + >>> ep_1 = TensorDict( + ... {"obs": torch.arange(100), + ... "episode": torch.zeros(100),}, + ... batch_size=[100] + ... ) + >>> ep_2 = TensorDict( + ... {"obs": torch.arange(4), + ... "episode": torch.ones(4),}, + ... batch_size=[4] + ... ) + >>> rb.extend(ep_1) + >>> rb.extend(ep_2) + >>> + >>> s = rb.sample(50) + >>> t = split_trajectories(s, trajectory_key="episode") + >>> print(t["obs"]) + tensor([[75, 76, 77, 78, 79], + [ 0, 1, 2, 3, 0]]) + >>> print(t["episode"]) + tensor([[0., 0., 0., 0., 0.], + [1., 1., 1., 1., 0.]]) + Examples: >>> import torch >>> from tensordict import TensorDict @@ -1427,6 +1533,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): """Samples slices of data along the first dimension, given start and stop signals, without replacement. + In this context, ``without replacement`` means that the same element (NOT trajectory) will not be sampled twice + before the counter is automatically reset. Within a single sample, however, only one slice of a given trajectory + will appear (see example below). + This class is to be used with static replay buffers or in between two replay buffer extensions. Extending the replay buffer will reset the the sampler, and continuous sampling without replacement is currently not @@ -1533,6 +1643,52 @@ class SliceSamplerWithoutReplacement(SliceSampler, SamplerWithoutReplacement): tensor([ 1, 2, 7, 9, 10, 13, 15, 18, 21, 22]) tensor([ 0, 3, 4, 20, 23]) + When requesting a large total number of samples with few trajectories and small span, the batch will contain + only at most one sample of each trajectory: + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from torchrl.collectors.utils import split_trajectories + >>> from torchrl.data import ReplayBuffer, LazyTensorStorage, SliceSampler, SliceSamplerWithoutReplacement + >>> + >>> rb = ReplayBuffer(storage=LazyTensorStorage(max_size=1000), + ... sampler=SliceSamplerWithoutReplacement( + ... slice_len=5, traj_key="episode",strict_length=False + ... )) + ... + >>> ep_1 = TensorDict( + ... {"obs": torch.arange(100), + ... "episode": torch.zeros(100),}, + ... batch_size=[100] + ... ) + >>> ep_2 = TensorDict( + ... {"obs": torch.arange(51), + ... "episode": torch.ones(51),}, + ... batch_size=[51] + ... ) + >>> rb.extend(ep_1) + >>> rb.extend(ep_2) + >>> + >>> s = rb.sample(50) + >>> t = split_trajectories(s, trajectory_key="episode") + >>> print(t["obs"]) + tensor([[14, 15, 16, 17, 18], + [ 3, 4, 5, 6, 7]]) + >>> print(t["episode"]) + tensor([[0., 0., 0., 0., 0.], + [1., 1., 1., 1., 1.]]) + >>> + >>> s = rb.sample(50) + >>> t = split_trajectories(s, trajectory_key="episode") + >>> print(t["obs"]) + tensor([[ 4, 5, 6, 7, 8], + [26, 27, 28, 29, 30]]) + >>> print(t["episode"]) + tensor([[0., 0., 0., 0., 0.], + [1., 1., 1., 1., 1.]]) + + """ def __init__( From d537dcb6347e2370fcdaed553bf3474d653cb5a6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 25 Nov 2024 21:42:10 +0000 Subject: [PATCH 02/27] [Feature] EnvBase.auto_specs_ ghstack-source-id: 329679238c5172d7ff13097ceaa189479d4f4145 Pull Request resolved: https://github.com/pytorch/rl/pull/2601 --- test/mocking_classes.py | 14 ++- test/test_env.py | 28 +++++ test/test_specs.py | 9 ++ torchrl/data/tensor_specs.py | 25 +++- torchrl/envs/common.py | 230 +++++++++++++++++++++++++++-------- torchrl/envs/utils.py | 30 +++-- torchrl/objectives/sac.py | 2 +- 7 files changed, 267 insertions(+), 71 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index eb517429c08..d78e2f27184 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1038,11 +1038,13 @@ def _step( tensordict: TensorDictBase, ) -> TensorDictBase: action = tensordict.get(self.action_key) + try: + device = self.full_action_spec[self.action_key].device + except KeyError: + device = self.device self.count += action.to( dtype=torch.int, - device=self.full_action_spec[self.action_key].device - if self.device is None - else self.device, + device=device if self.device is None else self.device, ) tensordict = TensorDict( source={ @@ -1275,8 +1277,10 @@ def __init__( max_steps = torch.tensor(5) if start_val is None: start_val = torch.zeros((), dtype=torch.int32) - if not max_steps.shape == self.batch_size: - raise RuntimeError("batch_size and max_steps shape must match.") + if max_steps.shape != self.batch_size: + raise RuntimeError( + f"batch_size and max_steps shape must match. Got self.batch_size={self.batch_size} and max_steps.shape={max_steps.shape}." + ) self.max_steps = max_steps diff --git a/test/test_env.py b/test/test_env.py index ab854a3b4be..81708b0b9a6 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -3526,6 +3526,34 @@ def test_single_env_spec(): assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape)) +def test_auto_spec(): + env = CountingEnv() + td = env.reset() + + policy = lambda td, action_spec=env.full_action_spec.clone(): td.update( + action_spec.rand() + ) + + env.full_observation_spec = Composite( + shape=env.full_observation_spec.shape, device=env.full_observation_spec.device + ) + env.full_action_spec = Composite( + shape=env.full_action_spec.shape, device=env.full_action_spec.device + ) + env.full_reward_spec = Composite( + shape=env.full_reward_spec.shape, device=env.full_reward_spec.device + ) + env.full_done_spec = Composite( + shape=env.full_done_spec.shape, device=env.full_done_spec.device + ) + env.full_state_spec = Composite( + shape=env.full_state_spec.shape, device=env.full_state_spec.device + ) + env._action_keys = ["action"] + env.auto_specs_(policy, tensordict=td.copy()) + env.check_env_specs(tensordict=td.copy()) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_specs.py b/test/test_specs.py index 39b09798ac2..3dedc6233a9 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -412,6 +412,15 @@ def test_getitem(self, shape, is_complete, device, dtype): with pytest.raises(KeyError): _ = ts["UNK"] + def test_setitem_newshape(self, shape, is_complete, device, dtype): + ts = self._composite_spec(shape, is_complete, device, dtype) + new_spec = ts.clone() + new_spec.shape = torch.Size(()) + new_spec.clear_device_() + ts["new_spec"] = new_spec + assert ts["new_spec"].shape == ts.shape + assert ts["new_spec"].device == ts.device + def test_setitem_forbidden_keys(self, shape, is_complete, device, dtype): ts = self._composite_spec(shape, is_complete, device, dtype) for key in {"shape", "device", "dtype", "space"}: diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 32e61bc3ede..ddf6ed41c99 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4372,11 +4372,20 @@ def set(self, name, spec): if spec is not None: shape = spec.shape if shape[: self.ndim] != self.shape: - raise ValueError( - "The shape of the spec and the Composite mismatch: the first " - f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " - f"Composite.shape={self.shape}." - ) + if ( + isinstance(spec, Composite) + and spec.ndim < self.ndim + and self.shape[: spec.ndim] == spec.shape + ): + # Try to set the composite shape + spec = spec.clone() + spec.shape = self.shape + else: + raise ValueError( + "The shape of the spec and the Composite mismatch: the first " + f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " + f"Composite.shape={self.shape}." + ) self._specs[name] = spec def __init__( @@ -4448,6 +4457,8 @@ def clear_device_(self): """Clears the device of the Composite.""" self._device = None for spec in self._specs.values(): + if spec is None: + continue spec.clear_device_() return self @@ -4530,6 +4541,10 @@ def __setitem__(self, key, value): and value.device != self.device ): if isinstance(value, Composite) and value.device is None: + # We make a clone not to mess up the spec that was provided. + # in set() we do the same for shape - these two ops should be grouped. + # we don't care about the overhead of cloning twice though because in theory + # we don't set specs often. value = value.clone().to(self.device) else: raise RuntimeError( diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 8adf36b0019..d5a062bc11e 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -356,6 +356,9 @@ class EnvBase(nn.Module, metaclass=_EnvPostInit): .. note:: Learn more about dynamic specs and environments :ref:`here `. """ + _batch_size: torch.Size | None + _device: torch.device | None + def __init__( self, *, @@ -364,34 +367,178 @@ def __init__( run_type_checks: bool = False, allow_done_after_reset: bool = False, ): + super().__init__() + self.__dict__.setdefault("_batch_size", None) - if device is not None: - self.__dict__["_device"] = _make_ordinal_device(torch.device(device)) - output_spec = self.__dict__.get("_output_spec") - if output_spec is not None: - self.__dict__["_output_spec"] = ( - output_spec.to(self.device) - if self.device is not None - else output_spec - ) - input_spec = self.__dict__.get("_input_spec") - if input_spec is not None: - self.__dict__["_input_spec"] = ( - input_spec.to(self.device) - if self.device is not None - else input_spec - ) + self.__dict__.setdefault("_device", None) - super().__init__() - if "is_closed" not in self.__dir__(): - self.is_closed = True if batch_size is not None: # we want an error to be raised if we pass batch_size but # it's already been set - self.batch_size = torch.Size(batch_size) + batch_size = self.batch_size = torch.Size(batch_size) + else: + batch_size = torch.Size(()) + + if device is not None: + device = self.__dict__["_device"] = _make_ordinal_device( + torch.device(device) + ) + + output_spec = self.__dict__.get("_output_spec") + if output_spec is None: + output_spec = self.__dict__["_output_spec"] = Composite( + shape=batch_size, device=device + ).lock_() + elif self._output_spec.device != device and device is not None: + self.__dict__["_output_spec"] = self.__dict__["_output_spec"].to( + self.device + ) + input_spec = self.__dict__.get("_input_spec") + if input_spec is None: + input_spec = self.__dict__["_input_spec"] = Composite( + shape=batch_size, device=device + ).lock_() + elif self._input_spec.device != device and device is not None: + self.__dict__["_input_spec"] = self.__dict__["_input_spec"].to(self.device) + + output_spec.unlock_() + input_spec.unlock_() + if "full_observation_spec" not in output_spec: + output_spec["full_observation_spec"] = Composite() + if "full_done_spec" not in output_spec: + output_spec["full_done_spec"] = Composite() + if "full_reward_spec" not in output_spec: + output_spec["full_reward_spec"] = Composite() + if "full_state_spec" not in input_spec: + input_spec["full_state_spec"] = Composite() + if "full_action_spec" not in input_spec: + input_spec["full_action_spec"] = Composite() + output_spec.lock_() + input_spec.lock_() + + if "is_closed" not in self.__dir__(): + self.is_closed = True self._run_type_checks = run_type_checks self._allow_done_after_reset = allow_done_after_reset + def auto_specs_( + self, + policy: Callable[[TensorDictBase], TensorDictBase], + *, + tensordict: TensorDictBase | None = None, + action_key: NestedKey | List[NestedKey] = "action", + done_key: NestedKey | List[NestedKey] | None = None, + observation_key: NestedKey | List[NestedKey] = "observation", + reward_key: NestedKey | List[NestedKey] = "reward", + batch_size: torch.Size | None = None, + ): + """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy. + + This method performs a rollout using the provided policy to infer the input and output specifications of the environment. + It updates the environment's specs for actions, observations, rewards, and done signals based on the data collected + during the rollout. + + Args: + policy (Callable[[TensorDictBase], TensorDictBase]): + A callable policy that takes a `TensorDictBase` as input and returns a `TensorDictBase` as output. + This policy is used to perform the rollout and determine the specs. + + Keyword Args: + tensordict (TensorDictBase, optional): + An optional `TensorDictBase` instance to be used as the initial state for the rollout. + If not provided, the environment's `reset` method will be called to obtain the initial state. + action_key (NestedKey or List[NestedKey], optional): + The key(s) used to identify actions in the `TensorDictBase`. Defaults to "action". + done_key (NestedKey or List[NestedKey], optional): + The key(s) used to identify done signals in the `TensorDictBase`. Defaults to ``None``, which will + attempt to use ["done", "terminated", "truncated"] as potential keys. + observation_key (NestedKey or List[NestedKey], optional): + The key(s) used to identify observations in the `TensorDictBase`. Defaults to "observation". + reward_key (NestedKey or List[NestedKey], optional): + The key(s) used to identify rewards in the `TensorDictBase`. Defaults to "reward". + + Returns: + EnvBase: The environment instance with updated specs. + + Raises: + RuntimeError: If there are keys in the output specs that are not accounted for in the provided keys. + """ + if self.batch_locked or tensordict is None: + batch_size = self.batch_size + else: + batch_size = tensordict.batch_size + if tensordict is None: + tensordict = self.reset() + + # Input specs + tensordict = policy(tensordict) + step_0 = self.step(tensordict.copy()) + tensordict2 = step_0.get("next").copy() + step_1 = self.step(policy(tensordict2).copy()) + nexts_0: TensorDictBase = step_0.pop("next") + nexts_1: TensorDictBase = step_1.pop("next") + + input_spec_stack = {} + tensordict.apply( + partial(_tensor_to_spec, stack=input_spec_stack), + tensordict2, + named=True, + nested_keys=True, + ) + input_spec = Composite(input_spec_stack, batch_size=batch_size) + if not self.batch_locked and batch_size != self.batch_size: + while input_spec.shape: + input_spec = input_spec[0] + if isinstance(action_key, NestedKey): + action_key = [action_key] + full_action_spec = input_spec.separates(*action_key, default=None) + + # Output specs + + output_spec_stack = {} + nexts_0.apply( + partial(_tensor_to_spec, stack=output_spec_stack), + nexts_1, + named=True, + nested_keys=True, + ) + + output_spec = Composite(output_spec_stack, batch_size=batch_size) + if not self.batch_locked and batch_size != self.batch_size: + while output_spec.shape: + output_spec = output_spec[0] + + if done_key is None: + done_key = ["done", "terminated", "truncated"] + full_done_spec = output_spec.separates(*done_key, default=None) + if full_done_spec is not None: + self.full_done_spec = full_done_spec + + if isinstance(reward_key, NestedKey): + reward_key = [reward_key] + full_reward_spec = output_spec.separates(*reward_key, default=None) + + if isinstance(observation_key, NestedKey): + observation_key = [observation_key] + full_observation_spec = output_spec.separates(*observation_key, default=None) + if not output_spec.is_empty(recurse=True): + raise RuntimeError( + f"Keys {list(output_spec.keys(True, True))} are unaccounted for." + ) + + if full_action_spec is not None: + self.full_action_spec = full_action_spec + if full_done_spec is not None: + self.full_done_specs = full_done_spec + if full_observation_spec is not None: + self.full_observation_spec = full_observation_spec + if full_reward_spec is not None: + self.full_reward_spec = full_reward_spec + full_state_spec = input_spec + self.full_state_spec = full_state_spec + + return self + @wraps(check_env_specs_func) def check_env_specs(self, *args, **kwargs): return check_env_specs_func(self, *args, **kwargs) @@ -475,7 +622,7 @@ def batch_size(self) -> torch.Size: in parallel). """ - _batch_size = self.__dict__["_batch_size"] + _batch_size = self.__dict__.get("_batch_size") if _batch_size is None: _batch_size = self._batch_size = torch.Size([]) return _batch_size @@ -667,8 +814,6 @@ def action_keys(self) -> List[NestedKey]: if action_keys is not None: return action_keys keys = self.full_action_spec.keys(True, True) - if not len(keys): - raise AttributeError("Could not find action spec") keys = sorted(keys, key=_repr_by_depth) self.__dict__["_action_keys"] = keys return keys @@ -827,15 +972,7 @@ def action_spec(self, value: TensorSpec) -> None: "Please use `env.action_spec_unbatched = value` to set unbatched versions instead." ) - if isinstance(value, Composite): - for _ in value.values(True, True): # noqa: B007 - break - else: - raise RuntimeError( - "An empty Composite was passed for the action spec. " - "This is currently not permitted." - ) - else: + if not isinstance(value, Composite): value = Composite( action=value.to(device), shape=self.batch_size, device=device ) @@ -892,7 +1029,6 @@ def reward_keys(self) -> List[NestedKey]: reward_keys = self.__dict__.get("_reward_keys") if reward_keys is not None: return reward_keys - reward_keys = sorted(self.full_reward_spec.keys(True, True), key=_repr_by_depth) self.__dict__["_reward_keys"] = reward_keys return reward_keys @@ -1030,15 +1166,7 @@ def reward_spec(self, value: TensorSpec) -> None: f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size}). " "Please use `env.reward_spec_unbatched = value` to set unbatched versions instead." ) - if isinstance(value, Composite): - for _ in value.values(True, True): # noqa: B007 - break - else: - raise RuntimeError( - "An empty Composite was passed for the reward spec. " - "This is currently not permitted." - ) - else: + if not isinstance(value, Composite): value = Composite( reward=value.to(device), shape=self.batch_size, device=device ) @@ -1319,15 +1447,7 @@ def done_spec(self, value: TensorSpec) -> None: raise ValueError( f"The value of spec.shape ({value.shape}) must match the env batch size ({self.batch_size})." ) - if isinstance(value, Composite): - for _ in value.values(True, True): # noqa: B007 - break - else: - raise RuntimeError( - "An empty Composite was passed for the done spec. " - "This is currently not permitted." - ) - else: + if not isinstance(value, Composite): value = Composite( done=value.to(device), terminated=value.to(device), @@ -3445,3 +3565,11 @@ def _has_dynamic_specs(spec: Composite): any(s == -1 for s in spec.shape) for spec in spec.values(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS) ) + + +def _tensor_to_spec(name, leaf, leaf_compare=None, *, stack): + shape = leaf.shape + if leaf_compare is not None: + shape_compare = leaf_compare.shape + shape = [s0 if s0 == s1 else -1 for s0, s1 in zip(shape, shape_compare)] + stack[name] = Unbounded(shape, device=leaf.device, dtype=leaf.dtype) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index c83591acb63..7454bce99b3 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -16,8 +16,6 @@ from enum import Enum from typing import Any, Dict, List, Union -import tensordict.base - import torch from tensordict import ( @@ -29,7 +27,7 @@ TensorDictBase, unravel_key, ) -from tensordict.base import _is_leaf_nontensor +from tensordict.base import _default_is_leaf, _is_leaf_nontensor from tensordict.nn import TensorDictModule, TensorDictModuleBase from tensordict.nn.probabilistic import ( # noqa interaction_type as exploration_type, @@ -691,7 +689,11 @@ def _per_level_env_check(data0, data1, check_dtype): def check_env_specs( - env, return_contiguous=True, check_dtype=True, seed: int | None = None + env, + return_contiguous=True, + check_dtype=True, + seed: int | None = None, + tensordict: TensorDictBase | None = None, ): """Tests an environment specs against the results of short rollout. @@ -715,6 +717,7 @@ def check_env_specs( setting the rng state back to what is was isn't a feature of most environment, we leave it to the user to accomplish that. Defaults to ``None``. + tensordict (TensorDict, optional): an optional tensordict instance to use for reset. Caution: this function resets the env seed. It should be used "offline" to check that an env is adequately constructed, but it may affect the seeding @@ -732,7 +735,16 @@ def check_env_specs( ) fake_tensordict = env.fake_tensordict() - real_tensordict = env.rollout(3, return_contiguous=return_contiguous) + if not env._batch_locked and tensordict is not None: + shape = torch.broadcast_shapes(fake_tensordict.shape, tensordict.shape) + fake_tensordict = fake_tensordict.expand(shape) + tensordict = tensordict.expand(shape) + real_tensordict = env.rollout( + 3, + return_contiguous=return_contiguous, + tensordict=tensordict, + auto_reset=tensordict is None, + ) if return_contiguous: fake_tensordict = fake_tensordict.unsqueeze(real_tensordict.batch_dims - 1) @@ -743,17 +755,17 @@ def check_env_specs( ) # eliminate empty containers fake_tensordict_select = fake_tensordict.select( - *fake_tensordict.keys(True, True, is_leaf=tensordict.base._default_is_leaf) + *fake_tensordict.keys(True, True, is_leaf=_default_is_leaf) ) real_tensordict_select = real_tensordict.select( - *real_tensordict.keys(True, True, is_leaf=tensordict.base._default_is_leaf) + *real_tensordict.keys(True, True, is_leaf=_default_is_leaf) ) # check keys fake_tensordict_keys = set( - fake_tensordict.keys(True, True, is_leaf=tensordict.base._is_leaf_nontensor) + fake_tensordict.keys(True, True, is_leaf=_is_leaf_nontensor) ) real_tensordict_keys = set( - real_tensordict.keys(True, True, is_leaf=tensordict.base._is_leaf_nontensor) + real_tensordict.keys(True, True, is_leaf=_is_leaf_nontensor) ) if fake_tensordict_keys != real_tensordict_keys: raise AssertionError( diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 52efb3d312b..cd7039c323d 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -1243,7 +1243,7 @@ def _compute_target(self, tensordict) -> Tensor: # unlike in continuous SAC, we can compute the exact expectation over all discrete actions next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1) if next_tensordict_select is not next_tensordict: - mask = ~done.squeeze(-1) + mask = ~done next_state_value = next_state_value.new_zeros( mask.shape ).masked_scatter_(mask, next_state_value) From 90c8e40f64bb76601d93a9416fa8723cd607ffe2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Dec 2024 16:24:13 +0000 Subject: [PATCH 03/27] [BugFix] Better account of composite distributions in PPO ghstack-source-id: 3d86f99bc5b20a53e4092d786e96a5f7e83405ac Pull Request resolved: https://github.com/pytorch/rl/pull/2622 --- torchrl/objectives/ppo.py | 53 +++++++++++++++++--------- torchrl/objectives/utils.py | 5 +++ torchrl/objectives/value/advantages.py | 22 ++++++++--- 3 files changed, 58 insertions(+), 22 deletions(-) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 8c64c1ba539..eb9a916dfc1 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -18,6 +18,7 @@ TensorDictParams, ) from tensordict.nn import ( + CompositeDistribution, dispatch, ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, @@ -33,6 +34,7 @@ _clip_value_loss, _GAMMA_LMBDA_DEPREC_ERROR, _reduce, + _sum_td_features, default_value_kwargs, distance_loss, ValueEstimators, @@ -462,9 +464,13 @@ def reset(self) -> None: def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor: try: - entropy = dist.entropy() + if isinstance(dist, CompositeDistribution): + kwargs = {"aggregate_probabilities": False, "include_sum": False} + else: + kwargs = {} + entropy = dist.entropy(**kwargs) if is_tensor_collection(entropy): - entropy = entropy.get(dist.entropy_key) + entropy = _sum_td_features(entropy) except NotImplementedError: x = dist.rsample((self.samples_mc_entropy,)) log_prob = dist.log_prob(x) @@ -497,13 +503,20 @@ def _log_weight( if isinstance(action, torch.Tensor): log_prob = dist.log_prob(action) else: - maybe_log_prob = dist.log_prob(tensordict) - if not isinstance(maybe_log_prob, torch.Tensor): - # In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not - # be a tensor - log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob) + if isinstance(dist, CompositeDistribution): + is_composite = True + kwargs = { + "inplace": False, + "aggregate_probabilities": False, + "include_sum": False, + } else: - log_prob = maybe_log_prob + is_composite = False + kwargs = {} + log_prob = dist.log_prob(tensordict, **kwargs) + if is_composite and not isinstance(prev_log_prob, TensorDict): + log_prob = _sum_td_features(log_prob) + log_prob.view_as(prev_log_prob) log_weight = (log_prob - prev_log_prob).unsqueeze(-1) kl_approx = (prev_log_prob - log_prob).unsqueeze(-1) @@ -598,6 +611,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: advantage = (advantage - loc) / scale log_weight, dist, kl_approx = self._log_weight(tensordict) + if is_tensor_collection(log_weight): + log_weight = _sum_td_features(log_weight) + log_weight = log_weight.view(advantage.shape) neg_loss = log_weight.exp() * advantage td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[]) if self.entropy_bonus: @@ -1149,16 +1165,19 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) except NotImplementedError: x = previous_dist.sample((self.samples_mc_kl,)) - previous_log_prob = previous_dist.log_prob(x) - current_log_prob = current_dist.log_prob(x) + if isinstance(previous_dist, CompositeDistribution): + kwargs = { + "aggregate_probabilities": False, + "inplace": False, + "include_sum": False, + } + else: + kwargs = {} + previous_log_prob = previous_dist.log_prob(x, **kwargs) + current_log_prob = current_dist.log_prob(x, **kwargs) if is_tensor_collection(current_log_prob): - previous_log_prob = previous_log_prob.get( - self.tensor_keys.sample_log_prob - ) - current_log_prob = current_log_prob.get( - self.tensor_keys.sample_log_prob - ) - + previous_log_prob = _sum_td_features(previous_log_prob) + current_log_prob = _sum_td_features(current_log_prob) kl = (previous_log_prob - current_log_prob).mean(0) kl = kl.unsqueeze(-1) neg_loss = neg_loss - self.beta * kl diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index 4dfed60e5a9..9c46fc98262 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -615,3 +615,8 @@ def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimize raise ValueError("Cannot group optimizers of different type.") params.extend(optimizer.param_groups) return cls(params) + + +def _sum_td_features(data: TensorDictBase) -> torch.Tensor: + # Sum all features and return a tensor + return data.sum(dim="feature", reduce=True) diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index fadfe932c50..bbd6a23bfdd 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -15,11 +15,14 @@ import torch from tensordict import TensorDictBase from tensordict.nn import ( + CompositeDistribution, dispatch, + ProbabilisticTensorDictModule, set_skip_existing, TensorDictModule, TensorDictModuleBase, ) +from tensordict.nn.probabilistic import interaction_type from tensordict.utils import NestedKey from torch import Tensor @@ -74,14 +77,22 @@ def new_func(self, *args, **kwargs): def _call_actor_net( - actor_net: TensorDictModuleBase, + actor_net: ProbabilisticTensorDictModule, data: TensorDictBase, params: TensorDictBase, log_prob_key: NestedKey, ): - # TODO: extend to handle time dimension (and vmap?) - log_pi = actor_net(data.select(*actor_net.in_keys, strict=False)).get(log_prob_key) - return log_pi + dist = actor_net.get_dist(data.select(*actor_net.in_keys, strict=False)) + if isinstance(dist, CompositeDistribution): + kwargs = { + "aggregate_probabilities": True, + "inplace": False, + "include_sum": False, + } + else: + kwargs = {} + s = actor_net._dist_sample(dist, interaction_type=interaction_type()) + return dist.log_prob(s, **kwargs) class ValueEstimatorBase(TensorDictModuleBase): @@ -1771,7 +1782,8 @@ def forward( data=tensordict, params=None, log_prob_key=self.tensor_keys.sample_log_prob, - ).view_as(value) + ) + log_pi = log_pi.view_as(value) # Compute the V-Trace correction done = tensordict.get(("next", self.tensor_keys.done)) From de61e4d5eefeb41cd0e69a3821ec1b8ebf34c8c8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Dec 2024 17:49:24 +0000 Subject: [PATCH 04/27] [BugFix] skip_done_states in SAC ghstack-source-id: 39d97360e3b0e45dd8c327487eac50ddafe2254d Pull Request resolved: https://github.com/pytorch/rl/pull/2613 --- test/test_cost.py | 2 + torchrl/objectives/sac.py | 151 ++++++++++++++++++++++++-------------- 2 files changed, 99 insertions(+), 54 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index c48b4a28b99..1f191e41db6 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -4493,6 +4493,7 @@ def test_sac_terminating( actor_network=actor, qvalue_network=qvalue, value_network=value, + skip_done_states=True, ) loss.set_keys( action=action_key, @@ -5204,6 +5205,7 @@ def test_discrete_sac_terminating( qvalue_network=qvalue, num_actions=actor.spec[action_key].space.n, action_space="one-hot", + skip_done_states=True, ) loss.set_keys( action=action_key, diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index cd7039c323d..dafff17011e 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -126,6 +126,10 @@ class SACLoss(LossModule): ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, ``"mean"``: the sum of the output will be divided by the number of elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. + skip_done_states (bool, optional): whether the actor network used for value computation should only be run on + valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the + shape of the data and that masking the data results in a valid data structure. Among other things, this may + not be true in MARL settings or when using RNNs. Defaults to ``False``. Examples: >>> import torch @@ -320,6 +324,7 @@ def __init__( priority_key: str = None, separate_losses: bool = False, reduction: str = None, + skip_done_states: bool = False, ) -> None: self._in_keys = None self._out_keys = None @@ -418,6 +423,7 @@ def __init__( raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._make_vmap() self.reduction = reduction + self.skip_done_states = skip_done_states def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( @@ -712,36 +718,44 @@ def _compute_target_v2(self, tensordict) -> Tensor: ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): next_tensordict = tensordict.get("next").copy() - # Check done state and avoid passing these to the actor - done = next_tensordict.get(self.tensor_keys.done) - if done is not None and done.any(): - next_tensordict_select = next_tensordict[~done.squeeze(-1)] - else: - next_tensordict_select = next_tensordict - next_dist = self.actor_network.get_dist(next_tensordict_select) - next_action = next_dist.rsample() - next_sample_log_prob = compute_log_prob( - next_dist, next_action, self.tensor_keys.log_prob - ) - if next_tensordict_select is not next_tensordict: - mask = ~done.squeeze(-1) - if mask.ndim < next_action.ndim: - mask = expand_right( - mask, (*mask.shape, *next_action.shape[mask.ndim :]) - ) - next_action = next_action.new_zeros(mask.shape).masked_scatter_( - mask, next_action + if self.skip_done_states: + # Check done state and avoid passing these to the actor + done = next_tensordict.get(self.tensor_keys.done) + if done is not None and done.any(): + next_tensordict_select = next_tensordict[~done.squeeze(-1)] + else: + next_tensordict_select = next_tensordict + next_dist = self.actor_network.get_dist(next_tensordict_select) + next_action = next_dist.rsample() + next_sample_log_prob = compute_log_prob( + next_dist, next_action, self.tensor_keys.log_prob ) - mask = ~done.squeeze(-1) - if mask.ndim < next_sample_log_prob.ndim: - mask = expand_right( - mask, - (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]), + if next_tensordict_select is not next_tensordict: + mask = ~done.squeeze(-1) + if mask.ndim < next_action.ndim: + mask = expand_right( + mask, (*mask.shape, *next_action.shape[mask.ndim :]) + ) + next_action = next_action.new_zeros(mask.shape).masked_scatter_( + mask, next_action ) - next_sample_log_prob = next_sample_log_prob.new_zeros( - mask.shape - ).masked_scatter_(mask, next_sample_log_prob) - next_tensordict.set(self.tensor_keys.action, next_action) + mask = ~done.squeeze(-1) + if mask.ndim < next_sample_log_prob.ndim: + mask = expand_right( + mask, + (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]), + ) + next_sample_log_prob = next_sample_log_prob.new_zeros( + mask.shape + ).masked_scatter_(mask, next_sample_log_prob) + next_tensordict.set(self.tensor_keys.action, next_action) + else: + next_dist = self.actor_network.get_dist(next_tensordict) + next_action = next_dist.rsample() + next_tensordict.set(self.tensor_keys.action, next_action) + next_sample_log_prob = compute_log_prob( + next_dist, next_action, self.tensor_keys.log_prob + ) # get q-values next_tensordict_expand = self._vmap_qnetworkN0( @@ -877,6 +891,10 @@ class DiscreteSACLoss(LossModule): ``"none"`` | ``"mean"`` | ``"sum"``. ``"none"``: no reduction will be applied, ``"mean"``: the sum of the output will be divided by the number of elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. + skip_done_states (bool, optional): whether the actor network used for value computation should only be run on + valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the + shape of the data and that masking the data results in a valid data structure. Among other things, this may + not be true in MARL settings or when using RNNs. Defaults to ``False``. Examples: >>> import torch @@ -1051,6 +1069,7 @@ def __init__( priority_key: str = None, separate_losses: bool = False, reduction: str = None, + skip_done_states: bool = False, ): if reduction is None: reduction = "mean" @@ -1133,6 +1152,7 @@ def __init__( ) self._make_vmap() self.reduction = reduction + self.skip_done_states = skip_done_states def _make_vmap(self): self._vmap_qnetworkN0 = _vmap_func( @@ -1218,35 +1238,58 @@ def _compute_target(self, tensordict) -> Tensor: with torch.no_grad(): next_tensordict = tensordict.get("next").clone(False) - done = next_tensordict.get(self.tensor_keys.done) - if done is not None and done.any(): - next_tensordict_select = next_tensordict[~done.squeeze(-1)] - else: - next_tensordict_select = next_tensordict + if self.skip_done_states: + done = next_tensordict.get(self.tensor_keys.done) + if done is not None and done.any(): + next_tensordict_select = next_tensordict[~done.squeeze(-1)] + else: + next_tensordict_select = next_tensordict - # get probs and log probs for actions computed from "next" - with self.actor_network_params.to_module(self.actor_network): - next_dist = self.actor_network.get_dist(next_tensordict_select) - next_log_prob = next_dist.logits - next_prob = next_log_prob.exp() + # get probs and log probs for actions computed from "next" + with self.actor_network_params.to_module(self.actor_network): + next_dist = self.actor_network.get_dist(next_tensordict_select) + next_log_prob = next_dist.logits + next_prob = next_log_prob.exp() - # get q-values for all actions - next_tensordict_expand = self._vmap_qnetworkN0( - next_tensordict_select, self.target_qvalue_network_params - ) - next_action_value = next_tensordict_expand.get( - self.tensor_keys.action_value - ) + # get q-values for all actions + next_tensordict_expand = self._vmap_qnetworkN0( + next_tensordict_select, self.target_qvalue_network_params + ) + next_action_value = next_tensordict_expand.get( + self.tensor_keys.action_value + ) - # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term - next_state_value = next_action_value.min(0)[0] - self._alpha * next_log_prob - # unlike in continuous SAC, we can compute the exact expectation over all discrete actions - next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1) - if next_tensordict_select is not next_tensordict: - mask = ~done - next_state_value = next_state_value.new_zeros( - mask.shape - ).masked_scatter_(mask, next_state_value) + # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term + next_state_value = ( + next_action_value.min(0)[0] - self._alpha * next_log_prob + ) + # unlike in continuous SAC, we can compute the exact expectation over all discrete actions + next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1) + if next_tensordict_select is not next_tensordict: + mask = ~done + next_state_value = next_state_value.new_zeros( + mask.shape + ).masked_scatter_(mask, next_state_value) + else: + # get probs and log probs for actions computed from "next" + with self.actor_network_params.to_module(self.actor_network): + next_dist = self.actor_network.get_dist(next_tensordict) + next_prob = next_dist.probs + next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob)) + + # get q-values for all actions + next_tensordict_expand = self._vmap_qnetworkN0( + next_tensordict, self.target_qvalue_network_params + ) + next_action_value = next_tensordict_expand.get( + self.tensor_keys.action_value + ) + # like in continuous SAC, we take the minimum of the value ensemble and subtract the entropy term + next_state_value = ( + next_action_value.min(0)[0] - self._alpha * next_log_prob + ) + # unlike in continuous SAC, we can compute the exact expectation over all discrete actions + next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1) tensordict.set( ("next", self.value_estimator.tensor_keys.value), next_state_value From 830f2f26ca91ec153f63e539c423223dddd95e21 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Dec 2024 17:46:24 +0000 Subject: [PATCH 05/27] [BugFix] ActionDiscretizer scalar integration ghstack-source-id: b22102f3730914b125ef0f813f4d2f22dec0b26e Pull Request resolved: https://github.com/pytorch/rl/pull/2619 --- test/mocking_classes.py | 69 +++++++++++++++++ test/test_transforms.py | 103 +++++++++++++++++++++----- torchrl/envs/transforms/transforms.py | 89 +++++++++++++++------- 3 files changed, 213 insertions(+), 48 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index d78e2f27184..bb902f879b1 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1927,3 +1927,72 @@ def _step( def _set_seed(self, seed: Optional[int]): self.manual_seed = seed return seed + + +class EnvWithScalarAction(EnvBase): + def __init__(self, singleton: bool = False, **kwargs): + super().__init__(**kwargs) + self.singleton = singleton + self.action_spec = Bounded( + -1, + 1, + shape=( + *self.batch_size, + 1, + ) + if self.singleton + else self.batch_size, + ) + self.observation_spec = Composite( + observation=Unbounded( + shape=( + *self.batch_size, + 3, + ) + ), + shape=self.batch_size, + ) + self.done_spec = Composite( + done=Unbounded(self.batch_size + (1,), dtype=torch.bool), + terminated=Unbounded(self.batch_size + (1,), dtype=torch.bool), + truncated=Unbounded(self.batch_size + (1,), dtype=torch.bool), + shape=self.batch_size, + ) + self.reward_spec = Unbounded( + shape=( + *self.batch_size, + 1, + ) + ) + + def _reset(self, td: TensorDict): + return TensorDict( + observation=torch.randn(*self.batch_size, 3, device=self.device), + done=torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device), + truncated=torch.zeros( + *self.batch_size, 1, dtype=torch.bool, device=self.device + ), + terminated=torch.zeros( + *self.batch_size, 1, dtype=torch.bool, device=self.device + ), + device=self.device, + ) + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + return TensorDict( + observation=torch.randn(*self.batch_size, 3, device=self.device), + reward=torch.zeros(1, device=self.device), + done=torch.zeros(*self.batch_size, 1, dtype=torch.bool, device=self.device), + truncated=torch.zeros( + *self.batch_size, 1, dtype=torch.bool, device=self.device + ), + terminated=torch.zeros( + *self.batch_size, 1, dtype=torch.bool, device=self.device + ), + ) + + def _set_seed(self, seed: Optional[int]): + ... diff --git a/test/test_transforms.py b/test/test_transforms.py index 8b2ada8c93a..ae428d35d97 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -41,6 +41,7 @@ CountingEnvCountPolicy, DiscreteActionConvMockEnv, DiscreteActionConvMockEnvNumpy, + EnvWithScalarAction, IncrementingEnv, MockBatchedLockedEnv, MockBatchedUnLockedEnv, @@ -66,6 +67,7 @@ CountingEnvCountPolicy, DiscreteActionConvMockEnv, DiscreteActionConvMockEnvNumpy, + EnvWithScalarAction, IncrementingEnv, MockBatchedLockedEnv, MockBatchedUnLockedEnv, @@ -11781,17 +11783,33 @@ def test_transform_inverse(self): class TestActionDiscretizer(TransformBase): @pytest.mark.parametrize("categorical", [True, False]) - def test_single_trans_env_check(self, categorical): - base_env = ContinuousActionVecMockEnv() + @pytest.mark.parametrize( + "env_cls", + [ + ContinuousActionVecMockEnv, + partial(EnvWithScalarAction, singleton=True), + partial(EnvWithScalarAction, singleton=False), + ], + ) + def test_single_trans_env_check(self, categorical, env_cls): + base_env = env_cls() env = base_env.append_transform( ActionDiscretizer(num_intervals=5, categorical=categorical) ) check_env_specs(env) @pytest.mark.parametrize("categorical", [True, False]) - def test_serial_trans_env_check(self, categorical): + @pytest.mark.parametrize( + "env_cls", + [ + ContinuousActionVecMockEnv, + partial(EnvWithScalarAction, singleton=True), + partial(EnvWithScalarAction, singleton=False), + ], + ) + def test_serial_trans_env_check(self, categorical, env_cls): def make_env(): - base_env = ContinuousActionVecMockEnv() + base_env = env_cls() return base_env.append_transform( ActionDiscretizer(num_intervals=5, categorical=categorical) ) @@ -11800,9 +11818,17 @@ def make_env(): check_env_specs(env) @pytest.mark.parametrize("categorical", [True, False]) - def test_parallel_trans_env_check(self, categorical): + @pytest.mark.parametrize( + "env_cls", + [ + ContinuousActionVecMockEnv, + partial(EnvWithScalarAction, singleton=True), + partial(EnvWithScalarAction, singleton=False), + ], + ) + def test_parallel_trans_env_check(self, categorical, env_cls): def make_env(): - base_env = ContinuousActionVecMockEnv() + base_env = env_cls() env = base_env.append_transform( ActionDiscretizer(num_intervals=5, categorical=categorical) ) @@ -11812,17 +11838,33 @@ def make_env(): check_env_specs(env) @pytest.mark.parametrize("categorical", [True, False]) - def test_trans_serial_env_check(self, categorical): - env = SerialEnv(2, ContinuousActionVecMockEnv).append_transform( + @pytest.mark.parametrize( + "env_cls", + [ + ContinuousActionVecMockEnv, + partial(EnvWithScalarAction, singleton=True), + partial(EnvWithScalarAction, singleton=False), + ], + ) + def test_trans_serial_env_check(self, categorical, env_cls): + env = SerialEnv(2, env_cls).append_transform( ActionDiscretizer(num_intervals=5, categorical=categorical) ) check_env_specs(env) @pytest.mark.parametrize("categorical", [True, False]) - def test_trans_parallel_env_check(self, categorical): - env = ParallelEnv( - 2, ContinuousActionVecMockEnv, mp_start_method=mp_ctx - ).append_transform(ActionDiscretizer(num_intervals=5, categorical=categorical)) + @pytest.mark.parametrize( + "env_cls", + [ + ContinuousActionVecMockEnv, + partial(EnvWithScalarAction, singleton=True), + partial(EnvWithScalarAction, singleton=False), + ], + ) + def test_trans_parallel_env_check(self, categorical, env_cls): + env = ParallelEnv(2, env_cls, mp_start_method=mp_ctx).append_transform( + ActionDiscretizer(num_intervals=5, categorical=categorical) + ) check_env_specs(env) def test_transform_no_env(self): @@ -11838,7 +11880,6 @@ def test_transform_compose(self): check_env_specs(env) @pytest.mark.skipif(not _has_gym, reason="gym required for this test") - @pytest.mark.parametrize("envname", ["cheetah", "pendulum"]) @pytest.mark.parametrize("interval_as_tensor", [False, True]) @pytest.mark.parametrize("categorical", [True, False]) @pytest.mark.parametrize( @@ -11851,15 +11892,37 @@ def test_transform_compose(self): ActionDiscretizer.SamplingStrategy.RANDOM, ], ) - def test_transform_env(self, envname, interval_as_tensor, categorical, sampling): + @pytest.mark.parametrize( + "env_cls", + [ + "cheetah", + "pendulum", + partial(EnvWithScalarAction, singleton=True), + partial(EnvWithScalarAction, singleton=False), + ], + ) + def test_transform_env(self, env_cls, interval_as_tensor, categorical, sampling): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - base_env = GymEnv( - HALFCHEETAH_VERSIONED() if envname == "cheetah" else PENDULUM_VERSIONED(), - device=device, - ) - if interval_as_tensor: - num_intervals = torch.arange(5, 11 if envname == "cheetah" else 6) + if env_cls == "cheetah": + base_env = GymEnv( + HALFCHEETAH_VERSIONED(), + device=device, + ) + num_intervals = torch.arange(5, 11) + elif env_cls == "pendulum": + base_env = GymEnv( + PENDULUM_VERSIONED(), + device=device, + ) + num_intervals = torch.arange(5, 6) else: + base_env = env_cls( + device=device, + ) + num_intervals = torch.arange(5, 6) + + if not interval_as_tensor: + # override num_intervals = 5 t = ActionDiscretizer( num_intervals=num_intervals, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7bdd25591cd..7ab5a2deb72 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -8585,24 +8585,32 @@ def _indent(s): def transform_input_spec(self, input_spec): try: - action_spec = input_spec["full_action_spec", self.in_keys_inv[0]] + action_spec = self.parent.full_action_spec_unbatched[self.in_keys_inv[0]] if not isinstance(action_spec, Bounded): raise TypeError( - f"action spec type {type(action_spec)} is not supported." + f"action spec type {type(action_spec)} is not supported. The action spec type must be Bounded." ) n_act = action_spec.shape if not n_act: - n_act = 1 + n_act = () + empty_shape = True else: - n_act = n_act[-1] + n_act = (n_act[-1],) + empty_shape = False self.n_act = n_act self.dtype = action_spec.dtype - interval = (action_spec.high - action_spec.low).unsqueeze(-1) + interval = action_spec.high - action_spec.low num_intervals = self.num_intervals + if not empty_shape: + interval = interval.unsqueeze(-1) + elif isinstance(num_intervals, torch.Tensor): + num_intervals = int(num_intervals.squeeze()) + self.num_intervals = torch.as_tensor(num_intervals) + def custom_arange(nint): result = torch.arange( start=0.0, @@ -8625,11 +8633,13 @@ def custom_arange(nint): if isinstance(num_intervals, int): arange = ( - custom_arange(num_intervals).expand(n_act, num_intervals) * interval - ) - self.register_buffer( - "intervals", action_spec.low.unsqueeze(-1) + arange + custom_arange(num_intervals).expand((*n_act, num_intervals)) + * interval ) + low = action_spec.low + if not empty_shape: + low = low.unsqueeze(-1) + self.register_buffer("intervals", low + arange) else: arange = [ custom_arange(_num_intervals) * interval @@ -8644,12 +8654,6 @@ def custom_arange(nint): ) ] - cls = ( - functools.partial(MultiCategorical, remove_singleton=False) - if self.categorical - else MultiOneHot - ) - if not isinstance(num_intervals, torch.Tensor): nvec = torch.as_tensor(num_intervals, device=action_spec.device) else: @@ -8657,7 +8661,10 @@ def custom_arange(nint): if nvec.ndim > 1: raise RuntimeError(f"Cannot use num_intervals with shape {nvec.shape}") if nvec.ndim == 0 or nvec.numel() == 1: - nvec = nvec.expand(action_spec.shape[-1]) + if not empty_shape: + nvec = nvec.expand(action_spec.shape[-1]) + else: + nvec = nvec.squeeze() self.register_buffer("nvec", nvec) if self.sampling == self.SamplingStrategy.RANDOM: # compute jitters @@ -8667,7 +8674,22 @@ def custom_arange(nint): if self.categorical else (*action_spec.shape[:-1], nvec.sum()) ) - action_spec = cls(nvec=nvec, shape=shape, device=action_spec.device) + + if not empty_shape: + cls = ( + functools.partial(MultiCategorical, remove_singleton=False) + if self.categorical + else MultiOneHot + ) + action_spec = cls(nvec=nvec, shape=shape, device=action_spec.device) + + else: + cls = Categorical if self.categorical else OneHot + action_spec = cls(n=int(nvec), shape=shape, device=action_spec.device) + + batch_size = self.parent.batch_size + if batch_size: + action_spec = action_spec.expand(batch_size + action_spec.shape) input_spec["full_action_spec", self.out_keys_inv[0]] = action_spec if self.out_keys_inv[0] != self.in_keys_inv[0]: @@ -8705,6 +8727,8 @@ def _inv_call(self, tensordict): if self.categorical: action = action.unsqueeze(-1) if isinstance(intervals, torch.Tensor): + shape = action.shape[: -intervals.ndim] + intervals = intervals.expand(shape + intervals.shape) action = intervals.gather(index=action, dim=-1).squeeze(-1) else: action = torch.stack( @@ -8715,17 +8739,26 @@ def _inv_call(self, tensordict): -1, ) else: - nvec = self.nvec.tolist() - action = action.split(nvec, dim=-1) - if isinstance(intervals, torch.Tensor): - intervals = intervals.unbind(-2) - action = torch.stack( - [ - intervals[action].view(action.shape[:-1]) - for (intervals, action) in zip(intervals, action) - ], - -1, - ) + nvec = self.nvec + empty_shape = not nvec.ndim + if not empty_shape: + nvec = nvec.tolist() + if isinstance(intervals, torch.Tensor): + shape = action.shape[: (-intervals.ndim + 1)] + intervals = intervals.expand(shape + intervals.shape) + intervals = intervals.unbind(-2) + action = action.split(nvec, dim=-1) + action = torch.stack( + [ + intervals[action].view(action.shape[:-1]) + for (intervals, action) in zip(intervals, action) + ], + -1, + ) + else: + shape = action.shape[: -intervals.ndim] + intervals = intervals.expand(shape + intervals.shape) + action = intervals[action].squeeze(-1) if self.sampling == self.SamplingStrategy.RANDOM: action = action + self.jitters * torch.rand_like(self.jitters) From c72583f75ab220c7ef89e9bd2505045ea5898db4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Dec 2024 17:46:25 +0000 Subject: [PATCH 06/27] [Feature, Test] Adding tests for envs that have no specs ghstack-source-id: 4c75691baa1e70f417e518df15c4208cff189950 Pull Request resolved: https://github.com/pytorch/rl/pull/2621 --- test/mocking_classes.py | 14 ++++++++++++++ test/test_env.py | 30 ++++++++++++++++++++++++++++++ torchrl/envs/common.py | 8 ++++++-- torchrl/envs/utils.py | 2 ++ 4 files changed, 52 insertions(+), 2 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index bb902f879b1..3c30286c419 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1996,3 +1996,17 @@ def _step( def _set_seed(self, seed: Optional[int]): ... + + +class EnvThatDoesNothing(EnvBase): + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + return TensorDict(batch_size=self.batch_size, device=self.device) + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + return TensorDict(batch_size=self.batch_size, device=self.device) + + def _set_seed(self, seed): + ... diff --git a/test/test_env.py b/test/test_env.py index 81708b0b9a6..b48b1a1cf8f 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -44,6 +44,7 @@ DiscreteActionConvMockEnvNumpy, DiscreteActionVecMockEnv, DummyModelBasedEnvBase, + EnvThatDoesNothing, EnvWithDynamicSpec, EnvWithMetadata, HeterogeneousCountingEnv, @@ -81,6 +82,7 @@ DiscreteActionConvMockEnvNumpy, DiscreteActionVecMockEnv, DummyModelBasedEnvBase, + EnvThatDoesNothing, EnvWithDynamicSpec, EnvWithMetadata, HeterogeneousCountingEnv, @@ -3554,6 +3556,34 @@ def test_auto_spec(): env.check_env_specs(tensordict=td.copy()) +def test_env_that_does_nothing(): + env = EnvThatDoesNothing() + env.check_env_specs() + r = env.rollout(3) + r.exclude( + "done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True + ) + assert r.is_empty() + p_env = SerialEnv(2, EnvThatDoesNothing) + p_env.check_env_specs() + r = p_env.rollout(3) + r.exclude( + "done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True + ) + assert r.is_empty() + p_env = ParallelEnv(2, EnvThatDoesNothing) + try: + p_env.check_env_specs() + r = p_env.rollout(3) + r.exclude( + "done", "terminated", ("next", "done"), ("next", "terminated"), inplace=True + ) + assert r.is_empty() + finally: + p_env.close() + del p_env + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index d5a062bc11e..bafe88b639a 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2434,8 +2434,12 @@ def _register_gym( # noqa: F811 apply_api_compatibility=apply_api_compatibility, ) - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: - raise NotImplementedError("EnvBase.forward is not implemented") + def forward(self, *args, **kwargs): + raise NotImplementedError( + "EnvBase.forward is not implemented. If you ended here during a call to `ParallelEnv(...)`, please use " + "a constructor such as `ParallelEnv(num_env, lambda env=env: env)` instead. " + "Batched envs require constructors because environment instances may not always be serializable." + ) @abc.abstractmethod def _step( diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 7454bce99b3..209349878ec 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -287,6 +287,8 @@ def __call__(self, tensordict): if self.validate(tensordict): if self.keep_other: out = self._exclude(self.exclude_from_root, tensordict, out=None) + if out is None: + out = tensordict.empty() else: out = next_td.empty() self._grab_and_place( From b2e9f291ad2862e6b9d8d34e68d0e2607acc9295 Mon Sep 17 00:00:00 2001 From: Mana <57663038+0xMana-git@users.noreply.github.com> Date: Mon, 2 Dec 2024 23:15:18 -0800 Subject: [PATCH 07/27] [Doc] Fix typo in torchrl/modules/distributions/continuous.py (#2624) --- torchrl/modules/distributions/continuous.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index f32a3b0c6fa..eb9093dbcfe 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -554,7 +554,7 @@ def get_mode(self): def mean(self): raise NotImplementedError( f"{type(self).__name__} does not have a closed form formula for the average. " - "Am estimate of this value can be computed using dist.sample((N,)).mean(dim=0), " + "An estimate of this value can be computed using dist.sample((N,)).mean(dim=0), " "where N is a large number of samples." ) From 8257799353253ec481100b08592e65525a659690 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valter=20Sch=C3=BCtz?= Date: Tue, 3 Dec 2024 15:12:10 +0100 Subject: [PATCH 08/27] [Doc] actor docstrings (#2626) Co-authored-by: Valter Schutz --- torchrl/modules/tensordict_module/actors.py | 10 +++++----- torchrl/modules/tensordict_module/probabilistic.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/torchrl/modules/tensordict_module/actors.py b/torchrl/modules/tensordict_module/actors.py index 888729835b5..6175bc8bf0c 100644 --- a/torchrl/modules/tensordict_module/actors.py +++ b/torchrl/modules/tensordict_module/actors.py @@ -153,14 +153,14 @@ class ProbabilisticActor(SafeProbabilisticTensorDictSequential): issues. If this value is out of bounds, it is projected back onto the desired space using the :obj:`TensorSpec.project` method. Default is ``False``. - default_interaction_type (str, optional): keyword-only argument. + default_interaction_type (tensordict.nn.InteractionType, optional): keyword-only argument. Default method to be used to retrieve - the output value. Should be one of: 'InteractionType.MODE', 'InteractionType.DETERMINISTIC', - 'InteractionType.MEDIAN', 'InteractionType.MEAN' or - 'InteractionType.RANDOM' (in which case the value is sampled + the output value. Should be one of: ``InteractionType.MODE``, ``InteractionType.DETERMINISTIC``, + ``InteractionType.MEDIAN``, ``InteractionType.MEAN`` or + ``InteractionType.RANDOM`` (in which case the value is sampled randomly from the distribution). TorchRL's ``ExplorationType`` class is a proxy to ``InteractionType``. - Defaults to is 'InteractionType.DETERMINISTIC'. + Defaults to ``InteractionType.DETERMINISTIC``. .. note:: When a sample is drawn, the :class:`ProbabilisticActor` instance will first look for the interaction mode dictated by the diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 8bd5143d20f..5ea006b8d2f 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -68,12 +68,12 @@ class SafeProbabilisticModule(ProbabilisticTensorDictModule): returned by the input module. If the sample is out of bounds, it is projected back onto the desired space using the `TensorSpec.project` method. Default is ``False``. - default_interaction_type (str, optional): default method to be used to retrieve - the output value. Should be one of: 'mode', 'median', 'mean' or 'random' + default_interaction_type (tensordict.nn.InteractionType, optional): default method to be used to retrieve + the output value. Should be one of: ``InteractionType.MODE``, ``InteractionType.MEDIAN``, ``InteractionType.MEAN`` or ``InteractionType.RANDOM`` (in which case the value is sampled randomly from the distribution). Default - is 'mode'. + is ``InteractionType.MODE``. Note: When a sample is drawn, the :obj:`ProbabilisticTDModule` instance will - fist look for the interaction mode dictated by the `interaction_typ()` + fist look for the interaction mode dictated by the `interaction_type()` global function. If this returns `None` (its default value), then the `default_interaction_type` of the :class:`~.ProbabilisticTDModule` instance will be used. Note that DataCollector instances will use From d22266d05d7ae10f53e3b904d847d44743beba40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valter=20Sch=C3=BCtz?= Date: Tue, 3 Dec 2024 15:12:49 +0100 Subject: [PATCH 09/27] [Doc] Update docstring for TruncatedNormal with correct parameter names (#2625) Co-authored-by: Valter Schutz --- torchrl/modules/distributions/continuous.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index eb9093dbcfe..e34f1be8ff9 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -205,8 +205,8 @@ class TruncatedNormal(D.Independent): Default is 5.0 - min (torch.Tensor or number, optional): minimum value of the distribution. Default = -1.0; - max (torch.Tensor or number, optional): maximum value of the distribution. Default = 1.0; + low (torch.Tensor or number, optional): minimum value of the distribution. Default = -1.0; + high (torch.Tensor or number, optional): maximum value of the distribution. Default = 1.0; tanh_loc (bool, optional): if ``True``, the above formula is used for the location scaling, otherwise the raw value is kept. Default is ``False``; From 607ebc52dc083290b6bcce98864881358f94fd7a Mon Sep 17 00:00:00 2001 From: Goia Rares Dan Tiago <115428237+raresdan@users.noreply.github.com> Date: Tue, 3 Dec 2024 16:35:51 +0200 Subject: [PATCH 10/27] [Refactor] Rename Recorder and LogReward (#2616) --- docs/source/reference/trainers.rst | 8 +-- sota-implementations/redq/utils.py | 10 ++-- test/test_trainer.py | 27 +++++----- torchrl/trainers/__init__.py | 2 + torchrl/trainers/helpers/trainers.py | 10 ++-- torchrl/trainers/trainers.py | 61 +++++++++++++++++++++-- tutorials/sphinx-tutorials/coding_ddpg.py | 6 +-- tutorials/sphinx-tutorials/coding_dqn.py | 8 +-- 8 files changed, 94 insertions(+), 38 deletions(-) diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index 11384bda0e6..8f6be633743 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -78,8 +78,8 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a constants update), data subsampling (:class:``~torchrl.trainers.BatchSubSampler``) and such. - **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger - some information retrieved from that data. Examples include the ``Recorder`` hook, the reward - logger (``LogReward``) and such. Hooks should return a dictionary (or a None value) containing the + some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward + logger (``LogScaler``) and such. Hooks should return a dictionary (or a None value) containing the data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value should be displayed on the progression bar printed on the training log. @@ -174,9 +174,9 @@ Trainer and hooks BatchSubSampler ClearCudaCache CountFramesLog - LogReward + LogScaler OptimizerHook - Recorder + LogValidationReward ReplayBufferTrainer RewardNormalizer SelectKeys diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index 9953fcb3112..fed4922b5a7 100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -81,8 +81,8 @@ BatchSubSampler, ClearCudaCache, CountFramesLog, - LogReward, - Recorder, + LogScalar, + LogValidationReward, ReplayBufferTrainer, RewardNormalizer, Trainer, @@ -331,7 +331,7 @@ def make_trainer( if recorder is not None: # create recorder object - recorder_obj = Recorder( + recorder_obj = LogValidationReward( record_frames=cfg.logger.record_frames, frame_skip=cfg.env.frame_skip, policy_exploration=policy_exploration, @@ -347,7 +347,7 @@ def make_trainer( # call recorder - could be removed recorder_obj(None) # create explorative recorder - could be optional - recorder_obj_explore = Recorder( + recorder_obj_explore = LogValidationReward( record_frames=cfg.logger.record_frames, frame_skip=cfg.env.frame_skip, policy_exploration=policy_exploration, @@ -369,7 +369,7 @@ def make_trainer( "post_steps", UpdateWeights(collector, update_weights_interval=1) ) - trainer.register_op("pre_steps_log", LogReward()) + trainer.register_op("pre_steps_log", LogScalar()) trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.env.frame_skip)) return trainer diff --git a/test/test_trainer.py b/test/test_trainer.py index f7e4ccffdf5..caae5bbe178 100644 --- a/test/test_trainer.py +++ b/test/test_trainer.py @@ -35,14 +35,14 @@ TensorDictReplayBuffer, ) from torchrl.envs.libs.gym import _has_gym -from torchrl.trainers import Recorder, Trainer +from torchrl.trainers import LogValidationReward, Trainer from torchrl.trainers.helpers import transformed_env_constructor from torchrl.trainers.trainers import ( _has_tqdm, _has_ts, BatchSubSampler, CountFramesLog, - LogReward, + LogScalar, mask_batch, OptimizerHook, ReplayBufferTrainer, @@ -638,7 +638,7 @@ def test_log_reward(self, logname, pbar): trainer = mocking_trainer() trainer.collected_frames = 0 - log_reward = LogReward(logname, log_pbar=pbar) + log_reward = LogScalar(logname, log_pbar=pbar) trainer.register_op("pre_steps_log", log_reward) td = TensorDict({REWARD_KEY: torch.ones(3)}, [3]) trainer._pre_steps_log_hook(td) @@ -654,7 +654,7 @@ def test_log_reward_register(self, logname, pbar): trainer = mocking_trainer() trainer.collected_frames = 0 - log_reward = LogReward(logname, log_pbar=pbar) + log_reward = LogScalar(logname, log_pbar=pbar) log_reward.register(trainer) td = TensorDict({REWARD_KEY: torch.ones(3)}, [3]) trainer._pre_steps_log_hook(td) @@ -873,7 +873,7 @@ def test_recorder(self, N=8): logger=logger, )() - recorder = Recorder( + recorder = LogValidationReward( record_frames=args.record_frames, frame_skip=args.frame_skip, policy_exploration=None, @@ -919,13 +919,12 @@ def test_recorder_load(self, backend, N=8): os.environ["CKPT_BACKEND"] = backend state_dict_has_been_called = [False] load_state_dict_has_been_called = [False] - Recorder.state_dict, Recorder_state_dict = _fun_checker( - Recorder.state_dict, state_dict_has_been_called + LogValidationReward.state_dict, Recorder_state_dict = _fun_checker( + LogValidationReward.state_dict, state_dict_has_been_called + ) + (LogValidationReward.load_state_dict, Recorder_load_state_dict,) = _fun_checker( + LogValidationReward.load_state_dict, load_state_dict_has_been_called ) - ( - Recorder.load_state_dict, - Recorder_load_state_dict, - ) = _fun_checker(Recorder.load_state_dict, load_state_dict_has_been_called) args = self._get_args() @@ -948,7 +947,7 @@ def _make_recorder_and_trainer(tmpdirname): )() environment.rollout(2) - recorder = Recorder( + recorder = LogValidationReward( record_frames=args.record_frames, frame_skip=args.frame_skip, policy_exploration=None, @@ -969,8 +968,8 @@ def _make_recorder_and_trainer(tmpdirname): assert recorder2._count == 8 assert state_dict_has_been_called[0] assert load_state_dict_has_been_called[0] - Recorder.state_dict = Recorder_state_dict - Recorder.load_state_dict = Recorder_load_state_dict + LogValidationReward.state_dict = Recorder_state_dict + LogValidationReward.load_state_dict = Recorder_load_state_dict def test_updateweights(): diff --git a/torchrl/trainers/__init__.py b/torchrl/trainers/__init__.py index 364c0dec725..9d593d64f17 100644 --- a/torchrl/trainers/__init__.py +++ b/torchrl/trainers/__init__.py @@ -8,6 +8,8 @@ ClearCudaCache, CountFramesLog, LogReward, + LogScalar, + LogValidationReward, mask_batch, OptimizerHook, Recorder, diff --git a/torchrl/trainers/helpers/trainers.py b/torchrl/trainers/helpers/trainers.py index 207bcec0ffd..4819d9e07e8 100644 --- a/torchrl/trainers/helpers/trainers.py +++ b/torchrl/trainers/helpers/trainers.py @@ -25,8 +25,8 @@ BatchSubSampler, ClearCudaCache, CountFramesLog, - LogReward, - Recorder, + LogScalar, + LogValidationReward, ReplayBufferTrainer, RewardNormalizer, SelectKeys, @@ -259,7 +259,7 @@ def make_trainer( if recorder is not None: # create recorder object - recorder_obj = Recorder( + recorder_obj = LogValidationReward( record_frames=cfg.record_frames, frame_skip=cfg.frame_skip, policy_exploration=policy_exploration, @@ -275,7 +275,7 @@ def make_trainer( # call recorder - could be removed recorder_obj(None) # create explorative recorder - could be optional - recorder_obj_explore = Recorder( + recorder_obj_explore = LogValidationReward( record_frames=cfg.record_frames, frame_skip=cfg.frame_skip, policy_exploration=policy_exploration, @@ -297,7 +297,7 @@ def make_trainer( "post_steps", UpdateWeights(collector, update_weights_interval=1) ) - trainer.register_op("pre_steps_log", LogReward()) + trainer.register_op("pre_steps_log", LogScalar()) trainer.register_op("pre_steps_log", CountFramesLog(frame_skip=cfg.frame_skip)) return trainer diff --git a/torchrl/trainers/trainers.py b/torchrl/trainers/trainers.py index 7e28da45f52..83bd050ef96 100644 --- a/torchrl/trainers/trainers.py +++ b/torchrl/trainers/trainers.py @@ -822,7 +822,7 @@ def __call__(self, *args, **kwargs): torch.cuda.empty_cache() -class LogReward(TrainerHookBase): +class LogScalar(TrainerHookBase): """Reward logger hook. Args: @@ -833,7 +833,7 @@ class LogReward(TrainerHookBase): in the input batch. Defaults to ``("next", "reward")`` Examples: - >>> log_reward = LogReward(("next", "reward")) + >>> log_reward = LogScalar(("next", "reward")) >>> trainer.register_op("pre_steps_log", log_reward) """ @@ -870,6 +870,23 @@ def register(self, trainer: Trainer, name: str = "log_reward"): trainer.register_module(name, self) +class LogReward(LogScalar): + """Deprecated class. Use LogScalar instead.""" + + def __init__( + self, + logname="r_training", + log_pbar: bool = False, + reward_key: Union[str, tuple] = None, + ): + warnings.warn( + "The 'LogReward' class is deprecated and will be removed in v0.9. Please use 'LogScalar' instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(logname=logname, log_pbar=log_pbar, reward_key=reward_key) + + class RewardNormalizer(TrainerHookBase): """Reward normalizer hook. @@ -1127,7 +1144,7 @@ def register(self, trainer: Trainer, name: str = "batch_subsampler"): trainer.register_module(name, self) -class Recorder(TrainerHookBase): +class LogValidationReward(TrainerHookBase): """Recorder hook for :class:`~torchrl.trainers.Trainer`. Args: @@ -1264,6 +1281,44 @@ def register(self, trainer: Trainer, name: str = "recorder"): ) +class Recorder(LogValidationReward): + """Deprecated class. Use LogValidationReward instead.""" + + def __init__( + self, + *, + record_interval: int, + record_frames: int, + frame_skip: int = 1, + policy_exploration: TensorDictModule, + environment: EnvBase = None, + exploration_type: ExplorationType = ExplorationType.RANDOM, + log_keys: Optional[List[Union[str, Tuple[str]]]] = None, + out_keys: Optional[Dict[Union[str, Tuple[str]], str]] = None, + suffix: Optional[str] = None, + log_pbar: bool = False, + recorder: EnvBase = None, + ) -> None: + warnings.warn( + "The 'Recorder' class is deprecated and will be removed in v0.9. Please use 'LogValidationReward' instead.", + DeprecationWarning, + stacklevel=2, + ) + super().__init__( + record_interval=record_interval, + record_frames=record_frames, + frame_skip=frame_skip, + policy_exploration=policy_exploration, + environment=environment, + exploration_type=exploration_type, + log_keys=log_keys, + out_keys=out_keys, + suffix=suffix, + log_pbar=log_pbar, + recorder=recorder, + ) + + class UpdateWeights(TrainerHookBase): """A collector weights update hook class. diff --git a/tutorials/sphinx-tutorials/coding_ddpg.py b/tutorials/sphinx-tutorials/coding_ddpg.py index 906d162f181..70176f9de4a 100644 --- a/tutorials/sphinx-tutorials/coding_ddpg.py +++ b/tutorials/sphinx-tutorials/coding_ddpg.py @@ -883,12 +883,12 @@ def make_ddpg_actor( # # As the training data is obtained using some exploration strategy, the true # performance of our algorithm needs to be assessed in deterministic mode. We -# do this using a dedicated class, ``Recorder``, which executes the policy in +# do this using a dedicated class, ``LogValidationReward``, which executes the policy in # the environment at a given frequency and returns some statistics obtained # from these simulations. # # The following helper function builds this object: -from torchrl.trainers import Recorder +from torchrl.trainers import LogValidationReward def make_recorder(actor_model_explore, transform_state_dict, record_interval): @@ -899,7 +899,7 @@ def make_recorder(actor_model_explore, transform_state_dict, record_interval): ) # must be instantiated to load the state dict environment.transform[2].load_state_dict(transform_state_dict) - recorder_obj = Recorder( + recorder_obj = LogValidationReward( record_frames=1000, policy_exploration=actor_model_explore, environment=environment, diff --git a/tutorials/sphinx-tutorials/coding_dqn.py b/tutorials/sphinx-tutorials/coding_dqn.py index 59188ad21f6..a10e8c1169a 100644 --- a/tutorials/sphinx-tutorials/coding_dqn.py +++ b/tutorials/sphinx-tutorials/coding_dqn.py @@ -140,8 +140,8 @@ from torchrl.objectives import DQNLoss, SoftUpdate from torchrl.record.loggers.csv import CSVLogger from torchrl.trainers import ( - LogReward, - Recorder, + LogScalar, + LogValidationReward, ReplayBufferTrainer, Trainer, UpdateWeights, @@ -666,7 +666,7 @@ def get_loss_module(actor, gamma): buffer_hook.register(trainer) weight_updater = UpdateWeights(collector, update_weights_interval=1) weight_updater.register(trainer) -recorder = Recorder( +recorder = LogValidationReward( record_interval=100, # log every 100 optimization steps record_frames=1000, # maximum number of frames in the record frame_skip=1, @@ -704,7 +704,7 @@ def get_loss_module(actor, gamma): # This will be reflected by the `total_rewards` value displayed in the # progress bar. # -log_reward = LogReward(log_pbar=True) +log_reward = LogScalar(log_pbar=True) log_reward.register(trainer) ############################################################################### From 3da76f0063aac5880312832a6a449f2adb9caf91 Mon Sep 17 00:00:00 2001 From: Oliver Slumbers <40644337+oslumbers@users.noreply.github.com> Date: Tue, 3 Dec 2024 14:42:16 +0000 Subject: [PATCH 11/27] [Feature] ActionDiscretizer custom sampling (#2609) Co-authored-by: Oliver Slumbers --- torchrl/envs/transforms/transforms.py | 46 ++++++++++++++------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 7ab5a2deb72..980273af96c 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -8583,6 +8583,26 @@ def _indent(s): f"\n{_indent(out_action_key)},\n{_indent(sampling)},\n{_indent(categorical)})" ) + def _custom_arange(self, nint, device): + result = torch.arange( + start=0.0, + end=1.0, + step=1 / nint, + dtype=self.dtype, + device=device, + ) + result_ = result + if self.sampling in ( + self.SamplingStrategy.HIGH, + self.SamplingStrategy.MEDIAN, + ): + result_ = (1 - result).flip(0) + if self.sampling == self.SamplingStrategy.MEDIAN: + result = (result + result_) / 2 + else: + result = result_ + return result + def transform_input_spec(self, input_spec): try: action_spec = self.parent.full_action_spec_unbatched[self.in_keys_inv[0]] @@ -8611,29 +8631,11 @@ def transform_input_spec(self, input_spec): num_intervals = int(num_intervals.squeeze()) self.num_intervals = torch.as_tensor(num_intervals) - def custom_arange(nint): - result = torch.arange( - start=0.0, - end=1.0, - step=1 / nint, - dtype=self.dtype, - device=action_spec.device, - ) - result_ = result - if self.sampling in ( - self.SamplingStrategy.HIGH, - self.SamplingStrategy.MEDIAN, - ): - result_ = (1 - result).flip(0) - if self.sampling == self.SamplingStrategy.MEDIAN: - result = (result + result_) / 2 - else: - result = result_ - return result - if isinstance(num_intervals, int): arange = ( - custom_arange(num_intervals).expand((*n_act, num_intervals)) + self._custom_arange(num_intervals, action_spec.device).expand( + (*n_act, num_intervals) + ) * interval ) low = action_spec.low @@ -8642,7 +8644,7 @@ def custom_arange(nint): self.register_buffer("intervals", low + arange) else: arange = [ - custom_arange(_num_intervals) * interval + self._custom_arange(_num_intervals, action_spec.device) * interval for _num_intervals, interval in zip( num_intervals.tolist(), interval.unbind(-2) ) From aed03fda451e1abebad6f7310c974b1b372c4a61 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 3 Dec 2024 14:50:26 +0000 Subject: [PATCH 12/27] [CI] Fix dreamer run in SOTA tests ghstack-source-id: dfe3ab6fe0d29fcdcaf57f31f84d04e07e36bad3 Pull Request resolved: https://github.com/pytorch/rl/pull/2627 --- .github/unittest/linux_sota/scripts/test_sota.py | 4 ++-- sota-implementations/dreamer/dreamer.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/unittest/linux_sota/scripts/test_sota.py b/.github/unittest/linux_sota/scripts/test_sota.py index d42f96d5ee1..b7af381634c 100644 --- a/.github/unittest/linux_sota/scripts/test_sota.py +++ b/.github/unittest/linux_sota/scripts/test_sota.py @@ -190,12 +190,12 @@ logger.backend= """, "dreamer": """python sota-implementations/dreamer/dreamer.py \ - collector.total_frames=200 \ + collector.total_frames=600 \ collector.init_random_frames=10 \ collector.frames_per_batch=200 \ env.n_parallel_envs=1 \ optimization.optim_steps_per_batch=1 \ - logger.video=True \ + logger.video=False \ logger.backend=csv \ replay_buffer.buffer_size=120 \ replay_buffer.batch_size=24 \ diff --git a/sota-implementations/dreamer/dreamer.py b/sota-implementations/dreamer/dreamer.py index d97066b87c5..1b9823c1dd1 100644 --- a/sota-implementations/dreamer/dreamer.py +++ b/sota-implementations/dreamer/dreamer.py @@ -321,8 +321,10 @@ def compile_rssms(module): t_collect_init = time.time() - test_env.close() - train_env.close() + if not test_env.is_closed: + test_env.close() + if not train_env.is_closed: + train_env.close() collector.shutdown() del test_env From 1ca134cc3243b28295ce9c2e8bf363814fa8ce32 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 3 Dec 2024 15:05:44 +0000 Subject: [PATCH 13/27] [BugFix] Fix MARL PPO tutorial action_spec call ghstack-source-id: 1d9058c45b28c0f0279e4243a2a0f96c622a51d8 Pull Request resolved: https://github.com/pytorch/rl/pull/2628 --- tutorials/sphinx-tutorials/multiagent_ppo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index e2ca3f6ecd8..0e6cc51adf6 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -450,8 +450,8 @@ out_keys=[env.action_key], distribution_class=TanhNormal, distribution_kwargs={ - "low": env.action_spec_unbatched[env.action_key].space.low, - "high": env.action_spec_unbatched[env.action_key].space.high, + "low": env.full_action_spec_unbatched[env.action_key].space.low, + "high": env.full_action_spec_unbatched[env.action_key].space.high, }, return_log_prob=True, log_prob_key=("agents", "sample_log_prob"), From 1cffffee92a37d16df3ddaf94fc29cc4b3292d5a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 3 Dec 2024 15:05:46 +0000 Subject: [PATCH 14/27] [BugFix] Fix export aoti_compile_and_package API change ghstack-source-id: 07a0f063f8955815157c2a3eac02c6460a82f672 Pull Request resolved: https://github.com/pytorch/rl/pull/2629 --- tutorials/sphinx-tutorials/export.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tutorials/sphinx-tutorials/export.py b/tutorials/sphinx-tutorials/export.py index 48dd8723ffc..d40ef09ff8c 100644 --- a/tutorials/sphinx-tutorials/export.py +++ b/tutorials/sphinx-tutorials/export.py @@ -343,8 +343,6 @@ with torch.no_grad(): pkg_path = aoti_compile_and_package( exported_policy, - args=(), - kwargs={"pixels": pixels}, # Specify the generated shared library path package_path=path, ) From 594462d6b0e5f2e18b417e177828ebfb4ac16235 Mon Sep 17 00:00:00 2001 From: kurtamohler Date: Wed, 4 Dec 2024 02:34:15 -0800 Subject: [PATCH 15/27] [Feature] Add `Stack` transform (#2567) --- .../scripts_unity_mlagents/run_test.sh | 1 + test/mocking_classes.py | 157 ++++++- test/test_transforms.py | 394 +++++++++++++++++- torchrl/envs/__init__.py | 1 + torchrl/envs/libs/unity_mlagents.py | 4 +- torchrl/envs/transforms/__init__.py | 1 + torchrl/envs/transforms/transforms.py | 211 ++++++++++ 7 files changed, 764 insertions(+), 5 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh b/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh index d5bb8695c44..05eb63c2b51 100755 --- a/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh +++ b/.github/unittest/linux_libs/scripts_unity_mlagents/run_test.sh @@ -23,6 +23,7 @@ conda deactivate && conda activate ./env python -c "import mlagents_envs" python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestUnityMLAgents --runslow +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_transforms.py --instafail -v --durations 200 --capture no -k test_transform_env[unity] coverage combine coverage xml -i diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 3c30286c419..4e943e03cfc 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional +from typing import Dict, List, Optional import torch import torch.nn as nn @@ -24,7 +24,12 @@ from torchrl.data.utils import consolidate_spec from torchrl.envs.common import EnvBase from torchrl.envs.model_based.common import ModelBasedEnvBase -from torchrl.envs.utils import _terminated_or_truncated +from torchrl.envs.utils import ( + _terminated_or_truncated, + check_marl_grouping, + MarlGroupMapType, +) + spec_dict = { "bounded": Bounded, @@ -1059,6 +1064,154 @@ def _step( return tensordict +class MultiAgentCountingEnv(EnvBase): + """A multi-agent env that is done after a given number of steps. + + All agents have identical specs. + + The count is incremented by 1 on each step. + + """ + + def __init__( + self, + n_agents: int, + group_map: MarlGroupMapType + | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, + max_steps: int = 5, + start_val: int = 0, + **kwargs, + ): + super().__init__(**kwargs) + self.max_steps = max_steps + self.start_val = start_val + self.n_agents = n_agents + self.agent_names = [f"agent_{idx}" for idx in range(n_agents)] + + if isinstance(group_map, MarlGroupMapType): + group_map = group_map.get_group_map(self.agent_names) + check_marl_grouping(group_map, self.agent_names) + + self.group_map = group_map + + observation_specs = {} + reward_specs = {} + done_specs = {} + action_specs = {} + + for group_name, agents in group_map.items(): + observation_specs[group_name] = {} + reward_specs[group_name] = {} + done_specs[group_name] = {} + action_specs[group_name] = {} + + for agent_name in agents: + observation_specs[group_name][agent_name] = Composite( + observation=Unbounded( + ( + *self.batch_size, + 3, + 4, + ), + dtype=torch.float32, + device=self.device, + ), + shape=self.batch_size, + device=self.device, + ) + reward_specs[group_name][agent_name] = Composite( + reward=Unbounded( + ( + *self.batch_size, + 1, + ), + device=self.device, + ), + shape=self.batch_size, + device=self.device, + ) + done_specs[group_name][agent_name] = Composite( + done=Categorical( + 2, + dtype=torch.bool, + shape=( + *self.batch_size, + 1, + ), + device=self.device, + ), + shape=self.batch_size, + device=self.device, + ) + action_specs[group_name][agent_name] = Composite( + action=Binary(n=1, shape=[*self.batch_size, 1], device=self.device), + shape=self.batch_size, + device=self.device, + ) + + self.observation_spec = Composite(observation_specs) + self.reward_spec = Composite(reward_specs) + self.done_spec = Composite(done_specs) + self.action_spec = Composite(action_specs) + self.register_buffer( + "count", + torch.zeros((*self.batch_size, 1), device=self.device, dtype=torch.int), + ) + + def _set_seed(self, seed: Optional[int]): + torch.manual_seed(seed) + + def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: + if tensordict is not None and "_reset" in tensordict.keys(): + _reset = tensordict.get("_reset") + self.count[_reset] = self.start_val + else: + self.count[:] = self.start_val + + source = {} + for group_name, agents in self.group_map.items(): + source[group_name] = {} + for agent_name in agents: + source[group_name][agent_name] = TensorDict( + source={ + "observation": torch.rand( + (*self.batch_size, 3, 4), device=self.device + ), + "done": self.count > self.max_steps, + "terminated": self.count > self.max_steps, + }, + batch_size=self.batch_size, + device=self.device, + ) + + tensordict = TensorDict(source, batch_size=self.batch_size, device=self.device) + return tensordict + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + self.count += 1 + source = {} + for group_name, agents in self.group_map.items(): + source[group_name] = {} + for agent_name in agents: + source[group_name][agent_name] = TensorDict( + source={ + "observation": torch.rand( + (*self.batch_size, 3, 4), device=self.device + ), + "done": self.count > self.max_steps, + "terminated": self.count > self.max_steps, + "reward": torch.zeros_like(self.count, dtype=torch.float), + }, + batch_size=self.batch_size, + device=self.device, + ) + tensordict = TensorDict(source, batch_size=self.batch_size, device=self.device) + return tensordict + + class IncrementingEnv(CountingEnv): # Same as CountingEnv but always increments the count by 1 regardless of the action. def _step( diff --git a/test/test_transforms.py b/test/test_transforms.py index ae428d35d97..d90c00b6a19 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -45,6 +45,7 @@ IncrementingEnv, MockBatchedLockedEnv, MockBatchedUnLockedEnv, + MultiAgentCountingEnv, MultiKeyCountingEnv, MultiKeyCountingEnvPolicy, NestedCountingEnv, @@ -71,6 +72,7 @@ IncrementingEnv, MockBatchedLockedEnv, MockBatchedUnLockedEnv, + MultiAgentCountingEnv, MultiKeyCountingEnv, MultiKeyCountingEnvPolicy, NestedCountingEnv, @@ -134,6 +136,7 @@ SerialEnv, SignTransform, SqueezeTransform, + Stack, StepCounter, TargetReturn, TensorDictPrimer, @@ -141,12 +144,14 @@ ToTensorImage, TrajCounter, TransformedEnv, + UnityMLAgentsEnv, UnsqueezeTransform, VC1Transform, VIPTransform, ) from torchrl.envs.libs.dm_control import _has_dm_control from torchrl.envs.libs.gym import _has_gym, GymEnv, set_gym_backend +from torchrl.envs.libs.unity_mlagents import _has_unity_mlagents from torchrl.envs.transforms import VecNorm from torchrl.envs.transforms.r3m import _R3MNet from torchrl.envs.transforms.rlhf import KLRewardTransform @@ -159,7 +164,7 @@ ) from torchrl.envs.transforms.vc1 import _has_vc from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform -from torchrl.envs.utils import check_env_specs, step_mdp +from torchrl.envs.utils import check_env_specs, MarlGroupMapType, step_mdp from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal from torchrl.modules.utils import get_primers_from_module @@ -2149,6 +2154,393 @@ def test_transform_no_env(self, device, batch): pytest.skip("TrajCounter cannot be called without env") +class TestStack(TransformBase): + def test_single_trans_env_check(self): + t = Stack( + in_keys=["observation", "observation_orig"], + out_key="observation_out", + dim=-1, + del_keys=False, + ) + env = TransformedEnv(ContinuousActionVecMockEnv(), t) + check_env_specs(env) + + def test_serial_trans_env_check(self): + def make_env(): + t = Stack( + in_keys=["observation", "observation_orig"], + out_key="observation_out", + dim=-1, + del_keys=False, + ) + return TransformedEnv(ContinuousActionVecMockEnv(), t) + + env = SerialEnv(2, make_env) + check_env_specs(env) + + def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv): + def make_env(): + t = Stack( + in_keys=["observation", "observation_orig"], + out_key="observation_out", + dim=-1, + del_keys=False, + ) + return TransformedEnv(ContinuousActionVecMockEnv(), t) + + env = maybe_fork_ParallelEnv(2, make_env) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + def test_trans_serial_env_check(self): + t = Stack( + in_keys=["observation", "observation_orig"], + out_key="observation_out", + dim=-2, + del_keys=False, + ) + + env = TransformedEnv(SerialEnv(2, ContinuousActionVecMockEnv), t) + check_env_specs(env) + + def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): + t = Stack( + in_keys=["observation", "observation_orig"], + out_key="observation_out", + dim=-2, + del_keys=False, + ) + + env = TransformedEnv(maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv), t) + try: + check_env_specs(env) + finally: + try: + env.close() + except RuntimeError: + pass + + @pytest.mark.parametrize("del_keys", [True, False]) + def test_transform_del_keys(self, del_keys): + td_orig = TensorDict( + { + "group_0": TensorDict( + { + "agent_0": TensorDict({"obs": torch.randn(10)}), + "agent_1": TensorDict({"obs": torch.randn(10)}), + } + ), + "group_1": TensorDict( + { + "agent_2": TensorDict({"obs": torch.randn(10)}), + "agent_3": TensorDict({"obs": torch.randn(10)}), + } + ), + } + ) + t = Stack( + in_keys=[ + ("group_0", "agent_0", "obs"), + ("group_0", "agent_1", "obs"), + ("group_1", "agent_2", "obs"), + ("group_1", "agent_3", "obs"), + ], + out_key="observations", + del_keys=del_keys, + ) + td = td_orig.clone() + t(td) + keys = td.keys(include_nested=True) + if del_keys: + assert ("group_0",) not in keys + assert ("group_0", "agent_0", "obs") not in keys + assert ("group_0", "agent_1", "obs") not in keys + assert ("group_1", "agent_2", "obs") not in keys + assert ("group_1", "agent_3", "obs") not in keys + else: + assert ("group_0", "agent_0", "obs") in keys + assert ("group_0", "agent_1", "obs") in keys + assert ("group_1", "agent_2", "obs") in keys + assert ("group_1", "agent_3", "obs") in keys + + assert ("observations",) in keys + + def _test_transform_no_env_tensor(self, compose=False): + td_orig = TensorDict( + { + "key1": torch.rand(1, 3), + "key2": torch.rand(1, 3), + "key3": torch.rand(1, 3), + }, + [1], + ) + td = td_orig.clone() + t = Stack( + in_keys=[("key1",), ("key2",)], + out_key=("stacked",), + in_key_inv=("stacked",), + out_keys_inv=[("key1",), ("key2",)], + dim=-2, + ) + if compose: + t = Compose(t) + + td = t(td) + + assert ("key1",) not in td.keys() + assert ("key2",) not in td.keys() + assert ("key3",) in td.keys() + assert ("stacked",) in td.keys() + + assert td["stacked"].shape == torch.Size([1, 2, 3]) + assert (td["stacked"][:, 0] == td_orig["key1"]).all() + assert (td["stacked"][:, 1] == td_orig["key2"]).all() + + td = t.inv(td) + assert (td == td_orig).all() + + def _test_transform_no_env_tensordict(self, compose=False): + def gen_value(): + return TensorDict( + { + "a": torch.rand(3), + "b": torch.rand(2, 4), + } + ) + + td_orig = TensorDict( + { + "key1": gen_value(), + "key2": gen_value(), + "key3": gen_value(), + }, + [], + ) + td = td_orig.clone() + t = Stack( + in_keys=[("key1",), ("key2",)], + out_key=("stacked",), + in_key_inv=("stacked",), + out_keys_inv=[("key1",), ("key2",)], + dim=0, + allow_positive_dim=True, + ) + if compose: + t = Compose(t) + td = t(td) + + assert ("key1",) not in td.keys() + assert ("key2",) not in td.keys() + assert ("stacked", "a") in td.keys(include_nested=True) + assert ("stacked", "b") in td.keys(include_nested=True) + assert ("key3",) in td.keys() + + assert td["stacked", "a"].shape == torch.Size([2, 3]) + assert td["stacked", "b"].shape == torch.Size([2, 2, 4]) + assert (td["stacked"][0] == td_orig["key1"]).all() + assert (td["stacked"][1] == td_orig["key2"]).all() + assert (td["key3"] == td_orig["key3"]).all() + + td = t.inv(td) + assert (td == td_orig).all() + + @pytest.mark.parametrize("datatype", ["tensor", "tensordict"]) + def test_transform_no_env(self, datatype): + if datatype == "tensor": + self._test_transform_no_env_tensor() + + elif datatype == "tensordict": + self._test_transform_no_env_tensordict() + + else: + raise RuntimeError(f"please add a test case for datatype {datatype}") + + @pytest.mark.parametrize("datatype", ["tensor", "tensordict"]) + def test_transform_compose(self, datatype): + if datatype == "tensor": + self._test_transform_no_env_tensor(compose=True) + + elif datatype == "tensordict": + self._test_transform_no_env_tensordict(compose=True) + + else: + raise RuntimeError(f"please add a test case for datatype {datatype}") + + @pytest.mark.parametrize("envtype", ["mock", "unity"]) + def test_transform_env(self, envtype): + if envtype == "mock": + base_env = MultiAgentCountingEnv( + n_agents=5, + ) + rollout_len = 6 + t = Stack( + in_keys=[ + ("agents", "agent_0"), + ("agents", "agent_2"), + ("agents", "agent_3"), + ], + out_key="stacked_agents", + in_key_inv="stacked_agents", + out_keys_inv=[ + ("agents", "agent_0"), + ("agents", "agent_2"), + ("agents", "agent_3"), + ], + ) + + elif envtype == "unity": + if not _has_unity_mlagents: + raise pytest.skip("mlagents not installed") + base_env = UnityMLAgentsEnv( + registered_name="3DBall", + no_graphics=True, + group_map=MarlGroupMapType.ALL_IN_ONE_GROUP, + ) + rollout_len = 200 + t = Stack( + in_keys=[("agents", f"agent_{idx}") for idx in range(12)], + out_key="stacked_agents", + in_key_inv="stacked_agents", + out_keys_inv=[("agents", f"agent_{idx}") for idx in range(12)], + ) + + try: + env = TransformedEnv(base_env, t) + check_env_specs(env) + + if envtype == "mock": + base_env.set_seed(123) + td_orig = base_env.reset() + if envtype == "mock": + env.set_seed(123) + td = env.reset() + + td_keys = td.keys(include_nested=True) + + if envtype == "mock": + assert ("agents", "agent_0") not in td_keys + assert ("agents", "agent_2") not in td_keys + assert ("agents", "agent_3") not in td_keys + assert ("agents", "agent_1") in td_keys + assert ("agents", "agent_4") in td_keys + assert ("stacked_agents",) in td_keys + + assert (td["stacked_agents"][0] == td_orig["agents", "agent_0"]).all() + assert (td["stacked_agents"][1] == td_orig["agents", "agent_2"]).all() + assert (td["stacked_agents"][2] == td_orig["agents", "agent_3"]).all() + assert (td["agents", "agent_1"] == td_orig["agents", "agent_1"]).all() + assert (td["agents", "agent_4"] == td_orig["agents", "agent_4"]).all() + else: + assert ("agents",) not in td_keys + assert ("stacked_agents",) in td_keys + assert td["stacked_agents"].shape[0] == 12 + + assert ("agents",) not in env.full_action_spec.keys(include_nested=True) + assert ("stacked_agents",) in env.full_action_spec.keys( + include_nested=True + ) + + td = env.step(env.full_action_spec.rand()) + td = env.rollout(rollout_len) + + if envtype == "mock": + assert td["next", "stacked_agents", "done"].shape == torch.Size( + [6, 3, 1] + ) + assert not (td["next", "stacked_agents", "done"][:-1]).any() + assert (td["next", "stacked_agents", "done"][-1]).all() + finally: + base_env.close() + + def test_transform_model(self): + t = Stack( + in_keys=[("next", "observation"), ("observation",)], + out_key="observation_out", + dim=-2, + del_keys=True, + ) + model = nn.Sequential(t, nn.Identity()) + td = TensorDict( + {("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, [] + ) + td = model(td) + assert "observation_out" in td.keys() + assert "observation" not in td.keys() + assert ("next", "observation") not in td.keys(True) + + @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) + def test_transform_rb(self, rbclass): + t = Stack( + in_keys=[("next", "observation"), "observation"], + out_key="observation_out", + dim=-2, + del_keys=True, + ) + rb = rbclass(storage=LazyTensorStorage(10)) + rb.append_transform(t) + td = TensorDict( + { + "observation": TensorDict({"stuff": torch.randn(3, 4)}, [3, 4]), + "next": TensorDict( + {"observation": TensorDict({"stuff": torch.randn(3, 4)}, [3, 4])}, + [], + ), + }, + [], + ).expand(10) + rb.extend(td) + td = rb.sample(2) + assert "observation_out" in td.keys() + assert "observation" not in td.keys() + assert ("next", "observation") not in td.keys(True) + + def test_transform_inverse(self): + td_orig = TensorDict( + { + "stacked": torch.rand(1, 2, 3), + "key3": torch.rand(1, 3), + }, + [1], + ) + td = td_orig.clone() + t = Stack( + in_keys=[("key1",), ("key2",)], + out_key=("stacked",), + in_key_inv=("stacked",), + out_keys_inv=[("key1",), ("key2",)], + dim=1, + allow_positive_dim=True, + ) + + td = t.inv(td) + + assert ("key1",) in td.keys() + assert ("key2",) in td.keys() + assert ("key3",) in td.keys() + assert ("stacked",) not in td.keys() + assert (td["key1"] == td_orig["stacked"][:, 0]).all() + assert (td["key2"] == td_orig["stacked"][:, 1]).all() + + td = t(td) + assert (td == td_orig).all() + + # Check that if `out_key` is not in the tensordict, + # then the inverse transform does nothing. + t = Stack( + in_keys=[("key1",), ("key2",)], + out_key=("sacked",), + dim=1, + allow_positive_dim=True, + ) + td = t.inv(td) + assert (td == td_orig).all() + + class TestCatTensors(TransformBase): @pytest.mark.parametrize("append", [True, False]) def test_cattensors_empty(self, append): diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 4cfb00cc307..36e4ec1a908 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -87,6 +87,7 @@ SelectTransform, SignTransform, SqueezeTransform, + Stack, StepCounter, TargetReturn, TensorDictPrimer, diff --git a/torchrl/envs/libs/unity_mlagents.py b/torchrl/envs/libs/unity_mlagents.py index 95c2460bc83..5aeabc4d0aa 100644 --- a/torchrl/envs/libs/unity_mlagents.py +++ b/torchrl/envs/libs/unity_mlagents.py @@ -363,12 +363,12 @@ def _make_td_out(self, tensordict_in, is_reset=False): # Add rewards if not is_reset: source[group_name][agent_name]["reward"] = torch.tensor( - steps.reward[steps_idx], + [steps.reward[steps_idx]], device=self.device, dtype=torch.float32, ) source[group_name][agent_name]["group_reward"] = torch.tensor( - steps.group_reward[steps_idx], + [steps.group_reward[steps_idx]], device=self.device, dtype=torch.float32, ) diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index bccbd9a4543..77f6ecc03bf 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -48,6 +48,7 @@ SelectTransform, SignTransform, SqueezeTransform, + Stack, StepCounter, TargetReturn, TensorDictPrimer, diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 980273af96c..0bab5868ded 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -4323,6 +4323,217 @@ def __repr__(self) -> str: ) +class Stack(Transform): + """Stacks tensors and tensordicts. + + Concatenates a sequence of tensors or tensordicts along a new dimension. + The tensordicts or tensors under ``in_keys`` must all have the same shapes. + + This transform only stacks the inputs into one output key. Stacking multiple + groups of input keys into different output keys requires multiple + transforms. + + This transform can be useful for environments that have multiple agents with + identical specs under different keys. The specs and tensordicts for the + agents can be stacked together under a shared key, in order to run MARL + algorithms that expect the tensors for observations, rewards, etc. to + contain batched data for all the agents. + + Args: + in_keys (sequence of NestedKey): keys to be stacked. + out_key (NestedKey): key of the resulting stacked entry. + in_key_inv (NestedKey, optional): key to unstack during :meth:`~.inv` + calls. Default is ``None``. + out_keys_inv (sequence of NestedKey, optional): keys of the resulting + unstacked entries after :meth:`~.inv` calls. Default is ``None``. + dim (int, optional): dimension to insert. Default is ``-1``. + allow_positive_dim (bool, optional): if ``True``, positive dimensions + are accepted. Defaults to ``False``, ie. non-negative dimensions are + not permitted. + + Keyword Args: + del_keys (bool, optional): if ``True``, the input values will be deleted + after stacking. Default is ``True``. + + Examples: + >>> import torch + >>> from tensordict import TensorDict + >>> from torchrl.envs import Stack + >>> td = TensorDict({"key1": torch.zeros(3), "key2": torch.ones(3)}, []) + >>> td + TensorDict( + fields={ + key1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), + key2: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> transform = Stack(in_keys=["key1", "key2"], out_key="out", dim=-2) + >>> transform(td) + TensorDict( + fields={ + out: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> td["out"] + tensor([[0., 0., 0.], + [1., 1., 1.]]) + + >>> agent_0 = TensorDict({"obs": torch.rand(4, 5), "reward": torch.zeros(1)}) + >>> agent_1 = TensorDict({"obs": torch.rand(4, 5), "reward": torch.zeros(1)}) + >>> td = TensorDict({"agent_0": agent_0, "agent_1": agent_1}) + >>> transform = Stack(in_keys=["agent_0", "agent_1"], out_key="agents") + >>> transform(td) + TensorDict( + fields={ + agents: TensorDict( + fields={ + obs: Tensor(shape=torch.Size([2, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False), + reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([2]), + device=None, + is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + """ + + invertible = True + + def __init__( + self, + in_keys: Sequence[NestedKey], + out_key: NestedKey, + in_key_inv: NestedKey | None = None, + out_keys_inv: Sequence[NestedKey] | None = None, + dim: int = -1, + allow_positive_dim: bool = False, + *, + del_keys: bool = True, + ): + if not allow_positive_dim and dim >= 0: + raise ValueError( + "dim should be negative to accommodate for envs of different " + "batch_sizes. If you need dim to be positive, set " + "allow_positive_dim=True." + ) + + if in_key_inv is None and out_keys_inv is not None: + raise ValueError("out_keys_inv was specified, but in_key_inv was not") + elif in_key_inv is not None and out_keys_inv is None: + raise ValueError("in_key_inv was specified, but out_keys_inv was not") + + super(Stack, self).__init__( + in_keys=in_keys, + out_keys=[out_key], + in_keys_inv=None if in_key_inv is None else [in_key_inv], + out_keys_inv=out_keys_inv, + ) + + for in_key in self.in_keys: + if len(in_key) == len(self.out_keys[0]): + if all(k1 == k2 for k1, k2 in zip(in_key, self.out_keys[0])): + raise ValueError(f"{self}: out_key cannot be in in_keys") + parent_keys = [] + for key in self.in_keys: + if isinstance(key, (list, tuple)): + for parent_level in range(1, len(key)): + parent_key = tuple(key[:-parent_level]) + if parent_key not in parent_keys: + parent_keys.append(parent_key) + self._maybe_del_parent_keys = sorted(parent_keys, key=len, reverse=True) + self.dim = dim + self._del_keys = del_keys + self._keys_to_exclude = None + + def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + values = [] + for in_key in self.in_keys: + value = tensordict.get(in_key, default=None) + if value is not None: + values.append(value) + elif not self.missing_tolerance: + raise KeyError( + f"{self}: '{in_key}' not found in tensordict {tensordict}" + ) + + out_tensor = torch.stack(values, dim=self.dim) + tensordict.set(self.out_keys[0], out_tensor) + if self._del_keys: + tensordict.exclude(*self.in_keys, inplace=True) + for parent_key in self._maybe_del_parent_keys: + if len(tensordict[parent_key].keys()) == 0: + tensordict.exclude(parent_key, inplace=True) + return tensordict + + forward = _call + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + if len(self.in_keys_inv) == 0: + return tensordict + + if self.in_keys_inv[0] not in tensordict.keys(include_nested=True): + return tensordict + values = torch.unbind(tensordict[self.in_keys_inv[0]], dim=self.dim) + for value, out_key_inv in _zip_strict(values, self.out_keys_inv): + tensordict = tensordict.set(out_key_inv, value) + return tensordict.exclude(self.in_keys_inv[0]) + + def _reset( + self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase + ) -> TensorDictBase: + with _set_missing_tolerance(self, True): + tensordict_reset = self._call(tensordict_reset) + return tensordict_reset + + def _transform_spec(self, spec: TensorSpec) -> TensorSpec: + if not isinstance(spec, Composite): + raise TypeError(f"{self}: Only specs of type Composite can be transformed") + + spec_keys = spec.keys(include_nested=True) + keys_to_stack = [key for key in spec_keys if key in self.in_keys] + specs_to_stack = [spec[key] for key in keys_to_stack] + + if len(specs_to_stack) == 0: + return spec + + stacked_specs = torch.stack(specs_to_stack, dim=self.dim) + spec.set(self.out_keys[0], stacked_specs) + + if self._del_keys: + for key in keys_to_stack: + del spec[key] + for parent_key in self._maybe_del_parent_keys: + if len(spec[parent_key]) == 0: + del spec[parent_key] + + return spec + + def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: + self._transform_spec(input_spec["full_state_spec"]) + self._transform_spec(input_spec["full_action_spec"]) + return input_spec + + def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: + return self._transform_spec(observation_spec) + + def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + return self._transform_spec(reward_spec) + + def transform_done_spec(self, done_spec: TensorSpec) -> TensorSpec: + return self._transform_spec(done_spec) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"in_keys={self.in_keys}, " + f"out_key={self.out_keys[0]}, " + f"dim={self.dim}" + ")" + ) + + class DiscreteActionProjection(Transform): """Projects discrete actions from a high dimensional space to a low dimensional space. From e7062a1d68caccf5b8a9f8ad35aef366f98cd46f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 4 Dec 2024 11:34:04 +0000 Subject: [PATCH 16/27] [BugFix] Fix typing for python 3.9 ghstack-source-id: 663da84096214611804a726e2d38d27a6f21c958 Pull Request resolved: https://github.com/pytorch/rl/pull/2631 --- test/mocking_classes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index 4e943e03cfc..b6f4ac7069b 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + from typing import Dict, List, Optional import torch From 2511c04a533e191d8200f75a60951385438e8e1e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 4 Dec 2024 12:52:56 +0000 Subject: [PATCH 17/27] [CI] Change doc image ghstack-source-id: eceab242294ec55135d79f29e848345a5d5d455e Pull Request resolved: https://github.com/pytorch/rl/pull/2632 --- .github/workflows/docs.yml | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 10ea80c1dcd..77abee7d4fc 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -28,13 +28,23 @@ jobs: with: repository: pytorch/rl upload-artifact: docs - runner: "linux.g5.4xlarge.nvidia.gpu" - docker-image: "nvidia/cudagl:11.4.0-base" timeout: 120 script: | set -e set -v - apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils + # apt-get update && apt-get install -y -f git wget gcc g++ dialog apt-utils + yum makecache + # yum install -y glfw glew mesa-libGL mesa-libGL-devel mesa-libOSMesa-devel egl-utils freeglut + # Install Mesa and OpenGL Libraries: + yum install -y glfw mesa-libGL mesa-libGL-devel egl-utils freeglut mesa-libGLU mesa-libEGL + # Install DRI Drivers: + yum install -y mesa-dri-drivers + # Install Xvfb for Headless Environments: + yum install -y xorg-x11-server-Xvfb + # xhost +local:docker + # Xvfb :1 -screen 0 1024x768x24 & + # export DISPLAY=:1 + root_dir="$(pwd)" conda_dir="${root_dir}/conda" env_dir="${root_dir}/env" @@ -51,7 +61,7 @@ jobs: conda activate "${env_dir}" # 2. upgrade pip, ninja and packaging - apt-get install python3-pip unzip -y -f + # apt-get install python3-pip unzip -y -f python3 -m pip install --upgrade pip python3 -m pip install setuptools ninja packaging cmake -U From b840a772c4ed7446cbba3241f1065f18539c0149 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 10 Dec 2024 10:54:55 -0800 Subject: [PATCH 18/27] [Example] Efficient Trajectory Sampling with CompletedTrajRepertoire ghstack-source-id: 4d5c587c69230aa8f3a1b9b6fe19f52fa683d703 Pull Request resolved: https://github.com/pytorch/rl/pull/2642 --- .../replay-buffers/filter-imcomplete-trajs.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 examples/replay-buffers/filter-imcomplete-trajs.py diff --git a/examples/replay-buffers/filter-imcomplete-trajs.py b/examples/replay-buffers/filter-imcomplete-trajs.py new file mode 100644 index 00000000000..271c7c00831 --- /dev/null +++ b/examples/replay-buffers/filter-imcomplete-trajs.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Efficient Trajectory Sampling with CompletedTrajRepertoire + +This example demonstrates how to design a custom transform that filters trajectories during sampling, +ensuring that only completed trajectories are present in sampled batches. This can be particularly useful +when dealing with environments where some trajectories might be corrupted or never reach a done state, +which could skew the learning process or lead to biased models. For instance, in robotics or autonomous +driving, a trajectory might be interrupted due to external factors such as hardware failures or human +intervention, resulting in incomplete or inconsistent data. By filtering out these incomplete trajectories, +we can improve the quality of the training data and increase the robustness of our models. +""" + +import torch +from tensordict import TensorDictBase +from torchrl.data import LazyTensorStorage, ReplayBuffer +from torchrl.envs import GymEnv, TrajCounter, Transform + + +class CompletedTrajectoryRepertoire(Transform): + """ + A transform that keeps track of completed trajectories and filters them out during sampling. + """ + + def __init__(self): + super().__init__() + self.completed_trajectories = set() + self.repertoire_tensor = torch.zeros((), dtype=torch.int64) + + def _update_repertoire(self, tensordict: TensorDictBase) -> None: + """Updates the repertoire of completed trajectories.""" + done = tensordict["next", "terminated"].squeeze(-1) + traj = tensordict["next", "traj_count"][done].view(-1) + if traj.numel(): + self.completed_trajectories = self.completed_trajectories.union( + traj.tolist() + ) + self.repertoire_tensor = torch.tensor( + list(self.completed_trajectories), dtype=torch.int64 + ) + + def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: + """Updates the repertoire of completed trajectories during insertion.""" + self._update_repertoire(tensordict) + return tensordict + + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """Filters out incomplete trajectories during sampling.""" + traj = tensordict["next", "traj_count"] + traj = traj.unsqueeze(-1) + has_traj = (traj == self.repertoire_tensor).any(-1) + has_traj = has_traj.view(tensordict.shape) + return tensordict[has_traj] + + +def main(): + # Create a CartPole environment with trajectory counting + env = GymEnv("CartPole-v1").append_transform(TrajCounter()) + + # Create a replay buffer with the completed trajectory repertoire transform + buffer = ReplayBuffer( + storage=LazyTensorStorage(1_000_000), transform=CompletedTrajectoryRepertoire() + ) + + # Roll out the environment for 1000 steps + while True: + rollout = env.rollout(1000, break_when_any_done=False) + if not rollout["next", "done"][-1].item(): + break + + # Extend the replay buffer with the rollout + buffer.extend(rollout) + + # Get the last trajectory count + last_traj_count = rollout[-1]["next", "traj_count"].item() + print(f"Incomplete trajectory: {last_traj_count}") + + # Sample from the replay buffer 10 times + for _ in range(10): + sample_traj_counts = buffer.sample(32)["next", "traj_count"].unique() + print(f"Sampled trajectories: {sample_traj_counts}") + assert last_traj_count not in sample_traj_counts + + +if __name__ == "__main__": + main() From 19dfefc84ec9e8998b7ef6e97578fe186372d48f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 11 Dec 2024 09:15:52 -0800 Subject: [PATCH 19/27] [BugFix] Fix init_random_frames=0 ghstack-source-id: 38a544ea15631f9affb4c385c09e7c4df94af55d Pull Request resolved: https://github.com/pytorch/rl/pull/2645 --- test/test_collector.py | 2 +- torchrl/collectors/collectors.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_collector.py b/test/test_collector.py index 38191a46eaa..5c91cb83633 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -1345,7 +1345,7 @@ def make_env(): functools.partial(MultiSyncDataCollector, cat_results="stack"), ], ) -@pytest.mark.parametrize("init_random_frames", [50]) # 1226: faster execution +@pytest.mark.parametrize("init_random_frames", [0, 50]) # 1226: faster execution @pytest.mark.parametrize( "explicit_spec,split_trajs", [[True, True], [False, False]] ) # 1226: faster execution diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 16eb5904b84..14fbc7d5f22 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -712,10 +712,10 @@ def __init__( ) self.reset_at_each_iter = reset_at_each_iter self.init_random_frames = ( - int(init_random_frames) if init_random_frames is not None else 0 + int(init_random_frames) if init_random_frames not in (None, -1) else 0 ) if ( - init_random_frames is not None + init_random_frames not in (-1, None, 0) and init_random_frames % frames_per_batch != 0 and RL_WARNINGS ): From 57dc25a446a94c4175ad1820473196ff5c49249a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 12 Dec 2024 11:19:09 -0800 Subject: [PATCH 20/27] [Refactor] Refactor trees ghstack-source-id: 368ba4c4402b6db0bc8b0688802ce161db9776b7 Pull Request resolved: https://github.com/pytorch/rl/pull/2634 --- test/test_storage_map.py | 104 ++++- torchrl/data/map/hash.py | 3 +- torchrl/data/map/tdstorage.py | 32 +- torchrl/data/map/tree.py | 585 ++++++++++++++++++++++-- torchrl/data/map/utils.py | 6 +- torchrl/data/replay_buffers/storages.py | 4 +- 6 files changed, 678 insertions(+), 56 deletions(-) diff --git a/test/test_storage_map.py b/test/test_storage_map.py index 9ff4431fb50..db2d0bc2c49 100644 --- a/test/test_storage_map.py +++ b/test/test_storage_map.py @@ -301,6 +301,7 @@ def _state0(self) -> TensorDict: def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict: done = torch.zeros_like(action, dtype=torch.bool).unsqueeze(-1) reward = action.clone() + action = action + torch.arange(action.shape[-1]) / action.shape[-1] return TensorDict( { @@ -326,7 +327,7 @@ def _make_forest(self) -> MCTSForest: forest.extend(r4) return forest - def _make_forest_intersect(self) -> MCTSForest: + def _make_forest_rebranching(self) -> MCTSForest: """ ├── 0 │ ├── 16 @@ -449,7 +450,7 @@ def test_forest_check_ids(self): def test_forest_intersect(self): state0 = self._state0() - forest = self._make_forest_intersect() + forest = self._make_forest_rebranching() tree = forest.get_tree(state0) subtree = forest.get_tree(TensorDict(observation=19)) @@ -467,13 +468,110 @@ def test_forest_intersect(self): def test_forest_intersect_vertices(self): state0 = self._state0() - forest = self._make_forest_intersect() + forest = self._make_forest_rebranching() tree = forest.get_tree(state0) assert len(tree.vertices(key_type="path")) > len(tree.vertices(key_type="hash")) assert len(tree.vertices(key_type="id")) == len(tree.vertices(key_type="hash")) with pytest.raises(ValueError, match="key_type must be"): tree.vertices(key_type="another key type") + @pytest.mark.skipif(not _has_gym, reason="requires gym") + def test_simple_tree(self): + from torchrl.envs import GymEnv + + env = GymEnv("Pendulum-v1") + r = env.rollout(10) + state0 = r[0] + forest = MCTSForest() + forest.extend(r) + # forest = self._make_forest_intersect() + tree = forest.get_tree(state0, compact=False) + assert tree.max_length() == 9 + for p in tree.valid_paths(): + assert len(p) == 9 + + @pytest.mark.parametrize( + "tree_type,compact", + [ + ["simple", False], + ["forest", False], + # parent of rebranching trees are still buggy + # ["rebranching", False], + # ["rebranching", True], + ], + ) + def test_forest_parent(self, tree_type, compact): + if tree_type == "simple": + if not _has_gym: + pytest.skip("requires gym") + from torchrl.envs import GymEnv + + env = GymEnv("Pendulum-v1") + r = env.rollout(10) + state0 = r[0] + forest = MCTSForest() + forest.extend(r) + tree = forest.get_tree(state0, compact=compact) + elif tree_type == "forest": + state0 = self._state0() + forest = self._make_forest() + tree = forest.get_tree(state0, compact=compact) + else: + state0 = self._state0() + forest = self._make_forest_rebranching() + tree = forest.get_tree(state0, compact=compact) + # Check access + tree.subtree.parent + tree.subtree.subtree.parent + tree.subtree.subtree.subtree.parent + + # check present of weakref + assert tree.subtree[0]._parent is not None + assert tree.subtree[0].subtree[0]._parent is not None + + # Check content + assert_close(tree.subtree.parent, tree) + for p in tree.valid_paths(): + root = tree + for it in p: + node = root.subtree[it] + assert_close(node.parent, root) + root = node + + def test_forest_action_attr(self): + state0 = self._state0() + forest = self._make_forest() + tree = forest.get_tree(state0) + assert tree.branching_action is None + assert (tree.subtree.branching_action != tree.subtree.prev_action).any() + assert ( + tree.subtree[0].subtree.branching_action + != tree.subtree[0].subtree.prev_action + ).any() + assert tree.prev_action is None + + @pytest.mark.parametrize("intersect", [False, True]) + def test_forest_check_obs_match(self, intersect): + state0 = self._state0() + if intersect: + forest = self._make_forest_rebranching() + else: + forest = self._make_forest() + tree = forest.get_tree(state0) + for path in tree.valid_paths(): + prev_tree = tree + for p in path: + subtree = prev_tree.subtree[p] + assert ( + subtree.node_data["observation"] + == subtree.rollout[..., -1]["next", "observation"] + ).all() + assert ( + subtree.node_observation + == subtree.rollout[..., -1]["next", "observation"] + ).all() + prev_tree = subtree + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/data/map/hash.py b/torchrl/data/map/hash.py index 01988dc43be..59526628dbe 100644 --- a/torchrl/data/map/hash.py +++ b/torchrl/data/map/hash.py @@ -75,7 +75,8 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: class SipHash(Module): """A Module to Compute SipHash values for given tensors. - A hash function module based on SipHash implementation in python. + A hash function module based on SipHash implementation in python. Input tensors should have shape ``[batch_size, num_features]`` + and the output shape will be ``[batch_size]``. Args: as_tensor (bool, optional): if ``True``, the bytes will be turned into integers diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py index a601f1e3261..9413033bac4 100644 --- a/torchrl/data/map/tdstorage.py +++ b/torchrl/data/map/tdstorage.py @@ -138,6 +138,10 @@ def __init__( self.collate_fn = collate_fn self.write_fn = write_fn + @property + def max_size(self): + return self.storage.max_size + @property def out_keys(self) -> List[NestedKey]: out_keys = self.__dict__.get("_out_keys_and_lazy") @@ -177,7 +181,7 @@ def from_tensordict_pair( collate_fn: Callable[[Any], Any] | None = None, write_fn: Callable[[Any, Any], Any] | None = None, consolidated: bool | None = None, - ): + ) -> TensorDictMap: """Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb. Args: @@ -238,7 +242,13 @@ def from_tensordict_pair( n_feat = 0 hash_module = [] for in_key in in_keys: - n_feat = source[in_key].shape[-1] + entry = source[in_key] + if entry.ndim == source.ndim: + # this is a good example of why td/tc are useful - carrying metadata + # allows us to know if there's a feature dim or not + n_feat = 0 + else: + n_feat = entry.shape[-1] if n_feat > RandomProjectionHash._N_COMPONENTS_DEFAULT: _hash_module = RandomProjectionHash() else: @@ -308,7 +318,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase): if not self._has_lazy_out_keys(): # TODO: make this work with pytrees and avoid calling select if keys match value = value.select(*self.out_keys, strict=False) + item, value = self._maybe_add_batch(item, value) + index = self._to_index(item, extend=True) + if index.unique().numel() < index.numel(): + # If multiple values point to the same place in the storage, we cannot process them by batch + # There could be a better way to deal with this, using unique ids. + vals = [] + for it, val in zip(item.split(1), value.split(1)): + self[it] = val + vals.append(val) + # __setitem__ may affect the content of the input data + value.update(TensorDictBase.lazy_stack(vals)) + return if self.write_fn is not None: + # We use this block in the following context: the value written in the storage is already present, + # but it needs to be updated. + # We first check if the value is already there using `contains`. If so, we pass the new value and the + # previous one to write_fn. The values that are not present are passed alone. if len(self): modifiable = self.contains(item) if modifiable.any(): @@ -322,8 +348,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase): value = self.write_fn(value) else: value = self.write_fn(value) - item, value = self._maybe_add_batch(item, value) - index = self._to_index(item, extend=True) self.storage.set(index, value) def __len__(self): diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 645f7704ddd..513a7b94e58 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import weakref from collections import deque from typing import Any, Callable, Dict, List, Literal, Tuple @@ -15,10 +16,13 @@ TensorClass, TensorDict, TensorDictBase, + unravel_key, ) from torchrl.data.map.tdstorage import TensorDictMap from torchrl.data.map.utils import _plot_plotly_box, _plot_plotly_tree from torchrl.data.replay_buffers.storages import ListStorage +from torchrl.data.tensor_specs import Composite + from torchrl.envs.common import EnvBase @@ -69,7 +73,9 @@ class Tree(TensorClass["nocast"]): """ - count: int = None + count: int | torch.Tensor = None + wins: int | torch.Tensor = None + index: torch.Tensor | None = None # The hash is None if the node has more than one action associated hash: int | None = None @@ -78,12 +84,249 @@ class Tree(TensorClass["nocast"]): # rollout following the observation encoded in node, in a TorchRL (TED) format rollout: TensorDict | None = None - # The data specifying the node - node: TensorDict | None = None + # The data specifying the node (typically an observation or a set of observations) + node_data: TensorDict | None = None # Stack of subtrees. A subtree is produced when an action is taken. subtree: "Tree" = None + # weakrefs to the parent(s) of the node + _parent: weakref.ref | List[weakref.ref] | None = None + + # Specs: contains information such as action or observation keys and spaces. + # If present, they should be structured like env specs are: + # Composite(input_spec=Composite(full_state_spec=..., full_action_spec=...), + # output_spec=Composite(full_observation_spec=..., full_reward_spec=..., full_done_spec=...)) + # where every leaf component is optional. + specs: Composite | None = None + + @classmethod + def make_node( + cls, + data: TensorDictBase, + *, + device: torch.device | None = None, + batch_size: torch.Size | None = None, + specs: Composite | None = None, + ) -> Tree: + """Creates a new node given some data.""" + if "next" in data.keys(): + rollout = data + if not rollout.ndim: + rollout = rollout.unsqueeze(0) + subtree = TensorDict.lazy_stack([cls.make_node(data["next"][..., -1])]) + else: + rollout = None + subtree = None + if device is None: + device = data.device + return cls( + count=torch.zeros(()), + wins=torch.zeros(()), + node=data.exclude("action", "next"), + rollout=rollout, + subtree=subtree, + device=device, + batch_size=batch_size, + ) + + # Specs + @property + def full_observation_spec(self): + """The observation spec of the tree. + + This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`.""" + return self.specs["output_spec", "full_observation_spec"] + + @property + def full_reward_spec(self): + """The reward spec of the tree. + + This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`.""" + return self.specs["output_spec", "full_reward_spec"] + + @property + def full_done_spec(self): + """The done spec of the tree. + + This is an alias for `Tree.specs['output_spec', 'full_done_spec']`.""" + return self.specs["output_spec", "full_done_spec"] + + @property + def full_state_spec(self): + """The state spec of the tree. + + This is an alias for `Tree.specs['input_spec', 'full_state_spec']`.""" + return self.specs["input_spec", "full_state_spec"] + + @property + def full_action_spec(self): + """The action spec of the tree. + + This is an alias for `Tree.specs['input_spec', 'full_action_spec']`.""" + return self.specs["input_spec", "full_action_spec"] + + @property + def selected_actions(self) -> torch.Tensor | TensorDictBase | None: + """Returns a tensor containing all the selected actions branching out from this node.""" + if self.subtree is None: + return None + return self.subtree.rollout[..., 0]["action"] + + @property + def prev_action(self) -> torch.Tensor | TensorDictBase | None: + """The action undertaken just before this node's observation was generated. + + Returns: + a tensor, tensordict or None if the node has no parent. + + .. seealso:: This will be equal to :class:`~torchrl.data.Tree.branching_action` whenever the rollout data contains a single step. + + .. seealso:: :class:`All actions associated with a given node (or observation) in the tree <~torchrl.data.Tree.selected_action>`. + + """ + if self.rollout is None: + return None + return self.rollout[..., -1]["action"] + + @property + def branching_action(self) -> torch.Tensor | TensorDictBase | None: + """Returns the action that branched out to this particular node. + + Returns: + a tensor, tensordict or None if the node has no parent. + + .. seealso:: This will be equal to :class:`~torchrl.data.Tree.prev_action` whenever the rollout data contains a single step. + + .. seealso:: :class:`All actions associated with a given node (or observation) in the tree <~torchrl.data.Tree.selected_action>`. + + """ + if self.rollout is None: + return None + return self.rollout[..., 0]["action"] + + @property + def node_observation(self) -> torch.Tensor | TensorDictBase: + """Returns the observation associated with this particular node. + + This is the observation (or bag of observations) that defines the node before a branching occurs. + If the node contains a :attr:`~.rollout` attribute, the node observation is typically identical to the + observation resulting from the last action undertaken, i.e., ``node.rollout[..., -1]["next", "observation"]``. + + If more than one observation key is associated with the tree specs, a :class:`~tensordict.TensorDict` instance + is returned instead. + + For a more consistent representation, see :attr:`~.node_observations`. + + """ + # TODO: implement specs + return self.node_data["observation"] + + @property + def node_observations(self) -> torch.Tensor | TensorDictBase: + """Returns the observations associated with this particular node in a TensorDict format. + + This is the observation (or bag of observations) that defines the node before a branching occurs. + If the node contains a :attr:`~.rollout` attribute, the node observation is typically identical to the + observation resulting from the last action undertaken, i.e., ``node.rollout[..., -1]["next", "observation"]``. + + If more than one observation key is associated with the tree specs, a :class:`~tensordict.TensorDict` instance + is returned instead. + + For a more consistent representation, see :attr:`~.node_observations`. + + """ + # TODO: implement specs + return self.node_data.select("observation") + + @property + def visits(self) -> int | torch.Tensor: + """Returns the number of visits associated with this particular node. + + This is an alias for the :attr:`~.count` attribute. + + """ + return self.count + + @visits.setter + def visits(self, count): + self.count = count + + def __setattr__(self, name: str, value: Any) -> None: + if name == "subtree" and value is not None: + wr = weakref.ref(self._tensordict) + if value._parent is None: + value._parent = wr + elif isinstance(value._parent, list): + value._parent.append(wr) + else: + value._parent = [value._parent, wr] + return super().__setattr__(name, value) + + @property + def parent(self) -> Tree | None: + """The parent of the node. + + If the node has a parent and this object is still present in the python workspace, it will be returned by this + property. + + For re-branching trees, this property may return a stack of trees where every index of the stack corresponds to + a different parent. + + .. note:: the ``parent`` attribute will match in content but not in identity: the tensorclass object is recustructed + using the same tensors (i.e., tensors that point to the same memory locations). + + Returns: + A ``Tree`` containing the parent data or ``None`` if the parent data is out of scope or the node is the root. + """ + parent = self._parent + if parent is not None: + # Check that all parents match + queue = [parent] + + def maybe_flatten_list(maybe_nested_list): + if isinstance(maybe_nested_list, list): + for p in maybe_nested_list: + if isinstance(p, list): + queue.append(p) + else: + yield p() + else: + yield maybe_nested_list() + + parent_result = None + while len(queue): + local_result = None + for r in maybe_flatten_list(queue.pop()): + if local_result is None: + local_result = r + elif r is not None and r is not local_result: + if isinstance(local_result, list): + local_result.append(r) + else: + local_result = [local_result, r] + if local_result is None: + continue + # replicate logic at macro level + if parent_result is None: + parent_result = local_result + else: + if isinstance(local_result, list): + local_result = [ + r for r in local_result if r not in parent_result + ] + else: + local_result = [local_result] + if isinstance(parent_result, list): + parent_result.extend(local_result) + else: + parent_result = [parent_result, *local_result] + if isinstance(parent_result, list): + return TensorDict.lazy_stack( + [self._from_tensordict(r) for r in parent_result] + ) + return self._from_tensordict(parent_result) + @property def num_children(self) -> int: """Number of children of this node. @@ -93,9 +336,19 @@ def num_children(self) -> int: return len(self.subtree) if self.subtree is not None else 0 @property - def is_terminal(self): - """Returns True if the the tree has no children nodes.""" - return self.subtree is None + def is_terminal(self) -> bool | torch.Tensor: + """Returns True if the tree has no children nodes.""" + if self.rollout is not None: + return self.rollout[..., -1]["next", "done"].squeeze(-1) + # If there is no rollout, there is no preceding data - either this is a root or it's a floating node. + # In either case, we assume that the node is not terminal. + return False + + def fully_expanded(self, env: EnvBase) -> bool: + """Returns True if the number of children is equal to the environment cardinality.""" + cardinality = env.cardinality(self.node_data) + num_actions = self.num_children + return cardinality == num_actions def get_vertex_by_id(self, id: int) -> Tree: """Goes through the tree and returns the node corresponding the given id.""" @@ -163,9 +416,6 @@ def vertices( if h in memo and not use_path: continue memo.add(h) - r = tree.rollout - if r is not None: - r = r["next", "observation"] if use_path: result[cur_path] = tree elif use_id: @@ -206,6 +456,14 @@ def num_vertices(self, *, count_repeat: bool = False) -> int: ) def edges(self) -> List[Tuple[int, int]]: + """Retrieves a list of edges in the tree. + + Each edge is represented as a tuple of two node IDs: the parent node ID and the child node ID. + The tree is traversed using Breadth-First Search (BFS) to ensure all edges are visited. + + Returns: + A list of tuples, where each tuple contains a parent node ID and a child node ID. + """ result = [] q = deque() parent = self.node_id @@ -221,22 +479,62 @@ def edges(self) -> List[Tuple[int, int]]: return result def valid_paths(self): + """Generates all valid paths in the tree. + + A valid path is a sequence of child indices that starts at the root node and ends at a leaf node. + Each path is represented as a tuple of integers, where each integer corresponds to the index of a child node. + + Yields: + tuple: A valid path in the tree. + """ + # Initialize a queue with the current tree node and an empty path q = deque() cur_path = () q.append((self, cur_path)) + # Perform BFS traversal of the tree while len(q): + # Dequeue the next tree node and its current path tree, cur_path = q.popleft() + # Get the number of child nodes n = int(tree.num_children) + # If this is a leaf node, yield the current path if not n: yield cur_path + # Iterate over the child nodes for i in range(n): cur_path_tree = cur_path + (i,) q.append((tree.subtree[i], cur_path_tree)) def max_length(self): - return max(*(len(path) for path in self.valid_paths())) + """Returns the maximum length of all valid paths in the tree. + + The length of a path is defined as the number of nodes in the path. + If the tree is empty, returns 0. + + Returns: + int: The maximum length of all valid paths in the tree. + + """ + lengths = tuple(len(path) for path in self.valid_paths()) + if len(lengths) == 0: + return 0 + elif len(lengths) == 1: + return lengths[0] + return max(*lengths) def rollout_from_path(self, path: Tuple[int]) -> TensorDictBase | None: + """Retrieves the rollout data along a given path in the tree. + + The rollout data is concatenated along the last dimension (dim=-1) for each node in the path. + If no rollout data is found along the path, returns ``None``. + + Args: + path: A tuple of integers representing the path in the tree. + + Returns: + The concatenated rollout data along the path, or None if no data is found. + + """ r = self.rollout tree = self rollouts = [] @@ -272,8 +570,19 @@ def plot( backend: str = "plotly", figure: str = "tree", info: List[str] = None, - make_labels: Callable[[Any], Any] | None = None, + make_labels: Callable[[Any, ...], Any] | None = None, ): + """Plots a visualization of the tree using the specified backend and figure type. + + Args: + backend: The plotting backend to use. Currently only supports 'plotly'. + figure: The type of figure to plot. Can be either 'tree' or 'box'. + info: A list of additional information to include in the plot (not currently used). + make_labels: An optional function to generate custom labels for the plot. + + Raises: + NotImplementedError: If an unsupported backend or figure type is specified. + """ if backend == "plotly": if figure == "box": _plot_plotly_box(self) @@ -284,33 +593,48 @@ def plot( else: pass raise NotImplementedError( - f"Unkown plotting backend {backend} with figure {figure}." + f"Unknown plotting backend {backend} with figure {figure}." ) class MCTSForest: """A collection of MCTS trees. + .. warning:: This class is currently under active development. Expect frequent API changes. + The class is aimed at storing rollouts in a storage, and produce trees based on a given root in that dataset. Keyword Args: data_map (TensorDictMap, optional): the storage to use to store the data (observation, reward, states etc). If not provided, it is lazily - initialized using :meth:`~torchrl.data.map.tdstorage.TensorDictMap.from_tensordict_pair`. - node_map (TensorDictMap, optional): TODO - done_keys (list of NestedKey): the done keys of the environment. If not provided, + initialized using :meth:`~torchrl.data.map.tdstorage.TensorDictMap.from_tensordict_pair` + using the list of :attr:`observation_keys` and :attr:`action_keys` as ``in_keys``. + node_map (TensorDictMap, optional): a map from the observation space to the index space. + Internally, the node map is used to gather all possible branches coming out of + a given node. For example, if an observation has two associated actions and outcomes + in the data map, then the :attr:`node_map` will return a data structure containing the + two indices in the :attr:`data_map` that correspond to these two outcomes. + If not provided, it is lazily initialized using + :meth:`~torchrl.data.map.tdstorage.TensorDictMap.from_tensordict_pair` using the list of + :attr:`observation_keys` as ``in_keys`` and the :class:`~torchrl.data.QueryModule` as + ``out_keys``. + max_size (int, optional): the size of the maps. + If not provided, defaults to ``data_map.max_size`` if this can be found, then + ``node_map.max_size``. If none of these are provided, defaults to `1000`. + done_keys (list of NestedKey, optional): the done keys of the environment. If not provided, defaults to ``("done", "terminated", "truncated")``. The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. - action_keys (list of NestedKey): the action keys of the environment. If not provided, + action_keys (list of NestedKey, optional): the action keys of the environment. If not provided, defaults to ``("action",)``. The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. - reward_keys (list of NestedKey): the reward keys of the environment. If not provided, + reward_keys (list of NestedKey, optional): the reward keys of the environment. If not provided, defaults to ``("reward",)``. The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. - observation_keys (list of NestedKey): the observation keys of the environment. If not provided, + observation_keys (list of NestedKey, optional): the observation keys of the environment. If not provided, defaults to ``("observation",)``. The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. + excluded_keys (list of NestedKey, optional): a list of keys to exclude from the data storage. consolidated (bool, optional): if ``True``, the data_map storage will be consolidated on disk. Defaults to ``False``. @@ -405,10 +729,12 @@ def __init__( *, data_map: TensorDictMap | None = None, node_map: TensorDictMap | None = None, + max_size: int | None = None, done_keys: List[NestedKey] | None = None, reward_keys: List[NestedKey] = None, observation_keys: List[NestedKey] = None, action_keys: List[NestedKey] = None, + excluded_keys: List[NestedKey] = None, consolidated: bool | None = None, ): @@ -416,55 +742,125 @@ def __init__( self.node_map = node_map + if max_size is None: + if data_map is not None: + max_size = data_map.max_size + if max_size != getattr(node_map, "max_size", max_size): + raise ValueError( + f"Conflicting max_size: got data_map.max_size={data_map.max_size} and node_map.max_size={node_map.max_size}." + ) + elif node_map is not None: + max_size = node_map.max_size + else: + max_size = None + elif data_map is not None and max_size != getattr( + data_map, "max_size", max_size + ): + raise ValueError( + f"Conflicting max_size: got data_map.max_size={data_map.max_size} and max_size={max_size}." + ) + elif node_map is not None and max_size != getattr( + node_map, "max_size", max_size + ): + raise ValueError( + f"Conflicting max_size: got node_map.max_size={node_map.max_size} and max_size={max_size}." + ) + self.max_size = max_size + self.done_keys = done_keys self.action_keys = action_keys self.reward_keys = reward_keys self.observation_keys = observation_keys + self.excluded_keys = excluded_keys self.consolidated = consolidated @property - def done_keys(self): + def done_keys(self) -> List[NestedKey]: + """Done Keys. + + Returns the keys used to indicate that an episode has ended. + The default done keys are "done", "terminated", and "truncated". These keys can be + used in the environment's output to signal the end of an episode. + + Returns: + A list of strings representing the done keys. + + """ done_keys = getattr(self, "_done_keys", None) if done_keys is None: - self._done_keys = done_keys = ("done", "terminated", "truncated") + self._done_keys = done_keys = ["done", "terminated", "truncated"] return done_keys @done_keys.setter def done_keys(self, value): - self._done_keys = value + self._done_keys = _make_list_of_nestedkeys(value, "done_keys") @property - def reward_keys(self): + def reward_keys(self) -> List[NestedKey]: + """Reward Keys. + + Returns the keys used to retrieve rewards from the environment's output. + The default reward key is "reward". + + Returns: + A list of strings or tuples representing the reward keys. + + """ reward_keys = getattr(self, "_reward_keys", None) if reward_keys is None: - self._reward_keys = reward_keys = ("reward",) + self._reward_keys = reward_keys = ["reward"] return reward_keys @reward_keys.setter def reward_keys(self, value): - self._reward_keys = value + self._reward_keys = _make_list_of_nestedkeys(value, "reward_keys") @property - def action_keys(self): + def action_keys(self) -> List[NestedKey]: + """Action Keys. + + Returns the keys used to retrieve actions from the environment's input. + The default action key is "action". + + Returns: + A list of strings or tuples representing the action keys. + + """ action_keys = getattr(self, "_action_keys", None) if action_keys is None: - self._action_keys = action_keys = ("action",) + self._action_keys = action_keys = ["action"] return action_keys @action_keys.setter def action_keys(self, value): - self._action_keys = value + self._action_keys = _make_list_of_nestedkeys(value, "action_keys") @property - def observation_keys(self): + def observation_keys(self) -> List[NestedKey]: + """Observation Keys. + + Returns the keys used to retrieve observations from the environment's output. + The default observation key is "observation". + + Returns: + A list of strings or tuples representing the observation keys. + """ observation_keys = getattr(self, "_observation_keys", None) if observation_keys is None: - self._observation_keys = observation_keys = ("observation",) + self._observation_keys = observation_keys = ["observation"] return observation_keys @observation_keys.setter def observation_keys(self, value): - self._observation_keys = value + self._observation_keys = _make_list_of_nestedkeys(value, "observation_keys") + + @property + def excluded_keys(self) -> List[NestedKey] | None: + return self._excluded_keys + + @excluded_keys.setter + def excluded_keys(self, value): + self._excluded_keys = _make_list_of_nestedkeys(value, "excluded_keys") def get_keys_from_env(self, env: EnvBase): """Writes missing done, action and reward keys to the Forest given an environment. @@ -482,8 +878,21 @@ def get_keys_from_env(self, env: EnvBase): @classmethod def _write_fn_stack(cls, new, old=None): + # This function updates the old values by adding the new ones + # if and only if the new ones are not there. + # If the old value is not provided, we assume there are none and the + # `new` is just prepared. + # This involves unsqueezing the last dim (since we'll be stacking tensors + # and calling unique). + # The update involves calling cat along the last dim + unique + # which will keep only the new values that were unknown to + # the storage. + # We use this method to track all the indices that are associated with + # an observation. Every time a new index is obtained, it is stacked alongside + # the others. if old is None: - result = new.apply(lambda x: x.unsqueeze(0), filter_empty=False) + # we unsqueeze the values to stack them along dim -1 + result = new.apply(lambda x: x.unsqueeze(-1), filter_empty=False) result.set( "count", torch.ones(result.shape, dtype=torch.int, device=result.device) ) @@ -493,28 +902,44 @@ def cat(name, x, y): if name == "count": return x if y.ndim < x.ndim: - y = y.unsqueeze(0) - result = torch.cat([x, y], 0).unique(dim=0, sorted=False) + y = y.unsqueeze(-1) + result = torch.cat([x, y], -1) + # Breaks on mps + if result.device.type == "mps": + result = result.cpu() + result = result.unique(dim=-1, sorted=False) + result = result.to("mps") + else: + result = result.unique(dim=-1, sorted=False) return result result = old.named_apply(cat, new, default=None) result.set_("count", old.get("count") + 1) return result - def _make_storage(self, source, dest): + def _make_data_map(self, source, dest): try: + kwargs = {} + if self.max_size is not None: + kwargs["max_size"] = self.max_size self.data_map = TensorDictMap.from_tensordict_pair( source, dest, in_keys=[*self.observation_keys, *self.action_keys], consolidated=self.consolidated, + **kwargs, ) + if self.max_size is None: + self.max_size = self.data_map.max_size except KeyError as err: raise KeyError( "A KeyError occurred during data map creation. This could be due to the wrong setting of a key in the MCTSForest constructor. Scroll up for more info." ) from err - def _make_storage_branches(self, source, dest): + def _make_node_map(self, source, dest): + kwargs = {} + if self.max_size is not None: + kwargs["max_size"] = self.max_size self.node_map = TensorDictMap.from_tensordict_pair( source, dest, @@ -528,26 +953,59 @@ def _make_storage_branches(self, source, dest): storage_constructor=ListStorage, collate_fn=TensorDict.lazy_stack, write_fn=self._write_fn_stack, + **kwargs, ) + if self.max_size is None: + self.max_size = self.data_map.max_size - def extend(self, rollout): + def extend(self, rollout, *, return_node: bool = False): source, dest = ( rollout.exclude("next").copy(), rollout.select("next", *self.action_keys).copy(), ) + if self.excluded_keys is not None: + dest = dest.exclude(*self.excluded_keys, inplace=True) + dest.get("next").exclude(*self.excluded_keys, inplace=True) if self.data_map is None: - self._make_storage(source, dest) + self._make_data_map(source, dest) # We need to set the action somewhere to keep track of what action lead to what child # # Set the action in the 'next' # dest[1:] = source[:-1].exclude(*self.done_keys) + # Add ('observation', 'action') -> ('next, observation') self.data_map[source] = dest value = source if self.node_map is None: - self._make_storage_branches(source, dest) + self._make_node_map(source, dest) + # map ('observation',) -> ('indices',) self.node_map[source] = TensorDict.lazy_stack(value.unbind(0)) + if return_node: + return self.get_tree(rollout) + + def add(self, step, *, return_node: bool = False): + source, dest = ( + step.exclude("next").copy(), + step.select("next", *self.action_keys).copy(), + ) + + if self.data_map is None: + self._make_data_map(source, dest) + + # We need to set the action somewhere to keep track of what action lead to what child + # # Set the action in the 'next' + # dest[1:] = source[:-1].exclude(*self.done_keys) + + # Add ('observation', 'action') -> ('next, observation') + self.data_map[source] = dest + value = source + if self.node_map is None: + self._make_node_map(source, dest) + # map ('observation',) -> ('indices',) + self.node_map[source] = value + if return_node: + return self.get_tree(step) def get_child(self, root: TensorDictBase) -> TensorDictBase: return self.data_map[root] @@ -573,6 +1031,8 @@ def _make_local_tree( while index.numel() <= 1: index = index.squeeze() d = self.data_map.storage[index] + + # Rebuild rollout step steps.append(merge_tensordicts(d, root, callback_exist=lambda *x: None)) d = d["next"] if d in self.node_map: @@ -582,6 +1042,15 @@ def _make_local_tree( if not compact: break else: + # If the root is provided and not gathered from the storage, it could be that its + # device doesn't match the data_map storage device. + root = steps[-1]["next"].select(*self.node_map.in_keys) + device = getattr(self.data_map.storage, "device", None) + if root.device != device: + if device is not None: + root = root.to(self.data_map.storage.device) + else: + root.clear_device_() index = None break rollout = None @@ -592,10 +1061,12 @@ def _make_local_tree( return ( Tree( rollout=rollout, - count=node_meta["count"], - node=root, + count=torch.zeros((), dtype=torch.int32), + wins=torch.zeros(()), + node_data=root, index=index, hash=None, + # We do this to avoid raising an exception as rollout and subtree must be provided together subtree=None, ), index, @@ -618,7 +1089,7 @@ def _make_tree_iter( ): q = deque() memo = {} - tree, indices, hash = self._make_local_tree(root, index=index) + tree, indices, hash = self._make_local_tree(root, index=index, compact=compact) tree.node_id = 0 result = tree @@ -626,7 +1097,6 @@ def _make_tree_iter( counter = 1 if indices is not None: q.append((tree, indices, hash, depth)) - del tree, indices while len(q): tree, indices, hash, depth = q.popleft() @@ -638,12 +1108,29 @@ def _make_tree_iter( subtree, subtree_indices, subtree_hash = memo.get(h, (None,) * 3) if subtree is None: subtree, subtree_indices, subtree_hash = self._make_local_tree( - tree.node, index=i, compact=compact + tree.node_data, + index=i, + compact=compact, ) subtree.node_id = counter counter += 1 subtree.hash = h memo[h] = (subtree, subtree_indices, subtree_hash) + else: + # We just need to save the two (or more) rollouts + subtree_bis, _, _ = self._make_local_tree( + tree.node_data, + index=i, + compact=compact, + ) + if subtree.rollout.ndim == subtree_bis.rollout.ndim: + subtree.rollout = TensorDict.stack( + [subtree.rollout, subtree_bis.rollout] + ) + else: + subtree.rollout = TensorDict.stack( + [*subtree.rollout, subtree_bis.rollout] + ) subtrees.append(subtree) if extend and subtree_indices is not None: @@ -668,3 +1155,15 @@ def valid_paths(cls, tree: Tree): def __len__(self): return len(self.data_map) + + +def _make_list_of_nestedkeys(obj: Any, attr: str) -> List[NestedKey]: + if obj is None: + return obj + if isinstance(obj, (str, tuple)): + return [obj] + if not isinstance(obj, list): + raise ValueError( + f"{attr} must be a list of NestedKeys or a NestedKey, got {obj}." + ) + return [unravel_key(key) for key in obj] diff --git a/torchrl/data/map/utils.py b/torchrl/data/map/utils.py index 570214f1cb2..d9588d79905 100644 --- a/torchrl/data/map/utils.py +++ b/torchrl/data/map/utils.py @@ -17,13 +17,13 @@ def _plot_plotly_tree( if make_labels is None: - def make_labels(tree): + def make_labels(tree, path, *args, **kwargs): return str((tree.node_id, tree.hash)) nr_vertices = tree.num_vertices() - vertices = tree.vertices() + vertices = tree.vertices(key_type="path") - v_label = [make_labels(subtree) for subtree in vertices.values()] + v_label = [make_labels(subtree, path) for path, subtree in vertices.items()] G = Graph(nr_vertices, tree.edges()) layout = G.layout_sugiyama(range(nr_vertices)) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 665cae254f5..ae0d97b7bab 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -246,8 +246,8 @@ def set( set_cursor: bool = True, ): if not isinstance(cursor, INT_CLASSES): - if (isinstance(cursor, torch.Tensor) and cursor.numel() <= 1) or ( - isinstance(cursor, np.ndarray) and cursor.size <= 1 + if (isinstance(cursor, torch.Tensor) and cursor.ndim == 0) or ( + isinstance(cursor, np.ndarray) and cursor.ndim == 0 ): self.set(int(cursor), data, set_cursor=set_cursor) return From 30d21e5990dcbcaf8beb7c37326a89526d8cfda7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 12 Dec 2024 13:31:36 -0800 Subject: [PATCH 21/27] [Feature] LLMHashingEnv ghstack-source-id: d1a20ecd023008683cf18cf9e694340cfdbdac8a Pull Request resolved: https://github.com/pytorch/rl/pull/2635 --- docs/source/reference/envs.rst | 2 + docs/source/reference/trainers.rst | 4 +- test/test_env.py | 25 ++++ torchrl/data/map/tree.py | 15 +- torchrl/envs/__init__.py | 2 +- torchrl/envs/common.py | 37 ++++- torchrl/envs/custom/__init__.py | 1 + torchrl/envs/custom/llm.py | 213 +++++++++++++++++++++++++++++ torchrl/envs/utils.py | 11 +- 9 files changed, 293 insertions(+), 17 deletions(-) create mode 100644 torchrl/envs/custom/llm.py diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 4519900ae8b..70fdf03c0ff 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -347,6 +347,8 @@ TorchRL offers a series of custom built-in environments. PendulumEnv TicTacToeEnv + LLMHashingEnv + Multi-agent environments ------------------------ diff --git a/docs/source/reference/trainers.rst b/docs/source/reference/trainers.rst index 8f6be633743..264534a725c 100644 --- a/docs/source/reference/trainers.rst +++ b/docs/source/reference/trainers.rst @@ -79,7 +79,7 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a - **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward - logger (``LogScaler``) and such. Hooks should return a dictionary (or a None value) containing the + logger (``LogScalar``) and such. Hooks should return a dictionary (or a None value) containing the data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value should be displayed on the progression bar printed on the training log. @@ -174,7 +174,7 @@ Trainer and hooks BatchSubSampler ClearCudaCache CountFramesLog - LogScaler + LogScalar OptimizerHook LogValidationReward ReplayBufferTrainer diff --git a/test/test_env.py b/test/test_env.py index b48b1a1cf8f..cef7a507f2a 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -8,6 +8,7 @@ import functools import gc import os.path +import random import re from collections import defaultdict from functools import partial @@ -114,6 +115,7 @@ DoubleToFloat, EnvBase, EnvCreator, + LLMHashingEnv, ParallelEnv, PendulumEnv, SerialEnv, @@ -3419,6 +3421,29 @@ def test_pendulum_env(self, device): r = env.rollout(10, tensordict=TensorDict(batch_size=[5], device=device)) assert r.shape == torch.Size((5, 10)) + def test_llm_hashing_env(self): + vocab_size = 5 + + class Tokenizer: + def __call__(self, obj): + return torch.randint(vocab_size, (len(obj.split(" ")),)).tolist() + + def decode(self, obj): + words = ["apple", "banana", "cherry", "date", "elderberry"] + return " ".join(random.choice(words) for _ in obj) + + def batch_decode(self, obj): + return [self.decode(_obj) for _obj in obj] + + def encode(self, obj): + return self(obj) + + tokenizer = Tokenizer() + env = LLMHashingEnv(tokenizer=tokenizer, vocab_size=vocab_size) + td = env.make_tensordict("some sentence") + assert isinstance(td, TensorDict) + env.check_env_specs(tensordict=td) + @pytest.mark.parametrize("device", [None, *get_default_devices()]) @pytest.mark.parametrize("env_device", [None, *get_default_devices()]) diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 513a7b94e58..c09db75aa5b 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -135,35 +135,40 @@ def make_node( def full_observation_spec(self): """The observation spec of the tree. - This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`.""" + This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`. + """ return self.specs["output_spec", "full_observation_spec"] @property def full_reward_spec(self): """The reward spec of the tree. - This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`.""" + This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`. + """ return self.specs["output_spec", "full_reward_spec"] @property def full_done_spec(self): """The done spec of the tree. - This is an alias for `Tree.specs['output_spec', 'full_done_spec']`.""" + This is an alias for `Tree.specs['output_spec', 'full_done_spec']`. + """ return self.specs["output_spec", "full_done_spec"] @property def full_state_spec(self): """The state spec of the tree. - This is an alias for `Tree.specs['input_spec', 'full_state_spec']`.""" + This is an alias for `Tree.specs['input_spec', 'full_state_spec']`. + """ return self.specs["input_spec", "full_state_spec"] @property def full_action_spec(self): """The action spec of the tree. - This is an alias for `Tree.specs['input_spec', 'full_action_spec']`.""" + This is an alias for `Tree.specs['input_spec', 'full_action_spec']`. + """ return self.specs["input_spec", "full_action_spec"] @property diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 36e4ec1a908..f3dec221ce0 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -5,7 +5,7 @@ from .batched_envs import ParallelEnv, SerialEnv from .common import EnvBase, EnvMetaData, make_tensordict -from .custom import PendulumEnv, TicTacToeEnv +from .custom import LLMHashingEnv, PendulumEnv, TicTacToeEnv from .env_creator import env_creator, EnvCreator, get_env_metadata from .gym_like import default_info_dict_reader, GymLikeEnv from .libs import ( diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index bafe88b639a..4f6002dedd3 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -14,8 +14,14 @@ import numpy as np import torch import torch.nn as nn -from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key -from tensordict.utils import NestedKey +from tensordict import ( + is_tensor_collection, + LazyStackedTensorDict, + TensorDictBase, + unravel_key, +) +from tensordict.base import _is_leaf_nontensor +from tensordict.utils import is_non_tensor, NestedKey from torchrl._utils import ( _ends_with, _make_ordinal_device, @@ -25,7 +31,13 @@ seed_generator, ) -from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded +from torchrl.data.tensor_specs import ( + Categorical, + Composite, + NonTensor, + TensorSpec, + Unbounded, +) from torchrl.data.utils import DEVICE_TYPING from torchrl.envs.utils import ( _make_compatible_policy, @@ -430,7 +442,6 @@ def auto_specs_( done_key: NestedKey | List[NestedKey] | None = None, observation_key: NestedKey | List[NestedKey] = "observation", reward_key: NestedKey | List[NestedKey] = "reward", - batch_size: torch.Size | None = None, ): """Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy. @@ -484,6 +495,7 @@ def auto_specs_( tensordict2, named=True, nested_keys=True, + is_leaf=_is_leaf_nontensor, ) input_spec = Composite(input_spec_stack, batch_size=batch_size) if not self.batch_locked and batch_size != self.batch_size: @@ -501,6 +513,7 @@ def auto_specs_( nexts_1, named=True, nested_keys=True, + is_leaf=_is_leaf_nontensor, ) output_spec = Composite(output_spec_stack, batch_size=batch_size) @@ -523,7 +536,8 @@ def auto_specs_( full_observation_spec = output_spec.separates(*observation_key, default=None) if not output_spec.is_empty(recurse=True): raise RuntimeError( - f"Keys {list(output_spec.keys(True, True))} are unaccounted for." + f"Keys {list(output_spec.keys(True, True))} are unaccounted for. " + f"Make sure you have passed all the leaf names to the auto_specs_ method." ) if full_action_spec is not None: @@ -541,6 +555,8 @@ def auto_specs_( @wraps(check_env_specs_func) def check_env_specs(self, *args, **kwargs): + return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs) + kwargs["return_contiguous"] = return_contiguous return check_env_specs_func(self, *args, **kwargs) check_env_specs.__doc__ = check_env_specs_func.__doc__ @@ -3206,7 +3222,10 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase: """ if self._simple_done: done = tensordict._get_str("done", default=None) - any_done = done.any() + if done is not None: + any_done = done.any() + else: + any_done = False if any_done: tensordict._set_str( "_reset", @@ -3572,6 +3591,12 @@ def _has_dynamic_specs(spec: Composite): def _tensor_to_spec(name, leaf, leaf_compare=None, *, stack): + if not (isinstance(leaf, torch.Tensor) or is_tensor_collection(leaf)): + stack[name] = NonTensor(shape=()) + return + elif is_non_tensor(leaf): + stack[name] = NonTensor(shape=leaf.shape) + return shape = leaf.shape if leaf_compare is not None: shape_compare = leaf_compare.shape diff --git a/torchrl/envs/custom/__init__.py b/torchrl/envs/custom/__init__.py index 8649d3d3e97..375a0e23a57 100644 --- a/torchrl/envs/custom/__init__.py +++ b/torchrl/envs/custom/__init__.py @@ -3,5 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from .llm import LLMHashingEnv from .pendulum import PendulumEnv from .tictactoeenv import TicTacToeEnv diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py new file mode 100644 index 00000000000..2f456482147 --- /dev/null +++ b/torchrl/envs/custom/llm.py @@ -0,0 +1,213 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import Callable, List, Union + +import torch +from tensordict import NestedKey, TensorDict, TensorDictBase +from tensordict.tensorclass import NonTensorData, NonTensorStack + +from torchrl.data import ( + Categorical as CategoricalSpec, + Composite, + NonTensor, + SipHash, + Unbounded, +) +from torchrl.envs import EnvBase +from torchrl.envs.utils import _StepMDP + + +class LLMHashingEnv(EnvBase): + """A text generation environment that uses a hashing module to identify unique observations. + + The primary goal of this environment is to identify token chains using a hashing function. + This allows the data to be stored in a :class:`~torchrl.data.MCTSForest` using nothing but hashes as node + identifiers, or easily prune repeated token chains in a data structure. + The following figure gives an overview of this workflow: + + .. figure:: /_static/img/rollout-llm.png + :alt: Data collection loop with our LLM environment. + + .. seealso:: the :ref:`Beam Search ` tutorial gives a practical example of how this env can be used. + + Args: + vocab_size (int): The size of the vocabulary. Can be omitted if the tokenizer is passed. + + Keyword Args: + hashing_module (Callable[[torch.Tensor], torch.Tensor], optional): + A hashing function that takes a tensor as input and returns a hashed tensor. + Defaults to :class:`~torchrl.data.SipHash` if not provided. + observation_key (NestedKey, optional): The key for the observation in the TensorDict. + Defaults to "observation". + text_output (bool, optional): Whether to include the text output in the observation. + Defaults to True. + tokenizer (transformers.Tokenizer | None, optional): + A tokenizer function that converts text to tensors. + Only used when `text_output` is `True`. + Must implement the following methods: `decode` and `batch_decode`. + Defaults to ``None``. + text_key (NestedKey | None, optional): The key for the text output in the TensorDict. + Defaults to "text". + + Examples: + >>> from tensordict import TensorDict + >>> from torchrl.envs import LLMHashingEnv + >>> from transformers import GPT2Tokenizer + >>> tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2") + >>> x = tokenizer(["Check out TorchRL!"])["input_ids"] + >>> env = LLMHashingEnv(tokenizer=tokenizer) + >>> td = TensorDict(observation=x, batch_size=[1]) + >>> td = env.reset(td) + >>> print(td) + TensorDict( + fields={ + done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False), + hash: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False), + observation: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.int64, is_shared=False), + terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False), + text: NonTensorStack( + ['Check out TorchRL!'], + batch_size=torch.Size([1]), + device=None)}, + batch_size=torch.Size([1]), + device=None, + is_shared=False) + + """ + + def __init__( + self, + vocab_size: int | None = None, + *, + hashing_module: Callable[[torch.Tensor], torch.Tensor] = None, + observation_key: NestedKey = "observation", + text_output: bool = True, + tokenizer: Callable[[Union[str, List[str]]], torch.Tensor] | None = None, + text_key: NestedKey | None = "text", + ): + super().__init__() + if vocab_size is None: + if tokenizer is None: + raise TypeError( + "You must provide a vocab_size integer if tokenizer is `None`." + ) + vocab_size = tokenizer.vocab_size + self._batch_locked = False + if hashing_module is None: + hashing_module = SipHash() + + self._hashing_module = hashing_module + self._tokenizer = tokenizer + self.observation_key = observation_key + observation_spec = { + observation_key: CategoricalSpec(n=vocab_size, shape=(-1,)), + "hashing": Unbounded(shape=(1,), dtype=torch.int64), + } + self.text_output = text_output + if not text_output: + text_key = None + elif text_key is None: + text_key = "text" + if text_key is not None: + observation_spec[text_key] = NonTensor(shape=()) + self.text_key = text_key + self.observation_spec = Composite(observation_spec) + self.action_spec = Composite(action=CategoricalSpec(vocab_size, shape=(1,))) + _StepMDP(self) + + def make_tensordict(self, input: str | List[str]) -> TensorDict: + """Converts a string or list of strings in a TensorDict with appropriate shape and device.""" + list_len = len(input) if isinstance(input, list) else 0 + tensordict = TensorDict( + {self.observation_key: self._tokenizer(input)}, device=self.device + ) + if list_len: + tensordict.batch_size = [list_len] + return self.reset(tensordict) + + def _reset(self, tensordict: TensorDictBase): + """Initializes the environment with a given observation. + + Args: + tensordict (TensorDictBase): A TensorDict containing the initial observation. + + Returns: + A TensorDict containing the initial observation, its hash, and other relevant information. + + """ + out = tensordict.empty() + obs = tensordict.get(self.observation_key, None) + if obs is None: + raise RuntimeError( + f"Resetting the {type(self).__name__} environment requires a prompt." + ) + if self.text_output: + if obs.ndim > 1: + text = self._tokenizer.batch_decode(obs) + text = NonTensorStack.from_list(text) + else: + text = self._tokenizer.decode(obs) + text = NonTensorData(text) + out.set(self.text_key, text) + + if obs.ndim > 1: + out.set("hashing", self._hashing_module(obs).unsqueeze(-1)) + else: + out.set("hashing", self._hashing_module(obs.unsqueeze(0)).transpose(0, -1)) + + if not self.full_done_spec.is_empty(): + out.update(self.full_done_spec.zero(tensordict.shape)) + else: + out.set("done", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool)) + out.set( + "terminated", torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool) + ) + return out + + def _step(self, tensordict): + """Takes an action (i.e., the next token to generate) and returns the next observation and reward. + + Args: + tensordict: A TensorDict containing the current observation and action. + + Returns: + A TensorDict containing the next observation, its hash, and other relevant information. + """ + out = tensordict.empty() + action = tensordict.get("action") + obs = torch.cat([tensordict.get(self.observation_key), action], -1) + kwargs = {self.observation_key: obs} + + catval = torch.cat([tensordict.get("hashing"), action], -1) + if obs.ndim > 1: + new_hash = self._hashing_module(catval).unsqueeze(-1) + else: + new_hash = self._hashing_module(catval.unsqueeze(0)).transpose(0, -1) + + if self.text_output: + if obs.ndim > 1: + text = self._tokenizer.batch_decode(obs) + text = NonTensorStack.from_list(text) + else: + text = self._tokenizer.decode(obs) + text = NonTensorData(text) + kwargs[self.text_key] = text + kwargs.update( + { + "hashing": new_hash, + "done": torch.zeros((*tensordict.batch_size, 1), dtype=torch.bool), + "terminated": torch.zeros( + (*tensordict.batch_size, 1), dtype=torch.bool + ), + } + ) + return out.update(kwargs) + + def _set_seed(self, *args): + """Sets the seed for the environment's randomness. + + .. note:: This environment has no randomness, so this method does nothing. + """ + pass diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 209349878ec..d2ec66475ab 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -76,7 +76,7 @@ def __get__(self, cls, owner): class _StepMDP: - """Stateful version of step_mdp. + """Stateful version of :func:`~torchrl.envs.step_mdp`. Precomputes the list of keys to include and exclude during a call to step_mdp to reduce runtime. @@ -778,12 +778,15 @@ def check_env_specs( ) zeroing_err_msg = ( "zeroing the two tensordicts did not make them identical. " - "Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}" + f"Check for discrepancies:\nFake=\n{fake_tensordict}\nReal=\n{real_tensordict}" ) from torchrl.envs.common import _has_dynamic_specs if _has_dynamic_specs(env.specs): - for real, fake in zip(real_tensordict.unbind(-1), fake_tensordict.unbind(-1)): + for real, fake in zip( + real_tensordict.filter_non_tensor_data().unbind(-1), + fake_tensordict.filter_non_tensor_data().unbind(-1), + ): fake = fake.apply(lambda x, y: x.expand_as(y), real) if (torch.zeros_like(real) != torch.zeros_like(fake)).any(): raise AssertionError(zeroing_err_msg) @@ -1367,6 +1370,8 @@ def _update_during_reset( reset_keys: List[NestedKey], ): """Updates the input tensordict with the reset data, based on the reset keys.""" + if not reset_keys: + return tensordict.update(tensordict_reset) roots = set() for reset_key in reset_keys: # get the node of the reset key From 4bc40a80899a580bfd09289d62fbba6473d1ed7f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 12 Dec 2024 13:31:37 -0800 Subject: [PATCH 22/27] [Feature] env.step_mdp ghstack-source-id: 145e37cd772fdd74e35e5ffe6accc5c81ad689f3 Pull Request resolved: https://github.com/pytorch/rl/pull/2636 --- torchrl/envs/common.py | 46 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 4f6002dedd3..78f89cc8a38 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -3015,6 +3015,52 @@ def add_truncated_keys(self) -> EnvBase: self.__dict__["_done_keys"] = None return self + def step_mdp(self, next_tensordict: TensorDictBase) -> TensorDictBase: + """Advances the environment state by one step using the provided `next_tensordict`. + + This method updates the environment's state by transitioning from the current + state to the next, as defined by the `next_tensordict`. The resulting tensordict + includes updated observations and any other relevant state information, with + keys managed according to the environment's specifications. + + Internally, this method utilizes a precomputed :class:`~torchrl.envs.utils._StepMDP` instance to efficiently + handle the transition of state, observation, action, reward, and done keys. The + :class:`~torchrl.envs.utils._StepMDP` class optimizes the process by precomputing the keys to include and + exclude, reducing runtime overhead during repeated calls. The :class:`~torchrl.envs.utils._StepMDP` instance + is created with `exclude_action=False`, meaning that action keys are retained in + the root tensordict. + + Args: + next_tensordict (TensorDictBase): A tensordict containing the state of the + environment at the next time step. This tensordict should include keys + for observations, actions, rewards, and done flags, as defined by the + environment's specifications. + + Returns: + TensorDictBase: A new tensordict representing the environment state after + advancing by one step. + + .. note:: The method ensures that the environment's key specifications are validated + against the provided `next_tensordict`, issuing warnings if discrepancies + are found. + + .. note:: This method is designed to work efficiently with environments that have + consistent key specifications, leveraging the `_StepMDP` class to minimize + overhead. + + Example: + >>> from torchrl.envs import GymEnv + >>> env = GymEnv("Pendulum-1") + >>> data = env.reset() + >>> for i in range(10): + ... # compute action + ... env.rand_action(data) + ... # Perform action + ... next_data = env.step(reset_data) + ... data = env.step_mdp(next_data) + """ + return self._step_mdp(next_tensordict) + @property def _step_mdp(self): step_func = self.__dict__.get("_step_mdp_value") From dd26ae79f09bcc2f00082143cb2265812d90b202 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 12 Dec 2024 13:37:40 -0800 Subject: [PATCH 23/27] [Feature] spec.cardinality ghstack-source-id: 1160900f8a81dd51dc72436e1af69c8248bff162 Pull Request resolved: https://github.com/pytorch/rl/pull/2638 --- test/test_specs.py | 79 ++++++++++++++++++++++ torchrl/data/tensor_specs.py | 124 ++++++++++++++++++++++++++++++----- torchrl/envs/common.py | 19 ++++++ 3 files changed, 207 insertions(+), 15 deletions(-) diff --git a/test/test_specs.py b/test/test_specs.py index 3dedc6233a9..5334281f0ee 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -1689,6 +1689,85 @@ def test_unboundeddiscrete( assert spec is not spec.clone() +class TestCardinality: + @pytest.mark.parametrize("shape1", [(5, 4)]) + def test_binary(self, shape1): + spec = Binary(n=4, shape=shape1, device="cpu", dtype=torch.bool) + assert spec.cardinality() == len(list(spec.enumerate())) + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_discrete( + self, + shape1, + ): + spec = Categorical(n=4, shape=shape1, device="cpu", dtype=torch.long) + assert spec.cardinality() == len(list(spec.enumerate())) + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_multidiscrete( + self, + shape1, + ): + if shape1 is None: + shape1 = (3,) + else: + shape1 = (*shape1, 3) + spec = MultiCategorical( + nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long + ) + assert spec.cardinality() == len(spec.enumerate()) + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_multionehot( + self, + shape1, + ): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = MultiOneHot(nvec=(4, 5, 6), shape=shape1, device="cpu", dtype=torch.long) + assert spec.cardinality() == len(list(spec.enumerate())) + + def test_non_tensor(self): + spec = NonTensor(shape=(3, 4), device="cpu") + with pytest.raises(RuntimeError, match="Cannot enumerate a NonTensorSpec."): + spec.cardinality() + + @pytest.mark.parametrize("shape1", [(5,), (5, 6)]) + def test_onehot( + self, + shape1, + ): + if shape1 is None: + shape1 = (15,) + else: + shape1 = (*shape1, 15) + spec = OneHot(n=15, shape=shape1, device="cpu", dtype=torch.long) + assert spec.cardinality() == len(list(spec.enumerate())) + + def test_composite(self): + batch_size = (5,) + spec2 = Binary(n=4, shape=(*batch_size, 4), device="cpu", dtype=torch.bool) + spec3 = Categorical(n=4, shape=batch_size, device="cpu", dtype=torch.long) + spec4 = MultiCategorical( + nvec=(4, 5, 6), shape=(*batch_size, 3), device="cpu", dtype=torch.long + ) + spec5 = MultiOneHot( + nvec=(4, 5, 6), shape=(*batch_size, 15), device="cpu", dtype=torch.long + ) + spec6 = OneHot(n=15, shape=(*batch_size, 15), device="cpu", dtype=torch.long) + spec = Composite( + spec2=spec2, + spec3=spec3, + spec4=spec4, + spec5=spec5, + spec6=spec6, + shape=batch_size, + ) + assert spec.cardinality() == len(spec.enumerate()) + + class TestUnbind: @pytest.mark.parametrize("shape1", [(5, 4)]) def test_binary(self, shape1): diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index ddf6ed41c99..c03fb40f1ac 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -41,7 +41,7 @@ unravel_key, ) from tensordict.base import NO_DEFAULT -from tensordict.utils import _getitem_batch_size, NestedKey +from tensordict.utils import _getitem_batch_size, is_non_tensor, NestedKey from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for DEVICE_TYPING = Union[torch.device, str, int] @@ -582,6 +582,16 @@ def clear_device_(self) -> T: """ return self + @abc.abstractmethod + def cardinality(self) -> int: + """The cardinality of the spec. + + This refers to the number of possible outcomes in a spec. It is assumed that the cardinality of a composite + spec is the cartesian product of all possible outcomes. + + """ + ... + def encode( self, val: np.ndarray | torch.Tensor | TensorDictBase, @@ -1515,6 +1525,9 @@ def __init__( def n(self): return self.space.n + def cardinality(self) -> int: + return self.n + def update_mask(self, mask): """Sets a mask to prevent some of the possible outcomes when a sample is taken. @@ -2107,6 +2120,9 @@ def enumerate(self) -> Any: f"enumerate is not implemented for spec of class {type(self).__name__}." ) + def cardinality(self) -> int: + return float("inf") + def __eq__(self, other): return ( type(other) == type(self) @@ -2426,8 +2442,11 @@ def __init__( shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs ) + def cardinality(self) -> Any: + raise RuntimeError("Cannot enumerate a NonTensorSpec.") + def enumerate(self) -> Any: - raise NotImplementedError("Cannot enumerate a NonTensorSpec.") + raise RuntimeError("Cannot enumerate a NonTensorSpec.") def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor: if isinstance(dest, torch.dtype): @@ -2466,10 +2485,10 @@ def one(self, shape=None): data=None, batch_size=(*shape, *self._safe_shape), device=self.device ) - def is_in(self, val: torch.Tensor) -> bool: + def is_in(self, val: Any) -> bool: shape = torch.broadcast_shapes(self._safe_shape, val.shape) return ( - isinstance(val, NonTensorData) + is_non_tensor(val) and val.shape == shape # We relax constrains on device as they're hard to enforce for non-tensor # tensordicts and pointless @@ -2832,6 +2851,9 @@ def __init__( ) self.update_mask(mask) + def cardinality(self) -> int: + return torch.as_tensor(self.nvec).prod() + def enumerate(self) -> torch.Tensor: nvec = self.nvec enum_disc = self.to_categorical_spec().enumerate() @@ -3220,13 +3242,20 @@ class Categorical(TensorSpec): The spec will have the shape defined by the ``shape`` argument: if a singleton dimension is desired for the training dimension, one should specify it explicitly. + Attributes: + n (int): The number of possible outcomes. + shape (torch.Size): The shape of the variable. + device (torch.device): The device of the tensors. + dtype (torch.dtype): The dtype of the tensors. + Args: - n (int): number of possible outcomes. + n (int): number of possible outcomes. If set to -1, the cardinality of the categorical spec is undefined, + and `set_provisional_n` must be called before sampling from this spec. shape: (torch.Size, optional): shape of the variable, default is "torch.Size([])". - device (str, int or torch.device, optional): device of the tensors. - dtype (str or torch.dtype, optional): dtype of the tensors. - mask (torch.Tensor or None): mask some of the possible outcomes when a - sample is taken. See :meth:`~.update_mask` for more information. + device (str, int or torch.device, optional): the device of the tensors. + dtype (str or torch.dtype, optional): the dtype of the tensors. + mask (torch.Tensor or None): A boolean mask to prevent some of the possible outcomes when a sample is taken. + See :meth:`~.update_mask` for more information. Examples: >>> categ = Categorical(3) @@ -3249,6 +3278,13 @@ class Categorical(TensorSpec): domain=discrete) >>> categ.rand() tensor([1]) + >>> categ = Categorical(-1) + >>> categ.set_provisional_n(5) + >>> categ.rand() + tensor(3) + + .. note:: When n is set to -1, calling `rand` without first setting a provisional n using `set_provisional_n` + will raise a ``RuntimeError``. """ @@ -3276,16 +3312,31 @@ def __init__( shape=shape, space=space, device=device, dtype=dtype, domain="discrete" ) self.update_mask(mask) + self._provisional_n = None def enumerate(self) -> torch.Tensor: - arange = torch.arange(self.n, dtype=self.dtype, device=self.device) + dtype = self.dtype + if dtype is torch.bool: + dtype = torch.uint8 + arange = torch.arange(self.n, dtype=dtype, device=self.device) if self.ndim: arange = arange.view(-1, *(1,) * self.ndim) return arange.expand(self.n, *self.shape) @property def n(self): - return self.space.n + n = self.space.n + if n == -1: + n = self._provisional_n + if n is None: + raise RuntimeError( + f"Undefined cardinality for {type(self)}. Please call " + f"spec.set_provisional_n(int)." + ) + return n + + def cardinality(self) -> int: + return self.n def update_mask(self, mask): """Sets a mask to prevent some of the possible outcomes when a sample is taken. @@ -3316,13 +3367,33 @@ def update_mask(self, mask): raise ValueError("Only boolean masks are accepted.") self.mask = mask + def set_provisional_n(self, n: int): + """Set the cardinality of the Categorical spec temporarily. + + This method is required to be called before sampling from the spec when n is -1. + + Args: + n (int): The cardinality of the Categorical spec. + + """ + self._provisional_n = n + def rand(self, shape: torch.Size = None) -> torch.Tensor: + if self.space.n < 0: + if self._provisional_n is None: + raise RuntimeError( + "Cannot generate random categorical samples for undefined cardinality (n=-1). " + "To sample from this class, first call Categorical.set_provisional_n(n) before calling rand()." + ) + n = self._provisional_n + else: + n = self.space.n if shape is None: shape = _size([]) if self.mask is None: return torch.randint( 0, - self.space.n, + n, _size([*shape, *self.shape]), device=self.device, dtype=self.dtype, @@ -3334,6 +3405,12 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: else: mask_flat = mask shape_out = mask.shape[:-1] + # Check that the mask has the right size + if mask_flat.shape[-1] != n: + raise ValueError( + "The last dimension of the mask must match the number of action allowed by the " + f"Categorical spec. Got mask.shape={self.mask.shape} and n={n}." + ) out = torch.multinomial(mask_flat.float(), 1).reshape(shape_out) return out @@ -3360,6 +3437,8 @@ def is_in(self, val: torch.Tensor) -> bool: dtype_match = val.dtype == self.dtype if not dtype_match: return False + if self.space.n == -1: + return True return (0 <= val).all() and (val < self.space.n).all() shape = self.mask.shape shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) @@ -3607,7 +3686,7 @@ def __init__( device: Optional[DEVICE_TYPING] = None, dtype: Union[str, torch.dtype] = torch.int8, ): - if n is None and not shape: + if n is None and shape is None: raise TypeError("Must provide either n or shape.") if n is None: n = shape[-1] @@ -3813,6 +3892,9 @@ def enumerate(self) -> torch.Tensor: arange = arange.expand(arange.shape[0], *self.shape) return arange + def cardinality(self) -> int: + return self.nvec._base.prod() + def update_mask(self, mask): """Sets a mask to prevent some of the possible outcomes when a sample is taken. @@ -4373,7 +4455,7 @@ def set(self, name, spec): shape = spec.shape if shape[: self.ndim] != self.shape: if ( - isinstance(spec, Composite) + isinstance(spec, (Composite, NonTensor)) and spec.ndim < self.ndim and self.shape[: spec.ndim] == spec.shape ): @@ -4382,7 +4464,7 @@ def set(self, name, spec): spec.shape = self.shape else: raise ValueError( - "The shape of the spec and the Composite mismatch: the first " + f"The shape of the spec {type(spec).__name__} and the Composite {type(self).__name__} mismatch: the first " f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and " f"Composite.shape={self.shape}." ) @@ -4798,6 +4880,18 @@ def clone(self) -> Composite: shape=self.shape, ) + def cardinality(self) -> int: + n = None + for spec in self.values(): + if spec is None: + continue + if n is None: + n = 1 + n = n * spec.cardinality() + if n is None: + n = 0 + return n + def enumerate(self) -> TensorDictBase: # We are going to use meshgrid to create samples of all the subspecs in here # but first let's get rid of the batch size, we'll put it back later diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 78f89cc8a38..3b55fd227a7 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -561,6 +561,25 @@ def check_env_specs(self, *args, **kwargs): check_env_specs.__doc__ = check_env_specs_func.__doc__ + def cardinality(self, tensordict: TensorDictBase | None = None) -> int: + """The cardinality of the action space. + + By default, this is just a wrapper around :meth:`env.action_space.cardinality <~torchrl.data.TensorSpec.cardinality>`. + + This class is useful when the action spec is variable: + + - The number of actions can be undefined, e.g., ``Categorical(n=-1)``; + - The action cardinality may depend on the action mask; + - The shape can be dynamic, as in ``Unbound(shape=(-1))``. + + In these cases, the :meth:`~.cardinality` should be overwritten, + + Args: + tensordict (TensorDictBase, optional): a tensordict containing the data required to compute the cardinality. + + """ + return self.full_action_spec.cardinality() + @classmethod def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs): # inplace update will write tensors in-place on the provided tensordict. From ef5a37d8a76f3cf920966ac01eaf8348c7278d3c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 12 Dec 2024 13:37:41 -0800 Subject: [PATCH 24/27] [Quality,BE] Better doc for step_mdp ghstack-source-id: 1f5aed6fb2e97ead9d379f9545ae742f7728c585 Pull Request resolved: https://github.com/pytorch/rl/pull/2639 --- torchrl/_utils.py | 1 + torchrl/envs/utils.py | 51 +++++++++++++++++++++---------------------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index d37aebb862f..c81ffcc962b 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -829,6 +829,7 @@ def _can_be_pickled(obj): def _make_ordinal_device(device: torch.device): if device is None: return device + device = torch.device(device) if device.type == "cuda" and device.index is None: return torch.device("cuda", index=torch.cuda.current_device()) if device.type == "mps" and device.index is None: diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index d2ec66475ab..f7403e6a69e 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -14,7 +14,7 @@ import re import warnings from enum import Enum -from typing import Any, Dict, List, Union +from typing import Any, Dict, List import torch @@ -339,48 +339,47 @@ def step_mdp( exclude_reward: bool = True, exclude_done: bool = False, exclude_action: bool = True, - reward_keys: Union[NestedKey, List[NestedKey]] = "reward", - done_keys: Union[NestedKey, List[NestedKey]] = "done", - action_keys: Union[NestedKey, List[NestedKey]] = "action", + reward_keys: NestedKey | List[NestedKey] = "reward", + done_keys: NestedKey | List[NestedKey] = "done", + action_keys: NestedKey | List[NestedKey] = "action", ) -> TensorDictBase: """Creates a new tensordict that reflects a step in time of the input tensordict. Given a tensordict retrieved after a step, returns the :obj:`"next"` indexed-tensordict. - The arguments allow for a precise control over what should be kept and what + The arguments allow for precise control over what should be kept and what should be copied from the ``"next"`` entry. The default behavior is: - move the observation entries, reward and done states to the root, exclude - the current action and keep all extra keys (non-action, non-done, non-reward). + move the observation entries, reward, and done states to the root, exclude + the current action, and keep all extra keys (non-action, non-done, non-reward). Args: - tensordict (TensorDictBase): tensordict with keys to be renamed - next_tensordict (TensorDictBase, optional): destination tensordict - keep_other (bool, optional): if ``True``, all keys that do not start with :obj:`'next_'` will be kept. + tensordict (TensorDictBase): The tensordict with keys to be renamed. + next_tensordict (TensorDictBase, optional): The destination tensordict. If `None`, a new tensordict is created. + keep_other (bool, optional): If ``True``, all keys that do not start with :obj:`'next_'` will be kept. Default is ``True``. - exclude_reward (bool, optional): if ``True``, the :obj:`"reward"` key will be discarded + exclude_reward (bool, optional): If ``True``, the :obj:`"reward"` key will be discarded from the resulting tensordict. If ``False``, it will be copied (and replaced) - from the ``"next"`` entry (if present). - Default is ``True``. - exclude_done (bool, optional): if ``True``, the :obj:`"done"` key will be discarded + from the ``"next"`` entry (if present). Default is ``True``. + exclude_done (bool, optional): If ``True``, the :obj:`"done"` key will be discarded from the resulting tensordict. If ``False``, it will be copied (and replaced) - from the ``"next"`` entry (if present). - Default is ``False``. - exclude_action (bool, optional): if ``True``, the :obj:`"action"` key will + from the ``"next"`` entry (if present). Default is ``False``. + exclude_action (bool, optional): If ``True``, the :obj:`"action"` key will be discarded from the resulting tensordict. If ``False``, it will be kept in the root tensordict (since it should not be present in - the ``"next"`` entry). - Default is ``True``. - reward_keys (NestedKey or list of NestedKey, optional): the keys where the reward is written. Defaults + the ``"next"`` entry). Default is ``True``. + reward_keys (NestedKey or list of NestedKey, optional): The keys where the reward is written. Defaults to "reward". - done_keys (NestedKey or list of NestedKey, optional): the keys where the done is written. Defaults + done_keys (NestedKey or list of NestedKey, optional): The keys where the done is written. Defaults to "done". - action_keys (NestedKey or list of NestedKey, optional): the keys where the action is written. Defaults + action_keys (NestedKey or list of NestedKey, optional): The keys where the action is written. Defaults to "action". Returns: - A new tensordict (or next_tensordict) containing the tensors of the t+1 step. + TensorDictBase: A new tensordict (or `next_tensordict` if provided) containing the tensors of the t+1 step. + + .. seealso:: :meth:`EnvBase.step_mdp` is the class-based version of this free function. It will attempt to cache the + key values to reduce the overhead of making a step in the MDP. Examples: - This funtion allows for this kind of loop to be used: >>> from tensordict import TensorDict >>> import torch >>> td = TensorDict({ @@ -784,8 +783,8 @@ def check_env_specs( if _has_dynamic_specs(env.specs): for real, fake in zip( - real_tensordict.filter_non_tensor_data().unbind(-1), - fake_tensordict.filter_non_tensor_data().unbind(-1), + real_tensordict_select.filter_non_tensor_data().unbind(-1), + fake_tensordict_select.filter_non_tensor_data().unbind(-1), ): fake = fake.apply(lambda x, y: x.expand_as(y), real) if (torch.zeros_like(real) != torch.zeros_like(fake)).any(): From 6c7d233a44278aed978c54205557ffd03d6f2443 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 12 Dec 2024 13:37:42 -0800 Subject: [PATCH 25/27] [Test] More comprehensive tests for auto_spec ghstack-source-id: 75352490436fd706af3d36f9b8016e80a8a3f46a Pull Request resolved: https://github.com/pytorch/rl/pull/2640 --- test/mocking_classes.py | 10 +++++++--- test/test_env.py | 11 ++++++++--- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/test/mocking_classes.py b/test/mocking_classes.py index b6f4ac7069b..6f666290376 100644 --- a/test/mocking_classes.py +++ b/test/mocking_classes.py @@ -1931,14 +1931,18 @@ def __init__(self): tensor=Unbounded(3), non_tensor=NonTensor(shape=()), ) + self._saved_obs_spec = self.observation_spec.clone() self.state_spec = Composite( non_tensor=NonTensor(shape=()), ) + self._saved_state_spec = self.state_spec.clone() self.reward_spec = Unbounded(1) + self._saved_full_reward_spec = self.full_reward_spec.clone() self.action_spec = Unbounded(1) + self._saved_full_action_spec = self.full_action_spec.clone() def _reset(self, tensordict): - data = self.observation_spec.zero() + data = self._saved_obs_spec.zero() data.set_non_tensor("non_tensor", 0) data.update(self.full_done_spec.zero()) return data @@ -1947,10 +1951,10 @@ def _step( self, tensordict: TensorDictBase, ) -> TensorDictBase: - data = self.observation_spec.zero() + data = self._saved_obs_spec.zero() data.set_non_tensor("non_tensor", tensordict["non_tensor"] + 1) data.update(self.full_done_spec.zero()) - data.update(self.full_reward_spec.zero()) + data.update(self._saved_full_reward_spec.zero()) return data def _set_seed(self, seed: Optional[int]): diff --git a/test/test_env.py b/test/test_env.py index cef7a507f2a..415c973b6fb 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -3553,8 +3553,13 @@ def test_single_env_spec(): assert env.input_spec.is_in(env.input_spec_unbatched.zeros(env.shape)) -def test_auto_spec(): - env = CountingEnv() +@pytest.mark.parametrize("env_type", [CountingEnv, EnvWithMetadata]) +def test_auto_spec(env_type): + if env_type is EnvWithMetadata: + obs_vals = ["tensor", "non_tensor"] + else: + obs_vals = "observation" + env = env_type() td = env.reset() policy = lambda td, action_spec=env.full_action_spec.clone(): td.update( @@ -3577,7 +3582,7 @@ def test_auto_spec(): shape=env.full_state_spec.shape, device=env.full_state_spec.device ) env._action_keys = ["action"] - env.auto_specs_(policy, tensordict=td.copy()) + env.auto_specs_(policy, tensordict=td.copy(), observation_key=obs_vals) env.check_env_specs(tensordict=td.copy()) From 17983d43ecbeeaf70a32a3dae9c189e4afc87068 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 12 Dec 2024 18:53:30 -0800 Subject: [PATCH 26/27] [Feature] ChessEnv ghstack-source-id: 087c3b12cd621ea11a252b34c4896133697bce1a Pull Request resolved: https://github.com/pytorch/rl/pull/2641 --- docs/source/reference/envs.rst | 1 + torchrl/envs/__init__.py | 2 +- torchrl/envs/custom/__init__.py | 1 + torchrl/envs/custom/chess.py | 197 ++++++++++++++++++++++++++++++++ 4 files changed, 200 insertions(+), 1 deletion(-) create mode 100644 torchrl/envs/custom/chess.py diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 70fdf03c0ff..9ef9d88dbe6 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -345,6 +345,7 @@ TorchRL offers a series of custom built-in environments. :toctree: generated/ :template: rl_template.rst + ChessEnv PendulumEnv TicTacToeEnv LLMHashingEnv diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index f3dec221ce0..b863ad0801c 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -5,7 +5,7 @@ from .batched_envs import ParallelEnv, SerialEnv from .common import EnvBase, EnvMetaData, make_tensordict -from .custom import LLMHashingEnv, PendulumEnv, TicTacToeEnv +from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv from .env_creator import env_creator, EnvCreator, get_env_metadata from .gym_like import default_info_dict_reader, GymLikeEnv from .libs import ( diff --git a/torchrl/envs/custom/__init__.py b/torchrl/envs/custom/__init__.py index 375a0e23a57..d2c85a7198f 100644 --- a/torchrl/envs/custom/__init__.py +++ b/torchrl/envs/custom/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from .chess import ChessEnv from .llm import LLMHashingEnv from .pendulum import PendulumEnv from .tictactoeenv import TicTacToeEnv diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py new file mode 100644 index 00000000000..b745f594b33 --- /dev/null +++ b/torchrl/envs/custom/chess.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict, Optional + +import torch +from tensordict import TensorDict, TensorDictBase +from torchrl.data import Categorical, Composite, NonTensor, Unbounded + +from torchrl.envs import EnvBase + +from torchrl.envs.utils import _classproperty + + +class ChessEnv(EnvBase): + """A chess environment that follows the TorchRL API. + + Requires: the `chess` library. More info `here `__. + + Args: + stateful (bool): Whether to keep track of the internal state of the board. + If False, the state will be stored in the observation and passed back + to the environment on each call. Default: ``False``. + + .. note:: the action spec is a :class:`~torchrl.data.Categorical` spec with a ``-1`` shape. + Unless :meth:`~torchrl.data.Categorical.set_provisional_n` is called with the cardinality of the legal moves, + valid random actions cannot be taken. :meth:`~torchrl.envs.EnvBase.rand_action` has been adapted to account for + this behavior. + + Examples: + >>> env = ChessEnv() + >>> r = env.reset() + >>> env.rand_step(r) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1, batch_size=torch.Size([]), device=None), + hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + fen: NonTensorData(data=rnbqkbnr/pppppppp/8/8/8/2N5/PPPPPPPP/R1BQKBNR b KQkq - 1 1, batch_size=torch.Size([]), device=None), + hashing: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int32, is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False), + terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), + turn: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> env.rollout(1000) + TensorDict( + fields={ + action: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False), + done: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False), + fen: NonTensorStack( + ['rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQ..., + batch_size=torch.Size([322]), + device=None), + hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False), + next: TensorDict( + fields={ + done: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False), + fen: NonTensorStack( + ['rnbqkbnr/pppppppp/8/8/2P5/8/PP1PPPPP/RNBQKBNR b ..., + batch_size=torch.Size([322]), + device=None), + hashing: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.int64, is_shared=False), + reward: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.int32, is_shared=False), + terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False), + turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([322]), + device=None, + is_shared=False), + terminated: Tensor(shape=torch.Size([322, 1]), device=cpu, dtype=torch.bool, is_shared=False), + turn: Tensor(shape=torch.Size([322]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([322]), + device=None, + is_shared=False) + + + """ + + _hash_table: Dict[int, str] = {} + + @_classproperty + def lib(cls): + try: + import chess + except ImportError: + raise ImportError( + "The `chess` library could not be found. Make sure you installed it through `pip install chess`." + ) + return chess + + def __init__(self, stateful: bool = False): + chess = self.lib + super().__init__() + self.full_observation_spec = Composite( + hashing=Unbounded(shape=(), dtype=torch.int64), + fen=NonTensor(shape=()), + turn=Categorical(n=2, dtype=torch.bool, shape=()), + ) + self.stateful = stateful + if not self.stateful: + self.full_state_spec = self.full_observation_spec.clone() + self.full_action_spec = Composite( + action=Categorical(n=-1, shape=(), dtype=torch.int64) + ) + self.full_reward_spec = Composite( + reward=Unbounded(shape=(1,), dtype=torch.int32) + ) + # done spec generated automatically + self.board = chess.Board() + if self.stateful: + self.action_spec.set_provisional_n(len(list(self.board.legal_moves))) + + def rand_action(self, tensordict: Optional[TensorDictBase] = None): + self._set_action_space(tensordict) + return super().rand_action(tensordict) + + def _reset(self, tensordict=None): + fen = None + if tensordict is not None: + fen = self._get_fen(tensordict) + dest = tensordict.empty() + else: + dest = TensorDict() + + if fen is None: + self.board.reset() + fen = self.board.fen() + else: + self.board.set_fen(fen.data) + + hashing = hash(fen) + + self._set_action_space() + turn = self.board.turn + return dest.set("fen", fen).set("hashing", hashing).set("turn", turn) + + def _set_action_space(self, tensordict: TensorDict | None = None): + if not self.stateful and tensordict is not None: + fen = self._get_fen(tensordict).data + self.board.set_fen(fen) + self.action_spec.set_provisional_n(self.board.legal_moves.count()) + + @classmethod + def _get_fen(cls, tensordict): + fen = tensordict.get("fen", None) + if fen is None: + hashing = tensordict.get("hashing", None) + if hashing is not None: + fen = cls._hash_table.get(hashing.item()) + return fen + + def _step(self, tensordict): + # action + action = tensordict.get("action") + board = self.board + if not self.stateful: + fen = self._get_fen(tensordict).data + board.set_fen(fen) + action = str(list(board.legal_moves)[action]) + # assert chess.Move.from_uci(action) in board.legal_moves + board.push_san(action) + self._set_action_space() + + # Collect data + fen = self.board.fen() + dest = tensordict.empty() + hashing = hash(fen) + dest.set("fen", fen) + dest.set("hashing", hashing) + + done = board.is_checkmate() + turn = torch.tensor(board.turn) + reward = torch.tensor([done]).int() * (turn.int() * 2 - 1) + done = done | board.is_stalemate() | board.is_game_over() + dest.set("reward", reward) + dest.set("turn", turn) + dest.set("done", [done]) + dest.set("terminated", [done]) + return dest + + def _set_seed(self, *args, **kwargs): + ... + + def cardinality(self, tensordict: TensorDictBase | None = None) -> int: + self._set_action_space(tensordict) + return self.action_spec.cardinality() From 9ee1ae7ee6a8663c2d0ef29d26b23959c52e6d26 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 13 Dec 2024 09:03:35 -0800 Subject: [PATCH 27/27] [Feature] CatFrames.make_rb_transform_and_sampler ghstack-source-id: 7ecf952ec9f102a831aefdba533027ff8c4c29cc Pull Request resolved: https://github.com/pytorch/rl/pull/2643 --- .../replay-buffers/catframes-in-buffer.py | 99 +++++++++++++++++++ test/test_transforms.py | 23 +++++ torchrl/data/replay_buffers/samplers.py | 7 ++ torchrl/envs/transforms/transforms.py | 83 +++++++++++++++- 4 files changed, 209 insertions(+), 3 deletions(-) create mode 100644 examples/replay-buffers/catframes-in-buffer.py diff --git a/examples/replay-buffers/catframes-in-buffer.py b/examples/replay-buffers/catframes-in-buffer.py new file mode 100644 index 00000000000..916fc63bc50 --- /dev/null +++ b/examples/replay-buffers/catframes-in-buffer.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torchrl.data import LazyTensorStorage, ReplayBuffer +from torchrl.envs import ( + CatFrames, + Compose, + DMControlEnv, + StepCounter, + ToTensorImage, + TransformedEnv, + UnsqueezeTransform, +) + +# Number of frames to stack together +frame_stack = 4 +# Dimension along which the stack should occur +stack_dim = -4 +# Max size of the buffer +max_size = 100_000 +# Batch size of the replay buffer +training_batch_size = 32 + +seed = 123 + + +def main(): + catframes = CatFrames( + N=frame_stack, + dim=stack_dim, + in_keys=["pixels_trsf"], + out_keys=["pixels_trsf"], + ) + env = TransformedEnv( + DMControlEnv( + env_name="cartpole", + task_name="balance", + device="cpu", + from_pixels=True, + pixels_only=True, + ), + Compose( + ToTensorImage( + from_int=True, + dtype=torch.float32, + in_keys=["pixels"], + out_keys=["pixels_trsf"], + shape_tolerant=True, + ), + UnsqueezeTransform( + dim=stack_dim, in_keys=["pixels_trsf"], out_keys=["pixels_trsf"] + ), + catframes, + StepCounter(), + ), + ) + env.set_seed(seed) + + transform, sampler = catframes.make_rb_transform_and_sampler( + batch_size=training_batch_size, + traj_key=("collector", "traj_ids"), + strict_length=True, + ) + + rb_transforms = Compose( + ToTensorImage( + from_int=True, + dtype=torch.float32, + in_keys=["pixels", ("next", "pixels")], + out_keys=["pixels_trsf", ("next", "pixels_trsf")], + shape_tolerant=True, + ), # C W' H' -> C W' H' (unchanged due to shape_tolerant) + UnsqueezeTransform( + dim=stack_dim, + in_keys=["pixels_trsf", ("next", "pixels_trsf")], + out_keys=["pixels_trsf", ("next", "pixels_trsf")], + ), # 1 C W' H' + transform, + ) + + rb = ReplayBuffer( + storage=LazyTensorStorage(max_size=max_size, device="cpu"), + sampler=sampler, + batch_size=training_batch_size, + transform=rb_transforms, + ) + + data = env.rollout(1000, break_when_any_done=False) + rb.extend(data) + + training_batch = rb.sample() + print(training_batch) + + +if __name__ == "__main__": + main() diff --git a/test/test_transforms.py b/test/test_transforms.py index d90c00b6a19..cc3ca40b059 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -933,6 +933,29 @@ def test_transform_rb(self, dim, N, padding, rbclass): assert (tdsample["out_" + key1] == td["out_" + key1]).all() assert (tdsample["next", "out_" + key1] == td["next", "out_" + key1]).all() + def test_transform_rb_maker(self): + env = CountingEnv(max_steps=10) + catframes = CatFrames( + in_keys=["observation"], out_keys=["observation_stack"], dim=-1, N=4 + ) + env.append_transform(catframes) + policy = lambda td: td.update(env.full_action_spec.zeros() + 1) + rollout = env.rollout(150, policy, break_when_any_done=False) + transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32) + rb = ReplayBuffer( + sampler=sampler, storage=LazyTensorStorage(150), transform=transform + ) + rb.extend(rollout) + sample = rb.sample(32) + assert "observation_stack" not in rb._storage._storage + assert sample.shape == (32,) + assert sample["observation_stack"].shape == (32, 4) + assert sample["next", "observation_stack"].shape == (32, 4) + assert ( + sample["observation_stack"] + == sample["observation_stack"][:, :1] + torch.arange(4) + ).all() + @pytest.mark.parametrize("dim", [-1]) @pytest.mark.parametrize("N", [3, 4]) @pytest.mark.parametrize("padding", ["same", "constant"]) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index b97b585aa3f..bbdf2387683 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -968,6 +968,9 @@ class SliceSampler(Sampler): """ + # We use this whenever we need to sample N times too many transitions to then select only a 1/N fraction of them + _batch_size_multiplier: int | None = 1 + def __init__( self, *, @@ -1295,6 +1298,8 @@ def _adjusted_batch_size(self, batch_size): return seq_length, num_slices def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]: + if self._batch_size_multiplier is not None: + batch_size = batch_size * self._batch_size_multiplier # pick up as many trajs as we need start_idx, stop_idx, lengths = self._get_stop_and_length(storage) # we have to make sure that the number of dims of the storage @@ -1747,6 +1752,8 @@ def _storage_len(self, storage): def sample( self, storage: Storage, batch_size: int ) -> Tuple[Tuple[torch.Tensor, ...], dict]: + if self._batch_size_multiplier is not None: + batch_size = batch_size * self._batch_size_multiplier start_idx, stop_idx, lengths = self._get_stop_and_length(storage) # we have to make sure that the number of dims of the storage # is the same as the stop/start signals since we will diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 0bab5868ded..f3329d085df 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -2825,9 +2825,9 @@ def _reset( class CatFrames(ObservationTransform): """Concatenates successive observation frames into a single tensor. - This can, for instance, account for movement/velocity of the observed - feature. Proposed in "Playing Atari with Deep Reinforcement Learning" ( - https://arxiv.org/pdf/1312.5602.pdf). + This transform is useful for creating a sense of movement or velocity in the observed features. + It can also be used with models that require access to past observations such as transformers and the like. + It was first proposed in "Playing Atari with Deep Reinforcement Learning" (https://arxiv.org/pdf/1312.5602.pdf). When used within a transformed environment, :class:`CatFrames` is a stateful class, and it can be reset to its native state by @@ -2915,6 +2915,14 @@ class CatFrames(ObservationTransform): such as those found in MARL settings, are currently not supported. If this feature is needed, please raise an issue on TorchRL repo. + .. note:: Storing stacks of frames in the replay buffer can significantly increase memory consumption (by N times). + To mitigate this, you can store trajectories directly in the replay buffer and apply :class:`CatFrames` at sampling time. + This approach involves sampling slices of the stored trajectories and then applying the frame stacking transform. + For convenience, :class:`CatFrames` provides a :meth:`~.make_rb_transform_and_sampler` method that creates: + + - A modified version of the transform suitable for use in replay buffers + - A corresponding :class:`SliceSampler` to use with the buffer + """ inplace = False @@ -2964,6 +2972,75 @@ def __init__( self.reset_key = reset_key self.done_key = done_key + def make_rb_transform_and_sampler( + self, batch_size: int, **sampler_kwargs + ) -> Tuple[Transform, "torchrl.data.replay_buffers.SliceSampler"]: # noqa: F821 + """Creates a transform and sampler to be used with a replay buffer when storing frame-stacked data. + + This method helps reduce redundancy in stored data by avoiding the need to + store the entire stack of frames in the buffer. Instead, it creates a + transform that stacks frames on-the-fly during sampling, and a sampler that + ensures the correct sequence length is maintained. + + Args: + batch_size (int): The batch size to use for the sampler. + **sampler_kwargs: Additional keyword arguments to pass to the + :class:`~torchrl.data.replay_buffers.SliceSampler` constructor. + + Returns: + A tuple containing: + - transform (Transform): A transform that stacks frames on-the-fly during sampling. + - sampler (SliceSampler): A sampler that ensures the correct sequence length is maintained. + + Example: + >>> env = TransformedEnv(...) + >>> catframes = CatFrames(N=4, ...) + >>> transform, sampler = catframes.make_rb_transform_and_sampler(batch_size=32) + >>> rb = ReplayBuffer(..., sampler=sampler, transform=transform) + + .. note:: When working with images, it's recommended to use distinct ``in_keys`` and ``out_keys`` in the preceding + :class:`~torchrl.envs.ToTensorImage` transform. This ensures that the tensors stored in the buffer are separate + from their processed counterparts, which we don't want to store. + For non-image data, consider inserting a :class:`~torchrl.envs.RenameTransform` before :class:`CatFrames` to create + a copy of the data that will be stored in the buffer. + + .. note:: When adding the transform to the replay buffer, one should pay attention to also pass the transforms + that precede CatFrames, such as :class:`~torchrl.envs.ToTensorImage` or :class:`~torchrl.envs.UnsqueezeTransform` + in such a way that the :class:`~torchrl.envs.CatFrames` transforms sees data formatted as it was during data + collection. + + .. note:: For a more complete example, refer to torchrl's github repo `examples` folder: + https://github.com/pytorch/rl/tree/main/examples/replay-buffers/catframes-in-buffer.py + + """ + from torchrl.data.replay_buffers import SliceSampler + + in_keys = self.in_keys + in_keys = in_keys + [unravel_key(("next", key)) for key in in_keys] + out_keys = self.out_keys + out_keys = out_keys + [unravel_key(("next", key)) for key in out_keys] + catframes = type(self)( + N=self.N, + in_keys=in_keys, + out_keys=out_keys, + dim=self.dim, + padding=self.padding, + padding_value=self.padding_value, + as_inverse=False, + reset_key=self.reset_key, + done_key=self.done_key, + ) + sampler = SliceSampler(slice_len=self.N, **sampler_kwargs) + sampler._batch_size_multiplier = self.N + transform = Compose( + lambda td: td.reshape(-1, self.N), + catframes, + lambda td: td[:, -1], + # We only store "pixels" to the replay buffer to save memory + ExcludeTransform(*out_keys, inverse=True), + ) + return transform, sampler + @property def done_key(self): done_key = self.__dict__.get("_done_key", None)