Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SIR model with unknown time of initial infection #2460

Merged
merged 9 commits into from
May 5, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyro/contrib/epidemiology/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .compartmental import CompartmentalModel
from .distributions import infection_dist
from .seir import OverdispersedSEIRModel, SimpleSEIRModel
from .sir import OverdispersedSIRModel, SimpleSIRModel, SparseSIRModel
from .sir import OverdispersedSIRModel, SimpleSIRModel, SparseSIRModel, UnknownStartSIRModel

__all__ = [
"CompartmentalModel",
Expand All @@ -13,5 +13,6 @@
"SimpleSEIRModel",
"SimpleSIRModel",
"SparseSIRModel",
"UnknownStartSIRModel",
"infection_dist",
]
9 changes: 5 additions & 4 deletions pyro/contrib/epidemiology/compartmental.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def heuristic(self, num_particles=1024):

# Fill in sample site values.
init = self.generate(init)
init["auxiliary"] = torch.stack(
[init[name] for name in self.compartments]).clamp_(min=0.5)
init["auxiliary"] = torch.stack([init[name] for name in self.compartments])
init["auxiliary"].clamp_(min=0.5, max=self.population - 0.5)
return init

def global_model(self):
Expand Down Expand Up @@ -215,6 +215,7 @@ def generate(self, fixed={}):
:returns: A dictionary mapping sample site name to sampled value.
:rtype: dict
"""
fixed = {k: torch.as_tensor(v) for k, v in fixed.items()}
model = self._generative_model
model = poutine.condition(model, fixed)
trace = poutine.trace(model).get_trace()
Expand Down Expand Up @@ -296,14 +297,14 @@ def predict(self, forecast=0):
This may be run only after :meth:`fit` and draws the same
``num_samples`` as passed to :meth:`fit`.

:param int forecast: The number of time steps to forecast forward.
:returns: A dictionary mapping sample site name (or compartment name)
to a tensor whose first dimension corresponds to sample batching.
:rtype: dict
"""
if not self.samples:
raise RuntimeError("Missing samples, try running .fit() first")
samples = self.samples
print("DEBUG {}".format(samples.keys()))
num_samples = len(next(iter(samples.values())))
particle_plate = pyro.plate("particles", num_samples,
dim=-1 - self.max_plate_nesting)
Expand Down Expand Up @@ -358,7 +359,7 @@ def _concat_series(self, samples, forecast=0):
series.append(samples.pop(key))
if series:
assert len(series) == self.duration + forecast
series = torch.broadcast_tensors(*series)
series = torch.broadcast_tensors(*map(torch.as_tensor, series))
samples[name] = torch.stack(series, dim=-1)

def _generative_model(self, forecast=0):
Expand Down
154 changes: 154 additions & 0 deletions pyro/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch
from torch.nn.functional import pad

import pyro
import pyro.distributions as dist
from pyro.ops.indexing import Index

from .compartmental import CompartmentalModel
from .distributions import infection_dist
Expand Down Expand Up @@ -318,3 +322,153 @@ def transition_bwd(self, params, prev, curr, t):
pyro.sample("obs_{}".format(t),
dist.Delta(curr["O"]).mask(self.mask[t]),
obs=self.data[t])


class UnknownStartSIRModel(CompartmentalModel):
"""
Susceptible-Infected-Recovered model with unknown date of first infection.

To customize this model we recommend forking and editing this class.

This is a stochastic discrete-time discrete-state model with three
compartments: "S" for susceptible, "I" for infected, and "R" for
recovered individuals (the recovered individuals are implicit: ``R =
population - S - I``) with transitions ``S -> I -> R``.

This model demonstrates:

1. How to incorporate spontaneous infections from external sources;
2. How to incorporate time-varying piecwise ``rho`` by supporting
fritzo marked this conversation as resolved.
Show resolved Hide resolved
forecasting in :meth:`transition_fwd` and using the
:class:`~pyro.ops.index.ing.Index` helper in :meth:`transition_bwd`.
3. How to override the :meth:`predict` method to compute extra
statistics.

:param int population: Total ``population = S + I + R``.
:param float recovery_time: Mean recovery time (duration in state
``I``). Must be greater than 1.
:param int pre_window: Number of time steps before beginning ``data``
fritzo marked this conversation as resolved.
Show resolved Hide resolved
where the initial infection may have occurred. Must be positive.
:param iterable data: Time series of new observed infections. Each time
step is Binomial distributed between 0 and the number of ``S -> I``
transitions. This allows false negative but no false positives.
"""

