Commit 570afff

author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 34cfaae commit 570afff

5 files changed

+75
-8
lines changed

docs/source/reference/data.rst

Lines changed: 66 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -975,7 +975,72 @@ The following classes are deprecated and just point to the classes above:
975975
Trees and Forests
976976
-----------------
977977

978-
TorchRL offers a set of classes and functions that can be used to represent trees and forests efficiently.
978+
TorchRL offers a set of classes and functions that can be used to represent trees and forests efficiently,
979+
which is particularly useful for Monte Carlo Tree Search (MCTS) algorithms.
980+
981+
TensorDictMap
982+
~~~~~~~~~~~~~
983+
984+
At its core, the MCTS API relies on the :class:`~torchrl.data.TensorDictMap`, which acts like a storage where indices can
985+
be any numerical object. In traditional storages (e.g., :class:`~torchrl.data.TensorStorage`), only integer indices
986+
are allowed:
987+
988+
>>> storage = TensorStorage(...)
989+
>>> data = storage[3]
990+
991+
:class:`~torchrl.data.TensorDictMap` allows us to make more advanced queries in the storage. The typical example is
992+
when we have a storage containing a set of MDPs and we want to rebuild a trajectory given its initial ``(observation, action)``
993+
pair. In tensor terms, this could be written with the following pseudocode:
994+
995+
>>> next_state = storage[observation, action]
996+
997+
(If more than one next state is associated with this pair, one could return a stack of ``next_states`` instead.)
998+
This API would make sense but it would be restrictive: allowing observations or actions that are composed of
999+
multiple tensors may be hard to implement. Instead, we provide a tensordict containing these values and let the storage
1000+
know what ``in_keys`` to look at to query the next state:
1001+
1002+
>>> td = TensorDict(observation=observation, action=action)
1003+
>>> next_td = storage[td]
1004+
1005+
Of course, this class also allows us to extend the storage with new data:
1006+
1007+
>>> storage[td] = next_state
1008+
1009+
This comes in handy because it allows us to represent complex rollout structures where different actions are taken
1010+
at a given node (i.e., for a given observation). Every ``(observation, action)`` pair that has been observed may lead us to
1011+
a (set of) rollout(s) that we can use further.
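
The lookup-by-content idea described above can be illustrated with a minimal pure-Python sketch. It is not TorchRL's implementation: ``ToyKeyedStorage`` and ``_freeze`` are hypothetical names, plain dictionaries stand in for tensordicts, and a Python ``dict`` plays the role of the hashing machinery. It only shows how restricting the lookup to a set of ``in_keys`` lets one ``(observation, action)`` pair map to one or several next states.

```python
from typing import Any, Dict, Hashable, List, Sequence


def _freeze(value: Any) -> Hashable:
    """Turn (nested) lists into tuples so they can be used as dict keys."""
    if isinstance(value, (list, tuple)):
        return tuple(_freeze(v) for v in value)
    return value


class ToyKeyedStorage:
    """Toy stand-in (hypothetical) for a storage indexed by the content of selected keys."""

    def __init__(self, in_keys: Sequence[str]):
        self.in_keys = tuple(in_keys)
        self._data: Dict[Hashable, List[dict]] = {}

    def _key(self, item: dict) -> Hashable:
        # Only the entries named in in_keys take part in the lookup.
        return tuple(_freeze(item[k]) for k in self.in_keys)

    def __setitem__(self, item: dict, value: dict) -> None:
        # Several values may be registered under the same query.
        self._data.setdefault(self._key(item), []).append(value)

    def __getitem__(self, item: dict) -> List[dict]:
        return self._data[self._key(item)]

    def __contains__(self, item: dict) -> bool:
        return self._key(item) in self._data


storage = ToyKeyedStorage(in_keys=["observation", "action"])
# "step" is not listed in in_keys, so it is ignored by the lookup.
query = {"observation": [0.0, 1.0], "action": 1, "step": 7}
storage[query] = {"observation": [0.5, 1.5], "reward": 1.0}
storage[query] = {"observation": [0.4, 1.6], "reward": 0.0}

same_pair = {"observation": [0.0, 1.0], "action": 1}
next_states = storage[same_pair]
```

Here ``next_states`` is a list holding both registered next states, which mirrors the "stack of ``next_states``" mentioned above. A real implementation would hash tensors rather than Python tuples.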
1012+
1013+
MCTSForest
1014+
~~~~~~~~~~
1015+
1016+
Building a tree from an initial observation then becomes just a matter of organizing data efficiently.
1017+
The :class:`~torchrl.data.MCTSForest` has at its core two storages: a first storage links observations to hashes and
1018+
indices of actions encountered in the past in the dataset:
1019+
1020+
>>> data = TensorDict(observation=observation)
1021+
>>> metadata = forest.node_map[data]
1022+
>>> index = metadata["_index"]
1023+
1024+
where ``forest`` is a :class:`~torchrl.data.MCTSForest` instance.
1025+
Then, a second storage keeps track of the actions and results associated with the observation:
1026+
1027+
>>> next_data = forest.data_map[index]
1028+
1029+
The ``next_data`` entry can have any shape, but it will usually match the shape of ``index`` (since each index
1030+
corresponds to one action). Once ``next_data`` is obtained, it can be put together with ``data`` to form a set of nodes,
1031+
and the tree can be expanded for each of these. The following figure shows how this is done.
1032+
1033+
.. figure:: /_static/img/collector-copy.png
1034+
1035+
Building a :class:`~torchrl.data.Tree` from a :class:`~torchrl.data.MCTSForest` object.
1036+
The flowchart represents a tree being built from an initial observation ``o``. The :meth:`~torchrl.data.MCTSForest.get_tree`
1037+
method passes the input data structure (the root node) to the ``node_map`` :class:`~torchrl.data.TensorDictMap` instance
1038+
that returns a set of hashes and indices. These indices are then used to query the corresponding tuples of
1039+
actions, next observations, rewards etc. that are associated with the root node.
1040+
A vertex is created from each of them (possibly with a longer rollout when a compact representation is requested).
1041+
These vertices are stacked together to form the branches of the tree at the root, and each of them is then expanded
1042+
in the same way. This process is repeated up to a given depth or until the tree cannot be
1043+
expanded any further.
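
The two-storage lookup and the recursive expansion described in this section can be sketched in plain Python. This is a toy model under stated assumptions, not the :class:`~torchrl.data.MCTSForest` implementation: ``ToyForest``, ``ToyNode`` and ``extend`` are hypothetical names, observations are tuples of floats, and no de-duplication of repeated ``(observation, action)`` pairs is attempted. The attribute names ``node_map`` and ``data_map`` and the method name ``get_tree`` are borrowed from the text purely for readability.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

Obs = Tuple[float, ...]


@dataclass
class ToyNode:
    observation: Obs
    action: Optional[int] = None  # action that led to this node (None at the root)
    reward: float = 0.0
    children: List["ToyNode"] = field(default_factory=list)


class ToyForest:
    """Toy two-storage structure (hypothetical), loosely mirroring the description above."""

    def __init__(self) -> None:
        # First storage: observation -> indices of the transitions starting there.
        self.node_map: Dict[Obs, List[int]] = {}
        # Second storage: index -> (action, next observation, reward).
        self.data_map: List[Tuple[int, Obs, float]] = []

    def extend(self, obs: Obs, action: int, next_obs: Obs, reward: float) -> None:
        index = len(self.data_map)
        self.data_map.append((action, next_obs, reward))
        self.node_map.setdefault(obs, []).append(index)

    def get_tree(self, obs: Obs, max_depth: int = 3) -> ToyNode:
        root = ToyNode(observation=obs)
        self._expand(root, max_depth)
        return root

    def _expand(self, node: ToyNode, depth: int) -> None:
        # Stop at the requested depth; unknown observations have no children.
        if depth == 0:
            return
        for index in self.node_map.get(node.observation, []):
            action, next_obs, reward = self.data_map[index]
            child = ToyNode(observation=next_obs, action=action, reward=reward)
            node.children.append(child)
            self._expand(child, depth - 1)


forest = ToyForest()
forest.extend((0.0,), 0, (1.0,), 0.0)
forest.extend((0.0,), 1, (2.0,), 1.0)
forest.extend((1.0,), 0, (3.0,), 0.5)
tree = forest.get_tree((0.0,), max_depth=2)
```

In this sketch the root has two children (one per action recorded at ``(0.0,)``), and only the first child can be expanded further. The depth limit also guards against cycles in the recorded transitions.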
9791044

9801045
.. currentmodule:: torchrl.data
9811046

test/test_storage_map.py

Lines changed: 4 additions & 3 deletions
@@ -372,15 +372,16 @@ def make_labels(tree):
372372
tree.rollout["next", "observation"],
373373
]
374374
)
375+
a = tree.rollout["action"].tolist()
375376
s = s.tolist()
376-
return f"{tree.node_id}: {s}"
377-
return f"{tree.node_id}"
377+
return f"node {tree.node_id}: states {s}, actions {a}"
378+
return f"node {tree.node_id}"
378379

379380
def test_forest_build(self):
380381
r0, *_ = self.dummy_rollouts()
381382
forest = self._make_forest()
382383
tree = forest.get_tree(r0[0])
383-
# tree.plot(make_labels=self.make_labels)
384+
tree.plot(make_labels=self.make_labels)
384385

385386
def test_forest_vertices(self):
386387
r0, *_ = self.dummy_rollouts()

test/test_transforms.py

Lines changed: 1 addition & 1 deletion
@@ -7098,7 +7098,7 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done):
70987098
torch.manual_seed(0)
70997099
env.set_seed(0)
71007100
r1 = env.rollout(100, break_when_any_done=break_when_any_done)
7101-
tensordict.tensordict.assert_allclose_td(r0, r1)
7101+
tensordict.assert_close(r0, r1)
71027102

71037103
def test_callable_default_value(self):
71047104
def create_tensor():

torchrl/data/map/utils.py

Lines changed: 4 additions & 3 deletions
@@ -53,24 +53,25 @@ def make_labels(tree):
5353
x=Xe,
5454
y=Ye,
5555
mode="lines",
56-
line={"color": "rgb(210,210,210)", "width": 1},
56+
line={"color": "rgb(210,210,210)", "width": 5},
5757
hoverinfo="none",
5858
)
5959
)
6060
fig.add_trace(
6161
go.Scatter(
6262
x=Xn,
6363
y=Yn,
64-
mode="markers",
64+
mode="markers+text",
6565
name="bla",
6666
marker={
6767
"symbol": "circle-dot",
68-
"size": 18,
68+
"size": 40,
6969
"color": "#6175c1", # '#DB4551',
7070
"line": {"color": "rgb(50,50,50)", "width": 1},
7171
},
7272
text=labels,
7373
hoverinfo="text",
74+
textposition="middle right",
7475
opacity=0.8,
7576
)
7677
)
