Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a time-heterogeneous SIR model #2517

Merged
merged 22 commits into from
Jun 5, 2020
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions examples/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@
from torch.distributions import biject_to, constraints

import pyro
from pyro.contrib.epidemiology import (OverdispersedSEIRModel, OverdispersedSIRModel, SimpleSEIRModel, SimpleSIRModel,
SuperspreadingSEIRModel, SuperspreadingSIRModel)
from pyro.contrib.epidemiology import (HeterogeneousSIRModel, OverdispersedSEIRModel, OverdispersedSIRModel,
SimpleSEIRModel, SimpleSIRModel, SuperspreadingSEIRModel, SuperspreadingSIRModel)

logging.basicConfig(format='%(message)s', level=logging.INFO)


def Model(args, data):
"""Dispatch between different model classes."""
if args.incubation_time > 0:
if args.heterogeneous:
assert args.incubation_time == 0
assert args.overdispersion == 0
return HeterogeneousSIRModel(args.population, args.recovery_time, data)
elif args.incubation_time > 0:
assert args.incubation_time > 1
if args.concentration < math.inf:
return SuperspreadingSEIRModel(args.population, args.incubation_time,
Expand Down Expand Up @@ -108,8 +112,9 @@ def hook_fn(kernel, *unused):

def evaluate(args, model, samples):
# Print estimated values.
names = {"basic_reproduction_number": "R0",
"response_rate": "rho"}
names = {"basic_reproduction_number": "R0"}
if not args.heterogeneous:
names["response_rate"] = "rho"
if args.concentration < math.inf:
names["concentration"] = "k"
if "od" in samples:
Expand All @@ -127,6 +132,8 @@ def evaluate(args, model, samples):

# Plot individual histograms.
fig, axes = plt.subplots(len(names), 1, figsize=(5, 2.5 * len(names)))
if len(names) == 1:
axes = [axes]
axes[0].set_title("Posterior parameter estimates")
for ax, (name, key) in zip(axes, names.items()):
truth = getattr(args, name)
Expand All @@ -139,7 +146,7 @@ def evaluate(args, model, samples):

# Plot pairwise joint distributions for selected variables.
covariates = [(name, samples[name]) for name in names.values()]
for i, aux in enumerate(samples["auxiliary"].unbind(-2)):
for i, aux in enumerate(samples["auxiliary"].squeeze(1).unbind(-2)):
covariates.append(("aux[{},0]".format(i), aux[:, 0]))
covariates.append(("aux[{},-1]".format(i), aux[:, -1]))
N = len(covariates)
Expand All @@ -162,9 +169,10 @@ def unconstrain(constraint, value):
value = biject_to(constraint).inv(value)
return value.reshape(args.num_samples, -1)

covariates = [
("R1", unconstrain(constraints.positive, samples["R0"])),
("rho", unconstrain(constraints.unit_interval, samples["rho"]))]
covariates = [("R1", unconstrain(constraints.positive, samples["R0"]))]
if not args.heterogeneous:
covariates.append(
("rho", unconstrain(constraints.unit_interval, samples["rho"])))
if "k" in samples:
covariates.append(
("k", unconstrain(constraints.positive, samples["k"])))
Expand Down Expand Up @@ -219,6 +227,25 @@ def predict(args, model, truth):
plt.legend(loc="upper left")
plt.tight_layout()

# Plot Re time series.
if args.heterogeneous:
plt.figure()
Re = samples["Re"]
median = Re.median(dim=0).values
p05 = Re.kthvalue(int(round(0.5 + 0.05 * args.num_samples)), dim=0).values
p95 = Re.kthvalue(int(round(0.5 + 0.95 * args.num_samples)), dim=0).values
plt.fill_between(time, p05, p95, color="red", alpha=0.3, label="90% CI")
plt.plot(time, median, "r-", label="median")
plt.plot(time[:args.duration], obs, "k.", label="observed")
plt.axvline(args.duration - 0.5, color="gray", lw=1)
plt.xlim(0, len(time) - 1)
plt.ylim(0, None)
plt.xlabel("day after first infection")
plt.ylabel("Re")
plt.title("Effective reproductive number over time")
plt.legend(loc="upper left")
plt.tight_layout()


def main(args):
pyro.enable_validation(__debug__)
Expand Down Expand Up @@ -257,6 +284,7 @@ def main(args):
help="If finite, use a superspreader model.")
parser.add_argument("-rho", "--response-rate", default=0.5, type=float)
parser.add_argument("-o", "--overdispersion", default=0., type=float)
parser.add_argument("-hg", "--heterogeneous", action="store_true")
parser.add_argument("--haar", action="store_true")
parser.add_argument("-hfm", "--haar-full-mass", default=0, type=int)
parser.add_argument("-n", "--num-samples", default=200, type=int)
Expand Down
6 changes: 4 additions & 2 deletions pyro/contrib/epidemiology/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

