-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoptimize_parameter.py
62 lines (53 loc) · 1.82 KB
/
optimize_parameter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Nov 27 18:28:29 2018
@author: ellereyireland1
"""
import system as sys
import run_gravity as g
import numpy as np
import coarse_graining as cg
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import save
import copy

# ---------------------------------------------------------------------------
# Experiment parameters (values unchanged from the original script).
# ---------------------------------------------------------------------------
# NOTE(review): first argument of System.random_system -- presumably the
# number of cities in the random system; confirm against system.py.
SYSTEM_SIZE = 1000
# NOTE(review): second argument of Coarse_graining -- its meaning is defined
# in coarse_graining.py; confirm before changing.
COARSE_GRAINING_ARG = 5
# Number of distance-coefficient values sampled, and the lower end of the sweep.
N_DISTANCE_SAMPLES = 1000
MIN_DISTANCE = 0.01
# The sweep's upper end is mean + SWEEP_WIDTH * sqrt(mean) of the grained
# distance matrix.
SWEEP_WIDTH = 5
FIGURE_PATH = ('/Users/ellereyireland1/Documents/University/Third_year/'
               'BSc_project/Report/Images/cost_function_d')

# Make the system
s = g.Gravity()
system = sys.System()
# set normal to False to use zipf distribution for city size
system.random_system(SYSTEM_SIZE, normal=False)

# Set up the original (fine-grained) system: run the Gravity tuning and
# flow computation on it.
s.set_system(system)
s.tuning_function()
s.set_flows()

# Coarse grain the system.
coarse_grainer = cg.Coarse_graining(system, COARSE_GRAINING_ARG)
grained_system = coarse_grainer.generate_new_system()
# Snapshot the reference flows of the grained system. The loop below re-runs
# set_flows() on this same grained_system object; if set_flows() updates
# flow_matrix in place, a plain reference would alias the reference flows to
# the re-tuned ones and the cost would be computed against itself. A deep
# copy is independent of how set_flows() stores its result.
original_flows = copy.deepcopy(grained_system.flow_matrix)

# Sweep the distance coefficient from MIN_DISTANCE up to
# mean + SWEEP_WIDTH * sqrt(mean) of the grained distance matrix.
mean_dist = np.mean(grained_system.distance_matrix)
bound = np.sqrt(mean_dist)
distances = list(np.linspace(MIN_DISTANCE, mean_dist + SWEEP_WIDTH * bound,
                             N_DISTANCE_SAMPLES))

# For each distance coefficient, rerun the tuning on the grained system and
# record the cost of its flows against the reference flows.
cost_values = []
for d in tqdm(distances):
    s.set_system(grained_system, distance=d)
    s.tuning_function()
    s.set_flows()
    grained_flows = grained_system.flow_matrix
    # cost_function stores its result on s.cost.
    s.cost_function(original_flows, grained_flows)
    cost_values.append(s.cost)
# To persist the results: save.save_object(cost_values, "Cost value array")

# Plot cost against the distance coefficient (usetex renders text with LaTeX,
# so a LaTeX installation is required).
fig = plt.figure(1, figsize=(15.0, 9.0))
plt.rc('text', usetex=True)
ax = fig.add_subplot(111)
ax.plot(distances, cost_values, 'r-')
ax.set_xlabel("Distance coefficient (length)", fontsize=15)
ax.set_ylabel("Value of cost function (unitless)", fontsize=15)
ax.set_title(r'Cost function value against the parameter d', fontsize=15)
plt.savefig(FIGURE_PATH)
plt.show()