Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] MCTS policy #2359

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions test/test_storage_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,11 @@ def test_map_rollout(self):
assert contains[: rollout.shape[-1]].all()
assert not contains[rollout.shape[-1] :].any()


class TestMCTSForest:
def test_forest_build(self):
...

def test_forest_extend_and_get(self):
...

Expand Down
60 changes: 55 additions & 5 deletions torchrl/data/map/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@
from typing import List

import torch
from tensordict import LazyStackedTensorDict, NestedKey, tensorclass, TensorDict
from torchrl.data import ListStorage, TensorDictMap
from torchrl.envs import EnvBase
from tensordict import (
LazyStackedTensorDict,
NestedKey,
tensorclass,
TensorDict,
TensorDictBase,
)
from torchrl.data.map.tdstorage import TensorDictMap
from torchrl.data.replay_buffers.storages import ListStorage
from torchrl.envs.common import EnvBase
from torchrl.modules.mcts.scores import MCTSScores


@tensorclass
Expand Down Expand Up @@ -85,6 +93,7 @@ def __init__(
reward_keys: List[NestedKey] = None,
observation_keys: List[NestedKey] = None,
action_keys: List[NestedKey] = None,
mcts_score: MCTSScores = MCTSScores.PUCT,
):

self.data_map = data_map
Expand All @@ -95,6 +104,7 @@ def __init__(
self.action_keys = action_keys
self.reward_keys = reward_keys
self.observation_keys = observation_keys
self.mcts_score = mcts_score

@property
def done_keys(self):
Expand Down Expand Up @@ -198,7 +208,7 @@ def _make_storage_branches(self, source, dest):
write_fn=self._write_fn_stack,
)

def extend(self, rollout):
def extend(self, rollout: TensorDictBase):
source, dest = rollout, rollout.get("next")
if self.data_map is None:
self._make_storage(source, dest)
Expand All @@ -211,7 +221,10 @@ def extend(self, rollout):
value = source
if self.node_map is None:
self._make_storage_branches(source, dest)
self.node_map[source] = TensorDict.lazy_stack(value.unbind(0))
if source.ndim and source.names[0] == "time":
self.node_map[source] = TensorDict.lazy_stack(value.unbind(0))
else:
self.node_map[source] = value

def get_child(self, root):
    """Looks up ``root`` in the forest's data map.

    Args:
        root: key used to index ``self.data_map``.

    Returns:
        The entry stored in ``self.data_map`` under ``root``.
    """
    data_map = self.data_map
    return data_map[root]
Expand Down Expand Up @@ -445,3 +458,40 @@ def extend(tree, parent):
extend(_tree, labels[-1])
fig = go.Figure(go.Treemap(labels=labels, parents=parents))
fig.show()

def maybe_make_mcts_score(self, criterion: MCTSScores, **kwargs):
    """Makes the MCTS score function if not already done.

    A new score function is built when none has been created yet, or when
    ``criterion`` differs from the criterion currently stored in
    ``self.mcts_score``. Otherwise the existing score function is reused.

    Args:
        criterion (MCTSScores): the scoring criterion to use. Its ``value``
            is called with ``**kwargs`` to build the score function.
        **kwargs: keyword arguments forwarded to ``criterion.value``.

    Returns:
        The score function associated with ``criterion``.
    """
    # Read the private slot directly: going through the ``mcts_score_fn``
    # property would call back into this method when nothing is cached.
    current_fn = getattr(self, "_mcts_score_fn", None)
    if current_fn is not None and criterion == self.mcts_score:
        return current_fn
    # Build from the *requested* criterion, then record it as the current one.
    score_fn = criterion.value(**kwargs)
    self.mcts_score = criterion
    self.mcts_score_fn = score_fn
    return score_fn


@property
def mcts_score_fn(self):
    """The score function for ``self.mcts_score``, built on first access."""
    score_fn = getattr(self, "_mcts_score_fn", None)
    if score_fn is None:
        score_fn = self.maybe_make_mcts_score(self.mcts_score)
    return score_fn


@mcts_score_fn.setter
def mcts_score_fn(self, value):
    self._mcts_score_fn = value

def select_node(self, root, criterion: MCTSScores, as_tensordict: bool = False):
    """Walks down the forest from ``root``, following the best-scored branch.

    Args:
        root: the starting node. It must be unbatched (``root.ndim == 0``).
        criterion (MCTSScores): scoring criterion, passed to
            ``maybe_make_mcts_score``.
        as_tensordict (bool, optional): if ``True``, the selected node is
            returned as ``TensorDict({"data_content": node})`` rather than
            as an ``MCTSNode``. Defaults to ``False``.

    Returns:
        ``root`` itself, unwrapped, when ``data_map`` is ``None`` or empty.
        Otherwise the selected node, wrapped according to ``as_tensordict``.

    Raises:
        RuntimeError: if the forest is non-empty and ``root`` has one or
            more batch dimensions.
    """
    if self.data_map is None or not len(self.data_map):
        return root
    self.maybe_make_mcts_score(criterion)
    # Recursively selects a node by using a given criterion
    if root.ndim:
        raise RuntimeError(
            "select_node expects an unbatched root (ndim == 0), "
            f"got ndim={root.ndim}."
        )
    # The score function does not change during the traversal: read it once.
    score_fn = self.mcts_score_fn
    while root in self.node_map:
        branches = self.node_map[root]
        # Each branch has an action and the resulting state
        scored_branches = score_fn(branches)
        # Now we take the argmax of the scored_branches
        argmax_score = scored_branches.get(score_fn.score_key).argmax()
        root = branches[argmax_score]

    if as_tensordict:
        return TensorDict({"data_content": root})
    return MCTSNode(root)
