-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathbrain_optimality_script.py
223 lines (194 loc) · 8.32 KB
/
brain_optimality_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import pdb
"""
Created on Wed Sep 3 16:27:45 2014
@author: rkp
"""
import numpy as np
import matplotlib.pyplot as plt; plt.close('all')
import network_gen
import area_compute
import brain_optimality as bropt
CALC_RANDS = False
SHOW_WVD = False
LOOP_OVER_NSWAPS = False
TEST_SPECIFIC_AREA_SETS = True
area_sets = ['CTX','CTXpl','Isocortex','OLF','HPF','TH',['CTX','TH']]
PLOT_SPECIFIC_AREA_SETS = True
TEST_RANDOM_AREA_SETS = False
ALL_PATHS_COST = False
# Parameters
sym = False
n_swaps = 1
n_permutations = 5000
cost_type = 'dist'
print_every = 500
y_lim = [84000,95000]
W,row_labels,col_labels = network_gen.quick_net()
centroids = area_compute.get_centroids(row_labels)
D = bropt.dist_mat(centroids)
if SHOW_WVD:
# Show correlation between distance and log weight
W_vec = W.flatten()
D_vec = D.flatten()
W_vec_nz = W_vec[W_vec>0]
D_vec_nz = D_vec[W_vec>0]
fig, ax = plt.subplots(1,1,facecolor='white')
ax.scatter(D_vec_nz,np.log(W_vec_nz))
ax.set_xlabel('Distance (dx = 100 um)')
ax.set_ylabel('Log[Weight]')
for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
ax.get_xticklabels() + ax.get_yticklabels()):
item.set_fontsize(20)
plt.draw()
if CALC_RANDS:
# Create 3 completely random permutations
num_rand_perm = 3
rand_c = np.zeros((num_rand_perm,),dtype=float)
for rp_idx in range(num_rand_perm):
# Make swapped distance matrix
D_swapped,_,_ = bropt.swap_nodes(D,row_labels,centroids,
n_swaps=D.shape[0])
rand_c[rp_idx] = bropt.cost(D_swapped,W,cost_type=cost_type)
print 'Costs of randomly shuffled networks'
print rand_c
if LOOP_OVER_NSWAPS:
# Loop over number of swaps
n_swaps_vec = np.arange(1,11)
plot_idxs = np.array([0,1,4,9])
p_values = np.ones((len(n_swaps_vec),))
D0 = bropt.dist_mat(centroids)
c0 = bropt.cost(D0,W,cost_type=cost_type)
for ns_idx,n_swaps in enumerate(n_swaps_vec):
print 'n_swaps = %.1f'%n_swaps
# Iterate over random permutations of pairs of nodes (not symmetrically)
c = np.zeros((n_permutations,))
node_pairs = [None for p_idx in range(n_permutations)]
centroid_pairs = [None for p_idx in range(n_permutations)]
for p_idx in range(n_permutations):
if not (p_idx+1)%print_every:
print 'Permutation #%d'%(p_idx+1)
D_swapped, node_pair, centroid_pair = \
bropt.swap_nodes(D0,row_labels,centroids,n_swaps=n_swaps,sym=sym)
node_pairs[p_idx] = node_pair
centroid_pairs[p_idx] = centroid_pair
c[p_idx] = bropt.cost(D_swapped,W,cost_type=cost_type)
p_value = ((c<=c0).sum()/float(n_permutations))
if ns_idx in plot_idxs:
fig,ax = plt.subplots(1,1,facecolor='w')
ax.scatter(np.arange(n_permutations),c,c='r')
ax.plot(np.arange(n_permutations),c0*np.ones((n_permutations,)),c='b',lw=5)
ax.set_xlim(0,n_permutations)
ax.set_ylim(y_lim[0],y_lim[1])
ax.set_xlabel('Permutation #')
ax.set_ylabel('Cost')
if sym:
ax.set_title('%d symmetric swaps, P = %.3f'%(n_swaps,p_value))
else:
ax.set_title('%d swaps, P = %.3f'%(n_swaps,p_value))
for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
ax.get_xticklabels() + ax.get_yticklabels()):
item.set_fontsize(20)
plt.draw()
p_values[ns_idx] = p_value
fig,ax = plt.subplots(1,1,facecolor='w')
ax.plot(n_swaps_vec,p_values,lw=2)
ax.set_xlabel('n_swaps')
ax.set_ylabel('P')
for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
ax.get_xticklabels() + ax.get_yticklabels()):
item.set_fontsize(20)
plt.draw()
if ALL_PATHS_COST:
# Run for all_paths cost function
# Parameters
sym = False
n_swaps = 1
n_permutations = 5000
cost_type = 'all_paths'
W,row_labels,col_labels = network_gen.quick_net()
centroids = area_compute.get_centroids(row_labels)
D0 = bropt.dist_mat(centroids)
c0 = bropt.cost(D0,W,cost_type=cost_type)
# Iterate over random permutations of pairs of nodes (not symmetrically)
c = np.zeros((n_permutations,))
node_pairs = [None for p_idx in range(n_permutations)]
centroid_pairs = [None for p_idx in range(n_permutations)]
for p_idx in range(n_permutations):
print 'Permutation #%d'%(p_idx+1)
D_swapped, node_pair, centroid_pair = \
bropt.swap_nodes(D0,row_labels,centroids,n_swaps=n_swaps,sym=sym)
node_pairs[p_idx] = node_pair
centroid_pairs[p_idx] = centroid_pair
c[p_idx] = bropt.cost(D_swapped,W,cost_type=cost_type)
fig,ax = plt.subplots(1,1,facecolor='w')
ax.scatter(np.arange(n_permutations),c,c='r')
ax.plot(np.arange(n_permutations),c0*np.ones((n_permutations,)),c='b',lw=4)
ax.set_xlabel('Permutation #')
ax.set_ylabel('Cost')
p_value = ((c<c0).sum()/float(n_permutations))
if sym:
ax.set_title('%d symmetric swaps, P = %.3f'%(n_swaps,p_value))
else:
ax.set_title('%d swaps, P = %.3f'%(n_swaps,p_value))
plt.draw()
if TEST_SPECIFIC_AREA_SETS:
# Run same analyses with specific area sets
p_per_set = [None for area_set in area_sets]
for area_set_idx,area_set in enumerate(area_sets):
print 'Getting areas from %s'%area_set
if area_set == 'Brain':
area_set = 'root'
# Get list of all areas in area set
area_subset = area_compute.get_area_subset(row_labels,area_set)
# Create weight and distance matrices for area subsets
area_mask = np.array([(area in area_subset) for area in row_labels])
centroids_subset = centroids[area_mask,:]
Wss = W[area_mask,:][:,area_mask]
Dss = D[area_mask,:][:,area_mask]
n_swaps_vec = np.arange(1,11)
plot_idxs = np.array([])
p_values = np.ones((len(n_swaps_vec),))
c0 = bropt.cost(Dss,Wss,cost_type=cost_type)
for ns_idx,n_swaps in enumerate(n_swaps_vec):
print 'n_swaps = %.1f'%n_swaps
# Iterate over random permutations of pairs of nodes (not symmetrically)
c = np.zeros((n_permutations,))
node_pairs = [None for p_idx in range(n_permutations)]
centroid_pairs = [None for p_idx in range(n_permutations)]
for p_idx in range(n_permutations):
if not (p_idx+1)%print_every:
print 'Permutation #%d'%(p_idx+1)
D_swapped, node_pair, centroid_pair = \
bropt.swap_nodes(Dss,area_subset,centroids_subset,n_swaps=n_swaps,sym=sym)
node_pairs[p_idx] = node_pair
centroid_pairs[p_idx] = centroid_pair
c[p_idx] = bropt.cost(D_swapped,Wss,cost_type=cost_type)
p_value = ((c<=c0).sum()/float(n_permutations))
if ns_idx in plot_idxs:
fig,ax = plt.subplots(1,1,facecolor='w')
ax.scatter(np.arange(n_permutations),c,c='r')
ax.plot(np.arange(n_permutations),c0*np.ones((n_permutations,)),c='b',lw=5)
ax.set_xlim(0,n_permutations)
ax.set_xlabel('Permutation #')
ax.set_ylabel('Cost')
if sym:
ax.set_title('%s: %d symmetric swaps, P = %.3f'%(area_set,n_swaps,p_value))
else:
ax.set_title('%s: %d swaps, P = %.3f'%(area_set,n_swaps,p_value))
for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
ax.get_xticklabels() + ax.get_yticklabels()):
item.set_fontsize(16)
plt.draw()
p_values[ns_idx] = p_value
p_per_set[area_set_idx] = p_values
if PLOT_SPECIFIC_AREA_SETS:
fig,ax = plt.subplots(1,1,facecolor='white')
for a_idx,p_values in enumerate(p_per_set):
ax.plot(n_swaps_vec, p_values, lw=3, label=area_sets[a_idx])
ax.set_xlabel('n_swaps')
ax.set_ylabel('P')
ax.legend(prop={'size':20})
for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
ax.get_xticklabels() + ax.get_yticklabels()):
item.set_fontsize(20)
plt.draw()