-
-
Notifications
You must be signed in to change notification settings - Fork 985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add SIR model with unknown time of initial infection #2460
Changes from all commits
bbbd486
43480e7
7a21fc5
d4cc9fa
3f9c925
29825c8
3d73031
27b20d1
fad40cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,12 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
from torch.nn.functional import pad | ||
|
||
import pyro | ||
import pyro.distributions as dist | ||
from pyro.ops.indexing import Index | ||
|
||
from .compartmental import CompartmentalModel | ||
from .distributions import infection_dist | ||
|
@@ -318,3 +322,164 @@ def transition_bwd(self, params, prev, curr, t): | |
pyro.sample("obs_{}".format(t), | ||
dist.Delta(curr["O"]).mask(self.mask[t]), | ||
obs=self.data[t]) | ||
|
||
|
||
class UnknownStartSIRModel(CompartmentalModel): | ||
""" | ||
Susceptible-Infected-Recovered model with unknown date of first infection. | ||
|
||
To customize this model we recommend forking and editing this class. | ||
|
||
This is a stochastic discrete-time discrete-state model with three | ||
compartments: "S" for susceptible, "I" for infected, and "R" for | ||
recovered individuals (the recovered individuals are implicit: ``R = | ||
population - S - I``) with transitions ``S -> I -> R``. | ||
|
||
This model demonstrates: | ||
|
||
1. How to incorporate spontaneous infections from external sources; | ||
2. How to incorporate time-varying piecewise ``rho`` by supporting | ||
forecasting in :meth:`transition_fwd` and using the | ||
:class:`~pyro.ops.index.ing.Index` helper in :meth:`transition_bwd`. | ||
3. How to override the :meth:`predict` method to compute extra | ||
statistics. | ||
|
||
:param int population: Total ``population = S + I + R``. | ||
:param float recovery_time: Mean recovery time (duration in state | ||
``I``). Must be greater than 1. | ||
:param int pre_obs_window: Number of time steps before beginning ``data`` | ||
where the initial infection may have occurred. Must be positive. | ||
:param iterable data: Time series of new observed infections. Each time | ||
step is Binomial distributed between 0 and the number of ``S -> I`` | ||
transitions. This allows false negative but no false positives. | ||
""" | ||
|
||
def __init__(self, population, recovery_time, pre_obs_window, data): | ||
compartments = ("S", "I") # R is implicit. | ||
duration = pre_obs_window + len(data) | ||
super().__init__(compartments, duration, population) | ||
|
||
assert isinstance(recovery_time, float) | ||
assert recovery_time > 1 | ||
self.recovery_time = recovery_time | ||
|
||
assert isinstance(pre_obs_window, int) and pre_obs_window > 0 | ||
self.pre_obs_window = pre_obs_window | ||
self.post_obs_window = len(data) | ||
|
||
# We set a small time-constant external infecton rate such that on | ||
# average there is a single external infection during the | ||
# pre_obs_window. This allows unknown time of initial infection | ||
# without introducing long-range coupling across time. | ||
self.external_rate = 1 / pre_obs_window | ||
|
||
# Prepend data with zeros. | ||
if isinstance(data, list): | ||
data = [0.] * self.pre_obs_window + data | ||
else: | ||
data = pad(data, (self.pre_obs_window, 0), value=0.) | ||
self.data = data | ||
|
||
series = ("S2I", "I2R", "obs") | ||
full_mass = [("R0", "rho0", "rho1")] | ||
|
||
def global_model(self): | ||
tau = self.recovery_time | ||
R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) | ||
|
||
# Assume two different response rates: rho0 before any observations | ||
# were made (in pre_obs_window), followed by a higher response rate rho1 | ||
# after observations were made (in post_obs_window). | ||
rho0 = pyro.sample("rho0", dist.Uniform(0, 1)) | ||
rho1 = pyro.sample("rho1", dist.Uniform(0, 1)) | ||
# Whereas each of rho0,rho1 are scalars (possibly batched over samples), | ||
# we construct a time series rho with an extra time dim on the right. | ||
rho = torch.cat([ | ||
rho0.unsqueeze(-1).expand(rho0.shape + (self.pre_obs_window,)), | ||
rho1.unsqueeze(-1).expand(rho1.shape + (self.post_obs_window,)), | ||
], dim=-1) | ||
|
||
# Model external infections as an infectious pseudo-individual added | ||
# to num_infectious when sampling S2I below. | ||
X = self.external_rate * tau / R0 | ||
|
||
return R0, X, tau, rho | ||
|
||
def initialize(self, params): | ||
# Start with no internal infections. | ||
return {"S": self.population, "I": 0} | ||
|
||
def transition_fwd(self, params, state, t): | ||
R0, X, tau, rho = params | ||
|
||
# Sample flows between compartments. | ||
S2I = pyro.sample("S2I_{}".format(t), | ||
infection_dist(individual_rate=R0 / tau, | ||
num_susceptible=state["S"], | ||
num_infectious=state["I"] + X, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i guess the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct. That modeling decision seems more parsimonious, simpler to code, and quantitatively close to the alternative (one initial infection at an unknown time) because during an outbreak, new infections should be dominated by local internal infections. |
||
population=self.population)) | ||
I2R = pyro.sample("I2R_{}".format(t), | ||
dist.Binomial(state["I"], 1 / tau)) | ||
|
||
# Update compartments with flows. | ||
state["S"] = state["S"] - S2I | ||
state["I"] = state["I"] + S2I - I2R | ||
|
||
# In .transition_fwd() t will always be an integer but may lie outside | ||
# of [0,self.duration) when forecasting. | ||
rho_t = rho[..., t] if t < self.duration else rho[..., -1] | ||
data_t = self.data[t] if t < self.duration else None | ||
|
||
# Condition on observations. | ||
pyro.sample("obs_{}".format(t), | ||
dist.ExtendedBinomial(S2I, rho_t), | ||
obs=data_t) | ||
|
||
def transition_bwd(self, params, prev, curr, t): | ||
R0, X, tau, rho = params | ||
|
||
# Reverse the flow computation. | ||
S2I = prev["S"] - curr["S"] | ||
I2R = prev["I"] - curr["I"] + S2I | ||
|
||
# Condition on flows between compartments. | ||
pyro.sample("S2I_{}".format(t), | ||
infection_dist(individual_rate=R0 / tau, | ||
num_susceptible=prev["S"], | ||
num_infectious=prev["I"] + X, | ||
population=self.population), | ||
obs=S2I) | ||
pyro.sample("I2R_{}".format(t), | ||
dist.ExtendedBinomial(prev["I"], 1 / tau), | ||
obs=I2R) | ||
|
||
# In .transition_bwd() t may be either an integer in [0,self.duration) | ||
# or may be a nonstandard slicing object like (Ellipsis, None, None); | ||
# we use the Index(-)[-] helper to support indexing with nonstandard t. | ||
rho_t = Index(rho)[..., t] | ||
|
||
# Condition on observations. | ||
pyro.sample("obs_{}".format(t), | ||
dist.ExtendedBinomial(S2I, rho_t), | ||
obs=self.data[t]) | ||
|
||
def predict(self, forecast=0): | ||
""" | ||
Augments | ||
:meth:`~pyro.contrib.epidemiology.compartmental.Compartmental.predict` | ||
with samples of ``first_infection`` i.e. the first time index at which | ||
the infection ``I`` becomes nonzero. Note this is measured from the | ||
beginning of ``pre_obs_window``, not the beginning of data. | ||
|
||
:param int forecast: The number of time steps to forecast forward. | ||
:returns: A dictionary mapping sample site name (or compartment name) | ||
to a tensor whose first dimension corresponds to sample batching. | ||
:rtype: dict | ||
""" | ||
samples = super().predict(forecast) | ||
|
||
# Extract the time index of the first infection (samples["I"] > 0) | ||
# for each sample trajectory in the samples["I"] tensor. | ||
samples["first_infection"] = samples["I"].cumsum(-1).eq(0).sum(-1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you please comment on this quantity? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
return samples |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comment on shape of rho?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done