diff --git a/pyro/contrib/epidemiology/__init__.py b/pyro/contrib/epidemiology/__init__.py index 00c46370c6..e433ae63bd 100644 --- a/pyro/contrib/epidemiology/__init__.py +++ b/pyro/contrib/epidemiology/__init__.py @@ -4,7 +4,7 @@ from .compartmental import CompartmentalModel from .distributions import infection_dist from .seir import OverdispersedSEIRModel, SimpleSEIRModel -from .sir import OverdispersedSIRModel, SimpleSIRModel, SparseSIRModel +from .sir import OverdispersedSIRModel, SimpleSIRModel, SparseSIRModel, UnknownStartSIRModel __all__ = [ "CompartmentalModel", @@ -13,5 +13,6 @@ "SimpleSEIRModel", "SimpleSIRModel", "SparseSIRModel", + "UnknownStartSIRModel", "infection_dist", ] diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index baddcbcbbb..74b99c5962 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ b/pyro/contrib/epidemiology/compartmental.py @@ -132,8 +132,8 @@ def heuristic(self, num_particles=1024): # Fill in sample site values. init = self.generate(init) - init["auxiliary"] = torch.stack( - [init[name] for name in self.compartments]).clamp_(min=0.5) + init["auxiliary"] = torch.stack([init[name] for name in self.compartments]) + init["auxiliary"].clamp_(min=0.5, max=self.population - 0.5) return init def global_model(self): @@ -215,6 +215,7 @@ def generate(self, fixed={}): :returns: A dictionary mapping sample site name to sampled value. :rtype: dict """ + fixed = {k: torch.as_tensor(v) for k, v in fixed.items()} model = self._generative_model model = poutine.condition(model, fixed) trace = poutine.trace(model).get_trace() @@ -296,6 +297,7 @@ def predict(self, forecast=0): This may be run only after :meth:`fit` and draws the same ``num_samples`` as passed to :meth:`fit`. + :param int forecast: The number of time steps to forecast forward. :returns: A dictionary mapping sample site name (or compartment name) to a tensor whose first dimension corresponds to sample batching. :rtype: dict @@ -303,7 +305,6 @@ def predict(self, forecast=0): if not self.samples: raise RuntimeError("Missing samples, try running .fit() first") samples = self.samples - print("DEBUG {}".format(samples.keys())) num_samples = len(next(iter(samples.values()))) particle_plate = pyro.plate("particles", num_samples, dim=-1 - self.max_plate_nesting) @@ -358,7 +359,7 @@ def _concat_series(self, samples, forecast=0): series.append(samples.pop(key)) if series: assert len(series) == self.duration + forecast - series = torch.broadcast_tensors(*series) + series = torch.broadcast_tensors(*map(torch.as_tensor, series)) samples[name] = torch.stack(series, dim=-1) def _generative_model(self, forecast=0): diff --git a/pyro/contrib/epidemiology/sir.py b/pyro/contrib/epidemiology/sir.py index 538ae21e2d..a1da86e842 100644 --- a/pyro/contrib/epidemiology/sir.py +++ b/pyro/contrib/epidemiology/sir.py @@ -1,8 +1,12 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import torch +from torch.nn.functional import pad + import pyro import pyro.distributions as dist +from pyro.ops.indexing import Index from .compartmental import CompartmentalModel from .distributions import infection_dist @@ -318,3 +322,164 @@ def transition_bwd(self, params, prev, curr, t): pyro.sample("obs_{}".format(t), dist.Delta(curr["O"]).mask(self.mask[t]), obs=self.data[t]) + + +class UnknownStartSIRModel(CompartmentalModel): + """ + Susceptible-Infected-Recovered model with unknown date of first infection. + + To customize this model we recommend forking and editing this class. + + This is a stochastic discrete-time discrete-state model with three + compartments: "S" for susceptible, "I" for infected, and "R" for + recovered individuals (the recovered individuals are implicit: ``R = + population - S - I``) with transitions ``S -> I -> R``. + + This model demonstrates: + + 1. How to incorporate spontaneous infections from external sources; + 2. How to incorporate time-varying piecewise ``rho`` by supporting + forecasting in :meth:`transition_fwd` and using the + :class:`~pyro.ops.index.ing.Index` helper in :meth:`transition_bwd`. + 3. How to override the :meth:`predict` method to compute extra + statistics. + + :param int population: Total ``population = S + I + R``. + :param float recovery_time: Mean recovery time (duration in state + ``I``). Must be greater than 1. + :param int pre_obs_window: Number of time steps before beginning ``data`` + where the initial infection may have occurred. Must be positive. + :param iterable data: Time series of new observed infections. Each time + step is Binomial distributed between 0 and the number of ``S -> I`` + transitions. This allows false negative but no false positives. + """ + + def __init__(self, population, recovery_time, pre_obs_window, data): + compartments = ("S", "I") # R is implicit. + duration = pre_obs_window + len(data) + super().__init__(compartments, duration, population) + + assert isinstance(recovery_time, float) + assert recovery_time > 1 + self.recovery_time = recovery_time + + assert isinstance(pre_obs_window, int) and pre_obs_window > 0 + self.pre_obs_window = pre_obs_window + self.post_obs_window = len(data) + + # We set a small time-constant external infecton rate such that on + # average there is a single external infection during the + # pre_obs_window. This allows unknown time of initial infection + # without introducing long-range coupling across time. + self.external_rate = 1 / pre_obs_window + + # Prepend data with zeros. + if isinstance(data, list): + data = [0.] * self.pre_obs_window + data + else: + data = pad(data, (self.pre_obs_window, 0), value=0.) + self.data = data + + series = ("S2I", "I2R", "obs") + full_mass = [("R0", "rho0", "rho1")] + + def global_model(self): + tau = self.recovery_time + R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) + + # Assume two different response rates: rho0 before any observations + # were made (in pre_obs_window), followed by a higher response rate rho1 + # after observations were made (in post_obs_window). + rho0 = pyro.sample("rho0", dist.Uniform(0, 1)) + rho1 = pyro.sample("rho1", dist.Uniform(0, 1)) + # Whereas each of rho0,rho1 are scalars (possibly batched over samples), + # we construct a time series rho with an extra time dim on the right. + rho = torch.cat([ + rho0.unsqueeze(-1).expand(rho0.shape + (self.pre_obs_window,)), + rho1.unsqueeze(-1).expand(rho1.shape + (self.post_obs_window,)), + ], dim=-1) + + # Model external infections as an infectious pseudo-individual added + # to num_infectious when sampling S2I below. + X = self.external_rate * tau / R0 + + return R0, X, tau, rho + + def initialize(self, params): + # Start with no internal infections. + return {"S": self.population, "I": 0} + + def transition_fwd(self, params, state, t): + R0, X, tau, rho = params + + # Sample flows between compartments. + S2I = pyro.sample("S2I_{}".format(t), + infection_dist(individual_rate=R0 / tau, + num_susceptible=state["S"], + num_infectious=state["I"] + X, + population=self.population)) + I2R = pyro.sample("I2R_{}".format(t), + dist.Binomial(state["I"], 1 / tau)) + + # Update compartments with flows. + state["S"] = state["S"] - S2I + state["I"] = state["I"] + S2I - I2R + + # In .transition_fwd() t will always be an integer but may lie outside + # of [0,self.duration) when forecasting. + rho_t = rho[..., t] if t < self.duration else rho[..., -1] + data_t = self.data[t] if t < self.duration else None + + # Condition on observations. + pyro.sample("obs_{}".format(t), + dist.ExtendedBinomial(S2I, rho_t), + obs=data_t) + + def transition_bwd(self, params, prev, curr, t): + R0, X, tau, rho = params + + # Reverse the flow computation. + S2I = prev["S"] - curr["S"] + I2R = prev["I"] - curr["I"] + S2I + + # Condition on flows between compartments. + pyro.sample("S2I_{}".format(t), + infection_dist(individual_rate=R0 / tau, + num_susceptible=prev["S"], + num_infectious=prev["I"] + X, + population=self.population), + obs=S2I) + pyro.sample("I2R_{}".format(t), + dist.ExtendedBinomial(prev["I"], 1 / tau), + obs=I2R) + + # In .transition_bwd() t may be either an integer in [0,self.duration) + # or may be a nonstandard slicing object like (Ellipsis, None, None); + # we use the Index(-)[-] helper to support indexing with nonstandard t. + rho_t = Index(rho)[..., t] + + # Condition on observations. + pyro.sample("obs_{}".format(t), + dist.ExtendedBinomial(S2I, rho_t), + obs=self.data[t]) + + def predict(self, forecast=0): + """ + Augments + :meth:`~pyro.contrib.epidemiology.compartmental.Compartmental.predict` + with samples of ``first_infection`` i.e. the first time index at which + the infection ``I`` becomes nonzero. Note this is measured from the + beginning of ``pre_obs_window``, not the beginning of data. + + :param int forecast: The number of time steps to forecast forward. + :returns: A dictionary mapping sample site name (or compartment name) + to a tensor whose first dimension corresponds to sample batching. + :rtype: dict + """ + samples = super().predict(forecast) + + # Extract the time index of the first infection (samples["I"] > 0) + # for each sample trajectory in the samples["I"] tensor. + samples["first_infection"] = samples["I"].cumsum(-1).eq(0).sum(-1) + + return samples diff --git a/pyro/ops/indexing.py b/pyro/ops/indexing.py index 88ba854964..7fc3c82e29 100644 --- a/pyro/ops/indexing.py +++ b/pyro/ops/indexing.py @@ -8,6 +8,76 @@ def _is_batched(arg): return isinstance(arg, torch.Tensor) and arg.dim() +def _flatten(args, out): + if isinstance(args, tuple): + for arg in args: + _flatten(arg, out) + else: + # Combine consecutive Ellipsis. + if args is Ellipsis and out and out[-1] is Ellipsis: + return + out.append(args) + + +def index(tensor, args): + """ + Indexing with nested tuples. + + See also the convenience wrapper :class:`Index`. + + This is useful for writing indexing code that is compatible with multiple + interpretations, e.g. scalar evaluation, vectorized evaluation, or + reshaping. + + For example suppose ``x`` is a parameter with ``x.dim() == 2`` and we wish + to generalize the expression ``x[..., t]`` where ``t`` can be any of: + + - a scalar ``t=1`` as in ``x[..., 1]``; + - a slice ``t=slice(None)`` equivalent to ``x[..., :]``; or + - a reshaping operation ``t=(Ellipsis, None)`` equivalent to + ``x.unsqueeze(-1)``. + + While naive indexing would work for the first two , the third example would + result in a nested tuple ``(Ellipsis, (Ellipsis, None))``. This helper + flattens that nested tuple and combines consecutive ``Ellipsis``. + + :param torch.Tensor tensor: A tensor to be indexed. + :param tuple args: An index, as args to ``__getitem__``. + :returns: A flattened interpetation of ``tensor[args]``. + :rtype: torch.Tensor + """ + if not isinstance(args, tuple): + return tensor[args] + if not args: + return tensor + + # Flatten. + flat = [] + _flatten(args, flat) + args = tuple(flat) + + return tensor[args] + + +class Index: + """ + Convenience wrapper around :func:`index`. + + The following are equivalent:: + + Index(x)[..., i, j, :] + index(x, (Ellipsis, i, j, slice(None))) + + :param torch.Tensor tensor: A tensor to be indexed. + :return: An object with a special :meth:`__getitem__` method. + """ + def __init__(self, tensor): + self._tensor = tensor + + def __getitem__(self, args): + return index(self._tensor, args) + + def vindex(tensor, args): """ Vectorized advanced indexing with broadcasting semantics. diff --git a/tests/contrib/epidemiology/test_sir.py b/tests/contrib/epidemiology/test_sir.py index cf06cd6bb9..a77c14cfd1 100644 --- a/tests/contrib/epidemiology/test_sir.py +++ b/tests/contrib/epidemiology/test_sir.py @@ -7,7 +7,7 @@ import pytest import torch -from pyro.contrib.epidemiology import OverdispersedSIRModel, SimpleSIRModel, SparseSIRModel +from pyro.contrib.epidemiology import OverdispersedSIRModel, SimpleSIRModel, SparseSIRModel, UnknownStartSIRModel logger = logging.getLogger(__name__) @@ -112,3 +112,48 @@ def test_sparse_smoke(duration, forecast, options): for O in samples["O"]: logger.info("imputed:\n{}".format(O)) assert (O[:duration][mask] == data[mask]).all() + + +@pytest.mark.parametrize("pre_obs_window", [6]) +@pytest.mark.parametrize("duration", [8]) +@pytest.mark.parametrize("forecast", [0, 7]) +@pytest.mark.parametrize("options", [ + {}, + {"dct": 1.}, + {"num_quant_bins": 8}, +], ids=str) +def test_unknown_start_smoke(duration, pre_obs_window, forecast, options): + population = 100 + recovery_time = 7.0 + + # Generate data. + data = [None] * duration + model = UnknownStartSIRModel(population, recovery_time, pre_obs_window, data) + for attempt in range(100): + data = model.generate({"R0": 1.5, "rho0": 0.1, "rho1": 0.5})["obs"] + assert len(data) == pre_obs_window + duration + data = data[pre_obs_window:] + if data.sum(): + break + assert data.sum() > 0, "failed to generate positive data" + logger.info("data:\n{}".format(data)) + + # Infer. + model = UnknownStartSIRModel(population, recovery_time, pre_obs_window, data) + num_samples = 5 + model.fit(warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options) + + # Predict and forecast. + samples = model.predict(forecast=forecast) + assert samples["S"].shape == (num_samples, pre_obs_window + duration + forecast) + assert samples["I"].shape == (num_samples, pre_obs_window + duration + forecast) + + # Check time of first infection. + t = samples["first_infection"] + logger.info("first_infection:\n{}".format(t)) + assert t.shape == (num_samples,) + assert (0 <= t).all() + assert (t < pre_obs_window + duration).all() + for I, ti in zip(samples["I"], t): + assert (I[:ti] == 0).all() + assert I[ti] > 0 diff --git a/tests/ops/test_indexing.py b/tests/ops/test_indexing.py index 882593daf9..4bf76a791a 100644 --- a/tests/ops/test_indexing.py +++ b/tests/ops/test_indexing.py @@ -8,10 +8,18 @@ import pyro.distributions as dist from pyro.distributions.util import broadcast_shape -from pyro.ops.indexing import Vindex +from pyro.ops.indexing import Index, Vindex from tests.common import assert_equal +class TensorMock: + def __getitem__(self, args): + return args + + +tensor_mock = TensorMock() + + def z(*args): return torch.zeros(*args, dtype=torch.long) @@ -132,3 +140,19 @@ def test_hmm_example(prev_enum_dim, curr_enum_dim): expected = probs_x[x_prev.unsqueeze(-1), x_curr.unsqueeze(-1), torch.arange(hidden_dim)] actual = Vindex(probs_x)[x_prev, x_curr, :] assert_equal(actual, expected) + + +@pytest.mark.parametrize("args,expected", [ + (0, 0), + (1, 1), + (None, None), + (slice(1, 2, 3), slice(1, 2, 3)), + (Ellipsis, Ellipsis), + ((0, 1, None, slice(1, 2, 3), Ellipsis), (0, 1, None, slice(1, 2, 3), Ellipsis)), + (((0, 1), (None, slice(1, 2, 3)), Ellipsis), (0, 1, None, slice(1, 2, 3), Ellipsis)), + ((Ellipsis, None), (Ellipsis, None)), + ((Ellipsis, (Ellipsis, None)), (Ellipsis, None)), + ((Ellipsis, (Ellipsis, None, None)), (Ellipsis, None, None)), +]) +def test_index(args, expected): + assert Index(tensor_mock)[args] == expected