-
Notifications
You must be signed in to change notification settings - Fork 0
/
viz.py
108 lines (89 loc) · 3.39 KB
/
viz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import networkx as nx
import os
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle
# Unpack an array's ``shape`` tuple into exactly two values; a missing
# dimension defaults to 0, so a 1-D shape ``(n,)`` becomes ``(n, 0)``.
def size(height=0, width=0):
    """Return ``(height, width)``, padding any omitted dimension with 0.

    Meant to be called as ``size(*arr.shape)`` so that 1-D and 2-D arrays
    both unpack into two names.  A shape with more than two dimensions
    raises ``TypeError`` (too many positional arguments).
    """
    dims = height, width
    return dims
def edgelist(neighbors):
    """Flatten an adjacency list into directed ``(source, target)`` pairs.

    ``neighbors[i]`` is iterated for the targets of node ``i``.  Pairs are
    emitted row by row, each row in its own order, so a reciprocal link
    (``j`` in ``neighbors[i]`` and ``i`` in ``neighbors[j]``) appears once
    per direction.  Nodes with an empty row contribute no pairs.
    """
    return [
        (source, target)
        for source, targets in enumerate(neighbors)
        for target in targets
    ]
def endplot(results, trange, neighbors, D):
    """Plot every degree of freedom of each distinct neighborhood in its own figure.

    Args:
        results: 2-D array indexed as ``results[t, n + i*D]``; column
            ``n + i*D`` holds degree ``n`` of agent ``i`` (layout taken from the
            indexing below). Presumably one row per entry of ``trange`` — confirm.
        trange: x-values shared by every plotted line.
        neighbors: sequence of neighborhoods, each an iterable of agent indices.
            A neighborhood equal to one already plotted is skipped.
        D: number of degrees of freedom per agent.

    For the first occurrence of a neighborhood at position ``k``, matplotlib
    figure number ``k`` is selected and cleared. Each degree of each member is
    then plotted, the figure is titled with the member indices, and a legend
    mapping colors to degrees is added.
    """
    # TODO(iamabel): There needs to be a better way to do color.
    # 'w' (white) is deliberately absent: white lines are invisible on the
    # default white axes background, so every 8th degree used to vanish.
    colors = "bgrcmyk"
    degree_colors = [colors[n % len(colors)] for n in range(D)]
    # Legend proxies: one solid patch per degree, reused by every figure.
    handles = [Rectangle((0, 0), 1, 1, color=c) for c in degree_colors]
    labels = ["Degree " + str(n) for n in range(D)]
    # Tuples are hashable, so "already plotted?" is an O(1) set lookup
    # instead of a linear scan over every neighborhood seen so far.
    done_neighbors = set()
    for neighborhood_i, neighborhood in enumerate(neighbors):
        members = tuple(neighborhood)
        if members in done_neighbors:
            continue
        done_neighbors.add(members)
        fig = plt.figure(neighborhood_i)
        plt.clf()
        for i in members:
            for n, color in enumerate(degree_colors):
                # Degree of freedom n of agent i for all times.
                plt.plot(trange, results[:, n + i*D], color)
        fig.suptitle("Group: " + ','.join(map(str, members)), fontsize=14, fontweight='bold')
        plt.legend(handles, labels)
