From 57dc25a446a94c4175ad1820473196ff5c49249a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 12 Dec 2024 11:19:09 -0800 Subject: [PATCH] [Refactor] Refactor trees ghstack-source-id: 368ba4c4402b6db0bc8b0688802ce161db9776b7 Pull Request resolved: https://github.com/pytorch/rl/pull/2634 --- test/test_storage_map.py | 104 ++++- torchrl/data/map/hash.py | 3 +- torchrl/data/map/tdstorage.py | 32 +- torchrl/data/map/tree.py | 585 ++++++++++++++++++++++-- torchrl/data/map/utils.py | 6 +- torchrl/data/replay_buffers/storages.py | 4 +- 6 files changed, 678 insertions(+), 56 deletions(-) diff --git a/test/test_storage_map.py b/test/test_storage_map.py index 9ff4431fb50..db2d0bc2c49 100644 --- a/test/test_storage_map.py +++ b/test/test_storage_map.py @@ -301,6 +301,7 @@ def _state0(self) -> TensorDict: def _make_td(state: torch.Tensor, action: torch.Tensor) -> TensorDict: done = torch.zeros_like(action, dtype=torch.bool).unsqueeze(-1) reward = action.clone() + action = action + torch.arange(action.shape[-1]) / action.shape[-1] return TensorDict( { @@ -326,7 +327,7 @@ def _make_forest(self) -> MCTSForest: forest.extend(r4) return forest - def _make_forest_intersect(self) -> MCTSForest: + def _make_forest_rebranching(self) -> MCTSForest: """ ├── 0 │ ├── 16 @@ -449,7 +450,7 @@ def test_forest_check_ids(self): def test_forest_intersect(self): state0 = self._state0() - forest = self._make_forest_intersect() + forest = self._make_forest_rebranching() tree = forest.get_tree(state0) subtree = forest.get_tree(TensorDict(observation=19)) @@ -467,13 +468,110 @@ def test_forest_intersect(self): def test_forest_intersect_vertices(self): state0 = self._state0() - forest = self._make_forest_intersect() + forest = self._make_forest_rebranching() tree = forest.get_tree(state0) assert len(tree.vertices(key_type="path")) > len(tree.vertices(key_type="hash")) assert len(tree.vertices(key_type="id")) == len(tree.vertices(key_type="hash")) with pytest.raises(ValueError, match="key_type 
must be"): tree.vertices(key_type="another key type") + @pytest.mark.skipif(not _has_gym, reason="requires gym") + def test_simple_tree(self): + from torchrl.envs import GymEnv + + env = GymEnv("Pendulum-v1") + r = env.rollout(10) + state0 = r[0] + forest = MCTSForest() + forest.extend(r) + # forest = self._make_forest_intersect() + tree = forest.get_tree(state0, compact=False) + assert tree.max_length() == 9 + for p in tree.valid_paths(): + assert len(p) == 9 + + @pytest.mark.parametrize( + "tree_type,compact", + [ + ["simple", False], + ["forest", False], + # parent of rebranching trees are still buggy + # ["rebranching", False], + # ["rebranching", True], + ], + ) + def test_forest_parent(self, tree_type, compact): + if tree_type == "simple": + if not _has_gym: + pytest.skip("requires gym") + from torchrl.envs import GymEnv + + env = GymEnv("Pendulum-v1") + r = env.rollout(10) + state0 = r[0] + forest = MCTSForest() + forest.extend(r) + tree = forest.get_tree(state0, compact=compact) + elif tree_type == "forest": + state0 = self._state0() + forest = self._make_forest() + tree = forest.get_tree(state0, compact=compact) + else: + state0 = self._state0() + forest = self._make_forest_rebranching() + tree = forest.get_tree(state0, compact=compact) + # Check access + tree.subtree.parent + tree.subtree.subtree.parent + tree.subtree.subtree.subtree.parent + + # check present of weakref + assert tree.subtree[0]._parent is not None + assert tree.subtree[0].subtree[0]._parent is not None + + # Check content + assert_close(tree.subtree.parent, tree) + for p in tree.valid_paths(): + root = tree + for it in p: + node = root.subtree[it] + assert_close(node.parent, root) + root = node + + def test_forest_action_attr(self): + state0 = self._state0() + forest = self._make_forest() + tree = forest.get_tree(state0) + assert tree.branching_action is None + assert (tree.subtree.branching_action != tree.subtree.prev_action).any() + assert ( + tree.subtree[0].subtree.branching_action + 
!= tree.subtree[0].subtree.prev_action + ).any() + assert tree.prev_action is None + + @pytest.mark.parametrize("intersect", [False, True]) + def test_forest_check_obs_match(self, intersect): + state0 = self._state0() + if intersect: + forest = self._make_forest_rebranching() + else: + forest = self._make_forest() + tree = forest.get_tree(state0) + for path in tree.valid_paths(): + prev_tree = tree + for p in path: + subtree = prev_tree.subtree[p] + assert ( + subtree.node_data["observation"] + == subtree.rollout[..., -1]["next", "observation"] + ).all() + assert ( + subtree.node_observation + == subtree.rollout[..., -1]["next", "observation"] + ).all() + prev_tree = subtree + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/data/map/hash.py b/torchrl/data/map/hash.py index 01988dc43be..59526628dbe 100644 --- a/torchrl/data/map/hash.py +++ b/torchrl/data/map/hash.py @@ -75,7 +75,8 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: class SipHash(Module): """A Module to Compute SipHash values for given tensors. - A hash function module based on SipHash implementation in python. + A hash function module based on SipHash implementation in python. Input tensors should have shape ``[batch_size, num_features]`` + and the output shape will be ``[batch_size]``. 
Args: as_tensor (bool, optional): if ``True``, the bytes will be turned into integers diff --git a/torchrl/data/map/tdstorage.py b/torchrl/data/map/tdstorage.py index a601f1e3261..9413033bac4 100644 --- a/torchrl/data/map/tdstorage.py +++ b/torchrl/data/map/tdstorage.py @@ -138,6 +138,10 @@ def __init__( self.collate_fn = collate_fn self.write_fn = write_fn + @property + def max_size(self): + return self.storage.max_size + @property def out_keys(self) -> List[NestedKey]: out_keys = self.__dict__.get("_out_keys_and_lazy") @@ -177,7 +181,7 @@ def from_tensordict_pair( collate_fn: Callable[[Any], Any] | None = None, write_fn: Callable[[Any, Any], Any] | None = None, consolidated: bool | None = None, - ): + ) -> TensorDictMap: """Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb. Args: @@ -238,7 +242,13 @@ def from_tensordict_pair( n_feat = 0 hash_module = [] for in_key in in_keys: - n_feat = source[in_key].shape[-1] + entry = source[in_key] + if entry.ndim == source.ndim: + # this is a good example of why td/tc are useful - carrying metadata + # allows us to know if there's a feature dim or not + n_feat = 0 + else: + n_feat = entry.shape[-1] if n_feat > RandomProjectionHash._N_COMPONENTS_DEFAULT: _hash_module = RandomProjectionHash() else: @@ -308,7 +318,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase): if not self._has_lazy_out_keys(): # TODO: make this work with pytrees and avoid calling select if keys match value = value.select(*self.out_keys, strict=False) + item, value = self._maybe_add_batch(item, value) + index = self._to_index(item, extend=True) + if index.unique().numel() < index.numel(): + # If multiple values point to the same place in the storage, we cannot process them by batch + # There could be a better way to deal with this, using unique ids. 
+ vals = [] + for it, val in zip(item.split(1), value.split(1)): + self[it] = val + vals.append(val) + # __setitem__ may affect the content of the input data + value.update(TensorDictBase.lazy_stack(vals)) + return if self.write_fn is not None: + # We use this block in the following context: the value written in the storage is already present, + # but it needs to be updated. + # We first check if the value is already there using `contains`. If so, we pass the new value and the + # previous one to write_fn. The values that are not present are passed alone. if len(self): modifiable = self.contains(item) if modifiable.any(): @@ -322,8 +348,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase): value = self.write_fn(value) else: value = self.write_fn(value) - item, value = self._maybe_add_batch(item, value) - index = self._to_index(item, extend=True) self.storage.set(index, value) def __len__(self): diff --git a/torchrl/data/map/tree.py b/torchrl/data/map/tree.py index 645f7704ddd..513a7b94e58 100644 --- a/torchrl/data/map/tree.py +++ b/torchrl/data/map/tree.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
from __future__ import annotations +import weakref from collections import deque from typing import Any, Callable, Dict, List, Literal, Tuple @@ -15,10 +16,13 @@ TensorClass, TensorDict, TensorDictBase, + unravel_key, ) from torchrl.data.map.tdstorage import TensorDictMap from torchrl.data.map.utils import _plot_plotly_box, _plot_plotly_tree from torchrl.data.replay_buffers.storages import ListStorage +from torchrl.data.tensor_specs import Composite + from torchrl.envs.common import EnvBase @@ -69,7 +73,9 @@ class Tree(TensorClass["nocast"]): """ - count: int = None + count: int | torch.Tensor = None + wins: int | torch.Tensor = None + index: torch.Tensor | None = None # The hash is None if the node has more than one action associated hash: int | None = None @@ -78,12 +84,249 @@ class Tree(TensorClass["nocast"]): # rollout following the observation encoded in node, in a TorchRL (TED) format rollout: TensorDict | None = None - # The data specifying the node - node: TensorDict | None = None + # The data specifying the node (typically an observation or a set of observations) + node_data: TensorDict | None = None # Stack of subtrees. A subtree is produced when an action is taken. subtree: "Tree" = None + # weakrefs to the parent(s) of the node + _parent: weakref.ref | List[weakref.ref] | None = None + + # Specs: contains information such as action or observation keys and spaces. + # If present, they should be structured like env specs are: + # Composite(input_spec=Composite(full_state_spec=..., full_action_spec=...), + # output_spec=Composite(full_observation_spec=..., full_reward_spec=..., full_done_spec=...)) + # where every leaf component is optional. 
+ specs: Composite | None = None + + @classmethod + def make_node( + cls, + data: TensorDictBase, + *, + device: torch.device | None = None, + batch_size: torch.Size | None = None, + specs: Composite | None = None, + ) -> Tree: + """Creates a new node given some data.""" + if "next" in data.keys(): + rollout = data + if not rollout.ndim: + rollout = rollout.unsqueeze(0) + subtree = TensorDict.lazy_stack([cls.make_node(data["next"][..., -1])]) + else: + rollout = None + subtree = None + if device is None: + device = data.device + return cls( + count=torch.zeros(()), + wins=torch.zeros(()), + node=data.exclude("action", "next"), + rollout=rollout, + subtree=subtree, + device=device, + batch_size=batch_size, + ) + + # Specs + @property + def full_observation_spec(self): + """The observation spec of the tree. + + This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`.""" + return self.specs["output_spec", "full_observation_spec"] + + @property + def full_reward_spec(self): + """The reward spec of the tree. + + This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`.""" + return self.specs["output_spec", "full_reward_spec"] + + @property + def full_done_spec(self): + """The done spec of the tree. + + This is an alias for `Tree.specs['output_spec', 'full_done_spec']`.""" + return self.specs["output_spec", "full_done_spec"] + + @property + def full_state_spec(self): + """The state spec of the tree. + + This is an alias for `Tree.specs['input_spec', 'full_state_spec']`.""" + return self.specs["input_spec", "full_state_spec"] + + @property + def full_action_spec(self): + """The action spec of the tree. 
+ + This is an alias for `Tree.specs['input_spec', 'full_action_spec']`.""" + return self.specs["input_spec", "full_action_spec"] + + @property + def selected_actions(self) -> torch.Tensor | TensorDictBase | None: + """Returns a tensor containing all the selected actions branching out from this node.""" + if self.subtree is None: + return None + return self.subtree.rollout[..., 0]["action"] + + @property + def prev_action(self) -> torch.Tensor | TensorDictBase | None: + """The action undertaken just before this node's observation was generated. + + Returns: + a tensor, tensordict or None if the node has no parent. + + .. seealso:: This will be equal to :class:`~torchrl.data.Tree.branching_action` whenever the rollout data contains a single step. + + .. seealso:: :class:`All actions associated with a given node (or observation) in the tree <torchrl.data.Tree.selected_actions>`. + + """ + if self.rollout is None: + return None + return self.rollout[..., -1]["action"] + + @property + def branching_action(self) -> torch.Tensor | TensorDictBase | None: + """Returns the action that branched out to this particular node. + + Returns: + a tensor, tensordict or None if the node has no parent. + + .. seealso:: This will be equal to :class:`~torchrl.data.Tree.prev_action` whenever the rollout data contains a single step. + + .. seealso:: :class:`All actions associated with a given node (or observation) in the tree <torchrl.data.Tree.selected_actions>`. + + """ + if self.rollout is None: + return None + return self.rollout[..., 0]["action"] + + @property + def node_observation(self) -> torch.Tensor | TensorDictBase: + """Returns the observation associated with this particular node. + + This is the observation (or bag of observations) that defines the node before a branching occurs. 
+ If the node contains a :attr:`~.rollout` attribute, the node observation is typically identical to the + observation resulting from the last action undertaken, i.e., ``node.rollout[..., -1]["next", "observation"]``. + + If more than one observation key is associated with the tree specs, a :class:`~tensordict.TensorDict` instance + is returned instead. + + For a more consistent representation, see :attr:`~.node_observations`. + + """ + # TODO: implement specs + return self.node_data["observation"] + + @property + def node_observations(self) -> torch.Tensor | TensorDictBase: + """Returns the observations associated with this particular node in a TensorDict format. + + This is the observation (or bag of observations) that defines the node before a branching occurs. + If the node contains a :attr:`~.rollout` attribute, the node observation is typically identical to the + observation resulting from the last action undertaken, i.e., ``node.rollout[..., -1]["next", "observation"]``. + + If more than one observation key is associated with the tree specs, a :class:`~tensordict.TensorDict` instance + is returned instead. + + For the single-entry representation, see :attr:`~.node_observation`. + + """ + # TODO: implement specs + return self.node_data.select("observation") + + @property + def visits(self) -> int | torch.Tensor: + """Returns the number of visits associated with this particular node. + + This is an alias for the :attr:`~.count` attribute. + + """ + return self.count + + @visits.setter + def visits(self, count): + self.count = count + + def __setattr__(self, name: str, value: Any) -> None: + if name == "subtree" and value is not None: + wr = weakref.ref(self._tensordict) + if value._parent is None: + value._parent = wr + elif isinstance(value._parent, list): + value._parent.append(wr) + else: + value._parent = [value._parent, wr] + return super().__setattr__(name, value) + + @property + def parent(self) -> Tree | None: + """The parent of the node. 
+ + If the node has a parent and this object is still present in the python workspace, it will be returned by this + property. + + For re-branching trees, this property may return a stack of trees where every index of the stack corresponds to + a different parent. + + .. note:: the ``parent`` attribute will match in content but not in identity: the tensorclass object is reconstructed + using the same tensors (i.e., tensors that point to the same memory locations). + + Returns: + A ``Tree`` containing the parent data or ``None`` if the parent data is out of scope or the node is the root. + """ + parent = self._parent + if parent is not None: + # Check that all parents match + queue = [parent] + + def maybe_flatten_list(maybe_nested_list): + if isinstance(maybe_nested_list, list): + for p in maybe_nested_list: + if isinstance(p, list): + queue.append(p) + else: + yield p() + else: + yield maybe_nested_list() + + parent_result = None + while len(queue): + local_result = None + for r in maybe_flatten_list(queue.pop()): + if local_result is None: + local_result = r + elif r is not None and r is not local_result: + if isinstance(local_result, list): + local_result.append(r) + else: + local_result = [local_result, r] + if local_result is None: + continue + # replicate logic at macro level + if parent_result is None: + parent_result = local_result + else: + if isinstance(local_result, list): + local_result = [ + r for r in local_result if r not in parent_result + ] + else: + local_result = [local_result] + if isinstance(parent_result, list): + parent_result.extend(local_result) + else: + parent_result = [parent_result, *local_result] + if isinstance(parent_result, list): + return TensorDict.lazy_stack( + [self._from_tensordict(r) for r in parent_result] + ) + return self._from_tensordict(parent_result) + @property def num_children(self) -> int: """Number of children of this node. 
@@ -93,9 +336,19 @@ def num_children(self) -> int: return len(self.subtree) if self.subtree is not None else 0 @property - def is_terminal(self): - """Returns True if the the tree has no children nodes.""" - return self.subtree is None + def is_terminal(self) -> bool | torch.Tensor: + """Returns the ``done`` flag of the last rollout step, or ``False`` if the node has no rollout.""" + if self.rollout is not None: + return self.rollout[..., -1]["next", "done"].squeeze(-1) + # If there is no rollout, there is no preceding data - either this is a root or it's a floating node. + # In either case, we assume that the node is not terminal. + return False + + def fully_expanded(self, env: EnvBase) -> bool: + """Returns True if the number of children is equal to the environment cardinality.""" + cardinality = env.cardinality(self.node_data) + num_actions = self.num_children + return cardinality == num_actions def get_vertex_by_id(self, id: int) -> Tree: """Goes through the tree and returns the node corresponding the given id.""" @@ -163,9 +416,6 @@ def vertices( if h in memo and not use_path: continue memo.add(h) - r = tree.rollout - if r is not None: - r = r["next", "observation"] if use_path: result[cur_path] = tree elif use_id: @@ -206,6 +456,14 @@ def num_vertices(self, *, count_repeat: bool = False) -> int: ) def edges(self) -> List[Tuple[int, int]]: + """Retrieves a list of edges in the tree. + + Each edge is represented as a tuple of two node IDs: the parent node ID and the child node ID. + The tree is traversed using Breadth-First Search (BFS) to ensure all edges are visited. + + Returns: + A list of tuples, where each tuple contains a parent node ID and a child node ID. + """ result = [] q = deque() parent = self.node_id @@ -221,22 +479,62 @@ def edges(self) -> List[Tuple[int, int]]: return result def valid_paths(self): + """Generates all valid paths in the tree. + + A valid path is a sequence of child indices that starts at the root node and ends at a leaf node. 
+ Each path is represented as a tuple of integers, where each integer corresponds to the index of a child node. + + Yields: + tuple: A valid path in the tree. + """ + # Initialize a queue with the current tree node and an empty path q = deque() cur_path = () q.append((self, cur_path)) + # Perform BFS traversal of the tree while len(q): + # Dequeue the next tree node and its current path tree, cur_path = q.popleft() + # Get the number of child nodes n = int(tree.num_children) + # If this is a leaf node, yield the current path if not n: yield cur_path + # Iterate over the child nodes for i in range(n): cur_path_tree = cur_path + (i,) q.append((tree.subtree[i], cur_path_tree)) def max_length(self): - return max(*(len(path) for path in self.valid_paths())) + """Returns the maximum length of all valid paths in the tree. + + The length of a path is defined as the number of nodes in the path. + If the tree is empty, returns 0. + + Returns: + int: The maximum length of all valid paths in the tree. + + """ + lengths = tuple(len(path) for path in self.valid_paths()) + if len(lengths) == 0: + return 0 + elif len(lengths) == 1: + return lengths[0] + return max(*lengths) def rollout_from_path(self, path: Tuple[int]) -> TensorDictBase | None: + """Retrieves the rollout data along a given path in the tree. + + The rollout data is concatenated along the last dimension (dim=-1) for each node in the path. + If no rollout data is found along the path, returns ``None``. + + Args: + path: A tuple of integers representing the path in the tree. + + Returns: + The concatenated rollout data along the path, or None if no data is found. + + """ r = self.rollout tree = self rollouts = [] @@ -272,8 +570,19 @@ def plot( backend: str = "plotly", figure: str = "tree", info: List[str] = None, - make_labels: Callable[[Any], Any] | None = None, + make_labels: Callable[[Any, ...], Any] | None = None, ): + """Plots a visualization of the tree using the specified backend and figure type. 
+ + Args: + backend: The plotting backend to use. Currently only supports 'plotly'. + figure: The type of figure to plot. Can be either 'tree' or 'box'. + info: A list of additional information to include in the plot (not currently used). + make_labels: An optional function to generate custom labels for the plot. + + Raises: + NotImplementedError: If an unsupported backend or figure type is specified. + """ if backend == "plotly": if figure == "box": _plot_plotly_box(self) @@ -284,33 +593,48 @@ def plot( else: pass raise NotImplementedError( - f"Unkown plotting backend {backend} with figure {figure}." + f"Unknown plotting backend {backend} with figure {figure}." ) class MCTSForest: """A collection of MCTS trees. + .. warning:: This class is currently under active development. Expect frequent API changes. + The class is aimed at storing rollouts in a storage, and produce trees based on a given root in that dataset. Keyword Args: data_map (TensorDictMap, optional): the storage to use to store the data (observation, reward, states etc). If not provided, it is lazily - initialized using :meth:`~torchrl.data.map.tdstorage.TensorDictMap.from_tensordict_pair`. - node_map (TensorDictMap, optional): TODO - done_keys (list of NestedKey): the done keys of the environment. If not provided, + initialized using :meth:`~torchrl.data.map.tdstorage.TensorDictMap.from_tensordict_pair` + using the list of :attr:`observation_keys` and :attr:`action_keys` as ``in_keys``. + node_map (TensorDictMap, optional): a map from the observation space to the index space. + Internally, the node map is used to gather all possible branches coming out of + a given node. For example, if an observation has two associated actions and outcomes + in the data map, then the :attr:`node_map` will return a data structure containing the + two indices in the :attr:`data_map` that correspond to these two outcomes. 
+ If not provided, it is lazily initialized using + :meth:`~torchrl.data.map.tdstorage.TensorDictMap.from_tensordict_pair` using the list of + :attr:`observation_keys` as ``in_keys`` and the :class:`~torchrl.data.QueryModule` as + ``out_keys``. + max_size (int, optional): the size of the maps. + If not provided, defaults to ``data_map.max_size`` if this can be found, then + ``node_map.max_size``. If none of these are provided, defaults to `1000`. + done_keys (list of NestedKey, optional): the done keys of the environment. If not provided, defaults to ``("done", "terminated", "truncated")``. The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. - action_keys (list of NestedKey): the action keys of the environment. If not provided, + action_keys (list of NestedKey, optional): the action keys of the environment. If not provided, defaults to ``("action",)``. The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. - reward_keys (list of NestedKey): the reward keys of the environment. If not provided, + reward_keys (list of NestedKey, optional): the reward keys of the environment. If not provided, defaults to ``("reward",)``. The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. - observation_keys (list of NestedKey): the observation keys of the environment. If not provided, + observation_keys (list of NestedKey, optional): the observation keys of the environment. If not provided, defaults to ``("observation",)``. The :meth:`~.get_keys_from_env` can be used to automatically determine the keys. + excluded_keys (list of NestedKey, optional): a list of keys to exclude from the data storage. consolidated (bool, optional): if ``True``, the data_map storage will be consolidated on disk. Defaults to ``False``. 
@@ -405,10 +729,12 @@ def __init__( *, data_map: TensorDictMap | None = None, node_map: TensorDictMap | None = None, + max_size: int | None = None, done_keys: List[NestedKey] | None = None, reward_keys: List[NestedKey] = None, observation_keys: List[NestedKey] = None, action_keys: List[NestedKey] = None, + excluded_keys: List[NestedKey] = None, consolidated: bool | None = None, ): @@ -416,55 +742,125 @@ def __init__( self.node_map = node_map + if max_size is None: + if data_map is not None: + max_size = data_map.max_size + if max_size != getattr(node_map, "max_size", max_size): + raise ValueError( + f"Conflicting max_size: got data_map.max_size={data_map.max_size} and node_map.max_size={node_map.max_size}." + ) + elif node_map is not None: + max_size = node_map.max_size + else: + max_size = None + elif data_map is not None and max_size != getattr( + data_map, "max_size", max_size + ): + raise ValueError( + f"Conflicting max_size: got data_map.max_size={data_map.max_size} and max_size={max_size}." + ) + elif node_map is not None and max_size != getattr( + node_map, "max_size", max_size + ): + raise ValueError( + f"Conflicting max_size: got node_map.max_size={node_map.max_size} and max_size={max_size}." + ) + self.max_size = max_size + self.done_keys = done_keys self.action_keys = action_keys self.reward_keys = reward_keys self.observation_keys = observation_keys + self.excluded_keys = excluded_keys self.consolidated = consolidated @property - def done_keys(self): + def done_keys(self) -> List[NestedKey]: + """Done Keys. + + Returns the keys used to indicate that an episode has ended. + The default done keys are "done", "terminated", and "truncated". These keys can be + used in the environment's output to signal the end of an episode. + + Returns: + A list of strings representing the done keys. 
+ + """ done_keys = getattr(self, "_done_keys", None) if done_keys is None: - self._done_keys = done_keys = ("done", "terminated", "truncated") + self._done_keys = done_keys = ["done", "terminated", "truncated"] return done_keys @done_keys.setter def done_keys(self, value): - self._done_keys = value + self._done_keys = _make_list_of_nestedkeys(value, "done_keys") @property - def reward_keys(self): + def reward_keys(self) -> List[NestedKey]: + """Reward Keys. + + Returns the keys used to retrieve rewards from the environment's output. + The default reward key is "reward". + + Returns: + A list of strings or tuples representing the reward keys. + + """ reward_keys = getattr(self, "_reward_keys", None) if reward_keys is None: - self._reward_keys = reward_keys = ("reward",) + self._reward_keys = reward_keys = ["reward"] return reward_keys @reward_keys.setter def reward_keys(self, value): - self._reward_keys = value + self._reward_keys = _make_list_of_nestedkeys(value, "reward_keys") @property - def action_keys(self): + def action_keys(self) -> List[NestedKey]: + """Action Keys. + + Returns the keys used to retrieve actions from the environment's input. + The default action key is "action". + + Returns: + A list of strings or tuples representing the action keys. + + """ action_keys = getattr(self, "_action_keys", None) if action_keys is None: - self._action_keys = action_keys = ("action",) + self._action_keys = action_keys = ["action"] return action_keys @action_keys.setter def action_keys(self, value): - self._action_keys = value + self._action_keys = _make_list_of_nestedkeys(value, "action_keys") @property - def observation_keys(self): + def observation_keys(self) -> List[NestedKey]: + """Observation Keys. + + Returns the keys used to retrieve observations from the environment's output. + The default observation key is "observation". + + Returns: + A list of strings or tuples representing the observation keys. 
+ """ observation_keys = getattr(self, "_observation_keys", None) if observation_keys is None: - self._observation_keys = observation_keys = ("observation",) + self._observation_keys = observation_keys = ["observation"] return observation_keys @observation_keys.setter def observation_keys(self, value): - self._observation_keys = value + self._observation_keys = _make_list_of_nestedkeys(value, "observation_keys") + + @property + def excluded_keys(self) -> List[NestedKey] | None: + return self._excluded_keys + + @excluded_keys.setter + def excluded_keys(self, value): + self._excluded_keys = _make_list_of_nestedkeys(value, "excluded_keys") def get_keys_from_env(self, env: EnvBase): """Writes missing done, action and reward keys to the Forest given an environment. @@ -482,8 +878,21 @@ def get_keys_from_env(self, env: EnvBase): @classmethod def _write_fn_stack(cls, new, old=None): + # This function updates the old values by adding the new ones + # if and only if the new ones are not there. + # If the old value is not provided, we assume there are none and the + # `new` is just prepared. + # This involves unsqueezing the last dim (since we'll be stacking tensors + # and calling unique). + # The update involves calling cat along the last dim + unique + # which will keep only the new values that were unknown to + # the storage. + # We use this method to track all the indices that are associated with + # an observation. Every time a new index is obtained, it is stacked alongside + # the others. 
if old is None: - result = new.apply(lambda x: x.unsqueeze(0), filter_empty=False) + # we unsqueeze the values to stack them along dim -1 + result = new.apply(lambda x: x.unsqueeze(-1), filter_empty=False) result.set( "count", torch.ones(result.shape, dtype=torch.int, device=result.device) ) @@ -493,28 +902,44 @@ def cat(name, x, y): if name == "count": return x if y.ndim < x.ndim: - y = y.unsqueeze(0) - result = torch.cat([x, y], 0).unique(dim=0, sorted=False) + y = y.unsqueeze(-1) + result = torch.cat([x, y], -1) + # Breaks on mps + if result.device.type == "mps": + result = result.cpu() + result = result.unique(dim=-1, sorted=False) + result = result.to("mps") + else: + result = result.unique(dim=-1, sorted=False) return result result = old.named_apply(cat, new, default=None) result.set_("count", old.get("count") + 1) return result - def _make_storage(self, source, dest): + def _make_data_map(self, source, dest): try: + kwargs = {} + if self.max_size is not None: + kwargs["max_size"] = self.max_size self.data_map = TensorDictMap.from_tensordict_pair( source, dest, in_keys=[*self.observation_keys, *self.action_keys], consolidated=self.consolidated, + **kwargs, ) + if self.max_size is None: + self.max_size = self.data_map.max_size except KeyError as err: raise KeyError( "A KeyError occurred during data map creation. This could be due to the wrong setting of a key in the MCTSForest constructor. Scroll up for more info." 
) from err - def _make_storage_branches(self, source, dest): + def _make_node_map(self, source, dest): + kwargs = {} + if self.max_size is not None: + kwargs["max_size"] = self.max_size self.node_map = TensorDictMap.from_tensordict_pair( source, dest, @@ -528,26 +953,59 @@ def _make_storage_branches(self, source, dest): storage_constructor=ListStorage, collate_fn=TensorDict.lazy_stack, write_fn=self._write_fn_stack, + **kwargs, ) + if self.max_size is None: + self.max_size = self.data_map.max_size - def extend(self, rollout): + def extend(self, rollout, *, return_node: bool = False): source, dest = ( rollout.exclude("next").copy(), rollout.select("next", *self.action_keys).copy(), ) + if self.excluded_keys is not None: + dest = dest.exclude(*self.excluded_keys, inplace=True) + dest.get("next").exclude(*self.excluded_keys, inplace=True) if self.data_map is None: - self._make_storage(source, dest) + self._make_data_map(source, dest) # We need to set the action somewhere to keep track of what action lead to what child # # Set the action in the 'next' # dest[1:] = source[:-1].exclude(*self.done_keys) + # Add ('observation', 'action') -> ('next, observation') self.data_map[source] = dest value = source if self.node_map is None: - self._make_storage_branches(source, dest) + self._make_node_map(source, dest) + # map ('observation',) -> ('indices',) self.node_map[source] = TensorDict.lazy_stack(value.unbind(0)) + if return_node: + return self.get_tree(rollout) + + def add(self, step, *, return_node: bool = False): + source, dest = ( + step.exclude("next").copy(), + step.select("next", *self.action_keys).copy(), + ) + + if self.data_map is None: + self._make_data_map(source, dest) + + # We need to set the action somewhere to keep track of what action lead to what child + # # Set the action in the 'next' + # dest[1:] = source[:-1].exclude(*self.done_keys) + + # Add ('observation', 'action') -> ('next, observation') + self.data_map[source] = dest + value = source + if 
self.node_map is None: + self._make_node_map(source, dest) + # map ('observation',) -> ('indices',) + self.node_map[source] = value + if return_node: + return self.get_tree(step) def get_child(self, root: TensorDictBase) -> TensorDictBase: return self.data_map[root] @@ -573,6 +1031,8 @@ def _make_local_tree( while index.numel() <= 1: index = index.squeeze() d = self.data_map.storage[index] + + # Rebuild rollout step steps.append(merge_tensordicts(d, root, callback_exist=lambda *x: None)) d = d["next"] if d in self.node_map: @@ -582,6 +1042,15 @@ def _make_local_tree( if not compact: break else: + # If the root is provided and not gathered from the storage, it could be that its + # device doesn't match the data_map storage device. + root = steps[-1]["next"].select(*self.node_map.in_keys) + device = getattr(self.data_map.storage, "device", None) + if root.device != device: + if device is not None: + root = root.to(self.data_map.storage.device) + else: + root.clear_device_() index = None break rollout = None @@ -592,10 +1061,12 @@ def _make_local_tree( return ( Tree( rollout=rollout, - count=node_meta["count"], - node=root, + count=torch.zeros((), dtype=torch.int32), + wins=torch.zeros(()), + node_data=root, index=index, hash=None, + # We do this to avoid raising an exception as rollout and subtree must be provided together subtree=None, ), index, @@ -618,7 +1089,7 @@ def _make_tree_iter( ): q = deque() memo = {} - tree, indices, hash = self._make_local_tree(root, index=index) + tree, indices, hash = self._make_local_tree(root, index=index, compact=compact) tree.node_id = 0 result = tree @@ -626,7 +1097,6 @@ def _make_tree_iter( counter = 1 if indices is not None: q.append((tree, indices, hash, depth)) - del tree, indices while len(q): tree, indices, hash, depth = q.popleft() @@ -638,12 +1108,29 @@ def _make_tree_iter( subtree, subtree_indices, subtree_hash = memo.get(h, (None,) * 3) if subtree is None: subtree, subtree_indices, subtree_hash = self._make_local_tree( - 
tree.node, index=i, compact=compact + tree.node_data, + index=i, + compact=compact, ) subtree.node_id = counter counter += 1 subtree.hash = h memo[h] = (subtree, subtree_indices, subtree_hash) + else: + # We just need to save the two (or more) rollouts + subtree_bis, _, _ = self._make_local_tree( + tree.node_data, + index=i, + compact=compact, + ) + if subtree.rollout.ndim == subtree_bis.rollout.ndim: + subtree.rollout = TensorDict.stack( + [subtree.rollout, subtree_bis.rollout] + ) + else: + subtree.rollout = TensorDict.stack( + [*subtree.rollout, subtree_bis.rollout] + ) subtrees.append(subtree) if extend and subtree_indices is not None: @@ -668,3 +1155,15 @@ def valid_paths(cls, tree: Tree): def __len__(self): return len(self.data_map) + + +def _make_list_of_nestedkeys(obj: Any, attr: str) -> List[NestedKey]: + if obj is None: + return obj + if isinstance(obj, (str, tuple)): + return [obj] + if not isinstance(obj, list): + raise ValueError( + f"{attr} must be a list of NestedKeys or a NestedKey, got {obj}." 
+ ) + return [unravel_key(key) for key in obj] diff --git a/torchrl/data/map/utils.py b/torchrl/data/map/utils.py index 570214f1cb2..d9588d79905 100644 --- a/torchrl/data/map/utils.py +++ b/torchrl/data/map/utils.py @@ -17,13 +17,13 @@ def _plot_plotly_tree( if make_labels is None: - def make_labels(tree): + def make_labels(tree, path, *args, **kwargs): return str((tree.node_id, tree.hash)) nr_vertices = tree.num_vertices() - vertices = tree.vertices() + vertices = tree.vertices(key_type="path") - v_label = [make_labels(subtree) for subtree in vertices.values()] + v_label = [make_labels(subtree, path) for path, subtree in vertices.items()] G = Graph(nr_vertices, tree.edges()) layout = G.layout_sugiyama(range(nr_vertices)) diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index 665cae254f5..ae0d97b7bab 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -246,8 +246,8 @@ def set( set_cursor: bool = True, ): if not isinstance(cursor, INT_CLASSES): - if (isinstance(cursor, torch.Tensor) and cursor.numel() <= 1) or ( - isinstance(cursor, np.ndarray) and cursor.size <= 1 + if (isinstance(cursor, torch.Tensor) and cursor.ndim == 0) or ( + isinstance(cursor, np.ndarray) and cursor.ndim == 0 ): self.set(int(cursor), data, set_cursor=set_cursor) return