diff --git a/examples/contrib/epidemiology/sir.py b/examples/contrib/epidemiology/sir.py index 25417f4c94..67ab4cf4b0 100644 --- a/examples/contrib/epidemiology/sir.py +++ b/examples/contrib/epidemiology/sir.py @@ -79,6 +79,7 @@ def generate_data(args): def infer(args, model): + parallel = args.num_chains > 1 energies = [] def hook_fn(kernel, *unused): @@ -91,16 +92,18 @@ def hook_fn(kernel, *unused): heuristic_ess_threshold=args.ess_threshold, warmup_steps=args.warmup_steps, num_samples=args.num_samples, + num_chains=args.num_chains, + mp_context="spawn" if parallel else None, max_tree_depth=args.max_tree_depth, arrowhead_mass=args.arrowhead_mass, num_quant_bins=args.num_bins, haar=args.haar, haar_full_mass=args.haar_full_mass, jit_compile=args.jit, - hook_fn=hook_fn) + hook_fn=None if parallel else hook_fn) mcmc.summary() - if args.plot: + if args.plot and energies: import matplotlib.pyplot as plt plt.figure(figsize=(6, 3)) plt.plot(energies) @@ -293,6 +296,7 @@ def main(args): parser.add_argument("-np", "--num-particles", default=1024, type=int) parser.add_argument("-ess", "--ess-threshold", default=0.5, type=float) parser.add_argument("-w", "--warmup-steps", type=int) + parser.add_argument("-c", "--num-chains", default=1, type=int) parser.add_argument("-t", "--max-tree-depth", default=5, type=int) parser.add_argument("-a", "--arrowhead-mass", action="store_true") parser.add_argument("-r", "--rng-seed", default=0, type=int) diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index c2860f3e0a..0b95b14cb4 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ b/pyro/contrib/epidemiology/compartmental.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import functools import logging import operator import re @@ -289,7 +290,7 @@ def generate(self, fixed={}): for name, site in trace.nodes.items() if site["type"] == "sample") - self._concat_series(samples) + self._concat_series(samples, trace) return samples @set_approx_log_prob_tol(0.1) @@ -365,20 +366,8 @@ def fit(self, **options): heuristic_options = {k.replace("heuristic_", ""): options.pop(k) for k in list(options) if k.startswith("heuristic_")} - - def heuristic(): - with poutine.block(): - init_values = self.heuristic(**heuristic_options) - assert isinstance(init_values, dict) - assert "auxiliary" in init_values, \ - ".heuristic() did not define auxiliary value" - if haar: - haar.user_to_aux(init_values) - logger.info("Heuristic init: {}".format(", ".join( - "{}={:0.3g}".format(k, v.item()) - for k, v in init_values.items() - if v.numel() == 1))) - return init_to_value(values=init_values) + init_strategy = init_to_generated( + generate=functools.partial(self._heuristic, haar, **heuristic_options)) # Configure a kernel. logger.info("Running inference...") @@ -387,7 +376,7 @@ def heuristic(): model = haar.reparam(model) kernel = NUTS(model, full_mass=full_mass, - init_strategy=init_to_generated(generate=heuristic), + init_strategy=init_strategy, max_plate_nesting=self.max_plate_nesting, jit_compile=options.pop("jit_compile", False), jit_options=options.pop("jit_options", None), @@ -464,7 +453,7 @@ def predict(self, forecast=0): for name, site in trace.nodes.items() if site["type"] == "sample") - self._concat_series(samples, forecast, vectorized=True) + self._concat_series(samples, trace, forecast) return samples @torch.no_grad() @@ -516,12 +505,27 @@ def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10): # Internal helpers ######################################## - def _concat_series(self, samples, forecast=0, vectorized=False): + def _heuristic(self, haar, **options): + with poutine.block(): + init_values = self.heuristic(**options) + assert isinstance(init_values, dict) + assert "auxiliary" in init_values, \ + ".heuristic() did not define auxiliary value" + if haar: + haar.user_to_aux(init_values) + logger.info("Heuristic init: {}".format(", ".join( + "{}={:0.3g}".format(k, v.item()) + for k, v in init_values.items() + if v.numel() == 1))) + return init_to_value(values=init_values) + + def _concat_series(self, samples, trace, forecast=0): """ Concatenate sequential time series into tensors, in-place. :param dict samples: A dictionary of samples. """ + time_dim = -2 if self.is_regional else -1 for name in set(self.compartments).union(self.series): pattern = name + "_[0-9]+" series = [] @@ -532,8 +536,9 @@ def _concat_series(self, samples, forecast=0, vectorized=False): continue assert len(series) == self.duration + forecast series = torch.broadcast_tensors(*map(torch.as_tensor, series)) - if vectorized and name != "obs": # TODO Generalize. - samples[name] = torch.cat(series, dim=1) + dim = time_dim - trace.nodes[name + "_0"]["fn"].event_dim + if series[0].dim() >= -dim: + samples[name] = torch.cat(series, dim=dim) else: samples[name] = torch.stack(series) diff --git a/pyro/contrib/epidemiology/distributions.py b/pyro/contrib/epidemiology/distributions.py index f2a3147d1a..7742a8641a 100644 --- a/pyro/contrib/epidemiology/distributions.py +++ b/pyro/contrib/epidemiology/distributions.py @@ -11,6 +11,7 @@ from pyro.distributions.util import is_validation_enabled _RELAX = False +_RELAX_MIN_VARIANCE = 0.1 def _all(x): @@ -98,7 +99,7 @@ def _validate_overdispersion(overdispersion): raise ValueError("Expected overdispersion < 2") -def _relaxed_binomial(total_count, probs, *, min_variance=0.25): +def _relaxed_binomial(total_count, probs): """ Returns a moment-matched :class:`~pyro.distributions.Normal` approximating a :class:`~pyro.distributions.Binomial` but allowing arbitrary real @@ -109,11 +110,11 @@ def _relaxed_binomial(total_count, probs, *, min_variance=0.25): mean = probs * total_count variance = total_count * probs * (1 - probs) - scale = variance.clamp(min=min_variance).sqrt() + scale = variance.clamp(min=_RELAX_MIN_VARIANCE).sqrt() return dist.Normal(mean, scale) -def _relaxed_beta_binomial(concentration1, concentration0, total_count, *, min_variance=0.25): +def _relaxed_beta_binomial(concentration1, concentration0, total_count): """ Returns a moment-matched :class:`~pyro.distributions.Normal` approximating a :class:`~pyro.distributions.BetaBinomial` but allowing arbitrary real @@ -128,7 +129,7 @@ def _relaxed_beta_binomial(concentration1, concentration0, total_count, *, min_v mean = beta_mean * total_count variance = beta_variance * total_count * (c + total_count) - scale = variance.clamp(min=min_variance).sqrt() + scale = variance.clamp(min=_RELAX_MIN_VARIANCE).sqrt() return dist.Normal(mean, scale) diff --git a/pyro/infer/mcmc/api.py b/pyro/infer/mcmc/api.py index c85a48d848..6169073d45 100644 --- a/pyro/infer/mcmc/api.py +++ b/pyro/infer/mcmc/api.py @@ -408,8 +408,12 @@ def model(data): # If transforms is not explicitly provided, infer automatically using # model args, kwargs. if self.transforms is None: + # Try to initialize kernel.transforms using kernel.setup(). + if getattr(self.kernel, "transforms", None) is None: + warmup_steps = 0 + self.kernel.setup(warmup_steps, *args, **kwargs) # Use `kernel.transforms` when available - if hasattr(self.kernel, 'transforms') and self.kernel.transforms is not None: + if getattr(self.kernel, "transforms", None) is not None: self.transforms = self.kernel.transforms # Else, get transforms from model (e.g. in multiprocessing). elif self.kernel.model: diff --git a/tests/contrib/epidemiology/test_distributions.py b/tests/contrib/epidemiology/test_distributions.py index a6ad24dee7..175688e310 100644 --- a/tests/contrib/epidemiology/test_distributions.py +++ b/tests/contrib/epidemiology/test_distributions.py @@ -9,7 +9,7 @@ import pyro.distributions as dist from pyro.contrib.epidemiology import beta_binomial_dist, binomial_dist, infection_dist -from pyro.contrib.epidemiology.distributions import set_relaxed_distributions +from pyro.contrib.epidemiology.distributions import _RELAX_MIN_VARIANCE, set_relaxed_distributions from tests.common import assert_close @@ -199,7 +199,7 @@ def test_relaxed_binomial(): d2 = binomial_dist(total_count, probs) assert isinstance(d2, dist.Normal) assert_close(d2.mean, d1.mean) - assert_close(d2.variance, d1.variance.clamp(min=0.25)) + assert_close(d2.variance, d1.variance.clamp(min=_RELAX_MIN_VARIANCE)) @pytest.mark.parametrize("overdispersion", [0.05, 0.1, 0.2, 0.5, 1.0]) @@ -214,7 +214,7 @@ def test_relaxed_overdispersed_binomial(overdispersion): d2 = binomial_dist(total_count, probs, overdispersion=overdispersion) assert isinstance(d2, dist.Normal) assert_close(d2.mean, d1.mean) - assert_close(d2.variance, d1.variance.clamp(min=0.25)) + assert_close(d2.variance, d1.variance.clamp(min=_RELAX_MIN_VARIANCE)) def test_relaxed_beta_binomial(): @@ -229,7 +229,7 @@ def test_relaxed_beta_binomial(): d2 = beta_binomial_dist(concentration1, concentration0, total_count) assert isinstance(d2, dist.Normal) assert_close(d2.mean, d1.mean) - assert_close(d2.variance, d1.variance.clamp(min=0.25)) + assert_close(d2.variance, d1.variance.clamp(min=_RELAX_MIN_VARIANCE)) @pytest.mark.parametrize("overdispersion", [0.05, 0.1, 0.2, 0.5, 1.0]) @@ -247,4 +247,4 @@ def test_relaxed_overdispersed_beta_binomial(overdispersion): overdispersion=overdispersion) assert isinstance(d2, dist.Normal) assert_close(d2.mean, d1.mean) - assert_close(d2.variance, d1.variance.clamp(min=0.25)) + assert_close(d2.variance, d1.variance.clamp(min=_RELAX_MIN_VARIANCE)) diff --git a/tests/contrib/epidemiology/test_models.py b/tests/contrib/epidemiology/test_models.py index 4f88924dd3..ee47348b8a 100644 --- a/tests/contrib/epidemiology/test_models.py +++ b/tests/contrib/epidemiology/test_models.py @@ -17,6 +17,7 @@ logger = logging.getLogger(__name__) +@pytest.mark.filterwarnings("ignore:num_chains") @pytest.mark.parametrize("duration", [3, 7]) @pytest.mark.parametrize("forecast", [0, 7]) @pytest.mark.parametrize("options", [ @@ -33,6 +34,9 @@ {"jit_compile": True}, {"jit_compile": True, "haar_full_mass": 2}, {"jit_compile": True, "num_quant_bins": 2}, + {"num_chains": 2, "mp_context": "spawn"}, + {"num_chains": 2, "mp_context": "spawn", "num_quant_bins": 2}, + {"num_chains": 2, "mp_context": "spawn", "jit_compile": True}, ], ids=str) def test_simple_sir_smoke(duration, forecast, options): population = 100 @@ -54,6 +58,7 @@ def test_simple_sir_smoke(duration, forecast, options): # Predict and forecast. samples = model.predict(forecast=forecast) + num_samples *= options.get("num_chains", 1) assert samples["S"].shape == (num_samples, duration + forecast) assert samples["I"].shape == (num_samples, duration + forecast) diff --git a/tests/test_examples.py b/tests/test_examples.py index 459266a43e..18dab90083 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -34,6 +34,7 @@ 'contrib/autoname/tree_data.py --num-epochs=1', 'contrib/cevae/synthetic.py --num-epochs=1', 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2', + 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -c=2', 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2', 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -k=1', 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2 -k=1',