Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a regional SIR model with approximate inference #2466

Merged
merged 31 commits into from
May 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
bbbd486
WIP implement unknown start time SIR model
fritzo Apr 30, 2020
43480e7
Merge branch 'dev' into sir-truncated
fritzo May 1, 2020
7a21fc5
Add .predict() method
fritzo May 1, 2020
d4cc9fa
Merge branch 'dev' into sir-truncated
fritzo May 1, 2020
3f9c925
Fix indexing logic
fritzo May 1, 2020
29825c8
Add test for Index()[]
fritzo May 1, 2020
3d73031
Fix docs
fritzo May 1, 2020
72c15a6
WIP sketch regional model
fritzo May 4, 2020
27b20d1
Address review comments
fritzo May 4, 2020
5b4c2de
Merge branch 'sir-truncated' into sir-regional
fritzo May 4, 2020
fad40cd
Fix typo
fritzo May 4, 2020
b3751e6
Merge branch 'sir-truncated' into sir-regional
fritzo May 4, 2020
ccce15d
Merge branch 'dev' into sir-regional
fritzo May 5, 2020
4691214
Merge branch 'dev' into sir-regional
fritzo May 5, 2020
00fcdd7
WIP refactor CompartmentalModel
fritzo May 6, 2020
8aba67e
WIP move enum dimensions to left
fritzo May 6, 2020
d1def5d
Order dimensions as EPTR and with aux as CTR
fritzo May 7, 2020
0cd7a81
Fix bugs
fritzo May 7, 2020
f9a2275
More fixes
fritzo May 7, 2020
7b82a47
Tweak docs
fritzo May 7, 2020
ed0d55f
Merge branch 'dev' into sir-regional
fritzo May 7, 2020
94b8aa2
Add __init__.py files to test dir to pacify pytest
fritzo May 7, 2020
b10fda9
Refactor to use an align_samples() helper
fritzo May 7, 2020
6e3388e
Make one parameter heterogeneous
fritzo May 7, 2020
1488c69
Add simple example script
fritzo May 8, 2020
1a3b62c
Add --coupling command line option
fritzo May 8, 2020
7403488
Add regional.py to test_examples.py
fritzo May 8, 2020
6c79086
Merge branch 'dev' into sir-regional
fritzo May 9, 2020
e1fde53
Expose approximation interface
fritzo May 9, 2020
bc03b28
Improve docs
fritzo May 9, 2020
9354464
Fix typo
fritzo May 10, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions examples/contrib/epidemiology/regional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import argparse
import logging

import torch

import pyro
from pyro.contrib.epidemiology import RegionalSIRModel

logging.basicConfig(format='%(message)s', level=logging.INFO)


def Model(args, data):
assert 0 <= args.coupling <= 1, args.coupling
population = torch.full((args.num_regions,), float(args.population))
coupling = torch.eye(args.num_regions).clamp(min=args.coupling)
return RegionalSIRModel(population, coupling, args.recovery_time, data)


def generate_data(args):
extended_data = [None] * (args.duration + args.forecast)
model = Model(args, extended_data)
logging.info("Simulating from a {}".format(type(model).__name__))
for attempt in range(100):
samples = model.generate({"R0": args.basic_reproduction_number,
"rho_c1": 10 * args.response_rate,
"rho_c0": 10 * (1 - args.response_rate)})
obs = samples["obs"][:args.duration]
S2I = samples["S2I"]

obs_sum = int(obs.sum())
S2I_sum = int(S2I[:args.duration].sum())
if obs_sum >= args.min_observations:
logging.info("Observed {:d}/{:d} infections:\n{}".format(
obs_sum, S2I_sum, " ".join(str(int(x)) for x in obs[:, 0])))
return {"S2I": S2I, "obs": obs}

raise ValueError("Failed to generate {} observations. Try increasing "
"--population or decreasing --min-observations"
.format(args.min_observations))


def infer(args, model):
energies = []

def hook_fn(kernel, *unused):
e = float(kernel._potential_energy_last)
energies.append(e)
if args.verbose:
logging.info("potential = {:0.6g}".format(e))

mcmc = model.fit(heuristic_num_particles=args.num_particles,
warmup_steps=args.warmup_steps,
num_samples=args.num_samples,
max_tree_depth=args.max_tree_depth,
num_quant_bins=args.num_bins,
hook_fn=hook_fn)

