diff --git a/examples/contrib/epidemiology/sir.py b/examples/contrib/epidemiology/sir.py index 75fdb7f0e7..d1d8788f09 100644 --- a/examples/contrib/epidemiology/sir.py +++ b/examples/contrib/epidemiology/sir.py @@ -11,14 +11,14 @@ import torch import pyro -from pyro.contrib.epidemiology import SIRModel +from pyro.contrib.epidemiology import SimpleSIRModel logging.basicConfig(format='%(message)s', level=logging.INFO) def generate_data(args): extended_data = [None] * (args.duration + args.forecast) - model = SIRModel(args.population, args.recovery_time, extended_data) + model = SimpleSIRModel(args.population, args.recovery_time, extended_data) for attempt in range(100): samples = model.generate({"R0": args.basic_reproduction_number, "rho": args.response_rate}) @@ -49,6 +49,7 @@ def hook_fn(kernel, *unused): mcmc = model.fit(warmup_steps=args.warmup_steps, num_samples=args.num_samples, max_tree_depth=args.max_tree_depth, + num_quant_bins=args.num_bins, dct=args.dct, hook_fn=hook_fn) @@ -131,8 +132,7 @@ def main(args): obs = dataset["obs"] # Run inference. - model = SIRModel(args.population, args.recovery_time, obs, - num_quant_bins=args.num_bins) + model = SimpleSIRModel(args.population, args.recovery_time, obs) samples = infer(args, model) # Evaluate fit. diff --git a/pyro/contrib/epidemiology/__init__.py b/pyro/contrib/epidemiology/__init__.py index 28812fce77..a6f3569c83 100644 --- a/pyro/contrib/epidemiology/__init__.py +++ b/pyro/contrib/epidemiology/__init__.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from .compartmental import CompartmentalModel -from .sir import SIRModel +from .sir import SimpleSIRModel __all__ = [ "CompartmentalModel", - "SIRModel", + "SimpleSIRModel", ] diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index 8f59672d55..4d4a0feff2 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ b/pyro/contrib/epidemiology/compartmental.py @@ -73,8 +73,6 @@ def transition_bwd(self, params, prev, curr, t): ... :param list compartments: A list of strings of compartment names. :param int duration: :param int population: - :param int num_quant_bins: The number of quantization bins to use. Note that - computational cost is exponential in `num_quant_bins`. Defaults to 4. """ def __init__(self, compartments, duration, population, *, @@ -94,10 +92,6 @@ def __init__(self, compartments, duration, population, *, assert len(compartments) == len(set(compartments)) self.compartments = compartments - assert isinstance(num_quant_bins, int) - assert num_quant_bins >= 2 - self.num_quant_bins = num_quant_bins - # Inference state. self.samples = {} @@ -223,13 +217,18 @@ def fit(self, **options): :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. :param full_mass: (Default ``False``). Specification of mass matrix of the :class:`~pyro.infer.mcmc.nuts.NUTS` kernel. + :param int num_quant_bins: The number of quantization bins to use. Note + that computational cost is exponential in `num_quant_bins`. + Defaults to 4. :param float dct: If provided, use a discrete cosine reparameterizer with this value as smoothness. :returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``. :rtype: ~pyro.infer.mcmc.api.MCMC """ logger.info("Running inference...") - self._dct = options.pop("dct", None) # Save for .predict(). + # Save these options for .predict(). + self.num_quant_bins = options.pop("num_quant_bins", 4) + self._dct = options.pop("dct", None) # Heuristically initialze to feasible latents. init_values = self.heuristic() diff --git a/pyro/contrib/epidemiology/sir.py b/pyro/contrib/epidemiology/sir.py index c72cd9f702..310450f363 100644 --- a/pyro/contrib/epidemiology/sir.py +++ b/pyro/contrib/epidemiology/sir.py @@ -10,26 +10,31 @@ from .compartmental import CompartmentalModel -class SIRModel(CompartmentalModel): +class SimpleSIRModel(CompartmentalModel): """ Susceptible-Infected-Recovered model. - :param int population: - :param float recovery_time: + To customize this model we recommend forking and editing this class. + + This is a stochastic discrete-time discrete-state model with three + compartments: "S" for susceptible, "I" for infected, and "R" for + recovered individuals (the recovered individuals are implicit: ``R = + population - S - I``) with transitions ``S -> I -> R``. + + :param int population: Total ``population = S + I + R``. + :param float recovery_time: Mean recovery time (duration in state + ``I``). Must be greater than 1. :param iterable data: Time series of new observed infections. :param int data: Time series of new observed infections. - :param int num_quant_bins: The number of quantization bins to use. Note that - computational cost is exponential in `num_quant_bins`. Defaults to 4. """ - def __init__(self, population, recovery_time, data, *, - num_quant_bins=4): + def __init__(self, population, recovery_time, data): compartments = ("S", "I") # R is implicit. duration = len(data) - super().__init__(compartments, duration, population, num_quant_bins=num_quant_bins) + super().__init__(compartments, duration, population) assert isinstance(recovery_time, float) - assert recovery_time > 0 + assert recovery_time > 1 self.recovery_time = recovery_time self.data = data @@ -61,7 +66,7 @@ def global_model(self): # Convert interpretable parameters to distribution parameters. rate_s = -R0 / (tau * self.population) - prob_i = 1 / (1 + tau) + prob_i = 1 / tau return rate_s, prob_i, rho diff --git a/tests/contrib/epidemiology/test_sir.py b/tests/contrib/epidemiology/test_sir.py index 04d0a83fea..78f0e2dfcb 100644 --- a/tests/contrib/epidemiology/test_sir.py +++ b/tests/contrib/epidemiology/test_sir.py @@ -3,17 +3,24 @@ import pytest -from pyro.contrib.epidemiology import SIRModel +from pyro.contrib.epidemiology import SimpleSIRModel @pytest.mark.parametrize("duration", [3, 7]) @pytest.mark.parametrize("forecast", [0, 7]) -def test_smoke(duration, forecast): +@pytest.mark.parametrize("options", [ + {}, + {"dct": 1.}, + {"num_quant_bins": 8}, + {"num_quant_bins": 12}, + {"num_quant_bins": 16}, +], ids=str) +def test_smoke(duration, forecast, options): population = 100 recovery_time = 7.0 # Generate data. - model = SIRModel(population, recovery_time, [None] * duration) + model = SimpleSIRModel(population, recovery_time, [None] * duration) for attempt in range(100): data = model.generate({"R0": 1.5, "rho": 0.5})["obs"] if data.sum(): @@ -21,9 +28,9 @@ def test_smoke(duration, forecast): assert data.sum() > 0, "failed to generate positive data" # Infer. - model = SIRModel(population, recovery_time, data) + model = SimpleSIRModel(population, recovery_time, data) num_samples = 5 - model.fit(warmup_steps=2, num_samples=num_samples, max_tree_depth=2) + model.fit(warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options) # Predict and forecast. samples = model.predict(forecast=forecast)