-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #102 from kaorahi/plot_tree
Implement #94 (plot_tree)
- Loading branch information
Showing
6 changed files
with
307 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
#!/bin/env python3 | ||
|
||
import sys | ||
import click | ||
import math | ||
import json | ||
import gzip | ||
import graphviz | ||
import matplotlib.pyplot as plt | ||
import matplotlib.colors as mcolors | ||
|
||
# 親ディレクトリからたどって import するための準備 | ||
from pathlib import Path | ||
sys.path.append(str(Path(__file__).parent.parent)) | ||
|
||
from mcts.dump import enrich_mcts_dict | ||
|
||
@click.command() | ||
@click.argument('input_json_path', type=click.Path(exists=True)) | ||
@click.argument('output_image_path', type=click.Path()) | ||
@click.option('--around-pv', type=click.BOOL, default=False, \ | ||
help="主分岐のまわりのみ表示するフラグ。デフォルトはFalse。") | ||
def plot_tree_main(input_json_path: str, output_image_path: str, around_pv: bool): | ||
# docstring 中の \b は click による rewrapping の抑止(入れないと改行が無視される) | ||
# https://click.palletsprojects.com/en/8.1.x/documentation/#preventing-rewrapping | ||
"""MCTSツリーを可視化。 | ||
\b | ||
Args: | ||
input_json_path (str): MCTSの状態を表すJSONファイルのパス。 | ||
output_image_path (str): 可視化結果を保存する画像ファイルのパス。 | ||
around_pv (bool): 最善応手系列の周辺のみ表示するフラグ。デフォルトはFalse。 | ||
\b | ||
Example: | ||
cd tamago | ||
(echo 'tamago-readsgf (;SZ[9]KM[7];B[fe];W[de];B[ec])'; | ||
echo 'lz-genmove_analyze 7777777'; | ||
echo 'tamago-dump_tree') \\ | ||
| python3 main.py --model model/model.bin --strict-visits 100 \\ | ||
| grep dump_version | gzip > tree.json.gz | ||
python3 graph/plot_tree.py tree.json.gz tree_graph | ||
display tree_graph.png | ||
""" | ||
|
||
opener = gzip.open if input_json_path.endswith('.gz') else open | ||
with opener(input_json_path, 'r') as file: | ||
state = json.load(file) | ||
|
||
enrich_mcts_dict(state) | ||
tree = state["tree"] | ||
node = tree["node"] | ||
sorted_indices_list = tree["sorted_indices_list"] | ||
|
||
# colormap = plt.cm.get_cmap('coolwarm_r') | ||
colormap = plt.cm.get_cmap('Spectral') | ||
# colormap = plt.cm.get_cmap('RdYlBu') | ||
# colormap = plt.cm.get_cmap('viridis') | ||
|
||
dot = graphviz.Digraph(comment='Visualization of MCTS Tree') | ||
|
||
for index in sorted_indices_list: | ||
item = node[index] | ||
# ルートノードの場合 | ||
if "parent_index" not in item: | ||
dot.node(str(index), label=f"root\n{item['node_visits']} visits") | ||
continue | ||
|
||
parent_index = item['parent_index'] | ||
parent = node[parent_index] | ||
# around_pv が指定された場合は、PV とその直下の子のみ表示する。 | ||
if around_pv and any(order > 0 for order in parent["orders_along_path"]): | ||
continue | ||
|
||
# ノードの作成 | ||
move = item['gtp_move'] | ||
visits = item['visits'] | ||
winrate = item['mean_black_winrate'] | ||
raw_winrate = item['raw_black_winrate'] | ||
node_color = get_color(winrate, colormap) | ||
border_color = get_color(raw_winrate, colormap) | ||
text_color = 'black' if abs(winrate - 0.5) < 0.25 else 'white' | ||
# 黒の着手(次が白番)は□、白の着手は○でノードを描く | ||
shape = 'square' if item["to_move"] == 'white' else 'circle' | ||
wr = int(winrate * 100) | ||
raw_wr = int(raw_winrate * 100) | ||
label = f"{move}\n{wr}%" if visits < 10 else f"{move}\n{wr}% (raw {raw_wr}%)\n{visits} visits" | ||
dot.node( | ||
str(index), | ||
label=label, | ||
color=border_color, | ||
fillcolor=node_color, | ||
fontcolor=text_color, | ||
style='filled', | ||
penwidth='5.0', | ||
height=get_size(visits, shape), | ||
fixedsize='true', | ||
shape=shape, | ||
) | ||
|
||
# エッジの作成 | ||
penwidth = max(0.5, item['policy'] * 10) | ||
dot.edge(str(parent_index), str(index), penwidth=f"{penwidth}") | ||
|
||
dot.render(output_image_path, format='png', view=False, cleanup=True) | ||
|
||
def get_color(value, colormap): | ||
emphasis = 1.5 # 色の違いを強調 | ||
v = 0.5 + (value - 0.5) * emphasis | ||
return mcolors.to_hex(colormap(v)) | ||
|
||
def get_size(visits, shape): | ||
size0 = 0.5 + math.log10(visits) | ||
# 正方形と円の面積が同じになるように | ||
size = size0 if shape == 'square' else size0 * 2 / (math.pi ** 0.5) | ||
return str(size) | ||
|
||
if __name__ == "__main__": | ||
plot_tree_main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import json | ||
from typing import Any, Dict, NoReturn | ||
|
||
from program import PROGRAM_NAME, VERSION, PROTOCOL_VERSION | ||
from board.go_board import GoBoard | ||
from board.coordinate import Coordinate | ||
from board.stone import Stone | ||
from mcts.constant import NOT_EXPANDED | ||
|
||
def dump_mcts_to_json(tree_dict: Dict[str, Any], board: GoBoard, superko: bool) -> str: | ||
"""MCTSの状態を表すJSON文字列を返す。 | ||
Args: | ||
tree_dict (Dict[str, Any]): 辞書化された「ツリーの状態」。 | ||
board (GoBoard): 現在の碁盤。 | ||
superko (bool): 超劫判定の有効化。 | ||
Returns: | ||
str: MCTSの状態を表すJSON文字列。 | ||
""" | ||
state = { | ||
"dump_version": 1, | ||
"tree": tree_dict, | ||
"board_size": board.get_board_size(), | ||
"komi": board.get_komi(), | ||
"superko": superko, | ||
"name": PROGRAM_NAME, | ||
"version": VERSION, | ||
"protocol_version": PROTOCOL_VERSION, | ||
} | ||
return json.dumps(state) | ||
|
||
def enrich_mcts_dict(state: Dict[str, Any]) -> NoReturn: | ||
"""MCTSの状態を表す辞書に便利項目をいろいろ追加する。 | ||
Args: | ||
state (Dict[str, Any]): MCTSの状態を表す辞書。 | ||
""" | ||
coord = Coordinate(board_size=state["board_size"]) | ||
tree = state["tree"] | ||
node = tree["node"] | ||
|
||
# index, parent_index, index_in_brother の逆引き | ||
for index, item in enumerate(node): | ||
item["index"] = index | ||
for index_in_brother, child_index in enumerate(item["children_index"]): | ||
if child_index == NOT_EXPANDED: | ||
continue | ||
child = node[child_index] | ||
child["parent_index"] = index | ||
child["index_in_brother"] = index_in_brother | ||
# 以下のコードは次の条件を前提としている。 | ||
# (将来もし tree.node の仕様が変わったら見落さないようにチェック) | ||
assert index < child_index, "Parent index must be less than child index." | ||
assert child_index < tree["num_nodes"], "Child index must be less than num_nodes." | ||
|
||
# 「親は子より前」「兄弟は order の小さい方が前」を保証したリスト | ||
sorted_indices_list = [] | ||
tree["sorted_indices_list"] = sorted_indices_list | ||
|
||
# expanded_children_index, sorted_indices_list, 兄弟内 order | ||
root_node = node[tree["current_root"]] | ||
nodes_pool = [root_node] | ||
while nodes_pool: | ||
item = nodes_pool.pop(0) | ||
sorted_indices_list.append(item["index"]) | ||
expanded_children_index = [i for i in item["children_index"] if i != NOT_EXPANDED] | ||
item["expanded_children_index"] = expanded_children_index | ||
expanded_children = [node[i] for i in expanded_children_index] | ||
expanded_children.sort(key=lambda item: item["node_visits"], reverse=True) | ||
for order, child in enumerate(expanded_children): | ||
child["order"] = order | ||
nodes_pool += expanded_children | ||
|
||
# その他いろいろな便利項目を追加 | ||
for item in node: | ||
is_root = "parent_index" not in item | ||
if is_root: | ||
item["level"] = 0 | ||
item["orders_along_path"] = [] | ||
item["to_move"] = tree["to_move"] | ||
continue | ||
parent = node[item["parent_index"]] | ||
item["level"] = parent["level"] + 1 | ||
item["orders_along_path"] = [*parent["orders_along_path"], item["order"]] | ||
item["to_move"] = _opposite_color(parent["to_move"]) | ||
# ルートノードは以下の項目を持たないことに注意 | ||
index_in_brother = item["index_in_brother"] | ||
item["policy"] = parent["children_policy"][index_in_brother] | ||
item["visits"] = parent["children_visits"][index_in_brother] | ||
item["value"] = parent["children_value"][index_in_brother] | ||
item["value_sum"] = parent["children_value_sum"][index_in_brother] | ||
item["gtp_move"] = coord.convert_to_gtp_format(parent["action"][index_in_brother]) | ||
item["mean_value"] = item["value_sum"] / item["visits"] | ||
last_move_color = _opposite_color(item["to_move"]) | ||
item["raw_black_winrate"] = _black_winrate(item["value"], last_move_color) | ||
item["mean_black_winrate"] = _black_winrate(item["mean_value"], last_move_color) | ||
|
||
def _opposite_color(color): | ||
return 'white' if color == 'black' else 'black' | ||
|
||
def _black_winrate(value, last_move_color): | ||
return value if last_move_color == "black" else 1.0 - value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
click>=6.7 | ||
numpy>=1.19.5 | ||
torch>=1.10.0 | ||
graphviz>=0.20.1 |