-
Notifications
You must be signed in to change notification settings - Fork 5
/
visualize_graphs.py
81 lines (58 loc) · 2.46 KB
/
visualize_graphs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#----------------------------------------------------------------------
# Script to visualize halos as graphs
# Author: Pablo Villanueva Domingo
# Last update: 10/11/21
#----------------------------------------------------------------------
import time, datetime
from Source.networks import *
from Source.plotting import *
from Source.load_data import *
# Visualization routine for plotting graphs
def visualize_graph(data, ind, projection="3d", edge_index=None):
    """Plot the nodes (and optionally the edges) of one graph and save it.

    Node coordinates are taken from the leading columns of ``data.x``: the
    first three for a 3D plot, the first two for a 2D plot. The figure is
    written to ``Plots/visualize_graph_<ind>`` and then closed.

    Args:
        data: Graph object exposing ``x``, sliced here as ``x[:, :k]``.
            NOTE(review): presumably a torch_geometric ``Data`` whose first
            feature columns are spatial positions -- confirm with the loader.
        ind: Identifier appended to the output file name.
        projection: Either ``"3d"`` or ``"2d"``.
        edge_index: Optional edge container; ``edge_index.t().tolist()`` must
            yield ``(source, target)`` node-index pairs. When ``None`` only
            the nodes are drawn.

    Raises:
        ValueError: If ``projection`` is neither ``"3d"`` nor ``"2d"``.
    """
    # Fail early and explicitly: an unknown projection used to leave `ax` and
    # `pos` unbound and crash further down with an unrelated-looking error.
    if projection not in ("3d", "2d"):
        raise ValueError(
            f"projection must be '3d' or '2d', got {projection!r}"
        )

    import os  # local import keeps this block self-contained

    is_3d = projection == "3d"
    fig = plt.figure(figsize=(4, 4))
    try:
        if is_3d:
            ax = fig.add_subplot(projection="3d")
            pos = data.x[:, :3]
        else:
            ax = fig.add_subplot()
            pos = data.x[:, :2]

        # Draw lines for each edge
        if edge_index is not None:
            # Convert the coordinates once instead of once per edge endpoint.
            coords = pos.tolist()
            for src, dst in edge_index.t().tolist():
                start, end = coords[src], coords[dst]
                if is_3d:
                    ax.plot([start[0], end[0]], [start[1], end[1]],
                            zs=[start[2], end[2]], linewidth=0.1, color='black')
                else:
                    ax.plot([start[0], end[0]], [start[1], end[1]],
                            linewidth=0.1, color='black')

        # Plot nodes on top of the edges
        if is_3d:
            ax.scatter(pos[:, 0], pos[:, 1], pos[:, 2], s=50, zorder=1000)
        else:
            ax.scatter(pos[:, 0], pos[:, 1], s=50, zorder=1000)

        # savefig does not create missing directories.
        os.makedirs("Plots", exist_ok=True)
        fig.savefig("Plots/visualize_graph_"+str(ind), bbox_inches='tight', dpi=300)
    finally:
        # This routine is called in a loop; without closing, every figure
        # stays registered in pyplot and memory grows with each call.
        plt.close(fig)
# Main routine to display graphs from several simulations
def display_graphs(simsuite, simset, n_sims, k_nn, nmax=20):
    """Render every other graph among the first ``nmax`` of a dataset.

    The dataset is obtained from ``create_dataset(simsuite, simset, n_sims)``.
    For each selected graph, edges are built with
    ``radius_graph(data.pos, r=k_nn)`` and the graph is drawn in 3D through
    ``visualize_graph``, using the graph's index in the dataset as the file
    identifier.

    Args:
        simsuite: Simulation suite name, forwarded to ``create_dataset``.
        simset: Simulation set name, forwarded to ``create_dataset``.
        n_sims: Number of simulations, forwarded to ``create_dataset``.
        k_nn: Value passed as the radius ``r`` to ``radius_graph``.
        nmax: Upper bound (exclusive) on the dataset indices considered.
            Defaults to 20, the value that was previously hard-coded.
    """
    # Load data and create dataset; the second returned value is not needed.
    dataset, _ = create_dataset(simsuite, simset, n_sims)

    for i, data in enumerate(dataset[:nmax]):
        # Take half of them: only even indices are plotted.
        if i % 2 != 0:
            continue

        # Link the nodes that lie within a distance k_nn of each other
        edge_index = radius_graph(data.pos, r=k_nn)

        # Pass "2d" instead of "3d" for a flat projection.
        visualize_graph(data, i, "3d", edge_index)
#--- MAIN ---#
# Linking radius handed to display_graphs, which uses it as the radius of the
# radius graph (despite the name, this is not a number of neighbors)
k_nn = 0.07
# Simulation suite, choose between "IllustrisTNG" and "SIMBA"
simsuite = "IllustrisTNG"
# Simulation set, choose between "CV" and "LH"
simset = "CV"
# Number of simulations considered, maximum 27 for CV and 1000 for LH
n_sims = 1

# Only run the (expensive) loading and plotting when executed as a script, so
# that importing this module has no side effects.
if __name__ == "__main__":
    time_ini = time.time()
    display_graphs(simsuite, simset, n_sims, k_nn)
    print("Finished. Time elapsed:",datetime.timedelta(seconds=time.time()-time_ini))