def __init__(self, population, recovery_time, pre_window, data):
compartments = ("S", "I") # R is implicit.
duration = pre_window + len(data)
super().__init__(compartments, duration, population)

assert isinstance(recovery_time, float)
assert recovery_time > 1
self.recovery_time = recovery_time

assert isinstance(pre_window, int) and pre_window > 0
self.pre_window = pre_window
self.post_window = len(data)

# Expect a single external infection during the pre_window.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please explain this assumption?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

self.external_rate = 1 / pre_window

# Prepend data with zeros.
if isinstance(data, list):
data = [0.] * self.pre_window + data
else:
data = pad(data, (self.pre_window, 0), value=0.)
self.data = data

series = ("S2I", "I2R", "obs")
full_mass = [("R0", "rho0", "rho1")]

def global_model(self):
tau = self.recovery_time
R0 = pyro.sample("R0", dist.LogNormal(0., 1.))

# Assume two different response rates: rho0 before any observations
# were made (in pre_window), followed by a higher response rate rho1
# after observations were made (in post_window).
rho0 = pyro.sample("rho0", dist.Uniform(0, 1))
rho1 = pyro.sample("rho1", dist.Uniform(0, 1))
rho = torch.cat([
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment on shape of rho?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

rho0.unsqueeze(-1).expand(rho0.shape + (self.pre_window,)),
rho1.unsqueeze(-1).expand(rho1.shape + (self.post_window,)),
], dim=-1)

# Model external infections as an infectious pseudo-individual added
# to num_infectious when sampling S2I below.
X = self.external_rate * tau / R0

return R0, X, tau, rho

def initialize(self, params):
# Start with no internal infections.
return {"S": self.population, "I": 0}

def transition_fwd(self, params, state, t):
R0, X, tau, rho = params

# Sample flows between compartments.
S2I = pyro.sample("S2I_{}".format(t),
infection_dist(individual_rate=R0 / tau,
num_susceptible=state["S"],
num_infectious=state["I"] + X,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i guess the X contribution is active at all times?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. That modeling decision seems more parsimonious, simpler to code, and quantitatively close to the alternative (one initial infection at an unknown time) because during an outbreak, new infections should be dominated by local internal infections.

population=self.population))
I2R = pyro.sample("I2R_{}".format(t),
dist.Binomial(state["I"], 1 / tau))

# Update compartments with flows.
state["S"] = state["S"] - S2I
state["I"] = state["I"] + S2I - I2R

# In .transition_fwd() t will always be an integer but may lie outside
# of [0,self.duration) when forecasting.
rho_t = rho[..., t] if t < self.duration else rho[..., -1]
data_t = self.data[t] if t < self.duration else None

# Condition on observations.
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2I, rho_t),
obs=data_t)

def transition_bwd(self, params, prev, curr, t):
R0, X, tau, rho = params

# Reverse the flow computation.
S2I = prev["S"] - curr["S"]
I2R = prev["I"] - curr["I"] + S2I

# Condition on flows between compartments.
pyro.sample("S2I_{}".format(t),
infection_dist(individual_rate=R0 / tau,
num_susceptible=prev["S"],
num_infectious=prev["I"] + X,
population=self.population),
obs=S2I)
pyro.sample("I2R_{}".format(t),
dist.ExtendedBinomial(prev["I"], 1 / tau),
obs=I2R)

# In .transition_bwd() t may be either an integer in [0,self.duration)
# or may be a nonstandard slicing object like (Ellipsis, None, None);
# we use the Index(-)[-] helper to support indexing with nonstandard t.
rho_t = Index(rho)[..., t]

# Condition on observations.
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2I, rho_t),
obs=self.data[t])

