Skip to content

Commit

Permalink
Merge pull request #102 from kaorahi/plot_tree
Browse files Browse the repository at this point in the history
Implement #94 (plot_tree)
  • Loading branch information
kobanium authored Jan 10, 2024
2 parents edb94c8 + 6bd7766 commit 8982970
Show file tree
Hide file tree
Showing 6 changed files with 307 additions and 2 deletions.
119 changes: 119 additions & 0 deletions graph/plot_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#!/bin/env python3

import sys
import click
import math
import json
import gzip
import graphviz
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# 親ディレクトリからたどって import するための準備
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))

from mcts.dump import enrich_mcts_dict

@click.command()
@click.argument('input_json_path', type=click.Path(exists=True))
@click.argument('output_image_path', type=click.Path())
@click.option('--around-pv', type=click.BOOL, default=False,
              help="主分岐のまわりのみ表示するフラグ。デフォルトはFalse。")
def plot_tree_main(input_json_path: str, output_image_path: str, around_pv: bool):
    # NOTE: click prints this docstring as the --help text, so it is runtime
    # output and its wording is kept unchanged.
    # Each "\b" line starts its own paragraph (after a blank line) and tells
    # click not to rewrap that paragraph; without it the line breaks are lost.
    # https://click.palletsprojects.com/en/8.1.x/documentation/#preventing-rewrapping
    """MCTSツリーを可視化。

    \b
    Args:
        input_json_path (str): MCTSの状態を表すJSONファイルのパス。
        output_image_path (str): 可視化結果を保存する画像ファイルのパス。
        around_pv (bool): 最善応手系列の周辺のみ表示するフラグ。デフォルトはFalse。

    \b
    Example:
        cd tamago
        (echo 'tamago-readsgf (;SZ[9]KM[7];B[fe];W[de];B[ec])';
         echo 'lz-genmove_analyze 7777777';
         echo 'tamago-dump_tree') \\
          | python3 main.py --model model/model.bin --strict-visits 100 \\
          | grep dump_version | gzip > tree.json.gz
        python3 graph/plot_tree.py tree.json.gz tree_graph
        display tree_graph.png
    """

    # Both gzip.open and the builtin open accept text mode with an explicit
    # encoding, so the JSON is decoded the same way for either input kind.
    opener = gzip.open if input_json_path.endswith('.gz') else open
    with opener(input_json_path, 'rt', encoding='utf-8') as file:
        state = json.load(file)

    # Adds the derived per-node fields used below (parent_index, to_move,
    # gtp_move, win rates, ...) and tree["sorted_indices_list"].
    enrich_mcts_dict(state)
    tree = state["tree"]
    node = tree["node"]
    sorted_indices_list = tree["sorted_indices_list"]

    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap works on both older and newer releases.
    # Alternative palettes that were tried: 'coolwarm_r', 'RdYlBu', 'viridis'.
    colormap = plt.get_cmap('Spectral')

    dot = graphviz.Digraph(comment='Visualization of MCTS Tree')

    for index in sorted_indices_list:
        item = node[index]
        # The root node has no parent and carries no move information.
        if "parent_index" not in item:
            dot.node(str(index), label=f"root\n{item['node_visits']} visits")
            continue

        parent_index = item['parent_index']
        parent = node[parent_index]
        # With around_pv, draw only the principal variation and the direct
        # children of its nodes: skip any node whose parent is off the PV
        # (i.e. the parent's path contains a non-best sibling order).
        if around_pv and any(order > 0 for order in parent["orders_along_path"]):
            continue

        # Build the node.
        move = item['gtp_move']
        visits = item['visits']
        winrate = item['mean_black_winrate']
        raw_winrate = item['raw_black_winrate']
        # Fill color shows the mean win rate, border color the raw win rate.
        node_color = get_color(winrate, colormap)
        border_color = get_color(raw_winrate, colormap)
        # Use white text when the win rate is far from 0.5.
        text_color = 'black' if abs(winrate - 0.5) < 0.25 else 'white'
        # A move played by black (white to move next) is drawn as a square,
        # a move played by white as a circle.
        shape = 'square' if item["to_move"] == 'white' else 'circle'
        wr = int(winrate * 100)
        raw_wr = int(raw_winrate * 100)
        # Rarely visited nodes get a short label.
        label = f"{move}\n{wr}%" if visits < 10 else f"{move}\n{wr}% (raw {raw_wr}%)\n{visits} visits"
        dot.node(
            str(index),
            label=label,
            color=border_color,
            fillcolor=node_color,
            fontcolor=text_color,
            style='filled',
            penwidth='5.0',
            height=get_size(visits, shape),
            fixedsize='true',
            shape=shape,
        )

        # Build the edge; its thickness follows the policy value of the move.
        penwidth = max(0.5, item['policy'] * 10)
        dot.edge(str(parent_index), str(index), penwidth=f"{penwidth}")

    dot.render(output_image_path, format='png', view=False, cleanup=True)

