Skip to content

Commit

Permalink
Add an SEIRD example with custom .compute_flows() (#2559)
Browse files Browse the repository at this point in the history
* Add more sections to epi tutorial

* Update tutorial

* Tweak ascii art

* Fix link

* Fix typos

* Add an example SEIRD model

* Add link to SimpleSEIRDModel
  • Loading branch information
fritzo authored Jul 15, 2020
1 parent acde6b5 commit b481e9a
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 3 deletions.
107 changes: 107 additions & 0 deletions pyro/contrib/epidemiology/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,113 @@ def transition(self, params, state, t):
obs=self.data[t] if t_is_observed else None)


class SimpleSEIRDModel(CompartmentalModel):
"""
Susceptible-Exposed-Infected-Recovered-Dead model.
To customize this model we recommend forking and editing this class.
This is a stochastic discrete-time discrete-state model with four
compartments: "S" for susceptible, "E" for exposed, "I" for infected, "D"
for deceased individuals, and "R" for recovered individuals (the recovered
individuals are implicit: ``R = population - S - E - I - D``) with
transitions ``S -> E -> I -> R`` and ``I -> D``.
Because the transitions are not simple linear succession, this model
implements a custom :meth:`compute_flows()` method.
:param int population: Total ``population = S + E + I + R + D``.
:param float incubation_time: Mean incubation time (duration in state
``E``). Must be greater than 1.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param float mortality_rate: Portion of infections resulting in death.
Must be in the open interval ``(0, 1)``.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> E``
transitions. This allows false negative but no false positives.
"""

def __init__(self, population, incubation_time, recovery_time,
mortality_rate, data):
compartments = ("S", "E", "I", "D") # R is implicit.
duration = len(data)
super().__init__(compartments, duration, population)

assert isinstance(incubation_time, float)
assert incubation_time > 1
self.incubation_time = incubation_time

assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time

assert isinstance(mortality_rate, float)
assert 0 < mortality_rate < 1
self.mortality_rate = mortality_rate

self.data = data

def global_model(self):
tau_e = self.incubation_time
tau_i = self.recovery_time
mu = self.mortality_rate
R0 = pyro.sample("R0", dist.LogNormal(0., 1.))
rho = pyro.sample("rho", dist.Beta(10, 10))
return R0, tau_e, tau_i, mu, rho

def initialize(self, params):
# Start with a single infection.
return {"S": self.population - 1, "E": 0, "I": 1, "D": 0}

def transition(self, params, state, t):
R0, tau_e, tau_i, mu, rho = params

# Sample flows between compartments.
S2E = pyro.sample("S2E_{}".format(t),
infection_dist(individual_rate=R0 / tau_i,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population))
E2I = pyro.sample("E2I_{}".format(t),
binomial_dist(state["E"], 1 / tau_e))
# Of the 1/tau_i expected recoveries-or-deaths, a portion mu die and
# the remaining recover. Alternatively we could model this with a
# Multinomial distribution I2_ and extract the two components I2D and
# I2R, however the Multinomial distribution does not currently
# implement overdispersion or moment matching.
I2D = pyro.sample("I2D_{}".format(t),
binomial_dist(state["I"], mu / tau_i))
I2R = pyro.sample("I2R_{}".format(t),
binomial_dist(state["I"] - I2D, 1 / tau_i))

# Update compartments with flows.
state["S"] = state["S"] - S2E
state["E"] = state["E"] + S2E - E2I
state["I"] = state["I"] + E2I - I2R - I2D
state["D"] = state["D"] + I2D

# Condition on observations.
t_is_observed = isinstance(t, slice) or t < self.duration
pyro.sample("obs_{}".format(t),
binomial_dist(S2E, rho),
obs=self.data[t] if t_is_observed else None)

def compute_flows(self, prev, curr, t):
S2E = prev["S"] - curr["S"] # S can only go to E.
I2D = curr["D"] - prev["D"] # D can only have come from I.
# We deduce the remaining flows by conservation of mass:
# curr - prev = inflows - outflows
E2I = prev["E"] - curr["E"] + S2E
I2R = prev["I"] - curr["I"] + E2I - I2D
return {
"S2E_{}".format(t): S2E,
"E2I_{}".format(t): E2I,
"I2D_{}".format(t): I2D,
"I2R_{}".format(t): I2R,
}


class OverdispersedSIRModel(CompartmentalModel):
"""
Generalizes :class:`SimpleSIRModel` with overdispersed distributions.
Expand Down
44 changes: 42 additions & 2 deletions tests/contrib/epidemiology/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import pyro.distributions as dist
from pyro.contrib.epidemiology.models import (HeterogeneousRegionalSIRModel, HeterogeneousSIRModel,
OverdispersedSEIRModel, OverdispersedSIRModel, RegionalSIRModel,
SimpleSEIRModel, SimpleSIRModel, SparseSIRModel, SuperspreadingSEIRModel,
SuperspreadingSIRModel, UnknownStartSIRModel)
SimpleSEIRDModel, SimpleSEIRModel, SimpleSIRModel, SparseSIRModel,
SuperspreadingSEIRModel, SuperspreadingSIRModel, UnknownStartSIRModel)
from tests.common import xfail_param

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -111,6 +111,46 @@ def test_simple_seir_smoke(duration, forecast, options, algo):
assert samples["I"].shape == (num_samples, duration + forecast)


@pytest.mark.parametrize("duration", [3, 7])
@pytest.mark.parametrize("forecast", [0, 7])
@pytest.mark.parametrize("algo,options", [
("svi", {}),
("mcmc", {}),
("mcmc", {"haar_full_mass": 2}),
], ids=str)
def test_simple_seird_smoke(duration, forecast, options, algo):
population = 100
incubation_time = 2.0
recovery_time = 7.0
mortality_rate = 0.1

# Generate data.
model = SimpleSEIRDModel(population, incubation_time, recovery_time,
mortality_rate, [None] * duration)
assert model.full_mass == [("R0", "rho")]
for attempt in range(100):
data = model.generate({"R0": 1.5, "rho": 0.5})["obs"]
if data.sum():
break
assert data.sum() > 0, "failed to generate positive data"

# Infer.
model = SimpleSEIRDModel(population, incubation_time, recovery_time,
mortality_rate, data)
num_samples = 5
if algo == "mcmc":
model.fit_mcmc(warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options)
else:
model.fit_svi(num_steps=2, num_samples=num_samples, **options)

# Predict and forecast.
samples = model.predict(forecast=forecast)
assert samples["S"].shape == (num_samples, duration + forecast)
assert samples["E"].shape == (num_samples, duration + forecast)
assert samples["I"].shape == (num_samples, duration + forecast)
assert samples["D"].shape == (num_samples, duration + forecast)


@pytest.mark.parametrize("duration", [3])
@pytest.mark.parametrize("forecast", [7])
@pytest.mark.parametrize("options", [
Expand Down
4 changes: 3 additions & 1 deletion tutorial/source/epi_intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,9 @@
"```diff\n",
"+ state[\"S\"] = state[\"S\"] - S2I # Correct\n",
"- state[\"S\"] -= S2I # AVOID: may corrupt tensors\n",
"```"
"```\n",
"\n",
"For a slightly more complex example, take a look at the [SimpleSEIRDModel](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.models.SimpleSEIRDModel)."
]
},
{
Expand Down

0 comments on commit b481e9a

Please sign in to comment.