mcmc.summary()
if args.plot:
import matplotlib.pyplot as plt
plt.figure(figsize=(6, 3))
plt.plot(energies)
plt.xlabel("MCMC step")
plt.ylabel("potential energy")
plt.title("MCMC energy trace")
plt.tight_layout()

return model.samples


def predict(args, model, truth):
samples = model.predict(forecast=args.forecast)
S2I = samples["S2I"]
median = S2I.median(dim=0).values
lines = ["Median prediction of new infections (starting on day 0):"]
for r in range(args.num_regions):
lines.append("Region {}: {}".format(r, " ".join(map(str, map(int, median[:, r])))))
logging.info("\n".join(lines))

# Optionally plot the latent and forecasted series of new infections.
if args.plot:
import matplotlib.pyplot as plt
fig, axes = plt.subplots(args.num_regions, sharex=True,
figsize=(6, 1 + args.num_regions))
time = torch.arange(args.duration + args.forecast)
p05 = S2I.kthvalue(int(round(0.5 + 0.05 * args.num_samples)), dim=0).values
p95 = S2I.kthvalue(int(round(0.5 + 0.95 * args.num_samples)), dim=0).values
for r, ax in enumerate(axes):
ax.fill_between(time, p05[:, r], p95[:, r], color="red", alpha=0.3, label="90% CI")
ax.plot(time, median[:, r], "r-", label="median")
ax.plot(time[:args.duration], model.data[:, r], "k.", label="observed")
ax.plot(time, truth[:, r], "k--", label="truth")
ax.axvline(args.duration - 0.5, color="gray", lw=1)
ax.set_xlim(0, len(time) - 1)
ax.set_ylim(0, None)
axes[0].set_title("New infections among {} regions each of size {}"
.format(args.num_regions, args.population))
axes[args.num_regions // 2].set_ylabel("inf./day")
axes[-1].set_xlabel("day after first infection")
axes[-1].legend(loc="upper left")
plt.tight_layout()
plt.subplots_adjust(hspace=0)


def main(args):
pyro.enable_validation(__debug__)
pyro.set_rng_seed(args.rng_seed)

# Generate data.
dataset = generate_data(args)
obs = dataset["obs"]

# Run inference.
model = Model(args, obs)
infer(args, model)

# Predict latent time series.
predict(args, model, truth=dataset["S2I"])


if __name__ == "__main__":
assert pyro.__version__.startswith('1.3.1')
parser = argparse.ArgumentParser(
description="Regional compartmental epidemiology modeling using HMC")
parser.add_argument("-p", "--population", default=1000, type=int)
parser.add_argument("-r", "--num-regions", default=2, type=int)
parser.add_argument("-c", "--coupling", default=0.1, type=float)
parser.add_argument("-m", "--min-observations", default=3, type=int)
parser.add_argument("-d", "--duration", default=20, type=int)
parser.add_argument("-f", "--forecast", default=10, type=int)
parser.add_argument("-R0", "--basic-reproduction-number", default=1.5, type=float)
parser.add_argument("-tau", "--recovery-time", default=7.0, type=float)
parser.add_argument("-rho", "--response-rate", default=0.5, type=float)
parser.add_argument("-n", "--num-samples", default=200, type=int)
parser.add_argument("-np", "--num-particles", default=1024, type=int)
parser.add_argument("-w", "--warmup-steps", default=100, type=int)
parser.add_argument("-t", "--max-tree-depth", default=5, type=int)
parser.add_argument("-nb", "--num-bins", default=4, type=int)
parser.add_argument("--rng-seed", default=0, type=int)
parser.add_argument("--cuda", action="store_true")
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--plot", action="store_true")
args = parser.parse_args()

if args.cuda:
torch.set_default_tensor_type(torch.cuda.FloatTensor)

main(args)

if args.plot:
import matplotlib.pyplot as plt
plt.show()
3 changes: 2 additions & 1 deletion pyro/contrib/epidemiology/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from .compartmental import CompartmentalModel
from .distributions import infection_dist
from .seir import OverdispersedSEIRModel, SimpleSEIRModel
from .sir import OverdispersedSIRModel, SimpleSIRModel, SparseSIRModel, UnknownStartSIRModel
from .sir import OverdispersedSIRModel, RegionalSIRModel, SimpleSIRModel, SparseSIRModel, UnknownStartSIRModel

__all__ = [
"CompartmentalModel",
"OverdispersedSEIRModel",
"OverdispersedSIRModel",
"RegionalSIRModel",
"SimpleSEIRModel",
"SimpleSIRModel",
"SparseSIRModel",
Expand Down
Loading