Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an SEIRD example with custom .compute_flows() #2559

Merged
merged 10 commits into from
Jul 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions pyro/contrib/epidemiology/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,113 @@ def transition(self, params, state, t):
obs=self.data[t] if t_is_observed else None)


class SimpleSEIRDModel(CompartmentalModel):
"""
Susceptible-Exposed-Infected-Recovered-Dead model.

To customize this model we recommend forking and editing this class.

This is a stochastic discrete-time discrete-state model with four
compartments: "S" for susceptible, "E" for exposed, "I" for infected, "D"
for deceased individuals, and "R" for recovered individuals (the recovered
individuals are implicit: ``R = population - S - E - I - D``) with
transitions ``S -> E -> I -> R`` and ``I -> D``.

Because the transitions are not simple linear succession, this model
implements a custom :meth:`compute_flows()` method.

:param int population: Total ``population = S + E + I + R + D``.
:param float incubation_time: Mean incubation time (duration in state
``E``). Must be greater than 1.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param float mortality_rate: Portion of infections resulting in death.
Must be in the open interval ``(0, 1)``.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> E``
transitions. This allows false negative but no false positives.
"""

def __init__(self, population, incubation_time, recovery_time,
mortality_rate, data):
compartments = ("S", "E", "I", "D") # R is implicit.
duration = len(data)
super().__init__(compartments, duration, population)

assert isinstance(incubation_time, float)
assert incubation_time > 1
self.incubation_time = incubation_time

assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time

assert isinstance(mortality_rate, float)
assert 0 < mortality_rate < 1
self.mortality_rate = mortality_rate

self.data = data

def global_model(self):
tau_e = self.incubation_time
tau_i = self.recovery_time
mu = self.mortality_rate
R0 = pyro.sample("R0", dist.LogNormal(0., 1.))
rho = pyro.sample("rho", dist.Beta(10, 10))
return R0, tau_e, tau_i, mu, rho

def initialize(self, params):
# Start with a single infection.
return {"S": self.population - 1, "E": 0, "I": 1, "D": 0}

def transition(self, params, state, t):
R0, tau_e, tau_i, mu, rho = params

# Sample flows between compartments.
S2E = pyro.sample("S2E_{}".format(t),
infection_dist(individual_rate=R0 / tau_i,
num_susceptible=state["S"],
num_infectious=state["I"],
population=self.population))
E2I = pyro.sample("E2I_{}".format(t),
binomial_dist(state["E"], 1 / tau_e))
# Of the 1/tau_i expected recoveries-or-deaths, a portion mu die and
# the remaining recover. Alternatively we could model this with a
# Multinomial distribution I2_ and extract the two components I2D and
# I2R, however the Multinomial distribution does not currently
# implement overdispersion or moment matching.
I2D = pyro.sample("I2D_{}".format(t),
binomial_dist(state["I"], mu / tau_i))
I2R = pyro.sample("I2R_{}".format(t),
binomial_dist(state["I"] - I2D, 1 / tau_i))

# Update compartments with flows.
state["S"] = state["S"] - S2E
state["E"] = state["E"] + S2E - E2I
state["I"] = state["I"] + E2I - I2R - I2D
state["D"] = state["D"] + I2D

# Condition on observations.
t_is_observed = isinstance(t, slice) or t < self.duration
pyro.sample("obs_{}".format(t),
binomial_dist(S2E, rho),
obs=self.data[t] if t_is_observed else None)

def compute_flows(self, prev, curr, t):
S2E = prev["S"] - curr["S"] # S can only go to E.
I2D = curr["D"] - prev["D"] # D can only have come from I.
# We deduce the remaining flows by conservation of mass:
# curr - prev = inflows - outflows
E2I = prev["E"] - curr["E"] + S2E
I2R = prev["I"] - curr["I"] + E2I - I2D
return {
"S2E_{}".format(t): S2E,
"E2I_{}".format(t): E2I,
"I2D_{}".format(t): I2D,
"I2R_{}".format(t): I2R,
}


class OverdispersedSIRModel(CompartmentalModel):
"""
Generalizes :class:`SimpleSIRModel` with overdispersed distributions.
Expand Down
44 changes: 42 additions & 2 deletions tests/contrib/epidemiology/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import pyro.distributions as dist
from pyro.contrib.epidemiology.models import (HeterogeneousRegionalSIRModel, HeterogeneousSIRModel,
OverdispersedSEIRModel, OverdispersedSIRModel, RegionalSIRModel,
SimpleSEIRModel, SimpleSIRModel, SparseSIRModel, SuperspreadingSEIRModel,
SuperspreadingSIRModel, UnknownStartSIRModel)
SimpleSEIRDModel, SimpleSEIRModel, SimpleSIRModel, SparseSIRModel,
SuperspreadingSEIRModel, SuperspreadingSIRModel, UnknownStartSIRModel)
from tests.common import xfail_param

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -111,6 +111,46 @@ def test_simple_seir_smoke(duration, forecast, options, algo):
assert samples["I"].shape == (num_samples, duration + forecast)


@pytest.mark.parametrize("duration", [3, 7])
@pytest.mark.parametrize("forecast", [0, 7])
@pytest.mark.parametrize("algo,options", [
("svi", {}),
("mcmc", {}),
("mcmc", {"haar_full_mass": 2}),
], ids=str)
def test_simple_seird_smoke(duration, forecast, options, algo):
population = 100
incubation_time = 2.0
recovery_time = 7.0
mortality_rate = 0.1

# Generate data.
model = SimpleSEIRDModel(population, incubation_time, recovery_time,
mortality_rate, [None] * duration)
assert model.full_mass == [("R0", "rho")]
for attempt in range(100):
data = model.generate({"R0": 1.5, "rho": 0.5})["obs"]
if data.sum():
break
assert data.sum() > 0, "failed to generate positive data"

# Infer.
model = SimpleSEIRDModel(population, incubation_time, recovery_time,
mortality_rate, data)
num_samples = 5
if algo == "mcmc":
model.fit_mcmc(warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options)
else:
model.fit_svi(num_steps=2, num_samples=num_samples, **options)

# Predict and forecast.
samples = model.predict(forecast=forecast)
assert samples["S"].shape == (num_samples, duration + forecast)
assert samples["E"].shape == (num_samples, duration + forecast)
assert samples["I"].shape == (num_samples, duration + forecast)
assert samples["D"].shape == (num_samples, duration + forecast)


@pytest.mark.parametrize("duration", [3])
@pytest.mark.parametrize("forecast", [7])
@pytest.mark.parametrize("options", [
Expand Down
4 changes: 3 additions & 1 deletion tutorial/source/epi_intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,9 @@
"```diff\n",
"+ state[\"S\"] = state[\"S\"] - S2I # Correct\n",
"- state[\"S\"] -= S2I # AVOID: may corrupt tensors\n",
"```"
"```\n",
"\n",
"For a slightly more complex example, take a look at the [SimpleSEIRDModel](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.models.SimpleSEIRDModel)."
]
},
{
Expand Down