Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement a regional SIR model with approximate inference #2466

Merged
merged 31 commits into dev from sir-regional
May 10, 2020
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
bbbd486
WIP implement unknown start time SIR model
fritzo Apr 30, 2020
43480e7
Merge branch 'dev' into sir-truncated
fritzo May 1, 2020
7a21fc5
Add .predict() method
fritzo May 1, 2020
d4cc9fa
Merge branch 'dev' into sir-truncated
fritzo May 1, 2020
3f9c925
Fix indexing logic
fritzo May 1, 2020
29825c8
Add test for Index()[]
fritzo May 1, 2020
3d73031
Fix docs
fritzo May 1, 2020
72c15a6
WIP sketch regional model
fritzo May 4, 2020
27b20d1
Address review comments
fritzo May 4, 2020
5b4c2de
Merge branch 'sir-truncated' into sir-regional
fritzo May 4, 2020
fad40cd
Fix typo
fritzo May 4, 2020
b3751e6
Merge branch 'sir-truncated' into sir-regional
fritzo May 4, 2020
ccce15d
Merge branch 'dev' into sir-regional
fritzo May 5, 2020
4691214
Merge branch 'dev' into sir-regional
fritzo May 5, 2020
00fcdd7
WIP refactor CompartmentalModel
fritzo May 6, 2020
8aba67e
WIP move enum dimensions to left
fritzo May 6, 2020
d1def5d
Order dimensions as EPTR and with aux as CTR
fritzo May 7, 2020
0cd7a81
Fix bugs
fritzo May 7, 2020
f9a2275
More fixes
fritzo May 7, 2020
7b82a47
Tweak docs
fritzo May 7, 2020
ed0d55f
Merge branch 'dev' into sir-regional
fritzo May 7, 2020
94b8aa2
Add __init__.py files to test dir to pacify pytest
fritzo May 7, 2020
b10fda9
Refactor to use an align_samples() helper
fritzo May 7, 2020
6e3388e
Make one parameter heterogeneous
fritzo May 7, 2020
1488c69
Add simple example script
fritzo May 8, 2020
1a3b62c
Add --coupling command line option
fritzo May 8, 2020
7403488
Add regional.py to test_examples.py
fritzo May 8, 2020
6c79086
Merge branch 'dev' into sir-regional
fritzo May 9, 2020
e1fde53
Expose approximation interface
fritzo May 9, 2020
bc03b28
Improve docs
fritzo May 9, 2020
9354464
Fix typo
fritzo May 10, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyro/contrib/epidemiology/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from .compartmental import CompartmentalModel
from .distributions import infection_dist
from .seir import OverdispersedSEIRModel, SimpleSEIRModel
from .sir import OverdispersedSIRModel, SimpleSIRModel, SparseSIRModel, UnknownStartSIRModel
from .sir import OverdispersedSIRModel, RegionalSIRModel, SimpleSIRModel, SparseSIRModel, UnknownStartSIRModel

