Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add overdispersed models to contrib.epidemiology #2498

Merged
merged 14 commits into from
May 30, 2020
Merged
9 changes: 8 additions & 1 deletion examples/contrib/epidemiology/regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,20 @@ def main(args):
parser.add_argument("-w", "--warmup-steps", default=100, type=int)
parser.add_argument("-t", "--max-tree-depth", default=5, type=int)
parser.add_argument("-nb", "--num-bins", default=4, type=int)
parser.add_argument("--double", action="store_true", default=True)
parser.add_argument("--single", action="store_false", dest="double")
parser.add_argument("--rng-seed", default=0, type=int)
parser.add_argument("--cuda", action="store_true")
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--plot", action="store_true")
args = parser.parse_args()

if args.cuda:
if args.double:
if args.cuda:
torch.set_default_tensor_type(torch.cuda.DoubleTensor)
else:
torch.set_default_dtype(torch.float64)
elif args.cuda:
torch.set_default_tensor_type(torch.cuda.FloatTensor)

main(args)
Expand Down
55 changes: 38 additions & 17 deletions examples/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from torch.distributions import biject_to, constraints

import pyro
from pyro.contrib.epidemiology import SimpleSEIRModel, SimpleSIRModel, SuperspreadingSEIRModel, SuperspreadingSIRModel
from pyro.contrib.epidemiology import (OverdispersedSEIRModel, OverdispersedSIRModel, SimpleSEIRModel, SimpleSIRModel,
SuperspreadingSEIRModel, SuperspreadingSIRModel)

logging.basicConfig(format='%(message)s', level=logging.INFO)

Expand All @@ -22,17 +23,22 @@ def Model(args, data):
"""Dispatch between different model classes."""
if args.incubation_time > 0:
assert args.incubation_time > 1
if args.concentration == math.inf:
return SimpleSEIRModel(args.population, args.incubation_time,
args.recovery_time, data)
else:
if args.concentration < math.inf:
return SuperspreadingSEIRModel(args.population, args.incubation_time,
args.recovery_time, data)
else:
if args.concentration == math.inf:
return SimpleSIRModel(args.population, args.recovery_time, data)
elif args.overdispersion > 0:
return OverdispersedSEIRModel(args.population, args.incubation_time,
args.recovery_time, data)
else:
return SimpleSEIRModel(args.population, args.incubation_time,
args.recovery_time, data)
else:
if args.concentration < math.inf:
return SuperspreadingSIRModel(args.population, args.recovery_time, data)
elif args.overdispersion > 0:
return OverdispersedSIRModel(args.population, args.recovery_time, data)
else:
return SimpleSIRModel(args.population, args.recovery_time, data)


def generate_data(args):
Expand All @@ -42,20 +48,29 @@ def generate_data(args):
for attempt in range(100):
samples = model.generate({"R0": args.basic_reproduction_number,
"rho": args.response_rate,
"k": args.concentration})
"k": args.concentration,
"od": args.overdispersion})
obs = samples["obs"][:args.duration]
new_I = samples.get("S2I", samples.get("E2I"))

obs_sum = int(obs.sum())
new_I_sum = int(new_I[:args.duration].sum())
if obs_sum >= args.min_observations:
assert 0 <= args.min_obs_portion < args.max_obs_portion <= 1
min_obs = int(math.ceil(args.min_obs_portion * args.population))
max_obs = int(math.floor(args.max_obs_portion * args.population))
if min_obs <= obs_sum <= max_obs:
logging.info("Observed {:d}/{:d} infections:\n{}".format(
obs_sum, new_I_sum, " ".join(str(int(x)) for x in obs)))
return {"new_I": new_I, "obs": obs}

raise ValueError("Failed to generate {} observations. Try increasing "
"--population or decreasing --min-observations"
.format(args.min_observations))
if obs_sum < min_obs:
raise ValueError("Failed to generate >={} observations. "
"Try decreasing --min-obs-portion (currently {})."
.format(min_obs, args.min_obs_portion))
else:
raise ValueError("Failed to generate <={} observations. "
"Try increasing --max-obs-portion (currently {})."
.format(max_obs, args.max_obs_portion))


