Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perfect hospital (seekers) and death observations #85

Merged
merged 26 commits into from
Jun 9, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
9f8bd0a
new plots for ensemble states
agarbuno Jun 5, 2020
dfedc10
added some plots and incorporated testing capabilities
agarbuno Jun 5, 2020
9a70465
added DataObservation class and found bug in data assimilation loop
odunbar Jun 7, 2020
9fc6200
small changes and fixed captions
agarbuno Jun 7, 2020
778c16a
simple sum_to_one() function added
odunbar Jun 8, 2020
f59155a
changed parameters in file
agarbuno Jun 8, 2020
e65917c
Merge branch 'ag/multiple-tests' of https://github.com/dburov190/risk…
agarbuno Jun 8, 2020
2401bd7
added a hack to fix the prevalance and variance when taking perfect o…
odunbar Jun 8, 2020
b7597e8
small edits
agarbuno Jun 8, 2020
3467a5c
Merge branch 'ag/multiple-tests' of https://github.com/dburov190/risk…
agarbuno Jun 8, 2020
c9b8f76
fixed hack on perftectly observed observations in probability space
agarbuno Jun 8, 2020
c35d183
fixed bug: truth state as probs for da
agarbuno Jun 9, 2020
b6b2d45
just before a node based observation class
agarbuno Jun 9, 2020
706f987
created new observation type based on perfect record
agarbuno Jun 9, 2020
d8dabdb
fixed bug in observed variances for perfect observations
agarbuno Jun 9, 2020
5d291ac
test with node informed states: H and D
agarbuno Jun 9, 2020
47c8059
added observation dependent sum_to_one() and revert to DataObservation
odunbar Jun 9, 2020
c0fec96
commented print statements for sum_to_one testing
odunbar Jun 9, 2020
41963ca
small changes to measurements edge values
agarbuno Jun 9, 2020
43a6203
removed chaff from DataObservation, set observed means in perfect set…
odunbar Jun 9, 2020
c5134a4
resolved conflicts
agarbuno Jun 9, 2020
e0809ff
fixed perfect testing for record based observations: do not use mean …
agarbuno Jun 9, 2020
68ccf1f
turn off assimilated state printing
agarbuno Jun 9, 2020
3b5c3e2
no model error in the epidemic da example
agarbuno Jun 9, 2020
f6baab6
towards parameters plots
agarbuno Jun 9, 2020
f65672a
perfect observations with observed variances
agarbuno Jun 9, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions epiforecast/epiplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,101 @@ def plot_master_eqns(states, t, axes = None, xlims = None, reduced_system = True
plt.tight_layout()

return axes

def plot_ensemble_states(states, t, axes = None,
xlims = None,
reduced_system = True,
leave = False,
figsize = (15, 4),
a_min = None,
a_max = None):
if axes is None:
fig, axes = plt.subplots(1, 2, figsize = figsize)

ensemble_size = states.shape[0]
if reduced_system:
N_eqns = 5
statuses = np.arange(N_eqns)
statuses_colors = ['C0', 'C1', 'C2', 'C4', 'C6']
else:
N_eqns = 6
statuses = np.arange(N_eqns)
statuses_colors = ['C0', 'C3', 'C1', 'C2', 'C4', 'C6']
population = states.shape[1]/N_eqns

states_sum = states.reshape(ensemble_size, N_eqns, -1, len(t)).sum(axis = 2)
states_perc = np.percentile(states_sum, q = [1, 10, 25, 50, 75, 90, 99], axis = 0)

for status in statuses:
if (reduced_system and status in [0, 3]) or (not reduced_system and status in [0, 2, 4]):
axes[0].fill_between(t, np.clip(states_perc[0,status], a_min, a_max), np.clip(states_perc[-1,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.)
axes[0].fill_between(t, np.clip(states_perc[1,status], a_min, a_max), np.clip(states_perc[-2,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.)
axes[0].fill_between(t, np.clip(states_perc[2,status], a_min, a_max), np.clip(states_perc[-3,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.)
axes[0].plot(t, states_perc[3,status], color = statuses_colors[status])

if (reduced_system and status in [1]) or (not reduced_system and status in [2, 3, 5]):
axes[1].fill_between(t, np.clip(states_perc[0,status], a_min, a_max), np.clip(states_perc[-1,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.)
axes[1].fill_between(t, np.clip(states_perc[1,status], a_min, a_max), np.clip(states_perc[-2,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.)
axes[1].fill_between(t, np.clip(states_perc[2,status], a_min, a_max), np.clip(states_perc[-3,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.)
axes[1].plot(t, states_perc[3,status], color = statuses_colors[status])

if (reduced_system and status in [2, 4]) or (not reduced_system and status in [3, 5]):
axes[2].fill_between(t, np.clip(states_perc[0,status], a_min, a_max), np.clip(states_perc[-1,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.)
axes[2].fill_between(t, np.clip(states_perc[1,status], a_min, a_max), np.clip(states_perc[-2,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.)
axes[2].fill_between(t, np.clip(states_perc[2,status], a_min, a_max), np.clip(states_perc[-3,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.)
axes[2].plot(t, states_perc[3,status], color = statuses_colors[status])

if reduced_system:
residual_state = population - states_sum.sum(axis = 1)
residual_state = np.percentile(residual_state, q = [1, 10, 25, 50, 75, 90, 99], axis = 0)
axes[0].fill_between(t, np.clip(residual_state[0], a_min, a_max), np.clip(residual_state[-1], a_min, a_max), alpha = .2, color = 'C3', linewidth = 0.)
axes[0].fill_between(t, np.clip(residual_state[1], a_min, a_max), np.clip(residual_state[-2], a_min, a_max), alpha = .2, color = 'C3', linewidth = 0.)
axes[0].fill_between(t, np.clip(residual_state[2], a_min, a_max), np.clip(residual_state[-3], a_min, a_max), alpha = .2, color = 'C3', linewidth = 0.)
axes[0].plot(t, np.clip(residual_state[3], a_min, a_max), color = 'C3')

axes[0].legend(['Susceptible', 'Resistant', 'Exposed'],
bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
ncol=3, mode="expand", borderaxespad=0.);
axes[1].legend(['Infected'],
bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
ncol=2, mode="expand", borderaxespad=0.);
axes[2].legend(['Hospitalized', 'Death'],
bbox_to_anchor=(0., 1.02, 1., .102), loc=3,
ncol=4, mode="expand", borderaxespad=0.);

for kk, ax in enumerate(axes):
ax.set_xlim(xlims)

plt.tight_layout()

return axes

def plot_kinetic_model_data(kinetic_model, axes):
statuses_name = kinetic_model.return_statuses
statuses_colors = ['C0', 'C3', 'C1', 'C2', 'C4', 'C6']
colors_dict = dict(zip(statuses_name, statuses_colors))

data = kinetic_model.statuses
axes[1].scatter(kinetic_model.current_time, data['I'][-1], c = colors_dict['I'], marker = 'x')
axes[2].scatter(kinetic_model.current_time, data['H'][-1], c = colors_dict['H'], marker = 'x')
axes[2].scatter(kinetic_model.current_time, data['D'][-1], c = colors_dict['D'], marker = 'x')

return axes

def plot_ensemble_transmission_latent_fraction(community_transmission_rate_trace, latent_periods_trace, time_horizon):
transmission_perc = np.percentile(community_transmission_rate_trace, q = [1, 25, 50, 75, 99], axis = 0)
latent_periods_perc = np.percentile(latent_periods_trace, q = [1, 25, 50, 75, 99], axis = 0)

fig, axes = plt.subplots(1, 2, figsize = (12, 4))

axes[0].fill_between(time_horizon, transmission_perc[0], transmission_perc[-1], alpha = .2, color = 'C0')
axes[0].fill_between(time_horizon, transmission_perc[1], transmission_perc[-2], alpha = .2, color = 'C0')
axes[0].plot(time_horizon, transmission_perc[2])
axes[0].set_title(r'Transmission rate: $\beta$');

axes[1].fill_between(time_horizon, latent_periods_perc[0], latent_periods_perc[-1], alpha = .2, color = 'C0')
axes[1].fill_between(time_horizon, latent_periods_perc[1], latent_periods_perc[-2], alpha = .2, color = 'C0')
axes[1].plot(time_horizon, latent_periods_perc[2])
axes[1].set_title(r'Hospitalization fraction: $h$');

return axes
18 changes: 9 additions & 9 deletions epiforecast/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def _set_ppv(self, scale = 'log'):
((1 - self.sensitivity) * self.prevalence + self.specificity * (1 - self.prevalence))

if scale == 'log':
logit_ppv = np.log(PPV/(1 - PPV))
logit_for = np.log(FOR/(1 - FOR))
logit_ppv = np.log(PPV/(1 - PPV + 1e-8))
logit_for = np.log(FOR/(1 - FOR + 1e-8))

self.logit_ppv_mean = logit_ppv.mean()
self.logit_ppv_var = logit_ppv.var()
Expand Down Expand Up @@ -117,10 +117,10 @@ def __init__(self,
#number of nodes in the graph
self.N = N
#number of different states a node can be in

if reduced_system == True:
self.status_catalog = dict(zip(['S', 'I', 'H', 'R', 'D'], np.arange(5)))

else:
self.status_catalog = dict(zip(['S', 'E', 'I', 'H', 'R', 'D'], np.arange(6)))
self.n_status = len(self.status_catalog.keys())
Expand Down Expand Up @@ -158,7 +158,7 @@ def find_observation_states(self, state):
else: #The value is too small
self.obs_states=np.array([],dtype=int)
print("no observation was above the threshold")


#combine them together
class Observation(StateInformedObservation, TestMeasurement):
Expand All @@ -173,9 +173,9 @@ def __init__(self,
reduced_system=True,
sensitivity = 0.80,
specificity = 0.99):

self.name=obs_name

StateInformedObservation.__init__(self,
N,
obs_frac,
Expand All @@ -189,8 +189,8 @@ def __init__(self,
sensitivity,
specificity,
reduced_system)
#State is a numpy array of size [self.N * n_status]

#State is a numpy array of size [self.N * n_status]
#def find_observation_states(self, state):
# obtain where one should make an observation based on the
# current state, and the contact network
Expand Down
Loading