def get_color(value, colormap):
    """Convert a win rate into a hex color string.

    The distance of ``value`` from 0.5 is scaled by 1.5 before it is passed
    to ``colormap``, so that differences in color are emphasized.

    Args:
        value: Win rate to colorize.
        colormap: Callable mapping a number to a color accepted by
            ``mcolors.to_hex``.

    Returns:
        The color as a hex string.
    """
    emphasis = 1.5  # exaggerate the color differences
    offset_from_even = (value - 0.5) * emphasis
    rgba = colormap(0.5 + offset_from_even)
    return mcolors.to_hex(rgba)

def get_size(visits, shape):
    """Return the node size (as a string) for a given visit count.

    The size grows with log10 of the visit count. For shapes other than
    'square' the size is scaled so that a circle has the same area as the
    square would have.

    Args:
        visits: Visit count of the node. Values below 1 are treated as 1,
            because log10 is undefined for 0 and would raise ValueError.
        shape: Node shape name; 'square' is used as is, anything else is
            sized as a circle.

    Returns:
        str: The size formatted with ``str``.
    """
    size0 = 0.5 + math.log10(max(visits, 1))
    # Make the circle's area equal to the square's: pi * (d / 2)^2 == s^2.
    size = size0 if shape == 'square' else size0 * 2 / (math.pi ** 0.5)
    return str(size)

# Run the click command only when this file is executed as a script.
if __name__ == "__main__":
    plot_tree_main()
14 changes: 13 additions & 1 deletion gtp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def __init__(self, board_size: int, superko: bool, model_file_path: str, \
"lz-analyze",
"lz-genmove_analyze",
"cgos-analyze",
"cgos-genmove_analyze"
"cgos-genmove_analyze",
"tamago-dump_tree",
]
self.superko = superko
self.board = GoBoard(board_size=board_size, komi=komi, check_superko=superko)
Expand Down Expand Up @@ -464,6 +465,15 @@ def _genmove_analyze(self, mode: str, arg_list: List[str]) -> NoReturn:
print_out(f"play {self.coordinate.convert_to_gtp_format(pos)}\n")


def _dump_tree(self) -> NoReturn:
    """Execute the tamago-dump_tree command.

    Serializes the current MCTS tree to JSON, sends a success response
    with ``ongoing=True``, then prints the JSON followed by an empty line.
    """
    # Serialize first, so nothing is reported if the dump itself fails.
    tree_json = self.mcts.dump_to_json(self.board, self.superko)
    respond_success("", ongoing=True)
    # One line of JSON, then a blank line.
    print(tree_json, end="\n\n")