def infer(args, model):
Expand Down Expand Up @@ -97,6 +112,8 @@ def evaluate(args, model, samples):
"response_rate": "rho"}
if args.concentration < math.inf:
names["concentration"] = "k"
if "od" in samples:
names["overdispersion"] = "od"
for name, key in names.items():
mean = samples[key].mean().item()
std = samples[key].std().item()
Expand Down Expand Up @@ -227,8 +244,9 @@ def main(args):
assert pyro.__version__.startswith('1.3.1')
parser = argparse.ArgumentParser(
description="Compartmental epidemiology modeling using HMC")
parser.add_argument("-p", "--population", default=1000, type=int)
parser.add_argument("-m", "--min-observations", default=3, type=int)
parser.add_argument("-p", "--population", default=1000, type=float)
parser.add_argument("-m", "--min-obs-portion", default=0.01, type=float)
parser.add_argument("-M", "--max-obs-portion", default=0.99, type=float)
parser.add_argument("-d", "--duration", default=20, type=int)
parser.add_argument("-f", "--forecast", default=10, type=int)
parser.add_argument("-R0", "--basic-reproduction-number", default=1.5, type=float)
Expand All @@ -238,6 +256,7 @@ def main(args):
parser.add_argument("-k", "--concentration", default=math.inf, type=float,
help="If finite, use a superspreader model.")
parser.add_argument("-rho", "--response-rate", default=0.5, type=float)
parser.add_argument("-o", "--overdispersion", default=0., type=float)
parser.add_argument("--haar", action="store_true")
parser.add_argument("-hfm", "--haar-full-mass", default=0, type=int)
parser.add_argument("-n", "--num-samples", default=200, type=int)
Expand All @@ -248,17 +267,19 @@ def main(args):
parser.add_argument("-a", "--arrowhead-mass", action="store_true")
parser.add_argument("-r", "--rng-seed", default=0, type=int)
parser.add_argument("-nb", "--num-bins", default=4, type=int)
parser.add_argument("--double", action="store_true")
parser.add_argument("--double", action="store_true", default=True)
parser.add_argument("--single", action="store_false", dest="double")
parser.add_argument("--cuda", action="store_true")
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--plot", action="store_true")
args = parser.parse_args()
args.population = int(args.population) # to allow e.g. --population=1e6

if args.double:
if args.cuda:
torch.set_default_tensor_type(torch.cuda.DoubleTensor)
else:
torch.set_default_tensor_type(torch.DoubleTensor)
torch.set_default_dtype(torch.float64)
elif args.cuda:
torch.set_default_tensor_type(torch.cuda.FloatTensor)

Expand Down
10 changes: 7 additions & 3 deletions pyro/contrib/epidemiology/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,22 @@
# SPDX-License-Identifier: Apache-2.0

from .compartmental import CompartmentalModel
from .distributions import infection_dist
from .models import (RegionalSIRModel, SimpleSEIRModel, SimpleSIRModel, SparseSIRModel, SuperspreadingSEIRModel,
SuperspreadingSIRModel, UnknownStartSIRModel)
from .distributions import beta_binomial_dist, binomial_dist, infection_dist
from .models import (OverdispersedSEIRModel, OverdispersedSIRModel, RegionalSIRModel, SimpleSEIRModel, SimpleSIRModel,
SparseSIRModel, SuperspreadingSEIRModel, SuperspreadingSIRModel, UnknownStartSIRModel)

__all__ = [
"CompartmentalModel",
"OverdispersedSEIRModel",
"OverdispersedSIRModel",
"RegionalSIRModel",
"SimpleSEIRModel",
"SimpleSIRModel",
"SparseSIRModel",
"SuperspreadingSEIRModel",
"SuperspreadingSIRModel",
"UnknownStartSIRModel",
"beta_binomial_dist",
"binomial_dist",
"infection_dist",
]
12 changes: 12 additions & 0 deletions pyro/contrib/epidemiology/compartmental.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import operator
import re
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import ExitStack
Expand All @@ -29,6 +30,13 @@
logger = logging.getLogger(__name__)


def _require_double_precision():
if torch.get_default_dtype() != torch.float64:
warnings.warn("CompartmentalModel is unstable for dtypes less than torch.float64; "
"try torch.set_default_dtype(torch.float64)",
RuntimeWarning)


class CompartmentalModel(ABC):
"""
Abstract base class for discrete-time discrete-value stochastic
Expand Down Expand Up @@ -311,6 +319,8 @@ def fit(self, **options):
:returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``.
:rtype: ~pyro.infer.mcmc.api.MCMC
"""
_require_double_precision()

# Parse options, saving some for use in .predict().
self.num_quant_bins = options.pop("num_quant_bins", 4)
haar = options.pop("haar", False)
Expand Down Expand Up @@ -413,8 +423,10 @@ def predict(self, forecast=0):
to a tensor whose first dimension corresponds to sample batching.
:rtype: dict
"""
_require_double_precision()
if not self.samples:
raise RuntimeError("Missing samples, try running .fit() first")

samples = self.samples
num_samples = len(next(iter(samples.values())))
particle_plate = pyro.plate("particles", num_samples,
Expand Down
Loading