-
Notifications
You must be signed in to change notification settings - Fork 6
/
visualize.py
107 lines (94 loc) · 3.59 KB
/
visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import json
import os
import numpy as np
from graphviz import Digraph
from argparse import ArgumentParser
# Windows settings: make the Graphviz executables discoverable by the
# `graphviz` package. The leading ';' is the Windows PATH separator, so
# gv_path can be appended to PATH verbatim.
# NOTE(review): hard-coded to a Graphviz 2.38 install location — adjust per machine.
gv_path = r';C:\Program Files (x86)\Graphviz2.38\bin'
# Only Windows understands this path (and the ';' separator); on other
# platforms appending it would just pollute PATH with a bogus entry.
if os.name == 'nt':
    os.environ['PATH'] = os.environ.get('PATH', '') + gv_path
# Candidate operation names, in index order. Visualizer.plot_cell labels each
# edge with OPS[np.argmax(op_vector)], so the position of every name matters.
OPS = ['irs23', 'irs25', 'irs43', 'irs45', 'identity']
class Visualizer(object):
    """Render the normal and reduce cells described in a JSON file with Graphviz.

    The JSON object's keys that start with ``normal`` / ``reduce`` hold lists
    of numbers; each list is converted to ints and decoded with ``np.argmax``.
    """

    def __init__(self, file, save_type='pdf', save_path='.', OPS=OPS):
        """Load the cell structure from *file*.

        Args:
            file: the json file to load the cell structure from.
            save_type: Graphviz output format passed to ``Digraph(format=...)``
                (e.g. ``'pdf'`` or ``'png'``).
            save_path: directory the rendered images are written into.
            OPS: predefined operation names; an op vector's argmax indexes
                into this sequence.
        """
        self.file = file
        self.save_type = save_type
        self.save_path = save_path
        self.OPS = OPS
        # JSON is UTF-8 by specification; don't depend on the platform's
        # default text encoding (e.g. cp1252 on Windows).
        with open(file, 'r', encoding='utf-8') as f:
            self.cell = json.load(f)

    def plot_cells(self):
        """Render both the normal cell and the reduce cell."""
        normal_cell, reduce_cell = self.split_cell_json(self.cell)
        self.plot_cell(normal_cell, 'normal')
        self.plot_cell(reduce_cell, 'reduce')

    def split_cell_json(self, cell):
        """Split *cell* into its ``normal*`` and ``reduce*`` entries.

        Returns:
            ``(normal_cell, reduce_cell)``: two dicts that keep the key order
            of *cell*, with every value converted to a list of ints.
        """
        normal_cell = {key: list(map(int, cell[key])) for key in cell if key.startswith('normal')}
        reduce_cell = {key: list(map(int, cell[key])) for key in cell if key.startswith('reduce')}
        return normal_cell, reduce_cell

    @staticmethod
    def _input_node(input_index):
        """Map an input-choice index to a node name: 0/1 are the cell inputs, n>=2 is node n-2."""
        if input_index == 0:
            return "c_{k-2}"
        if input_index == 1:
            return "c_{k-1}"
        return str(input_index - 2)

    def plot_cell(self, cell, cell_type='normal'):
        """Render one cell and write it under ``self.save_path``.

        *cell* is consumed in key order, four entries per intermediate node:
        even positions are input-choice vectors and odd positions are
        op-choice vectors. Each op entry draws an edge, labelled with
        ``self.OPS[argmax]``, from the most recently chosen input to
        intermediate node ``position // 4``. Every intermediate node is then
        connected to the output node ``c_{k}``.

        The output is rendered to ``<save_path>/<cell_type>_<json basename>``
        in ``self.save_type`` format.

        Raises:
            ValueError: if an op entry appears before any input entry.
        """
        dot = Digraph(
            format=self.save_type,
            edge_attr=dict(fontsize='20'),
            node_attr=dict(style='filled', shape='rect', align='center', fontsize='20',
                           height='0.5', width='0.5', penwidth='2'),
            engine='dot')
        dot.body.extend(['rankdir=LR'])
        dot.node("c_{k-2}", fillcolor='darkseagreen2')
        dot.node("c_{k-1}", fillcolor='darkseagreen2')
        steps = len(cell) // 4
        for i in range(steps):
            dot.node(str(i), fillcolor='lightblue')

        edge_in = edge_out = None
        for i, key in enumerate(cell):
            if i % 2 == 1:
                # op entry: draw the edge for the input chosen just before it
                if edge_in is None:
                    raise ValueError(
                        f"{cell_type} cell: op entry {key!r} appears before any input entry")
                op = self.OPS[np.argmax(cell[key])]
                dot.edge(edge_in, edge_out, label=op, fillcolor="gray")
            else:
                # input entry: remember source and destination for the next op
                edge_in = self._input_node(np.argmax(cell[key]))
                edge_out = str(i // 4)

        dot.node("c_{k}", fillcolor='palegoldenrod')
        for i in range(steps):
            dot.edge(str(i), "c_{k}", fillcolor="gray")

        base = os.path.basename(self.file)
        # Strip only a trailing '.json'; str.replace would also remove it mid-name.
        if base.endswith('.json'):
            base = base[:-len('.json')]
        filename = os.path.join(self.save_path, f'{cell_type}_{base}')
        print(filename)
        dot.render(filename, view=False, cleanup=True)
if __name__ == "__main__":
    import sys

    parser = ArgumentParser("Visualization")
    parser.add_argument('--file_path', default='outputs', type=str,
                        help='a cell JSON file, or a directory containing them')
    parser.add_argument("--save_path", default='outputs/vis', type=str,
                        help='directory for the rendered images')
    parser.add_argument("--save_type", default='png', type=str,
                        help='Graphviz output format, e.g. png or pdf')
    args = parser.parse_args()
    file_path = args.file_path
    save_type = args.save_type
    save_path = args.save_path

    if os.path.isfile(file_path):
        files = [file_path]
    elif os.path.isdir(file_path):
        # sorted() makes the processing (and printed) order deterministic.
        files = [os.path.join(file_path, name)
                 for name in sorted(os.listdir(file_path))
                 if name.endswith(".json")]
    else:
        # Report on stderr and exit non-zero so callers/scripts can detect it.
        sys.exit('Wrong file path')

    for file in files:
        vis = Visualizer(file, save_type, save_path)
        vis.plot_cells()