diff --git a/examples/contrib/epidemiology/sir.py b/examples/contrib/epidemiology/sir.py index 196a856f0d..1569f1ba38 100644 --- a/examples/contrib/epidemiology/sir.py +++ b/examples/contrib/epidemiology/sir.py @@ -66,7 +66,8 @@ def hook_fn(kernel, *unused): if args.verbose: logging.info("potential = {:0.6g}".format(e)) - mcmc = model.fit(warmup_steps=args.warmup_steps, + mcmc = model.fit(heuristic_num_particles=args.num_particles, + warmup_steps=args.warmup_steps, num_samples=args.num_samples, max_tree_depth=args.max_tree_depth, num_quant_bins=args.num_bins, @@ -184,6 +185,7 @@ def main(args): parser.add_argument("--dct", type=float, help="smoothing for discrete cosine reparameterizer") parser.add_argument("-n", "--num-samples", default=200, type=int) + parser.add_argument("-np", "--num-particles", default=1024, type=int) parser.add_argument("-w", "--warmup-steps", default=100, type=int) parser.add_argument("-t", "--max-tree-depth", default=5, type=int) parser.add_argument("-r", "--rng-seed", default=0, type=int) diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index 2d3f20d805..baddcbcbbb 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ b/pyro/contrib/epidemiology/compartmental.py @@ -245,6 +245,8 @@ def fit(self, **options): Defaults to 4. :param float dct: If provided, use a discrete cosine reparameterizer with this value as smoothness. + :param int heuristic_num_particles: Passed to :meth:`heuristic` as + ``num_particles``. Defaults to 1024. :returns: An MCMC object for diagnostics, e.g. ``MCMC.summary()``. :rtype: ~pyro.infer.mcmc.api.MCMC """ @@ -254,7 +256,9 @@ def fit(self, **options): # Heuristically initialze to feasible latents. logger.info("Heuristically initializing...") - init_values = self.heuristic() + heuristic_options = {k.replace("heuristic_", ""): options.pop(k) + for k in list(options) if k.startswith("heuristic_")} + init_values = self.heuristic(**heuristic_options) assert isinstance(init_values, dict) assert "auxiliary" in init_values, \ ".heuristic() did not define auxiliary value" diff --git a/tests/test_examples.py b/tests/test_examples.py index d209274fdc..8eb27b55af 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -33,13 +33,13 @@ 'contrib/autoname/mixture.py --num-epochs=1', 'contrib/autoname/tree_data.py --num-epochs=1', 'contrib/cevae/synthetic.py --num-epochs=1', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -k=1', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2 -k=1', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --dct=1', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=8', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --dct=1', + 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2', + 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2', + 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -k=1', + 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2 -k=1', + 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --dct=1', + 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=8', + 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --dct=1', 'contrib/forecast/bart.py --num-steps=2 --stride=99999', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000', 'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000',