diff --git a/pyro/contrib/epidemiology/models.py b/pyro/contrib/epidemiology/models.py index 3bc1a1e7bb..393b94f0fd 100644 --- a/pyro/contrib/epidemiology/models.py +++ b/pyro/contrib/epidemiology/models.py @@ -150,6 +150,113 @@ def transition(self, params, state, t): obs=self.data[t] if t_is_observed else None) +class SimpleSEIRDModel(CompartmentalModel): + """ + Susceptible-Exposed-Infected-Recovered-Dead model. + + To customize this model we recommend forking and editing this class. + + This is a stochastic discrete-time discrete-state model with four + compartments: "S" for susceptible, "E" for exposed, "I" for infected, "D" + for deceased individuals, and "R" for recovered individuals (the recovered + individuals are implicit: ``R = population - S - E - I - D``) with + transitions ``S -> E -> I -> R`` and ``I -> D``. + + Because the transitions are not simple linear succession, this model + implements a custom :meth:`compute_flows()` method. + + :param int population: Total ``population = S + E + I + R + D``. + :param float incubation_time: Mean incubation time (duration in state + ``E``). Must be greater than 1. + :param float recovery_time: Mean recovery time (duration in state + ``I``). Must be greater than 1. + :param float mortality_rate: Portion of infections resulting in death. + Must be in the open interval ``(0, 1)``. + :param iterable data: Time series of new observed infections. Each time + step is Binomial distributed between 0 and the number of ``S -> E`` + transitions. This allows false negative but no false positives. + """ + + def __init__(self, population, incubation_time, recovery_time, + mortality_rate, data): + compartments = ("S", "E", "I", "D") # R is implicit. + duration = len(data) + super().__init__(compartments, duration, population) + + assert isinstance(incubation_time, float) + assert incubation_time > 1 + self.incubation_time = incubation_time + + assert isinstance(recovery_time, float) + assert recovery_time > 1 + self.recovery_time = recovery_time + + assert isinstance(mortality_rate, float) + assert 0 < mortality_rate < 1 + self.mortality_rate = mortality_rate + + self.data = data + + def global_model(self): + tau_e = self.incubation_time + tau_i = self.recovery_time + mu = self.mortality_rate + R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) + rho = pyro.sample("rho", dist.Beta(10, 10)) + return R0, tau_e, tau_i, mu, rho + + def initialize(self, params): + # Start with a single infection. + return {"S": self.population - 1, "E": 0, "I": 1, "D": 0} + + def transition(self, params, state, t): + R0, tau_e, tau_i, mu, rho = params + + # Sample flows between compartments. + S2E = pyro.sample("S2E_{}".format(t), + infection_dist(individual_rate=R0 / tau_i, + num_susceptible=state["S"], + num_infectious=state["I"], + population=self.population)) + E2I = pyro.sample("E2I_{}".format(t), + binomial_dist(state["E"], 1 / tau_e)) + # Of the 1/tau_i expected recoveries-or-deaths, a portion mu die and + # the remaining recover. Alternatively we could model this with a + # Multinomial distribution I2_ and extract the two components I2D and + # I2R, however the Multinomial distribution does not currently + # implement overdispersion or moment matching. + I2D = pyro.sample("I2D_{}".format(t), + binomial_dist(state["I"], mu / tau_i)) + I2R = pyro.sample("I2R_{}".format(t), + binomial_dist(state["I"] - I2D, 1 / tau_i)) + + # Update compartments with flows. + state["S"] = state["S"] - S2E + state["E"] = state["E"] + S2E - E2I + state["I"] = state["I"] + E2I - I2R - I2D + state["D"] = state["D"] + I2D + + # Condition on observations. + t_is_observed = isinstance(t, slice) or t < self.duration + pyro.sample("obs_{}".format(t), + binomial_dist(S2E, rho), + obs=self.data[t] if t_is_observed else None) + + def compute_flows(self, prev, curr, t): + S2E = prev["S"] - curr["S"] # S can only go to E. + I2D = curr["D"] - prev["D"] # D can only have come from I. + # We deduce the remaining flows by conservation of mass: + # curr - prev = inflows - outflows + E2I = prev["E"] - curr["E"] + S2E + I2R = prev["I"] - curr["I"] + E2I - I2D + return { + "S2E_{}".format(t): S2E, + "E2I_{}".format(t): E2I, + "I2D_{}".format(t): I2D, + "I2R_{}".format(t): I2R, + } + + class OverdispersedSIRModel(CompartmentalModel): """ Generalizes :class:`SimpleSIRModel` with overdispersed distributions. diff --git a/tests/contrib/epidemiology/test_models.py b/tests/contrib/epidemiology/test_models.py index 5a1e2bfae3..2aea21c6d6 100644 --- a/tests/contrib/epidemiology/test_models.py +++ b/tests/contrib/epidemiology/test_models.py @@ -10,8 +10,8 @@ import pyro.distributions as dist from pyro.contrib.epidemiology.models import (HeterogeneousRegionalSIRModel, HeterogeneousSIRModel, OverdispersedSEIRModel, OverdispersedSIRModel, RegionalSIRModel, - SimpleSEIRModel, SimpleSIRModel, SparseSIRModel, SuperspreadingSEIRModel, - SuperspreadingSIRModel, UnknownStartSIRModel) + SimpleSEIRDModel, SimpleSEIRModel, SimpleSIRModel, SparseSIRModel, + SuperspreadingSEIRModel, SuperspreadingSIRModel, UnknownStartSIRModel) from tests.common import xfail_param logger = logging.getLogger(__name__) @@ -111,6 +111,46 @@ def test_simple_seir_smoke(duration, forecast, options, algo): assert samples["I"].shape == (num_samples, duration + forecast) +@pytest.mark.parametrize("duration", [3, 7]) +@pytest.mark.parametrize("forecast", [0, 7]) +@pytest.mark.parametrize("algo,options", [ + ("svi", {}), + ("mcmc", {}), + ("mcmc", {"haar_full_mass": 2}), +], ids=str) +def test_simple_seird_smoke(duration, forecast, options, algo): + population = 100 + incubation_time = 2.0 + recovery_time = 7.0 + mortality_rate = 0.1 + + # Generate data. + model = SimpleSEIRDModel(population, incubation_time, recovery_time, + mortality_rate, [None] * duration) + assert model.full_mass == [("R0", "rho")] + for attempt in range(100): + data = model.generate({"R0": 1.5, "rho": 0.5})["obs"] + if data.sum(): + break + assert data.sum() > 0, "failed to generate positive data" + + # Infer. + model = SimpleSEIRDModel(population, incubation_time, recovery_time, + mortality_rate, data) + num_samples = 5 + if algo == "mcmc": + model.fit_mcmc(warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options) + else: + model.fit_svi(num_steps=2, num_samples=num_samples, **options) + + # Predict and forecast. + samples = model.predict(forecast=forecast) + assert samples["S"].shape == (num_samples, duration + forecast) + assert samples["E"].shape == (num_samples, duration + forecast) + assert samples["I"].shape == (num_samples, duration + forecast) + assert samples["D"].shape == (num_samples, duration + forecast) + + @pytest.mark.parametrize("duration", [3]) @pytest.mark.parametrize("forecast", [7]) @pytest.mark.parametrize("options", [ diff --git a/tutorial/source/epi_intro.ipynb b/tutorial/source/epi_intro.ipynb index 699b24246d..fe01e87cf4 100644 --- a/tutorial/source/epi_intro.ipynb +++ b/tutorial/source/epi_intro.ipynb @@ -876,7 +876,9 @@ "```diff\n", "+ state[\"S\"] = state[\"S\"] - S2I # Correct\n", "- state[\"S\"] -= S2I # AVOID: may corrupt tensors\n", - "```" + "```\n", + "\n", + "For a slightly more complex example, take a look at the [SimpleSEIRDModel](http://docs.pyro.ai/en/latest/contrib.epidemiology.html#pyro.contrib.epidemiology.models.SimpleSEIRDModel)." ] }, {