def run(self) -> NoReturn: # pylint: disable=R0912,R0915
"""Go Text Protocolのクライアントの実行処理。
入力されたコマンドに対応する処理を実行し、応答メッセージを表示する。
Expand Down Expand Up @@ -570,6 +580,8 @@ def run(self) -> NoReturn: # pylint: disable=R0912,R0915
print("")
elif input_gtp_command == "cgos-genmove_analyze":
self._genmove_analyze("cgos", command_list[1:])
elif input_gtp_command == "tamago-dump_tree":
self._dump_tree()
elif input_gtp_command == "hash_record":
print_err(self.board.record.get_hash_history())
respond_success("")
Expand Down
103 changes: 103 additions & 0 deletions mcts/dump.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import json
from typing import Any, Dict, NoReturn

from program import PROGRAM_NAME, VERSION, PROTOCOL_VERSION
from board.go_board import GoBoard
from board.coordinate import Coordinate
from board.stone import Stone
from mcts.constant import NOT_EXPANDED

def dump_mcts_to_json(tree_dict: Dict[str, Any], board: GoBoard, superko: bool) -> str:
    """Return a JSON string describing the MCTS state.

    Args:
        tree_dict (Dict[str, Any]): The tree state, already converted to a dictionary.
        board (GoBoard): The current board; its size and komi are recorded.
        superko (bool): Whether superko checking is enabled.

    Returns:
        str: JSON text holding the tree together with board and program metadata.
    """
    board_size = board.get_board_size()
    komi = board.get_komi()
    return json.dumps({
        "dump_version": 1,
        "tree": tree_dict,
        "board_size": board_size,
        "komi": komi,
        "superko": superko,
        "name": PROGRAM_NAME,
        "version": VERSION,
        "protocol_version": PROTOCOL_VERSION,
    })

def enrich_mcts_dict(state: Dict[str, Any]) -> NoReturn:
    """Add derived convenience fields to a dumped MCTS state, in place.

    Fields added to every node: ``index``; to every expanded child:
    ``parent_index`` and ``index_in_brother``. Nodes reachable from
    ``tree["current_root"]`` additionally get ``expanded_children_index``,
    ``order`` (rank among siblings by ``node_visits``, 0 = most visited;
    not set on the root), ``level``, ``orders_along_path`` and ``to_move``;
    non-root reachable nodes also get ``policy``, ``visits``, ``value``,
    ``value_sum``, ``gtp_move``, ``mean_value``, ``raw_black_winrate`` and
    ``mean_black_winrate``. ``tree["sorted_indices_list"]`` is set to the
    breadth-first order of the reachable nodes.

    Args:
        state (Dict[str, Any]): Dictionary describing the MCTS state.
    """
    # Local import so this function does not depend on other file-level edits.
    from collections import deque

    coord = Coordinate(board_size=state["board_size"])
    tree = state["tree"]
    node = tree["node"]

    # Pass 1: reverse lookups (index, parent_index, index_in_brother).
    for index, item in enumerate(node):
        item["index"] = index
        for index_in_brother, child_index in enumerate(item["children_index"]):
            if child_index == NOT_EXPANDED:
                continue
            # The code below relies on these properties of tree.node; check
            # them so that a future format change does not go unnoticed.
            assert index < child_index, "Parent index must be less than child index."
            assert child_index < tree["num_nodes"], "Child index must be less than num_nodes."
            child = node[child_index]
            child["parent_index"] = index
            child["index_in_brother"] = index_in_brother

    # List guaranteeing "parents before children" and "siblings in
    # ascending order value".
    sorted_indices_list = []
    tree["sorted_indices_list"] = sorted_indices_list

    # Pass 2: breadth-first walk from the current root, filling
    # expanded_children_index, sorted_indices_list and the sibling order.
    # A deque gives O(1) pops from the front (list.pop(0) is O(n)).
    nodes_pool = deque([node[tree["current_root"]]])
    while nodes_pool:
        item = nodes_pool.popleft()
        sorted_indices_list.append(item["index"])
        expanded_children_index = [i for i in item["children_index"] if i != NOT_EXPANDED]
        item["expanded_children_index"] = expanded_children_index
        expanded_children = [node[i] for i in expanded_children_index]
        expanded_children.sort(key=lambda child: child["node_visits"], reverse=True)
        for order, child in enumerate(expanded_children):
            child["order"] = order
        nodes_pool.extend(expanded_children)

    # Pass 3: remaining convenience fields. Walking the breadth-first list
    # ensures each parent is completed before its children and skips nodes
    # that are not reachable from the current root (they have no "order").
    for index in sorted_indices_list:
        item = node[index]
        is_root = "parent_index" not in item
        if is_root:
            item["level"] = 0
            item["orders_along_path"] = []
            item["to_move"] = tree["to_move"]
            continue
        parent = node[item["parent_index"]]
        item["level"] = parent["level"] + 1
        item["orders_along_path"] = [*parent["orders_along_path"], item["order"]]
        item["to_move"] = _opposite_color(parent["to_move"])
        # Note that the root node does not have the fields below.
        index_in_brother = item["index_in_brother"]
        item["policy"] = parent["children_policy"][index_in_brother]
        item["visits"] = parent["children_visits"][index_in_brother]
        item["value"] = parent["children_value"][index_in_brother]
        item["value_sum"] = parent["children_value_sum"][index_in_brother]
        item["gtp_move"] = coord.convert_to_gtp_format(parent["action"][index_in_brother])
        item["mean_value"] = item["value_sum"] / item["visits"]
        # The move leading to this node was played by the side that is not
        # to move here.
        last_move_color = _opposite_color(item["to_move"])
        item["raw_black_winrate"] = _black_winrate(item["value"], last_move_color)
        item["mean_black_winrate"] = _black_winrate(item["mean_value"], last_move_color)

def _opposite_color(color):
return 'white' if color == 'black' else 'black'

def _black_winrate(value, last_move_color):
return value if last_move_color == "black" else 1.0 - value
36 changes: 35 additions & 1 deletion mcts/node.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""モンテカルロ木探索で使用するノードの実装。
"""
import json
from typing import Callable, Dict, List, NoReturn
from typing import Any, Callable, Dict, List, NoReturn

