-
Notifications
You must be signed in to change notification settings - Fork 11
/
visualize_feature_with_ci_calib.py
72 lines (61 loc) · 1.93 KB
/
visualize_feature_with_ci_calib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Santiago Nunez-Corrales and Eric Jakobsson
# Illinois Informatics and Molecular and Cell Biology
# University of Illinois at Urbana-Champaign
# {nunezco,jake}@illinois.edu
# A simple tunable model for COVID-19 response
import matplotlib.pyplot as plt
import scipy.stats as sps
import seaborn as sns
import pandas as pd
import numpy as np
from covidmodel import CovidModel
import sys
# Command-line interface:
#   python visualize_feature_with_ci_calib.py FEATURE YMAX IN_CSV OUT_IMAGE
# FEATURE is the CSV column to plot, YMAX the upper y-axis limit, IN_CSV the
# model output to read and OUT_IMAGE the path the figure is saved to.
# Extra trailing arguments are ignored.
if len(sys.argv) < 5:
    sys.exit(f"usage: {sys.argv[0]} FEATURE YMAX IN_CSV OUT_IMAGE")
feature = sys.argv[1]
ymax = float(sys.argv[2])
in_file = sys.argv[3]
out_file = sys.argv[4]

# "Step" counts model steps; dividing by this factor converts it to the
# "Days" used on the x-axis. Presumably the model advances 96 steps
# (15 minutes each) per simulated day -- confirm against CovidModel.
STEPS_PER_DAY = 96

# Raw model output. The confidence intervals computed later are taken over
# all rows that share a step (presumably one row per run) -- TODO confirm.
df0 = pd.read_csv(in_file)
if "Step" not in df0.columns or feature not in df0.columns:
    sys.exit(f"{in_file}: expected columns 'Step' and '{feature}'")
df0["Step"] = df0["Step"] / STEPS_PER_DAY
# Keep only the time axis and the feature being plotted. A dict is used so
# that FEATURE == "Step" still yields a single column, as before.
df = pd.DataFrame({"Step": df0["Step"], feature: df0[feature]})

# Axis limits: x covers the whole simulated period; the y floor is the
# smallest observed value of the feature (the ceiling is YMAX).
xmin = 0
xmax = df["Step"].max()
ymin = df[feature].min()
def _t_interval(values, confidence):
    """Return (low, high) of a Student-t confidence interval for the mean.

    Uses n - 1 degrees of freedom. When every sample is identical the
    standard error is 0 (scipy returns NaN for scale=0), so the interval
    collapses to the mean. With fewer than two samples the interval is
    undefined and (nan, nan) is returned.
    """
    n = len(values)
    if n < 2:
        return np.nan, np.nan
    mean = values.mean()
    sem = sps.sem(values)
    if sem == 0:
        return mean, mean
    return sps.t.interval(confidence, n - 1, loc=mean, scale=sem)


# Per-step mean of `feature` across all rows with that step, plus 95% and
# 99% Student-t confidence intervals for that mean.
avg = []
low_ci_95 = []
high_ci_95 = []
low_ci_99 = []
high_ci_99 = []
steps = []
print("Computing confidence intervals...")
# One groupby pass replaces re-filtering the whole frame once per step, and
# yields steps in ascending order so the plotted series is monotonic in x.
for step, group in df.groupby("Step")[feature]:
    # Drop NaNs once so mean, sample size and standard error all agree
    # (Series.mean skips NaN but scipy's sem would propagate it).
    values = group.dropna()
    f_mean = values.mean()
    lci95, hci95 = _t_interval(values, 0.95)
    lci99, hci99 = _t_interval(values, 0.99)
    steps.append(step)
    avg.append(f_mean)
    low_ci_95.append(lci95)
    high_ci_95.append(hci95)
    low_ci_99.append(lci99)
    high_ci_99.append(hci99)
df_stats = pd.DataFrame({
    "Step": steps,
    "mean": avg,
    "lci95": low_ci_95,
    "hci95": high_ci_95,
    "lci99": low_ci_99,
    "hci99": high_ci_99,
})
# Plot the per-step mean with its 95% confidence band and save it.
# Size (11.7 x 8.27 in, A4 landscape) and plain y tick labels are applied
# to the figure/axes that is actually saved, not to a separate figure.
fig, ax = plt.subplots(figsize=(11.7, 8.27))
ax.ticklabel_format(style='plain', axis='y')
ax.plot(df_stats["Step"], df_stats["mean"], color="darkred", label="Active")
ax.fill_between(df_stats["Step"], df_stats["lci95"], df_stats["hci95"],
                color='orangered', alpha=.1)
# Vertical reference lines at fixed days.
# NOTE(review): presumably calibration/intervention dates -- confirm.
for day, color in ((116, 'darkblue'), (130, 'mediumblue'), (136, 'royalblue')):
    ax.vlines(day, 0, ymax, colors=color)
ax.set_xlim([xmin, xmax])
ax.set_ylim([ymin, ymax])
ax.set_xlabel("Days")
ax.set_ylabel("Fraction")
ax.legend()
fig.savefig(out_file, dpi=300)
plt.close(fig)