diff --git a/epiforecast/data_assimilator.py b/epiforecast/data_assimilator.py index 4253c075..ec86ca7f 100644 --- a/epiforecast/data_assimilator.py +++ b/epiforecast/data_assimilator.py @@ -1,37 +1,38 @@ import numpy as np +import copy from epiforecast.ensemble_adjusted_kalman_filter import EnsembleAdjustedKalmanFilter class DataAssimilator: - def __init__(self, observations, errors, *, + def __init__(self, observations, errors, *, transition_rates_to_update_str = None, transmission_rate_to_update_flag = None): """ A data assimilator, to perform updates of model parameters and states using an - ensemble adjusted Kalman filter (EAKF) method. - + ensemble adjusted Kalman filter (EAKF) method. + Positional Args --------------- observations (list, [], or Observation): A list of Observations, or a single Observation. Generates the indices and covariances of observations - - errors (list, [], or Observation): Observation for the purpose of error checking. Error - observations are used to compute online differences at - the observed (according to Errors) between Kinetic and + + errors (list, [], or Observation): Observation for the purpose of error checking. Error + observations are used to compute online differences at + the observed (according to Errors) between Kinetic and Master Equation models Keyword Args - ------------ + ------------ transition_rates_to_update_str (list): list of strings naming the transition_rates we would like update with data. must coincide with naming found in epiforecast/populations.py. 
- If not provided, will set [] - + If not provided, will set [] + transmission_rate_to_update_flag (boolean): bool to update transmission rate with data If not provided will set False Methods ------- - + update(ensemble_state, data, contact_network=[], full_ensemble_transition_rates, full_ensemble_transmission_rate): Perform an update of the ensemble states `ensemble_state`, and if provided, ensemble parameters `full_ensemble_transition_rates`, `full_ensemble_transmission_rate` and the network @@ -39,14 +40,14 @@ def __init__(self, observations, errors, *, make_new_observation(state): For every Observation model, update the list of indices at which to observe (given by observations.obs_states). Returns a concatenated list of indices `observed_states` with duplicates removed. - - get_observation_cov(): For every Observation model, obtain the relevant variances when taking a measurement of data. Note we account for multiple node - measurements by using the minimum variance at that node (I.e if same node is queried twice in a time interval we take the most + + get_observation_cov(): For every Observation model, obtain the relevant variances when taking a measurement of data. Note we account for multiple node + measurements by using the minimum variance at that node (I.e if same node is queried twice in a time interval we take the most accurate test). Returns a diagonal covariance matrix for the distinct observed states, with the minimum variances on the diagonal. - - sum_to_one(state): Takes the state `state` and enforces that all statuses at a node sum to one. Does this by distributing the mass (1-(I+H+R+D)) into S and E, where - the mass is divided based on the previous state's relative mass in S and E. i.e Snew= S/(S+E)*(1-(I+H+R+D)), Enew = E/(S+E)*(1-(I+H+R+D)) - + + sum_to_one(state): Takes the state `state` and enforces that all statuses at a node sum to one. 
Does this by distributing the mass (1-(I+H+R+D)) into S and E, where + the mass is divided based on the previous state's relative mass in S and E. i.e Snew= S/(S+E)*(1-(I+H+R+D)), Enew = E/(S+E)*(1-(I+H+R+D)) + error_to_truth_state(state,data): updates emodel.obs_states and measures (user prescribed) differences between the data and state online. #Current implementation sums the difference in number of predicted states #and actual states in an given interval e.g 0.5 <= I <= 1.0 @@ -60,16 +61,16 @@ def __init__(self, observations, errors, *, if transition_rates_to_update_str is None: transition_rates_to_update_str = [] - + if transmission_rate_to_update_flag is None: - transmission_rate_to_update_flag = False - + transmission_rate_to_update_flag = False + if not isinstance(transition_rates_to_update_str,list):#if it's a string, not array transition_rates_to_update_str = [transition_rates_to_update_str] - + # observation models(s) - self.observations = observations - + self.observations = observations + # the data assimilation models (One for each observation model) self.damethod = EnsembleAdjustedKalmanFilter() @@ -80,19 +81,25 @@ def __init__(self, observations, errors, *, self.transition_rates_to_update_str = transition_rates_to_update_str self.transmission_rate_to_update_flag = transmission_rate_to_update_flag - def find_observation_states(self, ensemble_state): + def find_observation_states(self, + contact_network, + ensemble_state, + data): """ Make all the observations in the list self.observations. This sets observation.obs_states. 
""" - + print("Observation type : Number of Observed states") + observed_states = [] for observation in self.observations: - observation.find_observation_states(ensemble_state) - - observed_states = np.hstack([observation.obs_states for observation in self.observations]) - - return observed_states + observation.find_observation_states(contact_network, ensemble_state, data) + print(observation.name,":",len(observation.obs_states)) + if observation.obs_states.size > 0: + observed_states.extend(observation.obs_states) + + # observed_states = np.hstack([observation.obs_states for observation in self.observations]) + return np.array(observed_states) def observe(self, contact_network, @@ -101,24 +108,29 @@ def observe(self, scale = 'log', noisy_measurement = False): - + observed_means = [] + observed_variances = [] for observation in self.observations: - observation.observe(contact_network, - state, - data, - scale, - noisy_measurement) - - observed_means = np.hstack([observation.mean for observation in self.observations]) - observed_variances= np.hstack([observation.variance for observation in self.observations]) - - return observed_means, observed_variances - + if (observation.obs_states.size >0): + observation.observe(contact_network, + state, + data, + scale, + noisy_measurement) + + observed_means.extend(observation.mean) + observed_variances.extend(observation.variance) + + #observed_means = np.hstack([observation.mean for observation in self.observations]) + #observed_variances= np.hstack([observation.variance for observation in self.observations]) + + return np.array(observed_means), np.array(observed_variances) + # ensemble_state np.array([ensemble size, num status * num nodes] # data np.array([num status * num nodes]) # contact network networkx.graph (if provided) - # full_ensemble_transition_rates list[ensemble size] of TransitionRates objects from epiforecast.populations + # full_ensemble_transition_rates list[ensemble size] of TransitionRates objects from 
epiforecast.populations # full_ensemble_transmission_rate np.array([ensemble size]) def update(self, ensemble_state, @@ -133,7 +145,7 @@ def update(self, return ensemble_state, full_ensemble_transition_rates, full_ensemble_transmission_rate else: - + # Extract the transition_rates to update if len(self.transition_rates_to_update_str) > 0: @@ -147,47 +159,68 @@ def update(self, # Have to create here as rates_tmp unknown in advance if member == 0: ensemble_transition_rates = np.empty((0, rates_tmp.size), dtype=float) - + ensemble_transition_rates = np.append(ensemble_transition_rates, [rates_tmp], axis=0) ensemble_transition_rates = np.vstack(ensemble_transition_rates) else: # set to column of empties ensemble_transition_rates = np.empty((ensemble_size, 0), dtype=float) - + if self.transmission_rate_to_update_flag is True: ensemble_transmission_rate = full_ensemble_transmission_rate else: # set to column of empties ensemble_transmission_rate = np.empty((ensemble_size, 0), dtype=float) - + om = self.observations dam = self.damethod - obs_states = self.find_observation_states(ensemble_state) # Generate states to observe + obs_states = self.find_observation_states(user_network, ensemble_state, data) # Generate states to observe if (obs_states.size > 0): - - print("Partial states to be assimilated: ", obs_states.size) + + print("Total states to be assimilated: ", obs_states.size) # Get the truth indices, for the observation(s) truth,var = self.observe(user_network, ensemble_state, - data) - + data, + scale = None) cov = np.diag(var) + # Get the covariances for the observation(s), with the minimum returned if two overlap #cov = self.get_observation_cov() - + # Perform da model update with ensemble_state: states, transition and transmission rates - - (ensemble_state[:, obs_states], + prev_ensemble_state = copy.deepcopy(ensemble_state) + (ensemble_state[:, obs_states], new_ensemble_transition_rates, new_ensemble_transmission_rate) = dam.update(ensemble_state[:, obs_states], 
ensemble_transition_rates, ensemble_transmission_rate, truth, cov) - + + + # print states > 1 + #tmp = ensemble_state.reshape(ensemble_state.shape[0],5,om[0].N) + #sum_states = np.sum(tmp,axis=1) + # print(sum_states[sum_states > 1 + 1e-2]) + + # print(truth[:5]) + # print(" ") + # print(np.round(prev_ensemble_state[:,obs_states][:3,:5], 2)) + # print(" ") + # print(np.round(ensemble_state[:,obs_states][:3,:5], 2)) + + self.sum_to_one(prev_ensemble_state, ensemble_state) + + # print same states after the sum_to_one() + # tmp = ensemble_state.reshape(ensemble_state.shape[0],5,om[0].N) + # sum_states_after = np.sum(tmp,axis=1) + # print(sum_states_after[sum_states > 1 + 1e-2]) #see what the sum_to_one did to them + + # Update the new transition rates if required if len(self.transition_rates_to_update_str) > 0: @@ -213,7 +246,8 @@ def update(self, # Update the transmission_rate if required if self.transmission_rate_to_update_flag is True: full_ensemble_transmission_rate=new_ensemble_transmission_rate - + + print("EAKF error:", dam.error[-1]) else: print("No assimilation required") @@ -221,14 +255,14 @@ def update(self, # Error to truth if len(self.online_emodel)>0: self.error_to_truth_state(ensemble_state,data) - + # Return ensemble_state, transition rates, and transmission rate return ensemble_state, full_ensemble_transition_rates, full_ensemble_transmission_rate - + #defines a method to take a difference to the data state def error_to_truth_state(self,ensemble_state,data): - + em = self.online_emodel # get corresponding error model # Make sure you have a deterministic ERROR model - or it will not match truth @@ -242,7 +276,7 @@ def error_to_truth_state(self,ensemble_state,data): #print(truth[predicted_infected]) em.make_new_observation(truth[np.newaxis, :]) - + # take error actual_infected= em.obs_states @@ -258,3 +292,44 @@ def error_to_truth_state(self,ensemble_state,data): #as we measure a subset of states, we may need to enforce other states to sum to one + + 
def sum_to_one(self, prev_ensemble_state, ensemble_state): + N=self.observations[0].N + n_status=self.observations[0].n_status + if n_status == 6: + #First enforce probabilities == 1, by placing excess in susceptible and Exposed + #split based on their current proportionality. + #(Put all in S or E leads quickly to [0,1] bounding issues. + tmp = ensemble_state.reshape(ensemble_state.shape[0],n_status,N) + IHRDmass = np.sum(tmp[:,2:,:],axis=1) #sum over I H R D + Smass = ensemble_state[:,0:N]#mass in S + Emass = ensemble_state[:,N:2*N]#mass in E + fracS = Smass/(Smass+Emass)#get the proportion of mass in frac1 + fracE = 1.0-fracS + ensemble_state[:,0:N] = (1.0 - IHRDmass)*fracS #mult rows by fracS + ensemble_state[:,N:2*N] = (1.0 - IHRDmass)*fracE + + elif n_status==5: + # First obtain the mass contained in category "E" + prev_tmp = prev_ensemble_state.reshape(prev_ensemble_state.shape[0],n_status, N) + Emass = 1.0 - np.sum(prev_tmp,axis=1) # E= 1 - (S + I + H + R + D) + # for each observation we get the observed status e.g 'I' and fix it (as it was updated) + # we then normalize the other states e.g (S,'E',H,R,D) over the difference 1-I + for observation in self.observations: + if len(observation.obs_states > 0): + observed_nodes = np.remainder(observation.obs_states,N) + updated_status = observation.obs_status_idx + free_statuses = [ i for i in range(5) if i!= updated_status] + tmp = ensemble_state.reshape(ensemble_state.shape[0],n_status, N) + + # create arrays of the mass in the observed and the unobserved "free" statuses at the observed nodes. 
+ observed_tmp = tmp[:,:,observed_nodes] + updated_mass = observed_tmp[:, updated_status, :] + free_states = observed_tmp + free_states[:, updated_status, :] = np.zeros([free_states.shape[0], 1, free_states.shape[2]]) #remove this axis for the sum (but maintain the shape) + + free_mass = np.sum(free_states,axis=1) + Emass[:,observed_nodes] + + # normalize the free values e.g for S: set S = (1-I) * S/(S+E+H+R+D) + for i in free_statuses: + ensemble_state[:, i*N+observed_nodes] = (1.0 - updated_mass[:,0,:]) * (free_states[:, i, :] / free_mass) diff --git a/epiforecast/ensemble_adjusted_kalman_filter.py b/epiforecast/ensemble_adjusted_kalman_filter.py index bc416f77..49c91408 100755 --- a/epiforecast/ensemble_adjusted_kalman_filter.py +++ b/epiforecast/ensemble_adjusted_kalman_filter.py @@ -19,7 +19,7 @@ def __init__(self, full_svd = True, \ if full_svd != True and full_svd != False: sys.exit("Incorrect flag detected for full_svd (needs to be True/False)!") - + # Error self.error = np.empty(0) self.full_svd = full_svd @@ -38,18 +38,18 @@ def compute_error(self, x, x_t, cov): self.error = np.append(self.error, error) - + # x: forward evaluation of state, i.e. x(q), with shape (num_ensembles, num_elements) # q: model parameters, with shape (num_ensembles, num_elements) def update(self, ensemble_state, clinical_statistics, transmission_rates, truth, cov, r=1.0): ''' - ensemble_state (np.array): J x M of observed states for each of the J ensembles - + - clinical_statistics (np.array): transition rate model parameters for each of the J ensembles - transmission_rates (np.array): transmission rate of model parameters for each of the J ensembles - + - truth (np.array): M x 1 array of observed states. - cov (np.array): M x M array of covariances that represent observational uncertainty. 
@@ -67,14 +67,22 @@ def update(self, ensemble_state, clinical_statistics, transmission_rates, truth, assert (cov.ndim == 2), 'EAKF init: covariance must be 2d array' assert (truth.size == cov.shape[0] and truth.size == cov.shape[1]),\ 'EAKF init: truth and cov are not the correct sizes' - + # Observation data statistics at the observed nodes x_t = truth cov = r**2 * cov + # print("----------------------------------------------------") + # print(x_t[:3]) + # print(" ") + # print(np.diag(cov)[:3]) + # print("----------------------------------------------------") + cov = (1./np.maximum(x_t, 1e-9)/np.maximum(1-x_t, 1e-9))**2 * cov x_t = np.log(np.maximum(x_t, 1e-9)/np.maximum(1.-x_t, 1e-9)) + # print(np.diag(cov)) + try: cov_inv = np.linalg.inv(cov) except np.linalg.linalg.LinAlgError: @@ -83,37 +91,54 @@ def update(self, ensemble_state, clinical_statistics, transmission_rates, truth, # States x = np.log(np.maximum(ensemble_state, 1e-9) / np.maximum(1.0 - ensemble_state, 1e-9)) - # Stacked parameters and states # the transition and transmission parameters act similarly in the algorithm p = clinical_statistics q = transmission_rates - - zp = np.hstack([p, q, x]) + + #if only 1 state is given + if (ensemble_state.ndim == 1): + x=x[np.newaxis].T + + if p.size>0 and q.size>0: + zp = np.hstack([p, q, x]) + elif p.size>0 and q.size==0: + zp = np.hstack([p,x]) + elif q.size>0 and p.size==0: + zp = np.hstack([q, x]) + else: + zp = x + params_noise_active=False + x_t = x_t cov = cov - + # Ensemble size J = x.shape[0] - + # Sizes of q and x pqs = q[0].size +p[0].size xs = x[0].size zp_bar = np.mean(zp, 0) Sigma = np.cov(zp.T) - + + #if only one state is given + if Sigma.ndim < 2: + Sigma=np.array([Sigma]) + Sigma=Sigma[np.newaxis] + if self.full_svd == True: - # Add noises to the diagonal of sample covariance - # Current implementation involves a small constant + # Add noises to the diagonal of sample covariance + # Current implementation involves a small constant # This 
numerical trick can be deactivated if Sigma is not ill-conditioned if self.params_noise_active == True: - Sigma[:pqs,:pqs] = Sigma[:pqs,:pqs] + np.identity(pqs) * self.params_cov_noise + Sigma[:pqs,:pqs] = Sigma[:pqs,:pqs] + np.identity(pqs) * self.params_cov_noise if self.states_noise_active == True: Sigma[pqs:,pqs:] = Sigma[pqs:,pqs:] + np.identity(xs) * self.states_cov_noise - + # Follow Anderson 2001 Month. Weath. Rev. Appendix A. - # Preparing matrices for EAKF + # Preparing matrices for EAKF H = np.hstack([np.zeros((xs, pqs)), np.eye(xs)]) Hpq = np.hstack([np.eye(pqs), np.zeros((pqs, xs))]) F, Dp_vec, _ = la.svd(Sigma) @@ -122,7 +147,7 @@ def update(self, ensemble_state, clinical_statistics, transmission_rates, truth, G = np.diag(np.sqrt(Dp_vec)) G_inv = np.diag(1./np.sqrt(Dp_vec)) U, D_vec, _ = la.svd(np.linalg.multi_dot([G.T, F.T, H.T, cov_inv, H, F, G])) - B = np.diag((1.0 + D_vec) ** (-1.0 / 2.0)) + B = np.diag((1.0 + D_vec) ** (-1.0 / 2.0)) A = np.linalg.multi_dot([F, \ G.T, \ U, \ @@ -136,7 +161,7 @@ def update(self, ensemble_state, clinical_statistics, transmission_rates, truth, from sklearn.decomposition import TruncatedSVD # Follow Anderson 2001 Month. Weath. Rev. Appendix A. 
- # Preparing matrices for EAKF + # Preparing matrices for EAKF H = np.hstack([np.zeros((xs, pqs)), np.eye(xs)]) Hpq = np.hstack([np.eye(pqs), np.zeros((pqs, xs))]) svd1 = TruncatedSVD(n_components=J-1, random_state=42) @@ -157,10 +182,10 @@ def update(self, ensemble_state, clinical_statistics, transmission_rates, truth, F.T]) Sigma_u = np.linalg.multi_dot([A, Sigma, A.T]) - # Adding noises (model uncertainties) to the truncated dimensions + # Adding noises (model uncertainties) to the truncated dimensions # Need to further think about how to reduce the cost of full SVD here #F_u, Dp_u_vec, _ = la.svd(Sigma_u) - #Dp_u_vec[J-1:] = np.min(Dp_u_vec[:J-1]) + #Dp_u_vec[J-1:] = np.min(Dp_u_vec[:J-1]) #Sigma_u = np.linalg.multi_dot([F_u, np.diag(Dp_u_vec), F_u.T]) # Adding noises approximately to the truncated dimensions (Option #1) @@ -173,16 +198,16 @@ def update(self, ensemble_state, clinical_statistics, transmission_rates, truth, #vec = np.diag(Sigma_u) #vec = np.maximum(vec, svd1.singular_values_[-1]) ##vec = np.maximum(vec, np.sort(vec)[-J]) - #np.fill_diagonal(Sigma_u, vec) + #np.fill_diagonal(Sigma_u, vec) ## Adding noises into data for each ensemble member (Currently deactivated) # noise = np.array([np.random.multivariate_normal(np.zeros(xs), cov) for _ in range(J)]) # x_t = x_t + noise zu_bar = np.dot(Sigma_u, \ (np.dot(Sigma_inv, zp_bar) + np.dot(np.dot(H.T, cov_inv), x_t))) - + # Update parameters and state in `zu` - zu = np.dot(zp - zp_bar, A.T) + zu_bar + zu = np.dot(zp - zp_bar, A.T) + zu_bar # Store updated parameters and states x_logit = np.dot(zu, H.T) @@ -195,11 +220,15 @@ def update(self, ensemble_state, clinical_statistics, transmission_rates, truth, pqout=np.dot(zu,Hpq.T) new_clinical_statistics, new_transmission_rates = pqout[:, :clinical_statistics.shape[1]], pqout[:,clinical_statistics.shape[1]:] - + #self.x = np.append(self.x, [x_p], axis=0) + if (ensemble_state.ndim == 1): + new_ensemble_state=new_ensemble_state.squeeze() + # Compute error 
self.compute_error(x_logit,x_t,cov) + #print("new_clinical_statistics", new_clinical_statistics) #print("new_transmission_rates", new_transmission_rates) return new_ensemble_state, new_clinical_statistics, new_transmission_rates diff --git a/epiforecast/epiplots.py b/epiforecast/epiplots.py index 5eb21d5b..9a76f80c 100644 --- a/epiforecast/epiplots.py +++ b/epiforecast/epiplots.py @@ -63,3 +63,117 @@ def plot_master_eqns(states, t, axes = None, xlims = None, reduced_system = True plt.tight_layout() return axes + +def plot_ensemble_states(states, t, axes = None, + xlims = None, + reduced_system = True, + leave = False, + figsize = (15, 4), + a_min = None, + a_max = None): + if axes is None: + fig, axes = plt.subplots(1, 2, figsize = figsize) + + ensemble_size = states.shape[0] + if reduced_system: + N_eqns = 5 + statuses = np.arange(N_eqns) + statuses_colors = ['C0', 'C1', 'C2', 'C4', 'C6'] + else: + N_eqns = 6 + statuses = np.arange(N_eqns) + statuses_colors = ['C0', 'C3', 'C1', 'C2', 'C4', 'C6'] + population = states.shape[1]/N_eqns + + states_sum = states.reshape(ensemble_size, N_eqns, -1, len(t)).sum(axis = 2) + states_perc = np.percentile(states_sum, q = [1, 10, 25, 50, 75, 90, 99], axis = 0) + + for status in statuses: + if (reduced_system and status in [0, 3]) or (not reduced_system and status in [0, 2, 4]): + axes[0].fill_between(t, np.clip(states_perc[0,status], a_min, a_max), np.clip(states_perc[-1,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.) + axes[0].fill_between(t, np.clip(states_perc[1,status], a_min, a_max), np.clip(states_perc[-2,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.) + axes[0].fill_between(t, np.clip(states_perc[2,status], a_min, a_max), np.clip(states_perc[-3,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.) 
+ axes[0].plot(t, states_perc[3,status], color = statuses_colors[status]) + + if (reduced_system and status in [1]) or (not reduced_system and status in [2, 3, 5]): + axes[1].fill_between(t, np.clip(states_perc[0,status], a_min, a_max), np.clip(states_perc[-1,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.) + axes[1].fill_between(t, np.clip(states_perc[1,status], a_min, a_max), np.clip(states_perc[-2,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.) + axes[1].fill_between(t, np.clip(states_perc[2,status], a_min, a_max), np.clip(states_perc[-3,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.) + axes[1].plot(t, states_perc[3,status], color = statuses_colors[status]) + + if (reduced_system and status in [2, 4]) or (not reduced_system and status in [3, 5]): + axes[2].fill_between(t, np.clip(states_perc[0,status], a_min, a_max), np.clip(states_perc[-1,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.) + axes[2].fill_between(t, np.clip(states_perc[1,status], a_min, a_max), np.clip(states_perc[-2,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.) + axes[2].fill_between(t, np.clip(states_perc[2,status], a_min, a_max), np.clip(states_perc[-3,status], a_min, a_max), alpha = .2, color = statuses_colors[status], linewidth = 0.) + axes[2].plot(t, states_perc[3,status], color = statuses_colors[status]) + + if reduced_system: + residual_state = population - states_sum.sum(axis = 1) + residual_state = np.percentile(residual_state, q = [1, 10, 25, 50, 75, 90, 99], axis = 0) + axes[0].fill_between(t, np.clip(residual_state[0], a_min, a_max), np.clip(residual_state[-1], a_min, a_max), alpha = .2, color = 'C3', linewidth = 0.) + axes[0].fill_between(t, np.clip(residual_state[1], a_min, a_max), np.clip(residual_state[-2], a_min, a_max), alpha = .2, color = 'C3', linewidth = 0.) 
+ axes[0].fill_between(t, np.clip(residual_state[2], a_min, a_max), np.clip(residual_state[-3], a_min, a_max), alpha = .2, color = 'C3', linewidth = 0.) + axes[0].plot(t, np.clip(residual_state[3], a_min, a_max), color = 'C3') + + axes[0].legend(['Susceptible', 'Resistant', 'Exposed'], + bbox_to_anchor=(0., 1.02, 1., .102), loc=3, + ncol=3, mode="expand", borderaxespad=0.); + axes[1].legend(['Infected'], + bbox_to_anchor=(0., 1.02, 1., .102), loc=3, + ncol=2, mode="expand", borderaxespad=0.); + axes[2].legend(['Hospitalized', 'Death'], + bbox_to_anchor=(0., 1.02, 1., .102), loc=3, + ncol=4, mode="expand", borderaxespad=0.); + + for kk, ax in enumerate(axes): + ax.set_xlim(xlims) + + plt.tight_layout() + + return axes + +def plot_kinetic_model_data(kinetic_model, axes): + statuses_name = kinetic_model.return_statuses + statuses_colors = ['C0', 'C3', 'C1', 'C2', 'C4', 'C6'] + colors_dict = dict(zip(statuses_name, statuses_colors)) + + data = kinetic_model.statuses + axes[1].scatter(kinetic_model.current_time, data['I'][-1], c = colors_dict['I'], marker = 'x') + axes[2].scatter(kinetic_model.current_time, data['H'][-1], c = colors_dict['H'], marker = 'x') + axes[2].scatter(kinetic_model.current_time, data['D'][-1], c = colors_dict['D'], marker = 'x') + + # axes[2].set_ylim(-.5, 10) + + return axes + +def plot_ensemble_transmission_latent_fraction(community_transmission_rate_trace, latent_periods_trace, time_horizon): + transmission_perc = np.percentile(community_transmission_rate_trace, q = [1, 25, 50, 75, 99], axis = 0) + latent_periods_perc = np.percentile(latent_periods_trace, q = [1, 25, 50, 75, 99], axis = 0) + + fig, axes = plt.subplots(1, 2, figsize = (12, 4)) + + axes[0].fill_between(time_horizon, transmission_perc[0], transmission_perc[-1], alpha = .2, color = 'C0') + axes[0].fill_between(time_horizon, transmission_perc[1], transmission_perc[-2], alpha = .2, color = 'C0') + axes[0].plot(time_horizon, transmission_perc[2]) + axes[0].set_title(r'Transmission 
rate: $\beta$'); + + axes[1].fill_between(time_horizon, latent_periods_perc[0], latent_periods_perc[-1], alpha = .2, color = 'C0') + axes[1].fill_between(time_horizon, latent_periods_perc[1], latent_periods_perc[-2], alpha = .2, color = 'C0') + axes[1].plot(time_horizon, latent_periods_perc[2]) + axes[1].set_title(r'Latent period: $\gamma$'); + + return axes + +def plot_scalar_parameters(parameters, time_horizon, names): + percentiles = {} + fig, axes = plt.subplots(1, len(parameters), figsize = (4 * len(parameters), 4)) + + for kk, parameter in enumerate(names): + percentiles[parameter] = np.percentile(parameters[kk], q = [1, 25, 50, 75, 99], axis = 0) + + axes[kk].fill_between(time_horizon, percentiles[parameter][0], percentiles[parameter][-1], alpha = .2, color = 'C0') + axes[kk].fill_between(time_horizon, percentiles[parameter][1], percentiles[parameter][-2], alpha = .2, color = 'C0') + axes[kk].plot(time_horizon, percentiles[parameter][2]) + axes[kk].set_title(names[kk]); + + return axes diff --git a/epiforecast/measurements.py b/epiforecast/measurements.py index 6045c9db..55efa2b2 100644 --- a/epiforecast/measurements.py +++ b/epiforecast/measurements.py @@ -19,7 +19,7 @@ def __init__(self, self.status = status self.n_status = len(self.status_catalog.keys()) - def _set_prevalence(self, ensemble_states): + def _set_prevalence(self, ensemble_states, fixed_prevalence = None): """ Inputs: ------- @@ -27,9 +27,12 @@ def _set_prevalence(self, ensemble_states): status_idx : status id of interest. Following the ordering of the reduced system SIRHD. 
""" - population = ensemble_states.shape[1]/self.n_status - ensemble_size = ensemble_states.shape[0] - self.prevalence = ensemble_states.reshape(ensemble_size,self.n_status,-1)[:,self.status_catalog[self.status],:].sum(axis = 1)/population + if fixed_prevalence is None: + population = ensemble_states.shape[1]/self.n_status + ensemble_size = ensemble_states.shape[0] + self.prevalence = ensemble_states.reshape(ensemble_size,self.n_status,-1)[:,self.status_catalog[self.status],:].sum(axis = 1)/population + else: + self.prevalence = fixed_prevalence def _set_ppv(self, scale = 'log'): PPV = self.sensitivity * self.prevalence / \ @@ -39,8 +42,8 @@ def _set_ppv(self, scale = 'log'): ((1 - self.sensitivity) * self.prevalence + self.specificity * (1 - self.prevalence)) if scale == 'log': - logit_ppv = np.log(PPV/(1 - PPV)) - logit_for = np.log(FOR/(1 - FOR)) + logit_ppv = np.log(PPV/(1 - PPV + 1e-8)) + logit_for = np.log(FOR/(1 - FOR + 1e-8)) self.logit_ppv_mean = logit_ppv.mean() self.logit_ppv_var = logit_ppv.var() @@ -55,8 +58,8 @@ def _set_ppv(self, scale = 'log'): self.for_mean = FOR.mean() self.for_var = FOR.var() - def update_prevalence(self, ensemble_states, scale = 'log' ): - self._set_prevalence(ensemble_states) + def update_prevalence(self, ensemble_states, scale = 'log', fixed_prevalence=None ): + self._set_prevalence(ensemble_states, fixed_prevalence) self._set_ppv(scale = scale) def get_mean_and_variance(self, positive_test = True, scale = 'log'): @@ -82,8 +85,6 @@ def take_measurements(self, nodes_state_dict, scale = 'log', noisy_measurement = ------- """ - - measurements = {} uncertainty = {} @@ -97,13 +98,8 @@ def take_measurements(self, nodes_state_dict, scale = 'log', noisy_measurement = return measurements, uncertainty - - - #### Adding Observations in here - - #We observe a subset of nodes at a status, only if the state exceeds a given threshold value. #e.g we have a probability of observing I_i if (I_i > 0.8) when the observation takes place. 
class StateInformedObservation: @@ -117,10 +113,10 @@ def __init__(self, #number of nodes in the graph self.N = N #number of different states a node can be in - + if reduced_system == True: self.status_catalog = dict(zip(['S', 'I', 'H', 'R', 'D'], np.arange(5))) - + else: self.status_catalog = dict(zip(['S', 'E', 'I', 'H', 'R', 'D'], np.arange(6))) self.n_status = len(self.status_catalog.keys()) @@ -135,13 +131,14 @@ def __init__(self, self.obs_max_threshold = np.clip(max_threshold,0.0,1.0) #default init observation - self.obs_states = np.arange(int(self.obs_frac*self.N)*self.obs_status_idx.size) - + self.obs_states=np.empty(0) #updates the observation model when taking observation - def find_observation_states(self, state): + def find_observation_states(self, + contact_network, + state, + data): #Candidates for observations are those with a required state >= threshold - onetoN=np.arange(self.N) - candidate_states= np.hstack([self.N*self.obs_status_idx+i for i in onetoN]) + candidate_states= np.hstack([self.N*self.obs_status_idx+i for i in range(self.N)]) xmean = np.mean(state[:,candidate_states],axis=0) candidate_states_ens=candidate_states[(xmean>=self.obs_min_threshold) & \ @@ -150,15 +147,14 @@ def find_observation_states(self, state): M=candidate_states_ens.size if (int(self.obs_frac*M)>=1)&(self.obs_frac < 1.0) : # If there is at least one state to sample (...)>=1.0 - # If we don't just sample every state - choice=np.random.choice(onetoM, size=int(self.obs_frac*M), replace=False) + # and if we don't sample every state + choice=np.random.choice(np.arange(M), size=int(self.obs_frac*M), replace=False) self.obs_states=candidate_states_ens[choice] elif (self.obs_frac == 1.0): self.obs_states=candidate_states_ens else: #The value is too small self.obs_states=np.array([],dtype=int) print("no observation was above the threshold") - #combine them together class Observation(StateInformedObservation, TestMeasurement): @@ -173,9 +169,9 @@ def __init__(self, 
class DataInformedObservation:
    """
    Chooses which entries of the state vector to observe from the recorded
    node statuses (`data`) instead of from the model state.

    The state vector is laid out status-major: the entry for the node at
    position `p` (its position in `contact_network.nodes`) and the status with
    catalog index `k` is `k * N + p`.
    """

    def __init__(self,
                 N,
                 bool_type,
                 obs_status,
                 reduced_system):
        """
        Args
        ----
        N (int): number of nodes in the graph.

        bool_type (bool): True selects nodes whose recorded status equals an
                          observed status; False selects nodes whose recorded
                          status differs from it.

        obs_status (str or list): statuses to observe, e.g. 'H'. A string is
                                  iterated character by character.

        reduced_system (bool): True uses the 5-status catalog 'SIHRD',
                               otherwise the 6-status catalog 'SEIHRD'.
        """
        # number of nodes in the graph
        self.N = N
        # whether to look for where the status is (True) or where it is not (False)
        self.bool_type = bool_type

        # `== True` is kept on purpose so that non-bool flags pick the same
        # catalog as before.
        if reduced_system == True:
            status_names = ['S', 'I', 'H', 'R', 'D']
        else:
            status_names = ['S', 'E', 'I', 'H', 'R', 'D']

        # maps a status letter to its block index in the state vector
        self.status_catalog = dict(zip(status_names, np.arange(len(status_names))))
        self.n_status = len(self.status_catalog)
        self.obs_status = obs_status
        self.obs_status_idx = np.array([self.status_catalog[status] for status in obs_status],
                                       dtype=int)
        # Integer dtype so that an empty selection can still be used as an index.
        self.obs_states = np.empty(0, dtype=int)

    # updates the observation model when taking an observation
    def find_observation_states(self,
                                contact_network,
                                state,
                                data):
        """
        Set `self.obs_states` to the state-vector indices selected by `data`.

        Args
        ----
        contact_network: object whose `.nodes` iterates the node labels; the
                         position of a label in that iteration is its index
                         within each status block.

        state: unused here; kept so all observation classes share one signature.

        data (dict): {node label : status letter}.

        A node is selected when `(data[node] == status) == bool_type` holds for
        at least one `status` in `obs_status`. Every selected node contributes
        one index per observed status, ordered status by status and, within a
        status, by node position. With no match, `obs_states` is an empty
        integer array.
        """
        nodes = list(contact_network.nodes)

        # Boolean mask over node positions. This replaces a per-position list
        # membership test, which was quadratic in the number of nodes.
        selected = np.zeros(len(nodes), dtype=bool)
        for status in self.obs_status:
            selected |= np.array([(data[node] == status) == self.bool_type for node in nodes],
                                 dtype=bool)
        positions = np.flatnonzero(selected)

        # now add the required shift to obtain the correct status 'I' or 'H' etc.
        shifts = self.N * np.asarray(self.obs_status_idx, dtype=int)
        self.obs_states = (shifts[:, np.newaxis] + positions[np.newaxis, :]).ravel()
class DataNodeObservation(DataNodeInformedObservation, TestMeasurement):
    """
    Observation driven by recorded statuses that reports a mean and a variance
    for every status entry of each selected node, not only for the observed
    status.
    """

    def __init__(self,
                 N,
                 bool_type,
                 obs_status,
                 obs_name,
                 reduced_system=True,
                 sensitivity = 0.80,
                 specificity = 0.99):
        """
        Args
        ----
        N, bool_type, obs_status, reduced_system: forwarded to
            `DataNodeInformedObservation.__init__`.

        obs_name (str): stored as `self.name`.

        sensitivity, specificity (float): forwarded, together with `obs_status`
            and `reduced_system`, to `TestMeasurement.__init__`.
        """
        self.name = obs_name

        DataNodeInformedObservation.__init__(self,
                                             N,
                                             bool_type,
                                             obs_status,
                                             reduced_system)

        TestMeasurement.__init__(self,
                                 obs_status,
                                 sensitivity,
                                 specificity,
                                 reduced_system)

    # State is a numpy array of size [self.N * n_status]
    def find_observation_states(self,
                                contact_network,
                                state,
                                data):
        """
        Select the states to observe. This delegates to
        `DataNodeInformedObservation.find_observation_states`, which sets
        `obs_states`, `_obs_states` and `states_per_node`.
        """
        DataNodeInformedObservation.find_observation_states(self,
                                                            contact_network,
                                                            state,
                                                            data)

    # data is a dictionary {node number : status} data[i] = contact_network.node(i)
    def observe(self,
                contact_network,
                state,
                data,
                scale = 'log',
                noisy_measurement = False):
        """
        Set `self.mean` and `self.variance` for the states chosen by the last
        call to `find_observation_states`.

        `contact_network`, `state`, `data` and `noisy_measurement` are not
        read here. With `scale == 'log'` the mean and variance of the observed
        status are expressed on the logit scale.

        Both outputs are flat arrays with one entry per element of
        `states_per_node`. They are empty when no state was selected.
        """
        # Nothing matched the records (e.g. nobody hospitalized yet). The
        # general path would fail on `observed_variance[0]`, so return empty
        # measurements instead.
        if self._obs_states.size == 0:
            self.mean = np.empty(0)
            self.variance = np.empty(0)
            return

        # do not set mean = 1: the logit below must stay finite
        observed_mean = (1-0.05/6) * np.ones(self._obs_states.size)
        observed_variance = 1e-5 * np.ones(self._obs_states.size)

        if scale == 'log':
            # variance is rescaled by the squared derivative of the logit
            observed_variance = (1.0/observed_mean/(1-observed_mean))**2 * observed_variance
            observed_mean = np.log(observed_mean/(1 - observed_mean + 1e-8))

        # Default value for the statuses that were not the observed one.
        # NOTE(review): this default is not logit-transformed when
        # scale == 'log', unlike the observed status -- confirm this is intended.
        observed_means = (0.01/6) * np.ones_like(self.states_per_node)
        observed_variances = observed_variance[0] * np.ones_like(self.states_per_node)

        # overwrite the column(s) of the observed status with its mean
        observed_means[:, self.obs_status_idx] = observed_mean.reshape(-1,1)

        self.mean = observed_means.flatten()
        self.variance = observed_variances.flatten()
def deterministic_risk(contact_network, initial_states, ensemble_size=1):
    """
    Build an ensemble of identical risk states from known initial statuses.

    Args
    ----
    contact_network: sized container of nodes; only its length is used.

    initial_states (dict): {node : 'S' or 'I'}. Values are read in dictionary
                           order and must number `len(contact_network)`.
                           Any other status raises KeyError.

    ensemble_size (int): number of ensemble members (rows).

    Returns
    -------
    Array of shape (ensemble_size, 5 * population). Each row is the
    concatenation (S, I, H, R, D), with I = 1 and S = 0 for infected nodes,
    S = 1 for all others, and H = R = D = 0.
    """
    population = len(contact_network)

    init_catalog = {'S': False, 'I': True}
    # dtype=bool keeps the mask usable when `initial_states` is empty; a
    # default float64 empty array is rejected as an index.
    infected = np.array([init_catalog[status] for status in initial_states.values()],
                        dtype=bool)

    S = np.ones(population)
    I = np.zeros(population)
    I[infected] = 1.
    S[infected] = 0.

    # Every member starts from the same state, so build the row once and
    # repeat it; np.tile returns independent copies.
    member_state = np.concatenate((S, I, np.zeros(3 * population)))
    return np.tile(member_state, (ensemble_size, 1))
+seed = 212212 + +np.random.seed(seed) +random.seed(seed) + +# set numba seed + +seed_numba_random_state(seed) + +# +# Load an example network +# + +edges = load_edges(os.path.join('..', 'data', 'networks', 'edge_list_SBM_1e3_nobeds.txt')) +node_identifiers = load_node_identifiers(os.path.join('..', 'data', 'networks', 'node_identifier_SBM_1e3_nobeds.txt')) + +contact_network = nx.Graph() +contact_network.add_edges_from(edges) +contact_network = nx.convert_node_labels_to_integers(contact_network) +population = len(contact_network) + +# +# Build the contact simulator +# +start_time = -3 / 24 + +minute = 1 / 60 / 24 +hour = 60 * minute + +# +# Clinical parameters of an age-distributed population +# + +assign_ages(contact_network, distribution=[0.21, 0.4, 0.25, 0.08, 0.06]) + +# We process the clinical data to determine transition rates between each epidemiological state, +transition_rates = TransitionRates(contact_network, + latent_periods = 3.7, + community_infection_periods = 3.2, + hospital_infection_periods = 5.0, + hospitalization_fraction = AgeDependentConstant([0.002, 0.01, 0.04, 0.076, 0.16]), + community_mortality_fraction = AgeDependentConstant([ 1e-4, 1e-3, 0.001, 0.07, 0.015]), + hospital_mortality_fraction = AgeDependentConstant([0.019, 0.073, 0.193, 0.327, 0.512]) +) + +community_transmission_rate = 12.0 +hospital_transmission_reduction = 0.1 + +# +# Simulate the growth and equilibration of an epidemic +# +static_contact_interval = 6 * hour +simulation_length = 30 + +health_service = HealthService(patient_capacity = int(0.05 * len(contact_network)), + health_worker_population = len(node_identifiers['health_workers']), + static_population_network = contact_network) + + + +mean_contact_lifetime=0.5*minute + +epidemic_simulator = EpidemicSimulator( + contact_network = contact_network, + mean_contact_lifetime = mean_contact_lifetime, + contact_inception_rate = DiurnalContactInceptionRate(minimum = 2, maximum = 22), + transition_rates = transition_rates, + 
community_transmission_rate = community_transmission_rate, + hospital_transmission_reduction = hospital_transmission_reduction, + static_contact_interval = static_contact_interval, + health_service = health_service, + start_time = start_time + ) +ensemble_size = 20 #100 # minimum number for an 'ensemble' + +# # We process the clinical data to determine transition rates between each epidemiological state, +# transition_rates_ensemble = [] +# for i in range(ensemble_size): +# transition_rates_ensemble.append( +# TransitionRates(contact_network, +# latent_periods = np.random.normal(3.7,0.37), +# community_infection_periods = np.random.normal(3.2,0.32), +# hospital_infection_periods = np.random.normal(5.0,0.5), +# hospitalization_fraction = AgeDependentBetaSampler(mean=[0.002, 0.01, 0.04, 0.075, 0.16], b=4), +# community_mortality_fraction = AgeDependentBetaSampler(mean=[ 1e-4, 1e-3, 0.003, 0.01, 0.02], b=4), +# hospital_mortality_fraction = AgeDependentBetaSampler(mean=[0.019, 0.075, 0.195, 0.328, 0.514], b=4) +# ) +# ) +# + +transition_rates_ensemble = [] +for i in range(ensemble_size): + transition_rates_ensemble.append( + TransitionRates(contact_network, + latent_periods = np.random.normal(3.7,0.37), + community_infection_periods = np.random.normal(3.2, 0.32), + hospital_infection_periods = np.random.normal(5.0, 0.50), + hospitalization_fraction = transition_rates.hospitalization_fraction, + community_mortality_fraction = transition_rates.community_mortality_fraction, + hospital_mortality_fraction = transition_rates.hospital_mortality_fraction + ) + ) + + +#set transmission_rates +community_transmission_rate_ensemble = np.random.normal(12.0,1.0, size=(ensemble_size,1)) + +master_eqn_ensemble = MasterEquationModelEnsemble(contact_network = contact_network, + transition_rates = transition_rates_ensemble, + transmission_rate = community_transmission_rate_ensemble, + hospital_transmission_reduction = hospital_transmission_reduction, + ensemble_size = ensemble_size) + 
+#### + +medical_infection_test = Observation(N = population, + obs_frac = 1.00, + obs_status = 'I', + obs_name = "0.25 < Infected(100%) < 0.75", + min_threshold=0.25, + max_threshold=0.75) + +random_infection_test = Observation(N = population, + obs_frac = 0.01, + obs_status = 'I', + obs_name = "Random Infection Test") + +hospital_records = DataNodeObservation(N = population, + bool_type=True, + obs_status = 'H', + obs_name = "Hospitalized from Data", + specificity = 0.999, + sensitivity = 0.999) + +death_records = DataNodeObservation(N = population, + bool_type=True, + obs_status = 'D', + obs_name = "Deceased from Data", + specificity = 0.999, + sensitivity = 0.999) + + +# give the data assimilator the methods for how to choose observed states +# observations=[medical_infection_test, random_infection_test, hospital_records, death_records] +# observations=[medical_infection_test] +observations=[random_infection_test, hospital_records, death_records] + +# give the data assimilator which transition rates and transmission rate to assimilate +transition_rates_to_update_str=['latent_periods', 'community_infection_periods', 'hospital_infection_periods'] +transmission_rate_to_update_flag=True + +# create the assimilator +assimilator = DataAssimilator(observations = observations, + errors = [], + transition_rates_to_update_str= transition_rates_to_update_str, + transmission_rate_to_update_flag = transmission_rate_to_update_flag) + +time = start_time + +statuses = random_epidemic(contact_network, + fraction_infected=0.01) + +states_ensemble = random_risk(contact_network, + fraction_infected = 0.01, + ensemble_size = ensemble_size) + +# states_ensemble = deterministic_risk(contact_network, +# statuses, +# ensemble_size = ensemble_size) + +epidemic_simulator.set_statuses(statuses) +master_eqn_ensemble.set_states_ensemble(states_ensemble) + +# print(static_contact_interval) +# print(int(simulation_length/static_contact_interval)) + +# fig, axes = plt.subplots(1, 2, figsize = 
(15, 5)) +fig, axes = plt.subplots(1, 3, figsize = (16, 4)) + +transition_rates_to_update_str=['latent_periods', 'community_infection_periods', 'hospital_infection_periods'] + +community_transmission_rate_trace = np.copy(community_transmission_rate_ensemble) +latent_periods_trace = np.copy(np.array([member.latent_periods for member in transition_rates_ensemble]).reshape(-1,1)) +community_infection_periods_trace = np.copy(np.array([member.community_infection_periods for member in transition_rates_ensemble]).reshape(-1,1)) +hospital_infection_periods_trace = np.copy(np.array([member.hospital_infection_periods for member in transition_rates_ensemble]).reshape(-1,1)) + +for i in range(int(simulation_length/static_contact_interval)): + + epidemic_simulator.run(stop_time = epidemic_simulator.time + static_contact_interval) + #Within the epidemic_simulator: + # health_service discharge and admit patients [changes the contact network] + # contact_simulator run [changes the mean contact duration on the given network] + # set the new contact rates on the network! kinetic_model.set_mean_contact_duration(contact_duration) + # run the kinetic model [kinetic produces the current statuses used as data] + + # as kinetic sets the weights, we do not need to update the contact network. + # run the master equation model [master eqn produces the current states of the risk model] + master_eqn_ensemble.set_mean_contact_duration() #do not need to reset weights as already set in kinetic model + # would love to double check this! 
^ + states_ensemble = master_eqn_ensemble.simulate(static_contact_interval, n_steps = 25) + + if i % 1 == 0: + # perform data assimlation [update the master eqn states, the transition rates, and the transmission rate (if supplied)] + (states_ensemble, + transition_rates_ensemble, + community_transmission_rate_ensemble + ) = assimilator.update(ensemble_state = states_ensemble, + data = epidemic_simulator.kinetic_model.current_statuses, + full_ensemble_transition_rates = transition_rates_ensemble, + full_ensemble_transmission_rate = community_transmission_rate_ensemble, + user_network = contact_network) + + #update model parameters (transition and transmission rates) of the master eqn model + master_eqn_ensemble.update_transition_rates(transition_rates_ensemble) + master_eqn_ensemble.update_transmission_rate(community_transmission_rate_ensemble) + + + # for tracking purposes + community_transmission_rate_trace = np.hstack([community_transmission_rate_trace, community_transmission_rate_ensemble]) + latent_periods_trace = np.hstack([latent_periods_trace, np.array([member.latent_periods for member in transition_rates_ensemble]).reshape(-1,1)]) + community_infection_periods_trace = np.hstack([community_infection_periods_trace, np.array([member.community_infection_periods for member in transition_rates_ensemble]).reshape(-1,1)]) + hospital_infection_periods_trace = np.hstack([hospital_infection_periods_trace, np.array([member.hospital_infection_periods for member in transition_rates_ensemble]).reshape(-1,1)]) + + #update states/statuses/times for next iteration + master_eqn_ensemble.set_states_ensemble(states_ensemble) + + axes = plot_ensemble_states(master_eqn_ensemble.states_trace, + master_eqn_ensemble.simulation_time, + axes = axes, + xlims = (-0.1, simulation_length), + a_min = 0.0) + + axes = plot_kinetic_model_data(epidemic_simulator.kinetic_model, + axes = axes) + + plt.savefig('da_ric_tprobs_ninfectest_whospital_wdeath_wrandtest_nodedata_nomodelerror.png', 
rasterized=True, dpi=150) + +time_horizon = np.linspace(0.0, simulation_length, int(simulation_length/static_contact_interval) + 1) +parameters = [community_transmission_rate_trace, latent_periods_trace, community_infection_periods_trace, hospital_infection_periods_trace ] +parameters_names = ['transmission_rates', 'latent_periods', 'community_infection_periods', 'hospital_infection_periods'] + +axes = plot_scalar_parameters(parameters, time_horizon, parameters_names) +plt.savefig('da_parameters_ric_tprobs_ninfectest_whospital_wdeath_wrandtest_nodedata_nomodelerror.png', rasterized=True, dpi=150) + + +# np.savetxt("../data/simulation_data/simulation_data_NYC_DA_1e3.txt", np.c_[kinetic_model.times, kinetic_model.statuses['S'], kinetic_model.statuses['E'], kinetic_model.statuses['I'], kinetic_model.statuses['H'], kinetic_model.statuses['R'],kinetic_model.statuses['D']], header = 'S E I H R D seed: %d'%seed) + +# # plot all model compartments +# fig, axs = plt.subplots(nrows=2, sharex=True) + +# plt.sca(axs[0]) +# plt.plot(kinetic_model.times, kinetic_model.statuses['S']) +# plt.ylabel("Total susceptible, $S$") + +# plt.sca(axs[1]) +# plt.plot(kinetic_model.times, kinetic_model.statuses['E'], label='Exposed') +# plt.plot(kinetic_model.times, kinetic_model.statuses['I'], label='Infected') +# plt.plot(kinetic_model.times, kinetic_model.statuses['H'], label='Hospitalized') +# plt.plot(kinetic_model.times, kinetic_model.statuses['R'], label='Resistant') +# plt.plot(kinetic_model.times, kinetic_model.statuses['D'], label='Deceased') + +# plt.xlabel("Time (days)") +# plt.ylabel("Total $E, I, H, R, D$") +# plt.legend() + +# image_path = ("../figs/simple_epidemic_with_slow_contact_simulator_" + +# "maxlambda_{:d}.png".format(contact_simulator.mean_contact_rate.maximum_i)) + +# print("Saving a visualization of results at", image_path) +# plt.savefig(image_path, dpi=480) diff --git a/examples/simple_measurements.py b/examples/simple_measurements.py index da631a5f..77dc38b9 
100644 --- a/examples/simple_measurements.py +++ b/examples/simple_measurements.py @@ -57,9 +57,9 @@ y0[mm, : ] = np.hstack((S, I, H, R, D)) tF = 35 -res = ensemble_model.simulate(y0, tF, n_steps = 100) - -ode_states = res['states'][:,:,-1] +ensemble_model.set_states_ensemble(y0) +ensemble_model.set_mean_contact_duration() +ode_states = ensemble_model.simulate(tF, n_steps = 100) def random_state(population): """ @@ -91,12 +91,19 @@ def random_state(population): print(np.vstack([np.array(list(statuses.values())), list(mean.values()), list(var.values())]).T[:5]) print('\n3th Test: Hospitalized --------------------------------------') -test = TestMeasurement('H', specificity = 1., sensitivity = 1.) +test = TestMeasurement('H', specificity = .999, sensitivity = 0.999) test.update_prevalence(ode_states, scale = None) mean, var = test.take_measurements(statuses, scale = None) print(np.vstack([np.array(list(statuses.values())), list(mean.values()), list(var.values())]).T[47:47+6]) +print('\n4th Test: Hospitalized --------------------------------------') +test = TestMeasurement('H', specificity = .999, sensitivity = 0.999) +test.update_prevalence(ode_states) +mean, var = test.take_measurements(statuses) + +print(np.vstack([np.array(list(statuses.values())), list(mean.values()), list(var.values())]).T[47:47+6]) + print('\n4th Test: Noisy measurements for positive cases -------------') test = TestMeasurement('I') @@ -123,6 +130,14 @@ def random_state(population): print('Fraction of correct testing: %2.2f'%(np.array(list(mean.values())) == negative_test).mean()) +hospital_records = DataObservation(N = population, + bool_type=True, + obs_status = 'H', + obs_name = "Hospitalized from Data", + specificity = 0.999, + sensitivity = 0.999) + + # print('\n5th Test: Hospitalized --------------------------------------') # test = TestMeasurement(specificity = 1., sensitivity = 1.) # test.update_prevalence(ode_states, scale = 'log', status = 'H')