diff --git a/examples/contrib/epidemiology/regional.py b/examples/contrib/epidemiology/regional.py
new file mode 100644
index 0000000000..90aed28d74
--- /dev/null
+++ b/examples/contrib/epidemiology/regional.py
@@ -0,0 +1,170 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import logging
+
+import torch
+
+import pyro
+from pyro.contrib.epidemiology import RegionalSIRModel
+
+logging.basicConfig(format='%(message)s', level=logging.INFO)
+
+
+def Model(args, data):
+    """Build a RegionalSIRModel with equal populations and uniform off-diagonal coupling."""
+    assert 0 <= args.coupling <= 1, args.coupling
+    population = torch.full((args.num_regions,), float(args.population))
+    # Ones on the diagonal, args.coupling everywhere else.
+    coupling = torch.eye(args.num_regions).clamp(min=args.coupling)
+    return RegionalSIRModel(population, coupling, args.recovery_time, data)
+
+
+def generate_data(args):
+    """
+    Simulate epidemics until one has at least ``args.min_observations`` observed infections.
+
+    Returns a dict holding the full simulated ``"S2I"`` series and the ``"obs"``
+    series truncated to the first ``args.duration`` steps. Raises ``ValueError``
+    if none of 100 attempts produces enough observations.
+    """
+    extended_data = [None] * (args.duration + args.forecast)
+    model = Model(args, extended_data)
+    logging.info("Simulating from a {}".format(type(model).__name__))
+    for attempt in range(100):
+        samples = model.generate({"R0": args.basic_reproduction_number,
+                                  "rho_c1": 10 * args.response_rate,
+                                  "rho_c0": 10 * (1 - args.response_rate)})
+        obs = samples["obs"][:args.duration]
+        S2I = samples["S2I"]
+
+        obs_sum = int(obs.sum())
+        S2I_sum = int(S2I[:args.duration].sum())
+        if obs_sum >= args.min_observations:
+            logging.info("Observed {:d}/{:d} infections:\n{}".format(
+                obs_sum, S2I_sum, " ".join(str(int(x)) for x in obs[:, 0])))
+            return {"S2I": S2I, "obs": obs}
+
+    raise ValueError("Failed to generate {} observations. Try increasing "
+                     "--population or decreasing --min-observations"
+                     .format(args.min_observations))
+
+
+def infer(args, model):
+    """Fit ``model`` via ``model.fit``, recording the potential energy each time the hook runs."""
+    energies = []
+
+    def hook_fn(kernel, *unused):
+        e = float(kernel._potential_energy_last)
+        energies.append(e)
+        if args.verbose:
+            logging.info("potential = {:0.6g}".format(e))
+
+    mcmc = model.fit(heuristic_num_particles=args.num_particles,
+                     warmup_steps=args.warmup_steps,
+                     num_samples=args.num_samples,
+                     max_tree_depth=args.max_tree_depth,
+                     num_quant_bins=args.num_bins,
+                     hook_fn=hook_fn)
+
+    mcmc.summary()
+    if args.plot:
+        import matplotlib.pyplot as plt
+        plt.figure(figsize=(6, 3))
+        plt.plot(energies)
+        plt.xlabel("MCMC step")
+        plt.ylabel("potential energy")
+        plt.title("MCMC energy trace")
+        plt.tight_layout()
+
+    return model.samples
+
+
+def predict(args, model, truth):
+    """Log the median predicted new infections per region and optionally plot them."""
+    samples = model.predict(forecast=args.forecast)
+    S2I = samples["S2I"]
+    median = S2I.median(dim=0).values
+    lines = ["Median prediction of new infections (starting on day 0):"]
+    for r in range(args.num_regions):
+        lines.append("Region {}: {}".format(r, " ".join(map(str, map(int, median[:, r])))))
+    logging.info("\n".join(lines))
+
+    # Optionally plot the latent and forecasted series of new infections.
+    if args.plot:
+        import matplotlib.pyplot as plt
+        # squeeze=False always returns a 2-D array of axes, so a single region
+        # (--num-regions=1) is handled like any other instead of yielding a bare Axes.
+        fig, axes = plt.subplots(args.num_regions, sharex=True, squeeze=False,
+                                 figsize=(6, 1 + args.num_regions))
+        axes = axes[:, 0]
+        time = torch.arange(args.duration + args.forecast)
+        p05 = S2I.kthvalue(int(round(0.5 + 0.05 * args.num_samples)), dim=0).values
+        p95 = S2I.kthvalue(int(round(0.5 + 0.95 * args.num_samples)), dim=0).values
+        for r, ax in enumerate(axes):
+            ax.fill_between(time, p05[:, r], p95[:, r], color="red", alpha=0.3, label="90% CI")
+            ax.plot(time, median[:, r], "r-", label="median")
+            ax.plot(time[:args.duration], model.data[:, r], "k.", label="observed")
+            ax.plot(time, truth[:, r], "k--", label="truth")
+            ax.axvline(args.duration - 0.5, color="gray", lw=1)
+            ax.set_xlim(0, len(time) - 1)
+            ax.set_ylim(0, None)
+        axes[0].set_title("New infections among {} regions each of size {}"
+                          .format(args.num_regions, args.population))
+        axes[args.num_regions // 2].set_ylabel("inf./day")
+        axes[-1].set_xlabel("day after first infection")
+        axes[-1].legend(loc="upper left")
+        plt.tight_layout()
+        plt.subplots_adjust(hspace=0)
+
+
+def main(args):
+    """Generate synthetic data, run inference on it, then forecast."""
+    pyro.enable_validation(__debug__)
+    pyro.set_rng_seed(args.rng_seed)
+
+    # Generate data.
+    dataset = generate_data(args)
+    obs = dataset["obs"]
+
+    # Run inference.
+    model = Model(args, obs)
+    infer(args, model)
+
+    # Predict latent time series.
+    predict(args, model, truth=dataset["S2I"])
+
+
+if __name__ == "__main__":
+    assert pyro.__version__.startswith('1.3.1')
+    parser = argparse.ArgumentParser(
+        description="Regional compartmental epidemiology modeling using HMC")
+    parser.add_argument("-p", "--population", default=1000, type=int)
+    parser.add_argument("-r", "--num-regions", default=2, type=int)
+    parser.add_argument("-c", "--coupling", default=0.1, type=float)
+    parser.add_argument("-m", "--min-observations", default=3, type=int)
+    parser.add_argument("-d", "--duration", default=20, type=int)
+    parser.add_argument("-f", "--forecast", default=10, type=int)
+    parser.add_argument("-R0", "--basic-reproduction-number", default=1.5, type=float)
+    parser.add_argument("-tau", "--recovery-time", default=7.0, type=float)
+    parser.add_argument("-rho", "--response-rate", default=0.5, type=float)
+    parser.add_argument("-n", "--num-samples", default=200, type=int)
+    parser.add_argument("-np", "--num-particles", default=1024, type=int)
+    parser.add_argument("-w", "--warmup-steps", default=100, type=int)
+    parser.add_argument("-t", "--max-tree-depth", default=5, type=int)
+    parser.add_argument("-nb", "--num-bins", default=4, type=int)
+    parser.add_argument("--rng-seed", default=0, type=int)
+    parser.add_argument("--cuda", action="store_true")
+    parser.add_argument("--verbose", action="store_true")
+    parser.add_argument("--plot", action="store_true")
+    args = parser.parse_args()
+
+    if args.cuda:
+        torch.set_default_tensor_type(torch.cuda.FloatTensor)
+
+    main(args)
+
+    if args.plot:
+        import matplotlib.pyplot as plt
+        plt.show()
diff --git a/pyro/contrib/epidemiology/__init__.py b/pyro/contrib/epidemiology/__init__.py
index e433ae63bd..77b52bb2e1 100644
--- a/pyro/contrib/epidemiology/__init__.py
+++ b/pyro/contrib/epidemiology/__init__.py
@@ -4,12 +4,13 @@
 from .compartmental import CompartmentalModel
 from .distributions import infection_dist
 from .seir import OverdispersedSEIRModel, SimpleSEIRModel
-from .sir
import OverdispersedSIRModel, SimpleSIRModel, SparseSIRModel, UnknownStartSIRModel +from .sir import OverdispersedSIRModel, RegionalSIRModel, SimpleSIRModel, SparseSIRModel, UnknownStartSIRModel __all__ = [ "CompartmentalModel", "OverdispersedSEIRModel", "OverdispersedSIRModel", + "RegionalSIRModel", "SimpleSEIRModel", "SimpleSIRModel", "SparseSIRModel", diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index 74b99c5962..5feaaa669e 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ b/pyro/contrib/epidemiology/compartmental.py @@ -2,13 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 import logging +import operator import re from abc import ABC, abstractmethod from collections import OrderedDict +from contextlib import ExitStack +from functools import reduce import torch from torch.distributions import biject_to, constraints -from torch.nn.functional import pad import pyro.distributions as dist import pyro.distributions.hmm @@ -19,7 +21,7 @@ from pyro.infer.reparam import DiscreteCosineReparam from pyro.util import warn_if_nan -from .util import quantize, quantize_enumerate +from .util import align_samples, cat2, clamp, quantize, quantize_enumerate logger = logging.getLogger(__name__) @@ -71,20 +73,38 @@ def transition_bwd(self, params, prev, curr, t): ... in returned sample dictionaries. :ivar dict samples: Dictionary of posterior samples. :param list compartments: A list of strings of compartment names. - :param int duration: - :param int population: + :param int duration: The number of discrete time steps in this model. + :param population: Either the total population of a single-region model or + a tensor of each region's population in a regional model. + :type population: int or torch.Tensor + :param tuple approximate: Names of compartments for which pointwise + approximations should be provided in :meth:`transition_bwd`, e.g. 
if you + specify ``approximate=("I",)`` then ``prev["I_approx"]`` will be a + continuous-valued non-enumerated point estimate of ``prev["I"]``. + Approximations are useful to reduce computational cost. Approximations + are continuous-valued with support ``(-0.5, population + 0.5)``. + :param int num_quant_bins: Number of quantization bins in the auxiliary + variable spline. Defaults to 4. """ def __init__(self, compartments, duration, population, *, - num_quant_bins=4): + num_quant_bins=4, approximate=()): super().__init__() assert isinstance(duration, int) assert duration >= 1 self.duration = duration - assert isinstance(population, int) - assert population >= 2 + if isinstance(population, torch.Tensor): + assert population.dim() == 1 + assert (population >= 1).all() + self.is_regional = True + self.max_plate_nesting = 1 + else: + assert isinstance(population, int) + assert population >= 2 + self.is_regional = False + self.max_plate_nesting = 0 self.population = population compartments = tuple(compartments) @@ -92,12 +112,32 @@ def __init__(self, compartments, duration, population, *, assert len(compartments) == len(set(compartments)) self.compartments = compartments + assert isinstance(approximate, tuple) + assert all(name in compartments for name in approximate) + self.approximate = approximate + # Inference state. self.samples = {} + self._clear_plates() + + @property + def region_plate(self): + """ + Either a ``pyro.plate`` or a trivial ``ExitStack`` depending on whether + this model ``.is_regional``. + """ + if self._region_plate is None: + if self.is_regional: + self._region_plate = pyro.plate("region", len(self.population), dim=-1) + else: + self._region_plate = ExitStack() # Trivial context manager. 
+ return self._region_plate + + def _clear_plates(self): + self._region_plate = None # Overridable attributes and methods ######################################## - max_plate_nesting = 0 series = () full_mass = False @@ -126,14 +166,14 @@ def heuristic(self, num_particles=1024): for t in range(1, self.duration): smc.step() - # Select the most probably hypothesis. + # Select the most probable hypothesis. i = int(smc.state._log_weights.max(0).indices) init = {key: value[i] for key, value in smc.state.items()} # Fill in sample site values. init = self.generate(init) - init["auxiliary"] = torch.stack([init[name] for name in self.compartments]) - init["auxiliary"].clamp_(min=0.5, max=self.population - 0.5) + aux = torch.stack([init[name] for name in self.compartments], dim=0) + init["auxiliary"] = clamp(aux, min=0.5, max=self.population - 0.5) return init def global_model(self): @@ -254,6 +294,8 @@ def fit(self, **options): # Save these options for .predict(). self.num_quant_bins = options.pop("num_quant_bins", 4) self._dct = options.pop("dct", None) + if self._dct is not None and self.is_regional: + raise NotImplementedError("regional models do not support DiscreteCosineReparam") # Heuristically initialze to feasible latents. logger.info("Heuristically initializing...") @@ -287,6 +329,10 @@ def fit(self, **options): mcmc = MCMC(kernel, **options) mcmc.run() self.samples = mcmc.get_samples() + # Unsqueeze samples to align particle dim for use in poutine.condition. + # TODO refactor to an align_samples or particle_dim kwarg to MCMC.get_samples(). + self.samples = align_samples(self.samples, model, + particle_dim=-1 - self.max_plate_nesting) return mcmc # E.g. so user can run mcmc.summary(). 
@torch.no_grad() @@ -360,7 +406,7 @@ def _concat_series(self, samples, forecast=0): if series: assert len(series) == self.duration + forecast series = torch.broadcast_tensors(*map(torch.as_tensor, series)) - samples[name] = torch.stack(series, dim=-1) + samples[name] = torch.stack(series, dim=-2 if self.is_regional else -1) def _generative_model(self, forecast=0): """ @@ -371,52 +417,75 @@ def _generative_model(self, forecast=0): # Sample initial values. state = self.initialize(params) - state = {i: torch.tensor(float(value)) for i, value in state.items()} + state = {k: v if isinstance(v, torch.Tensor) else torch.tensor(float(v)) + for k, v in state.items()} # Sequentially transition. for t in range(self.duration + forecast): self.transition_fwd(params, state, t) - for name in self.compartments: - pyro.deterministic("{}_{}".format(name, t), state[name]) + with self.region_plate: + for name in self.compartments: + pyro.deterministic("{}_{}".format(name, t), state[name]) + + self._clear_plates() def _sequential_model(self): """ Sequential model used to sample latents in the interval [0:duration]. """ + C = len(self.compartments) + T = self.duration + R_shape = getattr(self.population, "shape", ()) # Region shape. + # Sample global parameters. params = self.global_model() # Sample the continuous reparameterizing variable. + shape = (C, T) + R_shape auxiliary = pyro.sample("auxiliary", dist.Uniform(-0.5, self.population + 0.5) - .mask(False) - .expand([len(self.compartments), self.duration]) - .to_event(2)) + .mask(False).expand(shape).to_event()) + if self.is_regional: + # This reshapes from (particle, 1, C, T, R) -> (particle, C, T, R) + # to allow aux below to have shape (particle, R) for region_plate. + auxiliary = auxiliary.squeeze(-4) # Sequentially transition. 
curr = self.initialize(params) - for t in poutine.markov(range(self.duration)): - aux_t = auxiliary[..., t] - prev = curr - curr = {name: quantize("{}_{}".format(name, t), aux, - min=0, max=self.population, - num_quant_bins=self.num_quant_bins) - for name, aux in zip(self.compartments, aux_t.unbind(-1))} + for t, aux_t in poutine.markov(enumerate(auxiliary.unbind(2))): + with self.region_plate: + prev, curr = curr, {} + for name, aux in zip(self.compartments, aux_t.unbind(1)): + curr[name] = quantize("{}_{}".format(name, t), aux, + min=0, max=self.population, + num_quant_bins=self.num_quant_bins) + # Enable approximate inference by using aux as a + # non-enumerated proxy for enumerated compartment values. + if name in self.approximate: + curr[name + "_approx"] = aux + prev.setdefault(name + "_approx", prev[name]) self.transition_bwd(params, prev, curr, t) + self._clear_plates() + def _vectorized_model(self): """ Vectorized model used for inference. """ + C = len(self.compartments) + T = self.duration + Q = self.num_quant_bins + R_shape = getattr(self.population, "shape", ()) # Region shape. + # Sample global parameters. params = self.global_model() # Sample the continuous reparameterizing variable. + shape = (C, T) + R_shape auxiliary = pyro.sample("auxiliary", dist.Uniform(-0.5, self.population + 0.5) - .mask(False) - .expand([len(self.compartments), self.duration]) - .to_event(2)) + .mask(False).expand(shape).to_event()) + assert auxiliary.shape == shape, "particle plates are not supported" # Manually enumerate. 
curr, logp = quantize_enumerate(auxiliary, min=0, max=self.population, @@ -428,41 +497,54 @@ def _vectorized_model(self): init = self.initialize(params) prev = {} for name in self.compartments: - if not isinstance(init[name], int): - raise NotImplementedError("TODO use torch.cat()") - prev[name] = pad(curr[name][:-1], (0, 0, 1, 0), value=init[name]) + value = init[name] + if isinstance(value, torch.Tensor): + value = value[..., None] # Because curr is enumerated on the right. + prev[name] = cat2(value, curr[name][:-1], dim=-3 if self.is_regional else -2) # Reshape to support broadcasting, similar to EnumMessenger. - C = len(self.compartments) - T = self.duration - Q = self.num_quant_bins # Number of quantization points. - - def enum_shape(position): - shape = [T] + [1] * (2 * C) - shape[1 + position] = Q - return torch.Size(shape) + def enum_reshape(tensor, position): + assert tensor.size(-1) == Q + assert tensor.dim() <= self.max_plate_nesting + 2 + tensor = tensor.permute(tensor.dim() - 1, *range(tensor.dim() - 1)) + shape = [Q] + [1] * (position + self.max_plate_nesting - (tensor.dim() - 2)) + shape.extend(tensor.shape[1:]) + return tensor.reshape(shape) for e, name in enumerate(self.compartments): - prev[name] = prev[name].reshape(enum_shape(e)) - curr[name] = curr[name].reshape(enum_shape(C + e)) - logp[name] = logp[name].reshape(enum_shape(C + e)) - t = (Ellipsis,) + (None,) * (2 * C) # Used to unsqueeze data tensors. + curr[name] = enum_reshape(curr[name], e) + logp[name] = enum_reshape(logp[name], e) + prev[name] = enum_reshape(prev[name], e + C) + + # Enable approximate inference by using aux as a non-enumerated proxy + # for enumerated compartment values. + for name in self.approximate: + aux = auxiliary[self.compartments.index(name)] + curr[name + "_approx"] = aux + prev[name + "_approx"] = cat2(init[name], aux[:-1], + dim=-2 if self.is_regional else -1) # Record transition factors. 
with poutine.block(), poutine.trace() as tr: - self.transition_bwd(params, prev, curr, t) + with pyro.plate("time", T, dim=-1 - self.max_plate_nesting): + t = slice(None) # Used to slice data tensors. + self.transition_bwd(params, prev, curr, t) + tr.trace.compute_log_prob() for name, site in tr.trace.nodes.items(): if site["type"] == "sample": - logp[name] = site["fn"].log_prob(site["value"]) + logp[name] = site["log_prob"] # Manually perform variable elimination. - logp = sum(logp.values()) - logp = logp.reshape(T, Q ** C, Q ** C) - logp = pyro.distributions.hmm._sequential_logmatmulexp(logp) - logp = logp.reshape(-1).logsumexp(0) + logp = reduce(operator.add, logp.values()) + logp = logp.reshape(Q ** C, Q ** C, T, -1) # prev, curr, T, batch + logp = logp.permute(3, 2, 0, 1).squeeze(0) # batch, T, prev, curr + logp = pyro.distributions.hmm._sequential_logmatmulexp(logp) # batch, prev, curr + logp = logp.reshape(-1, Q ** C * Q ** C).logsumexp(-1).sum() warn_if_nan(logp) pyro.factor("transition", logp) + self._clear_plates() + class _SMCModel: """ diff --git a/pyro/contrib/epidemiology/sir.py b/pyro/contrib/epidemiology/sir.py index 76de1bfec2..8388045f21 100644 --- a/pyro/contrib/epidemiology/sir.py +++ b/pyro/contrib/epidemiology/sir.py @@ -6,7 +6,6 @@ import pyro import pyro.distributions as dist -from pyro.ops.indexing import Index from .compartmental import CompartmentalModel from .distributions import infection_dist @@ -339,8 +338,7 @@ class UnknownStartSIRModel(CompartmentalModel): 1. How to incorporate spontaneous infections from external sources; 2. How to incorporate time-varying piecewise ``rho`` by supporting - forecasting in :meth:`transition_fwd` and using the - :class:`~pyro.ops.index.ing.Index` helper in :meth:`transition_bwd`. + forecasting in :meth:`transition_fwd`. 3. How to override the :meth:`predict` method to compute extra statistics. 
@@ -453,14 +451,9 @@ def transition_bwd(self, params, prev, curr, t): dist.ExtendedBinomial(prev["I"], 1 / tau), obs=I2R) - # In .transition_bwd() t may be either an integer in [0,self.duration) - # or may be a nonstandard slicing object like (Ellipsis, None, None); - # we use the Index(-)[-] helper to support indexing with nonstandard t. - rho_t = Index(rho)[..., t] - # Condition on observations. pyro.sample("obs_{}".format(t), - dist.ExtendedBinomial(S2I, rho_t), + dist.ExtendedBinomial(S2I, rho[..., t]), obs=self.data[t]) def predict(self, forecast=0): @@ -483,3 +476,154 @@ def predict(self, forecast=0): samples["first_infection"] = samples["I"].cumsum(-1).eq(0).sum(-1) return samples + + +class RegionalSIRModel(CompartmentalModel): + r""" + Susceptible-Infected-Recovered model with coupling across regions. + + To customize this model we recommend forking and editing this class. + + This is a stochastic discrete-time discrete-state model with three + compartments in each region: "S" for susceptible, "I" for infected, and "R" + for recovered individuals (the recovered individuals are implicit: ``R = + population - S - I``) with transitions ``S -> I -> R``. + + Regions are coupled by a ``coupling`` matrix with entries in ``[0,1]``. + The all ones matrix is equivalent to a single region. The identity matrix + is equivalent to a set of independent regions. This need not be symmetric, + but symmetric matrices are probably more physically plausible. The expected + number of new infections each time step ``S2I`` is Binomial distributed + with mean:: + + E[S2I] = S (1 - (1 - R0 / (population @ coupling)) ** (I @ coupling)) + ≈ R0 S (I @ coupling) / (population @ coupling) # for small I + + Thus in a nearly entirely susceptible population, a single infected + individual infects approximately ``R0`` new individuals on average, + independent of ``coupling``. + + This model demonstrates: + + 1. How to create a regional model with a ``population`` vector. + 2. 
How to model both homogeneous parameters (here ``R0``) and + heterogeneous parameters with hierarchical structure (here ``rho``) + using ``self.region_plate``. + 3. How to approximately couple regions in :meth:`transition_bwd` using + ``prev["I_approx"]``. + + :param torch.Tensor population: Tensor of per-region populations, defining + ``population = S + I + R``. + :param torch.Tensor coupling: Pairwise coupling matrix. Entries should be + in ``[0,1]``. + :param float recovery_time: Mean recovery time (duration in state ``I``). + Must be greater than 1. + :param iterable data: Time x Region sized tensor of new observed + infections. Each time step is a vector of Binomials distributed between + 0 and the number of ``S -> I`` transitions. This allows false negatives + but no false positives. + """ + + def __init__(self, population, coupling, recovery_time, data): + duration = len(data) + num_regions, = population.shape + assert coupling.shape == (num_regions, num_regions) + assert (0 <= coupling).all() + assert (coupling <= 1).all() + assert isinstance(recovery_time, float) + assert recovery_time > 1 + if isinstance(data, torch.Tensor): + # Data tensors should be oriented as (time, region). + assert data.shape == (duration, num_regions) + compartments = ("S", "I") # R is implicit. + + # We create a regional model by passing a vector of populations. + super().__init__(compartments, duration, population, approximate=("I",)) + + self.coupling = coupling + self.recovery_time = recovery_time + self.data = data + + series = ("S2I", "I2R", "obs") + full_mass = [("R0", "rho")] + + def global_model(self): + # Assume recovery time is a known constant. + tau = self.recovery_time + + # Assume reproductive number is unknown but homogeneous. + R0 = pyro.sample("R0", dist.LogNormal(0., 1.)) + + # Assume response rate is heterogeneous and model it with a + # hierarchical Gamma-Beta prior. 
+ rho_c1 = pyro.sample("rho_c1", dist.Gamma(2, 1)) + rho_c0 = pyro.sample("rho_c0", dist.Gamma(2, 1)) + with self.region_plate: + rho = pyro.sample("rho", dist.Beta(rho_c1, rho_c0)) + + return R0, tau, rho + + def initialize(self, params): + # Start with a single infection in region 0. + I = torch.zeros_like(self.population) + I[0] += 1 + S = self.population - I + return {"S": S, "I": I} + + def transition_fwd(self, params, state, t): + R0, tau, rho = params + + # Account for infections from all regions. + I_coupled = state["I"] @ self.coupling + pop_coupled = self.population @ self.coupling + + with self.region_plate: + # Sample flows between compartments. + S2I = pyro.sample("S2I_{}".format(t), + infection_dist(individual_rate=R0 / tau, + num_susceptible=state["S"], + num_infectious=I_coupled, + population=pop_coupled)) + I2R = pyro.sample("I2R_{}".format(t), + dist.Binomial(state["I"], 1 / tau)) + + # Update compartments with flows. + state["S"] = state["S"] - S2I + state["I"] = state["I"] + S2I - I2R + + # Condition on observations. + pyro.sample("obs_{}".format(t), + dist.ExtendedBinomial(S2I, rho), + obs=self.data[t] if t < self.duration else None) + + def transition_bwd(self, params, prev, curr, t): + R0, tau, rho = params + + # Account for infections from all regions. This uses approximate (point + # estimate) counts I_approx for infection from other regions, but uses + # the exact (enumerated) count I for infections from one's own region. + I_coupled = prev["I_approx"] @ self.coupling + I_coupled = I_coupled + (prev["I"] - prev["I_approx"]) * self.coupling.diag() + I_coupled = I_coupled.clamp(min=0) # In case I_approx is negative. + pop_coupled = self.population @ self.coupling + + # Reverse the flow computation. + S2I = prev["S"] - curr["S"] + I2R = prev["I"] - curr["I"] + S2I + + with self.region_plate: + # Condition on flows between compartments. 
+ pyro.sample("S2I_{}".format(t), + infection_dist(individual_rate=R0 / tau, + num_susceptible=prev["S"], + num_infectious=I_coupled, + population=pop_coupled), + obs=S2I) + pyro.sample("I2R_{}".format(t), + dist.ExtendedBinomial(prev["I"], 1 / tau), + obs=I2R) + + # Condition on observations. + pyro.sample("obs_{}".format(t), + dist.ExtendedBinomial(S2I, rho), + obs=self.data[t]) diff --git a/pyro/contrib/epidemiology/util.py b/pyro/contrib/epidemiology/util.py index 16cd8acc1f..004613e58f 100644 --- a/pyro/contrib/epidemiology/util.py +++ b/pyro/contrib/epidemiology/util.py @@ -2,14 +2,82 @@ # SPDX-License-Identifier: Apache-2.0 import numpy - import torch import pyro import pyro.distributions as dist +import pyro.poutine as poutine +from pyro.distributions.util import broadcast_shape from pyro.ops.tensor_utils import safe_log +def clamp(tensor, *, min=None, max=None): + """ + Like :func:`torch.clamp` but dispatches to :func:`torch.min` and/or + :func:`torch.max` if ``min`` and/or ``max`` is a :class:`~torch.Tensor`. + """ + if isinstance(min, torch.Tensor): + tensor = torch.max(tensor, min) + min = None + if isinstance(max, torch.Tensor): + tensor = torch.min(tensor, max) + max = None + if min is None and max is None: + return tensor + return tensor.clamp(min=min, max=max) + + +def cat2(lhs, rhs, *, dim=-1): + """ + Like ``torch.cat([lhs, rhs], dim=dim)`` but dispatches to + :func:`torch.nn.functional.pad` in case one of ``lhs`` or ``rhs`` is a + scalar. 
+ """ + assert dim < 0 + if not isinstance(lhs, torch.Tensor): + pad = (0, 0) * (-1 - dim) + (1, 0) + return torch.nn.functional.pad(rhs, pad, value=lhs) + if not isinstance(rhs, torch.Tensor): + pad = (0, 0) * (-1 - dim) + (0, 1) + return torch.nn.functional.pad(lhs, pad, value=rhs) + + diff = lhs.dim() - rhs.dim() + if diff > 0: + rhs = rhs.expand((1,) * diff + rhs.shape) + elif diff < 0: + diff = -diff + lhs = lhs.expand((1,) * diff + lhs.shape) + shape = list(broadcast_shape(lhs.shape, rhs.shape)) + shape[dim] = -1 + return torch.cat([lhs.expand(shape), rhs.expand(shape)], dim=dim) + + +@torch.no_grad() +def align_samples(samples, model, particle_dim): + """ + Unsqueeze stacked samples such that their particle dim all aligns. + This traces ``model`` to determine the ``event_dim`` of each site. + """ + assert particle_dim < 0 + + sample = {name: value[0] for name, value in samples.items()} + with poutine.block(), poutine.trace() as tr, poutine.condition(data=sample): + model() + + samples = samples.copy() + for name, value in samples.items(): + event_dim = tr.trace.nodes[name]["fn"].event_dim + pad = event_dim - particle_dim - value.dim() + if pad < 0: + raise ValueError("Cannot align samples, try moving particle_dim left") + if pad > 0: + shape = value.shape[:1] + (1,) * pad + value.shape[1:] + # Insert ``pad`` singleton dims right after the leading sample dim. + samples[name] = value.reshape(shape) + + return samples + + # this 8 x 10 tensor encodes the coefficients of 8 10-dimensional polynomials # that are used to construct the num_quant_bins=16 quantization strategy @@ -124,9 +192,17 @@ def compute_bin_probs(s, num_quant_bins=3): return probs +def _all(x): + return x.all() if isinstance(x, torch.Tensor) else x + + +def _unsqueeze(x): + return x.unsqueeze(-1) if isinstance(x, torch.Tensor) else x + + def quantize(name, x_real, min, max, num_quant_bins=4): """Randomly quantize in a way that preserves probability mass.""" - assert min < max + assert 
_all(min < max) lb = x_real.detach().floor() probs = compute_bin_probs(x_real - lb, num_quant_bins=num_quant_bins) @@ -144,7 +220,7 @@ def quantize(name, x_real, min, max, num_quant_bins=4): def quantize_enumerate(x_real, min, max, num_quant_bins=4): """Quantize, then manually enumerate.""" - assert min < max + assert _all(min < max) lb = x_real.detach().floor() probs = compute_bin_probs(x_real - lb, num_quant_bins=num_quant_bins) @@ -155,7 +231,7 @@ def quantize_enumerate(x_real, min, max, num_quant_bins=4): q = torch.arange(arange_min, arange_max) x = lb.unsqueeze(-1) + q - x = torch.max(x, 2 * min - 1 - x) - x = torch.min(x, 2 * max + 1 - x) + x = torch.max(x, 2 * _unsqueeze(min) - 1 - x) + x = torch.min(x, 2 * _unsqueeze(max) + 1 - x) return x, logits diff --git a/pyro/contrib/randomvariable/__init__.py b/pyro/contrib/randomvariable/__init__.py index e9d97ab2e0..8abfb1c723 100644 --- a/pyro/contrib/randomvariable/__init__.py +++ b/pyro/contrib/randomvariable/__init__.py @@ -1,3 +1,6 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + from pyro.contrib.randomvariable.random_variable import RandomVariable __all__ = [ diff --git a/tests/contrib/epidemiology/__init__.py b/tests/contrib/epidemiology/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/contrib/epidemiology/test_sir.py b/tests/contrib/epidemiology/test_sir.py index a77c14cfd1..f7bae63bcc 100644 --- a/tests/contrib/epidemiology/test_sir.py +++ b/tests/contrib/epidemiology/test_sir.py @@ -7,7 +7,8 @@ import pytest import torch -from pyro.contrib.epidemiology import OverdispersedSIRModel, SimpleSIRModel, SparseSIRModel, UnknownStartSIRModel +from pyro.contrib.epidemiology import (OverdispersedSIRModel, RegionalSIRModel, SimpleSIRModel, SparseSIRModel, + UnknownStartSIRModel) logger = logging.getLogger(__name__) @@ -157,3 +158,36 @@ def test_unknown_start_smoke(duration, pre_obs_window, forecast, options): for I, ti in zip(samples["I"], t): assert (I[:ti] == 0).all() assert I[ti] > 0 + + +@pytest.mark.parametrize("duration", [3, 7]) +@pytest.mark.parametrize("forecast", [0, 7]) +@pytest.mark.parametrize("options", [ + {}, + {"num_quant_bins": 8}, +], ids=str) +def test_regional_smoke(duration, forecast, options): + num_regions = 6 + coupling = torch.eye(num_regions).clamp(min=0.1) + population = torch.tensor([2., 3., 4., 10., 100., 1000.]) + recovery_time = 7.0 + + # Generate data. + model = RegionalSIRModel(population, coupling, recovery_time, + data=[None] * duration) + for attempt in range(100): + data = model.generate({"R0": 1.5, "rho": 0.5})["obs"] + assert data.shape == (duration, num_regions) + if data.sum(): + break + assert data.sum() > 0, "failed to generate positive data" + + # Infer. + model = RegionalSIRModel(population, coupling, recovery_time, data) + num_samples = 5 + model.fit(warmup_steps=1, num_samples=num_samples, max_tree_depth=2, **options) + + # Predict and forecast. 
+ samples = model.predict(forecast=forecast) + assert samples["S"].shape == (num_samples, duration + forecast, num_regions) + assert samples["I"].shape == (num_samples, duration + forecast, num_regions) diff --git a/tests/contrib/epidemiology/test_util.py b/tests/contrib/epidemiology/test_util.py new file mode 100644 index 0000000000..bc602620c6 --- /dev/null +++ b/tests/contrib/epidemiology/test_util.py @@ -0,0 +1,44 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from pyro.contrib.epidemiology.util import cat2, clamp +from tests.common import assert_equal + + +@pytest.mark.parametrize("min", [None, 0., (), (2,)], ids=str) +@pytest.mark.parametrize("max", [None, 1., (), (2,)], ids=str) +@pytest.mark.parametrize("shape", [(2,), (3, 2)], ids=str) +def test_clamp(shape, min, max): + tensor = torch.randn(shape) + if isinstance(min, tuple): + min = torch.zeros(min) + if isinstance(max, tuple): + max = torch.ones(max) + + actual = clamp(tensor, min=min, max=max) + + expected = tensor + if min is not None: + min = torch.as_tensor(min).expand_as(tensor) + expected = torch.max(min, expected) + if max is not None: + max = torch.as_tensor(max).expand_as(tensor) + expected = torch.min(max, expected) + + assert_equal(actual, expected) + + +@pytest.mark.parametrize("shape", [(), (2,), (2, 3), (2, 3, 4)], ids=str) +def test_cat2_scalar(shape): + tensor = torch.randn(shape) + for dim in range(-len(shape), 0): + expected_shape = list(shape) + expected_shape[dim] += 1 + expected_shape = torch.Size(expected_shape) + assert cat2(tensor, 0, dim=dim).shape == expected_shape + assert cat2(0, tensor, dim=dim).shape == expected_shape + assert (cat2(tensor, 0, dim=dim).unbind(dim)[-1] == 0).all() + assert (cat2(0, tensor, dim=dim).unbind(dim)[0] == 0).all() diff --git a/tests/contrib/forecast/__init__.py b/tests/contrib/forecast/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/tests/test_examples.py b/tests/test_examples.py index 8eb27b55af..de93d5c77e 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -40,6 +40,7 @@ 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --dct=1', 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=8', 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --dct=1', + 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2', 'contrib/forecast/bart.py --num-steps=2 --stride=99999', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000', 'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000', @@ -105,6 +106,7 @@ 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --cuda', 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --cuda', 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --dct=1 --cuda', + 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --cuda', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --cuda', 'lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --cuda', 'dmm/dmm.py --num-epochs=1 --cuda',