__all__ = [
"CompartmentalModel",
"OverdispersedSEIRModel",
"OverdispersedSIRModel",
"RegionalSIRModel",
"SimpleSEIRModel",
"SimpleSIRModel",
"SparseSIRModel",
Expand Down
158 changes: 112 additions & 46 deletions pyro/contrib/epidemiology/compartmental.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
# SPDX-License-Identifier: Apache-2.0

import logging
import operator
import re
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import ExitStack
from functools import reduce

import torch
from torch.distributions import biject_to, constraints
from torch.nn.functional import pad

import pyro.distributions as dist
import pyro.distributions.hmm
Expand All @@ -19,7 +21,7 @@
from pyro.infer.reparam import DiscreteCosineReparam
from pyro.util import warn_if_nan

from .util import quantize, quantize_enumerate
from .util import align_samples, cat2, clamp, quantize, quantize_enumerate

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -83,8 +85,16 @@ def __init__(self, compartments, duration, population, *,
assert duration >= 1
self.duration = duration

assert isinstance(population, int)
assert population >= 2
if isinstance(population, torch.Tensor):
assert population.dim() == 1
assert (population >= 1).all()
self.is_regional = True
self.max_plate_nesting = 1
else:
assert isinstance(population, int)
assert population >= 2
self.is_regional = False
self.max_plate_nesting = 0
self.population = population

compartments = tuple(compartments)
Expand All @@ -94,10 +104,26 @@ def __init__(self, compartments, duration, population, *,

# Inference state.
self.samples = {}
self._clear_plates()

@property
def region_plate(self):
"""
Either a ``pyro.plate`` or a trivial ``ExitStack`` depending on whether
this model ``.is_regional``.
"""
if self._region_plate is None:
if self.is_regional:
self._region_plate = pyro.plate("region", len(self.population), dim=-1)
else:
self._region_plate = ExitStack() # Trivial context manager.
return self._region_plate

def _clear_plates(self):
self._region_plate = None

# Overridable attributes and methods ########################################

max_plate_nesting = 0
series = ()
full_mass = False

Expand Down Expand Up @@ -126,14 +152,14 @@ def heuristic(self, num_particles=1024):
for t in range(1, self.duration):
smc.step()

# Select the most probably hypothesis.
# Select the most probable hypothesis.
i = int(smc.state._log_weights.max(0).indices)
init = {key: value[i] for key, value in smc.state.items()}

# Fill in sample site values.
init = self.generate(init)
init["auxiliary"] = torch.stack([init[name] for name in self.compartments])
init["auxiliary"].clamp_(min=0.5, max=self.population - 0.5)
aux = torch.stack([init[name] for name in self.compartments], dim=0)
init["auxiliary"] = clamp(aux, min=0.5, max=self.population - 0.5)
return init

def global_model(self):
Expand Down Expand Up @@ -287,6 +313,10 @@ def fit(self, **options):
mcmc = MCMC(kernel, **options)
mcmc.run()
self.samples = mcmc.get_samples()
# Unsqueeze samples to align particle dim for use in poutine.condition.
# TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples().
self.samples = align_samples(self.samples, model,
particle_dim=-1 - self.max_plate_nesting)
return mcmc # E.g. so user can run mcmc.summary().

@torch.no_grad()
Expand Down Expand Up @@ -360,7 +390,7 @@ def _concat_series(self, samples, forecast=0):
if series:
assert len(series) == self.duration + forecast
series = torch.broadcast_tensors(*map(torch.as_tensor, series))
samples[name] = torch.stack(series, dim=-1)
samples[name] = torch.stack(series, dim=-2 if self.is_regional else -1)

def _generative_model(self, forecast=0):
"""
Expand All @@ -371,52 +401,75 @@ def _generative_model(self, forecast=0):

# Sample initial values.
state = self.initialize(params)
state = {i: torch.tensor(float(value)) for i, value in state.items()}
state = {k: v if isinstance(v, torch.Tensor) else torch.tensor(float(v))
for k, v in state.items()}

# Sequentially transition.
for t in range(self.duration + forecast):
self.transition_fwd(params, state, t)
for name in self.compartments:
pyro.deterministic("{}_{}".format(name, t), state[name])
with self.region_plate:
for name in self.compartments:
pyro.deterministic("{}_{}".format(name, t), state[name])

self._clear_plates()

def _sequential_model(self):
"""
Sequential model used to sample latents in the interval [0:duration].
"""
C = len(self.compartments)
T = self.duration
R_shape = getattr(self.population, "shape", ()) # Region shape.

# Sample global parameters.
params = self.global_model()

# Sample the continuous reparameterizing variable.
shape = (C, T) + R_shape
auxiliary = pyro.sample("auxiliary",
dist.Uniform(-0.5, self.population + 0.5)
.mask(False)
.expand([len(self.compartments), self.duration])
.to_event(2))
.mask(False).expand(shape).to_event())
if self.is_regional:
# This reshapes from (particle, 1, C, T, R) -> (particle, C, T, R)
# to allow aux below to have shape (particle, R) for region_plate.
auxiliary = auxiliary.squeeze(-4)

# Sequentially transition.
curr = self.initialize(params)
for t in poutine.markov(range(self.duration)):
aux_t = auxiliary[..., t]
prev = curr
curr = {name: quantize("{}_{}".format(name, t), aux,
min=0, max=self.population,
num_quant_bins=self.num_quant_bins)
for name, aux in zip(self.compartments, aux_t.unbind(-1))}
for t, aux_t in poutine.markov(enumerate(auxiliary.unbind(2))):
with self.region_plate:
prev, curr = curr, {}
for name, aux in zip(self.compartments, aux_t.unbind(1)):
curr[name] = quantize("{}_{}".format(name, t), aux,
min=0, max=self.population,
num_quant_bins=self.num_quant_bins)
# In regional models, enable approximate inference by using aux
# as a non-enumerated proxy for enumerated compartment values.
if self.is_regional:
curr[name + "_approx"] = aux
prev.setdefault(name + "_approx", prev[name])
martinjankowiak marked this conversation as resolved.
Show resolved Hide resolved
self.transition_bwd(params, prev, curr, t)

self._clear_plates()

def _vectorized_model(self):
"""
Vectorized model used for inference.
"""
C = len(self.compartments)
T = self.duration
Q = self.num_quant_bins
R_shape = getattr(self.population, "shape", ()) # Region shape.

# Sample global parameters.
params = self.global_model()

# Sample the continuous reparameterizing variable.
shape = (C, T) + R_shape
auxiliary = pyro.sample("auxiliary",
dist.Uniform(-0.5, self.population + 0.5)
.mask(False)
.expand([len(self.compartments), self.duration])
.to_event(2))
.mask(False).expand(shape).to_event())
assert auxiliary.shape == shape, "particle plates are not supported"

# Manually enumerate.
curr, logp = quantize_enumerate(auxiliary, min=0, max=self.population,
Expand All @@ -428,41 +481,54 @@ def _vectorized_model(self):
init = self.initialize(params)
prev = {}
for name in self.compartments:
if not isinstance(init[name], int):
raise NotImplementedError("TODO use torch.cat()")
prev[name] = pad(curr[name][:-1], (0, 0, 1, 0), value=init[name])
value = init[name]
if isinstance(value, torch.Tensor):
value = value[..., None] # Because curr is enumerated on the right.
prev[name] = cat2(value, curr[name][:-1], dim=-3 if self.is_regional else -2)

# Reshape to support broadcasting, similar to EnumMessenger.
C = len(self.compartments)
T = self.duration
Q = self.num_quant_bins # Number of quantization points.

def enum_shape(position):
shape = [T] + [1] * (2 * C)
shape[1 + position] = Q
return torch.Size(shape)
def enum_reshape(tensor, position):
assert tensor.size(-1) == Q
assert tensor.dim() <= self.max_plate_nesting + 2
tensor = tensor.permute(tensor.dim() - 1, *range(tensor.dim() - 1))
shape = [Q] + [1] * (position + self.max_plate_nesting - (tensor.dim() - 2))
shape.extend(tensor.shape[1:])
return tensor.reshape(shape)

for e, name in enumerate(self.compartments):
prev[name] = prev[name].reshape(enum_shape(e))
curr[name] = curr[name].reshape(enum_shape(C + e))
logp[name] = logp[name].reshape(enum_shape(C + e))
t = (Ellipsis,) + (None,) * (2 * C) # Used to unsqueeze data tensors.
curr[name] = enum_reshape(curr[name], e)
logp[name] = enum_reshape(logp[name], e)
prev[name] = enum_reshape(prev[name], e + C)

# In regional models, enable approximate inference by using aux
# as a non-enumerated proxy for enumerated compartment values.
if self.is_regional:
for name, aux in zip(self.compartments, auxiliary):
curr[name + "_approx"] = aux
prev[name + "_approx"] = cat2(init[name], aux[:-1],
dim=-2 if self.is_regional else -1)

# Record transition factors.
with poutine.block(), poutine.trace() as tr:
self.transition_bwd(params, prev, curr, t)
with pyro.plate("time", T, dim=-1 - self.max_plate_nesting):
t = slice(None) # Used to slice data tensors.
self.transition_bwd(params, prev, curr, t)
tr.trace.compute_log_prob()
for name, site in tr.trace.nodes.items():
if site["type"] == "sample":
logp[name] = site["fn"].log_prob(site["value"])
logp[name] = site["log_prob"]

# Manually perform variable elimination.
logp = sum(logp.values())
logp = logp.reshape(T, Q ** C, Q ** C)
logp = pyro.distributions.hmm._sequential_logmatmulexp(logp)
logp = logp.reshape(-1).logsumexp(0)
logp = reduce(operator.add, logp.values())
logp = logp.reshape(Q ** C, Q ** C, T, -1) # prev, curr, T, batch
logp = logp.permute(3, 2, 0, 1).squeeze(0) # batch, T, prev, curr
logp = pyro.distributions.hmm._sequential_logmatmulexp(logp) # batch, prev, curr
logp = logp.reshape(-1, Q ** C * Q ** C).logsumexp(-1).sum()
warn_if_nan(logp)
pyro.factor("transition", logp)

self._clear_plates()


class _SMCModel:
"""
Expand Down
Loading