import numpy as np
import torch
from board.constant import BOARD_SIZE
from board.go_board import GoBoard
from common.print_console import print_err
Expand Down Expand Up @@ -217,6 +218,39 @@ def set_child_index(self, index: int, child_index: int) -> NoReturn:
self.children_index[index] = child_index


def to_dict(self) -> Dict[str, Any]:
    """Return the state of this node as a dictionary.

    Returns:
        Dict[str, Any]: The attributes listed below, keyed by attribute
            name, after being passed through ``_make_serializable``.
    """
    # Attributes included in the dump, in output order.
    dumped_attributes = (
        "node_visits",
        "virtual_loss",
        "node_value_sum",
        "raw_value",
        "action",
        "children_index",
        "children_value",
        "children_visits",
        "children_policy",
        "children_virtual_loss",
        "children_value_sum",
        "noise",
        "num_children",
    )
    snapshot = {name: getattr(self, name) for name in dumped_attributes}
    self._make_serializable(snapshot)
    return snapshot

def _make_serializable(self, dic):
for key in dic:
val = dic[key]
if isinstance(val, np.ndarray):
val = val.tolist()
elif isinstance(val, torch.Tensor):
val = val.item()
dic[key] = val

def print_search_result(self, board: GoBoard, pv_dict: Dict[str, List[str]]) -> NoReturn:
"""探索結果を表示する。探索した手の探索回数とValueの平均値を表示する。
Expand Down
36 changes: 36 additions & 0 deletions mcts/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from mcts.sequential_halving import get_candidates_and_visit_pairs
from mcts.node import MCTSNode
from mcts.time_manager import TimeControl, TimeManager
from mcts.dump import dump_mcts_to_json

class MCTSTree: # pylint: disable=R0902
"""モンテカルロ木探索の実装クラス。
Expand All @@ -42,6 +43,7 @@ def __init__(self, network: DualNet, tree_size: int=MCTS_TREE_SIZE, \
self.current_root = 0
self.batch_size = batch_size
self.cgos_mode = cgos_mode
self.to_move = Stone.BLACK


def search_best_move(self, board: GoBoard, color: Stone, time_manager: TimeManager, \
Expand Down Expand Up @@ -137,6 +139,7 @@ def search(self, board: GoBoard, color: Stone, time_manager: TimeManager, \
time_manager (TimeManager): 思考時間管理インスタンス。
analysis_query (Dict[str, Any]) : 解析情報。
"""
self.to_move = color
analysis_clock = time.time()
search_board = copy.deepcopy(board)

Expand Down Expand Up @@ -447,6 +450,39 @@ def get_best_move_sequence(self, pv_list: List[str], index: int) -> List[str]:
return self.get_best_move_sequence(pv_list, next_index)


def dump_to_json(self, board: GoBoard, superko: bool) -> str:
    """Return a JSON string describing the MCTS state.

    Args:
        board (GoBoard): The current board.
        superko (bool): Whether superko checking is enabled.

    Returns:
        str: JSON text produced by ``dump_mcts_to_json`` from this tree's
            dictionary form.
    """
    tree_state = self.to_dict()
    return dump_mcts_to_json(tree_state, board, superko)


def to_dict(self) -> Dict[str, Any]:
    """Return the state of the tree as a dictionary.

    Returns:
        Dict[str, Any]: The first ``num_nodes`` nodes in dictionary form
            plus the tree-level attributes; ``to_move`` is written as
            'black' or 'white'.
    """
    dumped_nodes = []
    for i in range(self.num_nodes):
        dumped_nodes.append(self.node[i].to_dict())

    if self.to_move == Stone.BLACK:
        to_move_name = 'black'
    else:
        to_move_name = 'white'

    # The network and batch_queue attributes are deliberately left out of
    # the dump.
    return {
        "node": dumped_nodes,
        "num_nodes": self.num_nodes,
        "root": self.root,
        "current_root": self.current_root,
        "batch_size": self.batch_size,
        "cgos_mode": self.cgos_mode,
        "to_move": to_move_name,
    }


def get_tentative_policy(candidates: List[int]) -> Dict[int, float]:
"""ニューラルネットワークの計算が行われるまでに使用するPolicyを取得する。
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
click>=6.7
numpy>=1.19.5
torch>=1.10.0
graphviz>=0.20.1

0 comments on commit 8982970

Please sign in to comment.