def predict(self, forecast=0):
"""
Augments
:meth:`~pyro.contrib.epidemiology.compartmental.Compartmental.predict`
with samples of ``first_infection``.

:param int forecast: The number of time steps to forecast forward.
:returns: A dictionary mapping sample site name (or compartment name)
to a tensor whose first dimension corresponds to sample batching.
:rtype: dict
"""
samples = super().predict(forecast)
samples["first_infection"] = samples["I"].cumsum(-1).eq(0).sum(-1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you please comment on this quantity?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return samples
70 changes: 70 additions & 0 deletions pyro/ops/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,76 @@ def _is_batched(arg):
return isinstance(arg, torch.Tensor) and arg.dim()


def _flatten(args, out):
if isinstance(args, tuple):
for arg in args:
_flatten(arg, out)
else:
# Combine consecutive Ellipsis.
if args is Ellipsis and out and out[-1] is Ellipsis:
return
out.append(args)


def index(tensor, args):
"""
Indexing with nested tuples.

See also the convenience wrapper :class:`Index`.

This is useful for writing indexing code that is compatible with multiple
interpretations, e.g. scalar evaluation, vectorized evaluation, or
reshaping.

For example suppose ``x`` is a parameter with ``x.dim() == 2`` and we wish
to generalize the expression ``x[..., t]`` where ``t`` can be any of:

- a scalar ``t=1`` as in ``x[..., 1]``;
- a slice ``t=slice(None)`` equivalent to ``x[..., :]``; or
- a reshaping operation ``t=(Ellipsis, None)`` equivalent to
``x.unsqueeze(-1)``.
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved

While naive indexing would work for the first two , the third example would
result in a nested tuple ``(Ellipsis, (Ellipsis, None))``. This helper
flattens that nested tuple and combines consecutive ``Ellipsis``.

:param torch.Tensor tensor: A tensor to be indexed.
:param tuple args: An index, as args to ``__getitem__``.
:returns: A flattened interpetation of ``tensor[args]``.
:rtype: torch.Tensor
"""
if not isinstance(args, tuple):
return tensor[args]
if not args:
return tensor

# Flatten.
flat = []
_flatten(args, flat)
args = tuple(flat)

return tensor[args]


class Index:
"""
Convenience wrapper around :func:`index`.

The following are equivalent::

Index(x)[..., i, j, :]
index(x, (Ellipsis, i, j, slice(None)))

:param torch.Tensor tensor: A tensor to be indexed.
:return: An object with a special :meth:`__getitem__` method.
"""
def __init__(self, tensor):
self._tensor = tensor

def __getitem__(self, args):
return index(self._tensor, args)


def vindex(tensor, args):
"""
Vectorized advanced indexing with broadcasting semantics.
Expand Down
47 changes: 46 additions & 1 deletion tests/contrib/epidemiology/test_sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
import torch

from pyro.contrib.epidemiology import OverdispersedSIRModel, SimpleSIRModel, SparseSIRModel
from pyro.contrib.epidemiology import OverdispersedSIRModel, SimpleSIRModel, SparseSIRModel, UnknownStartSIRModel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -112,3 +112,48 @@ def test_sparse_smoke(duration, forecast, options):
for O in samples["O"]:
logger.info("imputed:\n{}".format(O))
assert (O[:duration][mask] == data[mask]).all()


@pytest.mark.parametrize("pre_window", [6])
@pytest.mark.parametrize("duration", [8])
@pytest.mark.parametrize("forecast", [0, 7])
@pytest.mark.parametrize("options", [
{},
{"dct": 1.},
{"num_quant_bins": 8},
], ids=str)
def test_unknown_start_smoke(duration, pre_window, forecast, options):
population = 100
recovery_time = 7.0

# Generate data.
data = [None] * duration
model = UnknownStartSIRModel(population, recovery_time, pre_window, data)
for attempt in range(100):
data = model.generate({"R0": 1.5, "rho0": 0.1, "rho1": 0.5})["obs"]
assert len(data) == pre_window + duration
data = data[pre_window:]
if data.sum():
break
assert data.sum() > 0, "failed to generate positive data"
logger.info("data:\n{}".format(data))

# Infer.
model = UnknownStartSIRModel(population, recovery_time, pre_window, data)
num_samples = 5
model.fit(warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options)

# Predict and forecast.
samples = model.predict(forecast=forecast)
assert samples["S"].shape == (num_samples, pre_window + duration + forecast)
assert samples["I"].shape == (num_samples, pre_window + duration + forecast)

# Check time of first infection.
t = samples["first_infection"]
logger.info("first_infection:\n{}".format(t))
assert t.shape == (num_samples,)
assert (0 <= t).all()
assert (t < pre_window + duration).all()
for I, ti in zip(samples["I"], t):
assert (I[:ti] == 0).all()
assert I[ti] > 0
Loading