@@ -975,7 +975,72 @@ The following classes are deprecated and just point to the classes above:
Trees and Forests
-----------------

- TorchRL offers a set of classes and functions that can be used to represent trees and forests efficiently.
+ TorchRL offers a set of classes and functions that can be used to represent trees and forests efficiently,
+ which is particularly useful for Monte Carlo Tree Search (MCTS) algorithms.
+
+ TensorDictMap
+ ~~~~~~~~~~~~~
+
+ At its core, the MCTS API relies on the :class:`~torchrl.data.TensorDictMap`, which acts like a storage where indices can
+ be any numerical object. In traditional storages (e.g., :class:`~torchrl.data.TensorStorage`), only integer indices
+ are allowed:
+
+ >>> storage = TensorStorage(...)
+ >>> data = storage[3]
+
+ :class:`~torchrl.data.TensorDictMap` allows us to make more advanced queries in the storage. The typical example is
+ when we have a storage containing a set of MDPs and we want to rebuild a trajectory given its initial (observation, action)
+ pair. In tensor terms, this could be written with the following pseudocode:
+
+ >>> next_state = storage[observation, action]
+
+ (if there is more than one next state associated with this pair, one could return a stack of ``next_states`` instead).
+ This API would make sense but it would be restrictive: allowing observations or actions that are composed of
+ multiple tensors may be hard to implement. Instead, we provide a tensordict containing these values and let the storage
+ know what ``in_keys`` to look at to query the next state:
+
+ >>> td = TensorDict(observation=observation, action=action)
+ >>> next_td = storage[td]
+
+ Of course, this class also allows us to extend the storage with new data:
+
+ >>> storage[td] = next_state
+
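The content-based indexing described above can be pictured with plain Python dictionaries. The sketch below is conceptual only: ``TDMapSketch`` is a hypothetical stand-in for :class:`~torchrl.data.TensorDictMap`, using plain dicts with hashable values in place of tensordicts and tensor hashing.

```python
# Conceptual sketch only: TDMapSketch is a hypothetical stand-in for
# torchrl's TensorDictMap; plain dicts replace tensordicts here.

class TDMapSketch:
    def __init__(self, in_keys):
        # entries of the query mapping that act as the index
        self.in_keys = in_keys
        self._store = {}

    def _key(self, td):
        # derive a single hashable key from the selected entries only,
        # mirroring how an in_keys-driven map indexes its storage
        return tuple(td[k] for k in self.in_keys)

    def __getitem__(self, td):
        return self._store[self._key(td)]

    def __setitem__(self, td, value):
        self._store[self._key(td)] = value

# plain dicts stand in for tensordicts
storage = TDMapSketch(in_keys=["observation", "action"])
td = {"observation": 3, "action": 1}
storage[td] = {"next_observation": 4}                # extend the storage
next_td = storage[{"observation": 3, "action": 1}]   # query by content
```

Note that only the ``in_keys`` entries participate in the lookup, so a query tensordict may carry extra entries without affecting the result.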
+ This comes in handy because it allows us to represent complex rollout structures where different actions are undertaken
+ at a given node (i.e., for a given observation). All ``(observation, action)`` pairs that have been observed may lead us
+ to a rollout (or a set of rollouts) that we can use further.
+
+ MCTSForest
+ ~~~~~~~~~~
+
+ Building a tree from an initial observation then becomes just a matter of organizing data efficiently.
+ The :class:`~torchrl.data.MCTSForest` has at its core two storages: a first storage links observations to the hashes and
+ indices of the actions encountered in the past in the dataset:
+
+ >>> data = TensorDict(observation=observation)
+ >>> metadata = forest.node_map[data]
+ >>> index = metadata["_index"]
+
+ where ``forest`` is a :class:`~torchrl.data.MCTSForest` instance.
+ Then, a second storage keeps track of the actions and results associated with the observation:
+
+ >>> next_data = forest.data_map[index]
+
+ The ``next_data`` entry can have any shape, but it will usually match the shape of ``index`` (since each index
+ corresponds to one action). Once ``next_data`` is obtained, it can be put together with ``data`` to form a set of nodes,
+ and the tree can be expanded for each of these. The following figure shows how this is done.
+
+ .. figure:: /_static/img/collector-copy.png
+
+     Building a :class:`~torchrl.data.Tree` from a :class:`~torchrl.data.MCTSForest` object.
+     The flowchart represents a tree being built from an initial observation ``o``. The :meth:`~torchrl.data.MCTSForest.get_tree`
+     method passes the input data structure (the root node) to the ``node_map`` :class:`~torchrl.data.TensorDictMap` instance,
+     which returns a set of hashes and indices. These indices are then used to query the corresponding tuples of
+     actions, next observations, rewards, etc. that are associated with the root node.
+     A vertex is created from each of them (possibly with a longer rollout when a compact representation is requested).
+     The stack of vertices is then used to build up the tree further; these vertices are stacked together and make
+     up the branches of the tree at the root. This process is repeated for a given depth or until the tree cannot be
+     expanded anymore.
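The two-storage layout and the expansion step can be sketched with plain dictionaries. This is a conceptual analogy, not torchrl's implementation: ``node_map``, ``data_map``, ``extend``, and ``get_tree`` below are simplified stand-ins for their :class:`~torchrl.data.MCTSForest` counterparts, and observations/actions are plain strings.

```python
# Illustrative two-storage forest: simplified stand-in, not torchrl code.
node_map = {}   # observation -> list of transition indices
data_map = {}   # index -> (action, next_observation)

def extend(rollout):
    """Store a rollout given as (obs, action, next_obs) transitions."""
    for obs, action, next_obs in rollout:
        idx = len(data_map)
        data_map[idx] = (action, next_obs)
        node_map.setdefault(obs, []).append(idx)

def get_tree(obs, depth=3):
    """Recursively expand every action stored for ``obs``."""
    if depth == 0 or obs not in node_map:
        return {}
    # one branch per (action, next_obs) tuple associated with the node
    return {
        data_map[idx][0]: get_tree(data_map[idx][1], depth - 1)
        for idx in node_map[obs]
    }

extend([("o0", "a0", "o1"), ("o1", "a1", "o2")])
extend([("o0", "a2", "o3")])
tree = get_tree("o0")  # two branches at the root, one per action seen at "o0"
```

Two rollouts that share the initial observation ``"o0"`` produce two branches at the root, which is the stacking of vertices described in the figure above.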

.. currentmodule:: torchrl.data