Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Epidemic Data Storage with example #110

Merged
merged 10 commits into from
Jun 16, 2020
7 changes: 3 additions & 4 deletions epiforecast/data_assimilator.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,6 @@ def sum_to_one(self, prev_ensemble_state, ensemble_state):
free_mass = np.sum(free_states,axis=1) + Emass[:,observed_nodes]

# normalize the free values e.g for S: set S = (1-I) * S/(S+E+H+R+D)
for i in free_statuses:

ensemble_state[:, i*N+observed_nodes] = (1.0 - updated_mass[:,0,:]) * (free_states[:, i, :] / free_mass)

if free_mass > 0:
for i in free_statuses:
ensemble_state[:, i*N+observed_nodes] = (1.0 - updated_mass[:,0,:]) * (free_states[:, i, :] / free_mass)
86 changes: 86 additions & 0 deletions epiforecast/network_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import copy
from collections import namedtuple

class StaticNetwork:
dburov190 marked this conversation as resolved.
Show resolved Hide resolved
"""
A container to hold a static contact network, it contains
The network, when it used in the simulation, and the data (statuses)
at these times.
"""
def __init__(self,
contact_network,
start_time,
end_time):

self.contact_network = copy.deepcopy(contact_network)
self.start_time = copy.deepcopy(start_time)
self.end_time = copy.deepcopy(end_time)

def set_end_statuses(self, end_statuses):
self.end_statuses = copy.deepcopy(end_statuses)

def set_start_statuses(self, start_statuses):
self.start_statuses = copy.deepcopy(start_statuses)

class StaticNetworkSeries:
"""
A container to hold a series of StaticNetwork objects. It stores the networks as a
dictionary with keys given by a named tuple StartEndTime which will set/get networks based
on the provided start_time, end_time or both
"""
def __init__(self, static_contact_interval):
"""
Args
----
static_contact_interval (float): the fixed duration at which the network is static. (so we can
deduce end time from start time, start_time from end time).
"""
self.static_network_series={}
self.static_contact_interval=static_contact_interval
self.StartEndTime = namedtuple("StartEndTime",["start","end"])
dburov190 marked this conversation as resolved.
Show resolved Hide resolved

def save_network_by_start_time(self,
contact_network,
start_time):

end_time = start_time+self.static_contact_interval
start_end_time = self.StartEndTime(start=start_time, end=end_time)
new_network = StaticNetwork(contact_network,
start_time,
end_time)

self.static_network_series[start_end_time] = new_network
dburov190 marked this conversation as resolved.
Show resolved Hide resolved

def save_network_by_end_time(self,
contact_network,
end_time):

start_end_time = self.StartEndTime(start=end_time-self.static_contact_interval, end=end_time)
new_network = StaticNetwork(contact_network,
start_time,
end_time)

self.static_network_series[start_end_time] = new_network

def save_end_statuses_to_network(self,
end_time,
end_statuses):
start_end_time = next(filter(lambda keys: abs(keys.end - end_time) < 1e-8, self.static_network_series.keys()))
dburov190 marked this conversation as resolved.
Show resolved Hide resolved
self.static_network_series[start_end_time].set_end_statuses(end_statuses)
dburov190 marked this conversation as resolved.
Show resolved Hide resolved

def save_start_statuses_to_network(self,
start_time,
start_statuses):
start_end_time =next(filter(lambda keys: abs(keys.start - start_time) < 1e-8, self.static_network_series.keys()))
self.static_network_series[start_end_time].set_start_statuses(start_statuses)

def get_network_from_start_time(self,
start_time):
start_end_time = next(filter(lambda keys: abs(keys.start - start_time) < 1e-8, self.static_network_series.keys()))
return self.static_network_series[start_end_time]

def get_network_from_end_time(self,
end_time):
start_end_time =next(filter(lambda keys: abs(keys.end - end_time) < 1e-8, self.static_network_series.keys()))
return self.static_network_series[start_end_time]

31 changes: 21 additions & 10 deletions epiforecast/risk_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ def set_mean_contact_duration(self, new_mean_contact_duration=None):

self.L = nx.to_scipy_sparse_matrix(self.contact_network, weight = 'exposed_by_infected')

def set_contact_network(self, new_contact_network):
self.contact_network = new_contact_network
dburov190 marked this conversation as resolved.
Show resolved Hide resolved
# Automatically reset the edge weights
self.set_mean_contact_duration()
dburov190 marked this conversation as resolved.
Show resolved Hide resolved

def update_transmission_rate(self, new_transmission_rate):
"""
new_transmission_rate : `np.array` of length `ensemble_size`
Expand Down Expand Up @@ -282,8 +287,10 @@ def simulate(self, time_window, n_steps = 50, closure = 'independent', **kwargs)

return self.y0

def simulate_backwards(self, y0, stop_time, n_steps = 100, start_time = 0.0, closure = 'independent', **kwargs):
def simulate_backwards(self, time_window, n_steps = 100, closure = 'independent', **kwargs):
"""
Note: start time > stop time...