def sigmaovertime(alpha, xn, neighborhoods, trange, r=None):
    """Print and plot the mean normalized neighbor disagreement per time step.

    For each row ``row`` of ``r`` (one time step) and each agent ``i``, this
    computes the mean over neighbors ``j`` of
    ``||row[i*D:(i+1)*D] - row[j*D:(j+1)*D]|| / (alpha * ||xn[i, :]||)``.
    An agent with no neighbors contributes 0. That step's sigma is the mean
    of these per-agent values. It is printed, and the sigma series is plotted
    against ``trange`` on the current matplotlib figure.

    Args:
        alpha: scale factor for each agent's reference distance ``d_0``.
        xn: 2-D array with one row per agent; its column count is D, the
            number of degrees of freedom per agent.
        neighborhoods: entry ``i`` lists the neighbor indices of agent ``i``.
        trange: x-values for the plot; presumably one per row of ``r`` — confirm.
        r: iterable of per-time-step rows, each holding all agents' ``D``
            values concatenated (layout read from the slicing below). If
            omitted, a module-level ``r`` is used; the original body read that
            global.

    Returns:
        False (kept for compatibility with existing callers).

    Raises:
        ValueError: if ``r`` is not given and no module-level ``r`` exists.

    NOTE(review): if ``||xn[i, :]||`` is 0, then ``d_0`` is 0. numpy then yields
    inf/nan with a RuntimeWarning instead of raising.
    """
    # TODO(iamabel): r needs to be theta dot but is theta
    if r is None:
        # Backward compatibility: the original body resolved ``r`` as a global.
        r = globals().get("r")
        if r is None:
            raise ValueError("sigmaovertime needs the per-time-step state array `r`")
    _, D = size(*xn.shape)
    sigmas = []
    for dthetas in r:
        # Reset each step so sigma reflects only this step, not a running
        # mean over all earlier steps.
        sigma_is = []
        for neighbor_i, neighborhood_i in enumerate(neighborhoods):
            d_0 = alpha * np.linalg.norm(xn[neighbor_i, :])
            own = dthetas[neighbor_i*D:(neighbor_i*D)+D]
            if not neighborhood_i:
                # nan avoidance: averaging an empty list would give nan.
                d_ijs = [0]
            else:
                d_ijs = [
                    np.linalg.norm(own - dthetas[neighbor_j*D:(neighbor_j*D)+D]) / d_0
                    for neighbor_j in neighborhood_i
                ]
            sigma_is.append(np.average(d_ijs))
        sigma = np.average(sigma_is)
        print("sigma: " + str(sigma) + " for alpha: " + str(alpha))
        sigmas.append(sigma)
    plt.plot(trange, sigmas)
    return False
def flagplot(neighborhoods):
    """Draw the neighborhood graph with a flag image in place of each node, then show it.

    The undirected graph is built from ``edgelist(neighborhoods)``, so nodes
    with no edges are absent. It is laid out with ``nx.spring_layout``, and
    only its edges are drawn. A small inset axes displaying an image is then
    centered on each node. Node ``n`` uses the ``n``-th entry, in sorted path
    order, of the files in ``./assets`` (relative to the working directory).

    Raises:
        FileNotFoundError: if ``./assets`` does not exist.
        IndexError: if a node index has no corresponding file in ``./assets``.
    """
    G = nx.Graph(edgelist(neighborhoods))
    # Images from https://en.wikipedia.org/wiki/Gallery_of_sovereign_state_flags
    assets = sorted(os.path.join("./assets", name) for name in os.listdir("./assets"))
    for n in G:
        # G.nodes[n] is the node-attribute dict (the old G.node alias is gone).
        G.nodes[n]['image'] = mpimg.imread(assets[n])
    pos = nx.spring_layout(G)
    fig = plt.figure(figsize=(7, 7))
    ax = plt.gca()
    ax.set_aspect('equal')
    nx.draw_networkx_edges(G, pos, ax=ax)
    to_display = ax.transData.transform  # data -> display (pixel) coordinates
    to_figure = fig.transFigure.inverted().transform  # display -> figure-fraction coordinates
    flagsize = 0.1  # side length of each flag's inset axes, as a fraction of the figure
    f2 = flagsize / 2.0
    for n in G:
        xx, yy = to_display(pos[n])
        xa, ya = to_figure((xx, yy))
        # Inset axes centered on the node, in figure-fraction coordinates.
        a = plt.axes([xa - f2, ya - f2, flagsize, flagsize])
        a.set_aspect('equal')
        a.imshow(G.nodes[n]['image'])
        a.set_xticks([])
        a.set_yticks([])
    ax.axis('off')
    plt.show()
def netplot(neighborhoods):
    """Draw the neighborhood graph with a spring layout in figure 77, then show it.

    The undirected graph is built from ``edgelist(neighborhoods)``, so nodes
    with no edges do not appear.
    """
    # Select (or create) figure 77; draw_spring draws into the current figure.
    plt.figure(77)
    G = nx.Graph(edgelist(neighborhoods))
    nx.draw_spring(G)
    plt.show()