from .compartmental import CompartmentalModel
from .distributions import beta_binomial_dist, binomial_dist, infection_dist
from .models import (OverdispersedSEIRModel, OverdispersedSIRModel, RegionalSIRModel, SimpleSEIRModel, SimpleSIRModel,
SparseSIRModel, SuperspreadingSEIRModel, SuperspreadingSIRModel, UnknownStartSIRModel)
from .models import (HeterogeneousSIRModel, OverdispersedSEIRModel, OverdispersedSIRModel, RegionalSIRModel,
SimpleSEIRModel, SimpleSIRModel, SparseSIRModel, SuperspreadingSEIRModel, SuperspreadingSIRModel,
UnknownStartSIRModel)

__all__ = [
"CompartmentalModel",
"HeterogeneousSIRModel",
"OverdispersedSEIRModel",
"OverdispersedSIRModel",
"RegionalSIRModel",
Expand Down
142 changes: 114 additions & 28 deletions pyro/contrib/epidemiology/compartmental.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import torch
from torch.distributions import biject_to, constraints
from torch.distributions.utils import lazy_property

import pyro.distributions as dist
import pyro.distributions.hmm
Expand All @@ -23,7 +24,7 @@
from pyro.infer.reparam import HaarReparam, SplitReparam
from pyro.infer.smcfilter import SMCFailed
from pyro.infer.util import is_validation_enabled
from pyro.util import warn_if_nan
from pyro.util import optional, warn_if_nan

from .distributions import set_approx_log_prob_tol, set_approx_sample_thresh
from .util import align_samples, cat2, clamp, quantize, quantize_enumerate
Expand Down Expand Up @@ -109,12 +110,12 @@ def __init__(self, compartments, duration, population, *,
assert population.dim() == 1
assert (population >= 1).all()
self.is_regional = True
self.max_plate_nesting = 1
self.max_plate_nesting = 2 # [time, region]
else:
assert isinstance(population, int)
assert population >= 2
self.is_regional = False
self.max_plate_nesting = 0
self.max_plate_nesting = 1 # [time]
self.population = population

compartments = tuple(compartments)
Expand All @@ -130,6 +131,16 @@ def __init__(self, compartments, duration, population, *,
self.samples = {}
self._clear_plates()

@property
def time_plate(self):
"""
A ``pyro.plate`` for the time dimension.
"""
if self._time_plate is None:
self._time_plate = pyro.plate("time", self.duration,
dim=-2 if self.is_regional else -1)
return self._time_plate

@property
def region_plate(self):
"""
Expand All @@ -144,6 +155,7 @@ def region_plate(self):
return self._region_plate

def _clear_plates(self):
self._time_plate = None
self._region_plate = None

# Overridable attributes and methods ########################################
Expand Down Expand Up @@ -430,7 +442,7 @@ def predict(self, forecast=0):
for name, site in trace.nodes.items()
if site["type"] == "sample")

self._concat_series(samples, forecast)
self._concat_series(samples, forecast, vectorized=True)
return samples

@torch.no_grad()
Expand Down Expand Up @@ -472,7 +484,7 @@ def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10):

# Select the most probable hypothesis.
i = int(smc.state._log_weights.max(0).indices)
init = {key: value[i] for key, value in smc.state.items()}
init = {key: value[i, 0] for key, value in smc.state.items()}

# Fill in sample site values.
init = self.generate(init)
Expand All @@ -482,7 +494,7 @@ def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10):

# Internal helpers ########################################

def _concat_series(self, samples, forecast=0):
def _concat_series(self, samples, forecast=0, vectorized=False):
"""
Concatenate sequential time series into tensors, in-place.

Expand All @@ -494,19 +506,58 @@ def _concat_series(self, samples, forecast=0):
for key in list(samples):
if re.match(pattern, key):
series.append(samples.pop(key))
if series:
assert len(series) == self.duration + forecast
series = torch.broadcast_tensors(*map(torch.as_tensor, series))
samples[name] = torch.stack(series, dim=-2 if self.is_regional else -1)
if not series:
continue
assert len(series) == self.duration + forecast
series = torch.broadcast_tensors(*map(torch.as_tensor, series))
if vectorized and name != "obs": # TODO Generalize.
samples[name] = torch.cat(series, dim=1)
else:
samples[name] = torch.stack(series)

@lazy_property
@torch.no_grad()
def _non_compartmental(self):
"""
A dict mapping name -> (distribution, is_regional) for all
non-compartmental sites in :meth:`transition`. For simple models this
is often empty; for time-heterogeneous models this may contain
time-local latent variables.
"""
# Trace a simple invocation of .transition().
with torch.no_grad(), poutine.block():
params = self.global_model()
prev = self.initialize(params)
for name in self.approximate:
prev[name + "_approx"] = prev[name]
curr = prev.copy()
with poutine.trace() as tr:
self.transition(params, curr, 0)
flows = self.compute_flows(prev, curr, 0)

# Extract latent variables that are not compartmental flows.
result = OrderedDict()
for name, site in tr.trace.iter_stochastic_nodes():
if name in flows:
continue
assert name.endswith("_0"), name
name = name[:-2]
assert name in self.series, name
# TODO This supports only the region_plate. For full plate support,
# this could be replaced by a self.plate() method as in EasyGuide.
is_regional = any(f.name == "region" for f in site["cond_indep_stack"])
result[name] = site["fn"], is_regional
return result

def _transition_bwd(self, params, prev, curr, t):
"""
Helper to collect probabilty factors from .transition() conditioned on
previous and current enumerated states.
"""
# Run .transition() conditioned on computed flows.
flows = self.compute_flows(prev, curr, t)
with poutine.condition(data=flows):
cond_data = {"{}_{}".format(k, t): v for k, v in curr.items()}
cond_data.update(self.compute_flows(prev, curr, t))
with poutine.condition(data=cond_data):
state = prev.copy()
self.transition(params, state, t) # Mutates state.

Expand Down Expand Up @@ -537,7 +588,7 @@ def _generative_model(self, forecast=0):
self.transition(params, state, t)
with self.region_plate:
for name in self.compartments:
pyro.deterministic("{}_{}".format(name, t), state[name])
pyro.deterministic("{}_{}".format(name, t), state[name], event_dim=0)

self._clear_plates()

Expand All @@ -552,30 +603,47 @@ def _sequential_model(self):
# Sample global parameters.
params = self.global_model()

# Sample the continuous reparameterizing variable.
# Sample the compartmental continuous reparameterizing variable.
shape = (C, T) + R_shape
auxiliary = pyro.sample("auxiliary",
dist.Uniform(-0.5, self.population + 0.5)
.mask(False).expand(shape).to_event())
num_samples = auxiliary.size(0)
if self.is_regional:
# This reshapes from (particle, 1, C, T, R) -> (particle, C, T, R)
# to allow aux below to have shape (particle, R) for region_plate.
auxiliary = auxiliary.squeeze(-4)
auxiliary = auxiliary.squeeze(1)
assert auxiliary.shape == (num_samples, 1, C, T) + R_shape
aux = [aux.unbind(2) for aux in auxiliary.unbind(2)]

# Sample any non-compartmental time series in batch.
# TODO Consider using pyro.contrib.forecast.util.reshape_batch to
# support DiscreteCosineReparam and HaarReparam along the time dim.
non_compartmental = OrderedDict()
for name, (fn, is_regional) in self._non_compartmental.items():
fn = dist.ImproperUniform(fn.support, fn.batch_shape, fn.event_shape)
with self.time_plate, optional(self.region_plate, is_regional):
non_compartmental[name] = pyro.sample(name, fn)

# Sequentially transition.
curr = self.initialize(params)
for t, aux_t in poutine.markov(enumerate(auxiliary.unbind(2))):
for t in poutine.markov(range(T)):
with self.region_plate:
prev, curr = curr, {}
for name, aux in zip(self.compartments, aux_t.unbind(1)):
curr[name] = quantize("{}_{}".format(name, t), aux,

# Extract any non-compartmental variables.
for name, value in non_compartmental.items():
curr[name] = value[:, t:t+1]

# Extract and enumerate all compartmental variables.
for c, name in enumerate(self.compartments):
curr[name] = quantize("{}_{}".format(name, t), aux[c][t],
min=0, max=self.population,
num_quant_bins=self.num_quant_bins)
# Enable approximate inference by using aux as a
# non-enumerated proxy for enumerated compartment values.
if name in self.approximate:
curr[name + "_approx"] = aux
curr[name + "_approx"] = aux[c][t]
prev.setdefault(name + "_approx", prev[name])

self._transition_bwd(params, prev, curr, t)

self._clear_plates()
Expand All @@ -592,7 +660,7 @@ def _vectorized_model(self):
# Sample global parameters.
params = self.global_model()

# Sample the continuous reparameterizing variable.
# Sample the compartmental continuous reparameterizing variable.
shape = (C, T) + R_shape
auxiliary = pyro.sample("auxiliary",
dist.Uniform(-0.5, self.population + 0.5)
Expand All @@ -605,14 +673,25 @@ def _vectorized_model(self):
curr = OrderedDict(zip(self.compartments, curr))
logp = OrderedDict(zip(self.compartments, logp))

# Sample any non-compartmental time series in batch.
# TODO Consider using pyro.contrib.forecast.util.reshape_batch to
# support DiscreteCosineReparam and HaarReparam along the time dim.
for name, (fn, is_regional) in self._non_compartmental.items():
fn = dist.ImproperUniform(fn.support, fn.batch_shape, fn.event_shape)
with self.time_plate, optional(self.region_plate, is_regional):
curr[name] = pyro.sample(name, fn)

# Truncate final value from the right then pad initial value onto the left.
init = self.initialize(params)
prev = {}
for name in self.compartments:
value = init[name]
if isinstance(value, torch.Tensor):
value = value[..., None] # Because curr is enumerated on the right.
prev[name] = cat2(value, curr[name][:-1], dim=-3 if self.is_regional else -2)
for name, value in init.items():
if name in self.compartments:
if isinstance(value, torch.Tensor):
value = value[..., None] # Because curr is enumerated on the right.
prev[name] = cat2(value, curr[name][:-1],
dim=-3 if self.is_regional else -2)
else: # non-compartmental
prev[name] = cat2(init[name], curr[name][:-1], dim=-curr[name].dim())

# Reshape to support broadcasting, similar to EnumMessenger.
def enum_reshape(tensor, position):
Expand All @@ -638,12 +717,19 @@ def enum_reshape(tensor, position):

# Record transition factors.
with poutine.block(), poutine.trace() as tr:
with pyro.plate("time", T, dim=-1 - self.max_plate_nesting):
with self.time_plate:
t = slice(0, T, 1) # Used to slice data tensors.
self._transition_bwd(params, prev, curr, t)
tr.trace.compute_log_prob()
for name, site in tr.trace.nodes.items():
if site["type"] == "sample":
log_prob = site["log_prob"]
if log_prob.dim() <= self.max_plate_nesting: # Not enumerated.
pyro.factor("transition_" + name, site["log_prob_sum"])
continue
if self.is_regional and log_prob.shape[-1:] != R_shape:
# Poor man's tensor variable elimination.
log_prob = log_prob.expand(log_prob.shape[:-1] + R_shape) / R_shape[0]
logp[name] = site["log_prob"]

# Manually perform variable elimination.
Expand Down
Loading