84 changes: 84 additions & 0 deletions torchrl/modules/mcts/policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from enum import Enum

from tensordict.nn import TensorDictModuleBase

from torchrl.data.map.tree import MCTSForest
from torchrl.envs.common import EnvBase

from torchrl.modules.mcts.scores import MCTSScores

class ExpansionStrategies(Enum):
    """Ways of listing the candidate actions during the MCTS expansion phase."""

    # Every action enumerated by the environment's action spec.
    Exhaustive = "Exhaustive"
    # Declared but not implemented by ``MCTSPolicy.get_possible_actions``.
    Sampling = "Sampling"
    # Declared but not implemented by ``MCTSPolicy.get_possible_actions``.
    Embedding = "Embedding"


class MCTSPolicy(TensorDictModuleBase):
    """Policy module running the selection, expansion and simulation phases of MCTS.

    Args:
        simulation_env (EnvBase): environment used to step the candidate
            actions and to run the simulation rollouts.

    Keyword Args:
        forest (MCTSForest, optional): forest queried during selection. A new
            ``MCTSForest()`` is created when ``None``.
        expansion_strategy (ExpansionStrategies): how candidate actions are
            listed. Only ``ExpansionStrategies.Exhaustive`` is implemented.
        selection_criterion (MCTSScores): criterion forwarded to
            ``MCTSForest.select_node``.
    """

    # Extra keyword arguments forwarded to ``env.rollout`` in ``forward``.
    rollout_kwargs = {"break_when_any_done": False}
    # Number of simulations per candidate; values > 1 add a leading dimension.
    num_sim = 1
    # ``max_steps`` passed to ``env.rollout`` in ``forward``.
    max_rollout_steps = 100

    def __init__(
        self,
        simulation_env: EnvBase,
        *,
        # Quoted so the ``X | None`` union is not evaluated at definition
        # time, which fails on Python < 3.10.
        forest: "MCTSForest | None" = None,
        expansion_strategy: ExpansionStrategies = ExpansionStrategies.Exhaustive,
        selection_criterion: MCTSScores = MCTSScores.PUCT,
    ):
        super().__init__()
        self.env = simulation_env
        if forest is None:
            forest = MCTSForest()
        self.forest = forest
        self.expansion_strategy = expansion_strategy
        self.selection_criterion = selection_criterion

    def forward(self, node):
        """Runs one selection / expansion / simulation pass starting from ``node``.

        The backpropagation phase is not implemented yet, and nothing is
        returned.
        """
        # 1. Selection
        selected_node = self.select_node(node)

        # 2. Expansion: generate new child nodes for all possible responses to this move
        actions = self.get_possible_actions()

        # 3. Simulation
        # Expands child to make all possible moves
        node_with_actions = self.set_actions(selected_node, actions)

        # we may want to expand the children_with_node to do more than one simulation
        if self.num_sim > 1:
            node_with_actions = node_with_actions.expand(
                self.num_sim, *node_with_actions.shape
            )
        # Get init state of rollouts (new children)
        _, reset_nodes = self.env.step_and_maybe_reset(node_with_actions)

        # Get the rollouts
        rollouts = self.env.rollout(
            max_steps=self.max_rollout_steps,
            tensordict=reset_nodes,
            auto_reset=False,
            **self.rollout_kwargs,
        )
        # Update stats of the child_with_move
        # NOTE(review): ``update_stats`` is not defined in this class in the
        # visible code -- confirm that the base class or a subclass provides it.
        self.update_stats(node_with_actions, rollouts)

        # 4. Backprop

    def select_node(self, node):
        """Delegates node selection to the forest, using ``selection_criterion``."""
        return self.forest.select_node(node, criterion=self.selection_criterion)

    def get_possible_actions(self):
        """Lists the candidate actions according to ``expansion_strategy``.

        Raises:
            NotImplementedError: for any strategy other than
                ``ExpansionStrategies.Exhaustive``.
        """
        if self.expansion_strategy == ExpansionStrategies.Exhaustive:
            # lists the possible moves at the node
            return self.env.full_action_spec.enumerate()
        raise NotImplementedError(
            f"Expansion strategy {self.expansion_strategy} is not implemented yet."
        )

    def set_actions(self, node, actions):
        """Expands ``node`` along a new leading dimension and writes ``actions`` into it.

        The new leading dimension has size ``actions.shape[0]``.
        """
        return node.expand(actions.shape[0], *node.shape).update(actions)
Loading