Args:
-------
y0 : `np.array` of initial states for simulation of size (M, 5 times N)
dburov190 marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -292,23 +299,27 @@ def simulate_backwards(self, y0, stop_time, n_steps = 100, start_time = 0.0, clo
start_time : initial time of simulation
closure : by default consider that closure = 'independent'
"""
self.tf = 0.
self.y0 = np.copy(y0)
t = np.linspace(stop_time, start_time, n_steps + 1)
self.stop_time = self.start_time - time_window
dburov190 marked this conversation as resolved.
Show resolved Hide resolved
t = np.linspace(self.start_time, self.stop_time, n_steps + 1)
self.dt = np.diff(t).min()
yt = np.empty((len(y0.flatten()), len(t)))
yt[:,0] = np.copy(y0.flatten())

yt = np.empty((len(self.y0.flatten()), len(t)))
dburov190 marked this conversation as resolved.
Show resolved Hide resolved
yt[:,0] = np.copy(self.y0.flatten())
dburov190 marked this conversation as resolved.
Show resolved Hide resolved

for jj, time in tqdm(enumerate(t[:-1]), desc = 'Simulate backward', total = n_steps):
for jj, time in tqdm(enumerate(t[:-1]),
desc = '[ Master equations ] Time window [%2.3f, %2.3f]'%(self.stop_time, self.start_time),
total = n_steps):
self.eval_closure(self.y0, closure = closure)
for mm, member in enumerate(self.ensemble):
if self.ix_reduced:
self.y0[mm] += self.dt * self.do_step(t, self.y0[mm], member, closure = closure)
else:
self.y0[mm] += self.dt * self.do_step_full(t, self.y0[mm], member, closure = closure)
self.y0[mm] = np.clip(self.y0[mm], 0., 1.)
self.tf += self.dt
yt[:,jj + 1] = np.copy(self.y0.flatten())

return {'times' : t,
'states': yt.reshape(self.M, -1, len(t))}
self.simulation_time = t
self.states_trace = yt.reshape(self.M, -1, len(t))
dburov190 marked this conversation as resolved.
Show resolved Hide resolved
self.start_time -= time_window

return self.y0
dburov190 marked this conversation as resolved.
Show resolved Hide resolved
74 changes: 74 additions & 0 deletions examples/saving_contact_networks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import os, sys; sys.path.append(os.path.join(".."))

from timeit import default_timer as timer

import networkx as nx
import numpy as np
import pandas as pd
import random
import datetime as dt
import matplotlib.dates as mdates
import matplotlib.ticker as ticker
import matplotlib.pyplot as plt

from numba import set_num_threads

set_num_threads(1)

from epiforecast.scenarios import random_epidemic
from epiforecast.network_storage import StaticNetworkSeries
#
# Set random seeds for reproducibility
#
seed = 212212
np.random.seed(seed)

contact_network = nx.barabasi_albert_graph(30000, 10)
population = len(contact_network)

time = 0.0
static_contact_interval = 0.25
simulation_length = 1

contact_network = nx.barabasi_albert_graph(1000, 10)

network_storage = StaticNetworkSeries(static_contact_interval)


statuses = random_epidemic(contact_network,
fraction_infected=0.01)

print("saving all the networks")
current_infected = [node for node in contact_network.nodes if statuses[node] == 'I']
print("infected at time", time, current_infected)

for i in range(int(simulation_length/static_contact_interval)):
#save network and start time statuses
network_storage.save_network_by_start_time(contact_network=contact_network, start_time=time)
network_storage.save_start_statuses_to_network(start_time=time, start_statuses=statuses)

#pretend we 'simuate' forward
contact_network = nx.barabasi_albert_graph(1000, 10)
statuses = random_epidemic(contact_network, fraction_infected=0.01)
time = time+static_contact_interval
current_infected = [node for node in contact_network.nodes if statuses[node] == 'I']
print("infected at time", time, current_infected)

#save end time statuses
network_storage.save_end_statuses_to_network(end_time=time, end_statuses = statuses)

print(" ")
print("loading the networks backwards by end time")
for i in range(int(simulation_length/static_contact_interval)):
net=network_storage.get_network_from_end_time(end_time=time)
current_infected = [node for node in net.contact_network.nodes if net.end_statuses[node] == 'I']
print("infected at time", net.end_time, current_infected)
time = time - static_contact_interval

print(" ")
print("loading the networks forwards by start time")
for i in range(int(simulation_length/static_contact_interval)):
net=network_storage.get_network_from_start_time(start_time=time)
current_infected = [node for node in net.contact_network.nodes if net.start_statuses[node] == 'I']
print("infected at time", net.start_time, current_infected)
time = time + static_contact_interval
Loading