Skip to content

Commit 737bb64

Browse files
author
Vincent Moens
committed
[Feature] MCTS Scoring functions
ghstack-source-id: 7ee601a Pull Request resolved: #2358
1 parent db9e2af commit 737bb64

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

torchrl/modules/mcts/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from .scores import PUCTScore, UCBScore

torchrl/modules/mcts/scores.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
7+
import functools
8+
import math
9+
from abc import abstractmethod
10+
from enum import Enum
11+
12+
from tensordict import NestedKey, TensorDictBase
13+
from tensordict.nn import TensorDictModuleBase
14+
from torch import nn
15+
16+
17+
class MCTSScore(TensorDictModuleBase):
    """Common base class for modules that compute a score for MCTS nodes.

    Concrete subclasses override :meth:`forward`.
    """

    @abstractmethod
    def forward(self, node):
        """Score ``node``; the computation is supplied by subclasses.

        NOTE(review): ``abstractmethod`` only prevents instantiation when the
        class uses ``ABCMeta`` as its metaclass -- confirm whether
        ``TensorDictModuleBase`` provides that. If it does not, this default
        implementation is callable and simply returns ``None``.
        """
21+
22+
23+
class PUCTScore(MCTSScore):
    """Compute a PUCT-style score for MCTS nodes.

    The value written under ``score_key`` is::

        win_count / visits + c * prior_prob * sqrt(total_visits) / (1 + visits)

    Entries whose visit count is exactly zero use a denominator of 1 in the
    first term, so unvisited nodes yield a finite value instead of the
    NaN/inf that a division by zero would produce.

    Keyword Args:
        c: multiplier applied to the exploration (second) term.
        win_count_key: key under which the win count is read.
        visits_key: key under which the per-node visit count is read.
        total_visits_key: key under which the total visit count is read.
        prior_prob_key: key under which the prior probability is read.
        score_key: key under which the resulting score is written.
    """

    c: float

    def __init__(
        self,
        *,
        c: float,
        win_count_key: NestedKey = "win_count",
        visits_key: NestedKey = "visits",
        total_visits_key: NestedKey = "total_visits",
        prior_prob_key: NestedKey = "prior_prob",
        score_key: NestedKey = "score",
    ):
        super().__init__()
        self.c = c
        self.win_count_key = win_count_key
        self.visits_key = visits_key
        self.total_visits_key = total_visits_key
        self.prior_prob_key = prior_prob_key
        self.score_key = score_key
        self.in_keys = [
            self.win_count_key,
            self.prior_prob_key,
            self.total_visits_key,
            self.visits_key,
        ]
        self.out_keys = [self.score_key]

    def forward(self, node: TensorDictBase) -> TensorDictBase:
        """Write the score into ``node`` under ``score_key`` and return ``node``.

        The entries read from ``node`` are assumed to be torch tensors (the
        computation relies on tensor methods such as ``sqrt``).
        """
        win_count = node.get(self.win_count_key)
        visits = node.get(self.visits_key)
        n_total = node.get(self.total_visits_key)
        prior_prob = node.get(self.prior_prob_key)
        # Only exact zeros are replaced, so every visited entry keeps the
        # original ``win_count / visits`` value; unvisited entries avoid the
        # 0 / 0 -> NaN (or x / 0 -> inf) that would poison later comparisons.
        safe_visits = visits.masked_fill(visits == 0, 1)
        exploitation = win_count / safe_visits
        exploration = self.c * prior_prob * n_total.sqrt() / (1 + visits)
        node.set(self.score_key, exploitation + exploration)
        return node
61+
62+
63+
class UCBScore(MCTSScore):
    """Compute a UCB-style score for MCTS nodes.

    The value written under ``score_key`` is::

        win_count / visits + c * sqrt(total_visits) / (1 + visits)

    Entries whose visit count is exactly zero use a denominator of 1 in the
    first term, so unvisited nodes yield a finite value instead of the
    NaN/inf that a division by zero would produce.

    NOTE(review): the textbook UCB1 exploration term is
    ``c * sqrt(ln(total_visits) / visits)``; the expression implemented here
    differs from it -- confirm that this form is intended.

    Keyword Args:
        c: multiplier applied to the exploration (second) term.
        win_count_key: key under which the win count is read.
        visits_key: key under which the per-node visit count is read.
        total_visits_key: key under which the total visit count is read.
        score_key: key under which the resulting score is written.
    """

    c: float

    def __init__(
        self,
        *,
        c: float,
        win_count_key: NestedKey = "win_count",
        visits_key: NestedKey = "visits",
        total_visits_key: NestedKey = "total_visits",
        score_key: NestedKey = "score",
    ):
        super().__init__()
        self.c = c
        self.win_count_key = win_count_key
        self.visits_key = visits_key
        self.total_visits_key = total_visits_key
        self.score_key = score_key
        self.in_keys = [self.win_count_key, self.total_visits_key, self.visits_key]
        self.out_keys = [self.score_key]

    def forward(self, node: TensorDictBase) -> TensorDictBase:
        """Write the score into ``node`` under ``score_key`` and return ``node``.

        The entries read from ``node`` are assumed to be torch tensors (the
        computation relies on tensor methods such as ``sqrt``).
        """
        win_count = node.get(self.win_count_key)
        visits = node.get(self.visits_key)
        n_total = node.get(self.total_visits_key)
        # Only exact zeros are replaced, so every visited entry keeps the
        # original ``win_count / visits`` value; unvisited entries avoid the
        # 0 / 0 -> NaN (or x / 0 -> inf) that would poison later comparisons.
        safe_visits = visits.masked_fill(visits == 0, 1)
        exploitation = win_count / safe_visits
        exploration = self.c * n_total.sqrt() / (1 + visits)
        node.set(self.score_key, exploitation + exploration)
        return node
93+
94+
95+
class MCTSScores(Enum):
    """Enumeration of the MCTS scoring rules referenced by this module.

    ``PUCT`` and ``UCB`` hold ``functools.partial`` factories that pre-bind
    the constant ``c``; the remaining members hold plain string labels.

    NOTE(review): newer Python versions (3.13+) are reported to emit a
    ``FutureWarning`` for ``functools.partial`` objects assigned in an Enum
    body and recommend wrapping them in ``enum.member()`` -- confirm against
    the supported Python versions.
    """

    # AlphaGo default value
    PUCT = functools.partial(PUCTScore, c=5)
    # default from Auer et al. 2002
    UCB = functools.partial(UCBScore, c=math.sqrt(2))
    UCB1_TUNED = "UCB1-Tuned"
    EXP3 = "EXP3"
    PUCT_VARIANT = "PUCT-Variant"

0 commit comments

Comments
 (0)