Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify compartmental model interface #2445

Merged
merged 2 commits into from
Apr 26, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
import torch

import pyro
from pyro.contrib.epidemiology import SIRModel
from pyro.contrib.epidemiology import SimpleSIRModel

logging.basicConfig(format='%(message)s', level=logging.INFO)


def generate_data(args):
extended_data = [None] * (args.duration + args.forecast)
model = SIRModel(args.population, args.recovery_time, extended_data)
model = SimpleSIRModel(args.population, args.recovery_time, extended_data)
for attempt in range(100):
samples = model.generate({"R0": args.basic_reproduction_number,
"rho": args.response_rate})
Expand Down Expand Up @@ -49,6 +49,7 @@ def hook_fn(kernel, *unused):
mcmc = model.fit(warmup_steps=args.warmup_steps,
num_samples=args.num_samples,
max_tree_depth=args.max_tree_depth,
num_quant_bins=args.num_bins,
dct=args.dct,
hook_fn=hook_fn)

Expand Down Expand Up @@ -131,8 +132,7 @@ def main(args):
obs = dataset["obs"]

# Run inference.
model = SIRModel(args.population, args.recovery_time, obs,
num_quant_bins=args.num_bins)
model = SimpleSIRModel(args.population, args.recovery_time, obs)
samples = infer(args, model)

# Evaluate fit.
Expand Down
4 changes: 2 additions & 2 deletions pyro/contrib/epidemiology/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

from .compartmental import CompartmentalModel
from .sir import SIRModel
from .sir import SimpleSIRModel

__all__ = [
"CompartmentalModel",
"SIRModel",
"SimpleSIRModel",
]
13 changes: 6 additions & 7 deletions pyro/contrib/epidemiology/compartmental.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@ def transition_bwd(self, params, prev, curr, t): ...
:param list compartments: A list of strings of compartment names.
:param int duration:
:param int population:
:param int num_quant_bins: The number of quantization bins to use. Note that
computational cost is exponential in `num_quant_bins`. Defaults to 4.
"""

def __init__(self, compartments, duration, population, *,
Expand All @@ -94,10 +92,6 @@ def __init__(self, compartments, duration, population, *,
assert len(compartments) == len(set(compartments))
self.compartments = compartments

assert isinstance(num_quant_bins, int)
assert num_quant_bins >= 2
self.num_quant_bins = num_quant_bins

# Inference state.
self.samples = {}

Expand Down Expand Up @@ -223,13 +217,18 @@ def fit(self, **options):
:class:`~pyro.infer.mcmc.nuts.NUTS` kernel.
:param full_mass: (Default ``False``). Specification of mass matrix
of the :class:`~pyro.infer.mcmc.nuts.NUTS` kernel.
:param int num_quant_bins: The number of quantization bins to use. Note
that computational cost is exponential in `num_quant_bins`.
Defaults to 4.
:param float dct: If provided, use a discrete cosine reparameterizer
with this value as smoothness.
:returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
:rtype: ~pyro.infer.mcmc.api.MCMC
"""
logger.info("Running inference...")
self._dct = options.pop("dct", None) # Save for .predict().
# Save these options for .predict().
self.num_quant_bins = options.pop("num_quant_bins", 4)
self._dct = options.pop("dct", None)

# Heuristically initialze to feasible latents.
init_values = self.heuristic()
Expand Down
25 changes: 17 additions & 8 deletions pyro/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,35 @@
from .compartmental import CompartmentalModel


class SIRModel(CompartmentalModel):
class SimpleSIRModel(CompartmentalModel):
"""
Susceptible-Infected-Recovered model.

:param int population:
:param float recovery_time:
To customize this model we recommend forking and editing this class.

This is a stochastic discrete-time discrete-state model with three
compartments: "S" for susceptible, "I" for infected, and "R" for
recovered individuals (the recovered individuals are implicit: ``R =
population - S - I``) with transitions ``S -> I -> R``.

:param int population: Total ``population = S + I + R``.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param iterable data: Time series of new observed infections.
:param int data: Time series of new observed infections.
:param int num_quant_bins: The number of quantization bins to use. Note that
computational cost is exponential in `num_quant_bins`. Defaults to 4.
:param int num_quant_bins: The number of quantization bins to use.
Note that computational cost is exponential in `num_quant_bins`.
Defaults to 4.
"""

def __init__(self, population, recovery_time, data, *,
num_quant_bins=4):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't see where num_quant_bins is piped here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops, thanks for catching, I've removed.

compartments = ("S", "I") # R is implicit.
duration = len(data)
super().__init__(compartments, duration, population, num_quant_bins=num_quant_bins)
super().__init__(compartments, duration, population)

assert isinstance(recovery_time, float)
assert recovery_time > 0
assert recovery_time > 1
self.recovery_time = recovery_time

self.data = data
Expand Down Expand Up @@ -61,7 +70,7 @@ def global_model(self):

# Convert interpretable parameters to distribution parameters.
rate_s = -R0 / (tau * self.population)
prob_i = 1 / (1 + tau)
prob_i = 1 / tau

return rate_s, prob_i, rho

Expand Down
17 changes: 12 additions & 5 deletions tests/contrib/epidemiology/test_sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,34 @@

import pytest

from pyro.contrib.epidemiology import SIRModel
from pyro.contrib.epidemiology import SimpleSIRModel


@pytest.mark.parametrize("duration", [3, 7])
@pytest.mark.parametrize("forecast", [0, 7])
def test_smoke(duration, forecast):
@pytest.mark.parametrize("options", [
{},
{"dct": 1.},
{"num_quant_bins": 8},
{"num_quant_bins": 12},
{"num_quant_bins": 16},
], ids=str)
def test_smoke(duration, forecast, options):
population = 100
recovery_time = 7.0

# Generate data.
model = SIRModel(population, recovery_time, [None] * duration)
model = SimpleSIRModel(population, recovery_time, [None] * duration)
for attempt in range(100):
data = model.generate({"R0": 1.5, "rho": 0.5})["obs"]
if data.sum():
break
assert data.sum() > 0, "failed to generate positive data"

# Infer.
model = SIRModel(population, recovery_time, data)
model = SimpleSIRModel(population, recovery_time, data)
num_samples = 5
model.fit(warmup_steps=2, num_samples=num_samples, max_tree_depth=2)
model.fit(warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options)

# Predict and forecast.
samples = model.predict(forecast=forecast)
Expand Down