diff --git a/examples/contrib/epidemiology/regional.py b/examples/contrib/epidemiology/regional.py index b7fdb5532d..1f2b39210a 100644 --- a/examples/contrib/epidemiology/regional.py +++ b/examples/contrib/epidemiology/regional.py @@ -59,6 +59,7 @@ def hook_fn(kernel, *unused): num_quant_bins=args.num_bins, haar=args.haar, haar_full_mass=args.haar_full_mass, + jit_compile=args.jit, hook_fn=hook_fn) mcmc.summary() @@ -149,6 +150,8 @@ def main(args): parser.add_argument("--single", action="store_false", dest="double") parser.add_argument("--rng-seed", default=0, type=int) parser.add_argument("--cuda", action="store_true") + parser.add_argument("--jit", action="store_true", default=True) + parser.add_argument("--nojit", action="store_true", dest="jit") parser.add_argument("--verbose", action="store_true") parser.add_argument("--plot", action="store_true") args = parser.parse_args() diff --git a/examples/contrib/epidemiology/sir.py b/examples/contrib/epidemiology/sir.py index 13c23332c3..25417f4c94 100644 --- a/examples/contrib/epidemiology/sir.py +++ b/examples/contrib/epidemiology/sir.py @@ -96,6 +96,7 @@ def hook_fn(kernel, *unused): num_quant_bins=args.num_bins, haar=args.haar, haar_full_mass=args.haar_full_mass, + jit_compile=args.jit, hook_fn=hook_fn) mcmc.summary() @@ -299,6 +300,8 @@ def main(args): parser.add_argument("--double", action="store_true", default=True) parser.add_argument("--single", action="store_false", dest="double") parser.add_argument("--cuda", action="store_true") + parser.add_argument("--jit", action="store_true", default=True) + parser.add_argument("--nojit", action="store_true", dest="jit") parser.add_argument("--verbose", action="store_true") parser.add_argument("--plot", action="store_true") args = parser.parse_args() diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index 14accbf888..c2860f3e0a 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ b/pyro/contrib/epidemiology/compartmental.py @@ -382,7 +382,6 @@ def heuristic(): # Configure a kernel. logger.info("Running inference...") - max_tree_depth = options.pop("max_tree_depth", 5) model = self._relaxed_model if self.relaxed else self._quantized_model if haar: model = haar.reparam(model) @@ -390,7 +389,11 @@ def heuristic(): full_mass=full_mass, init_strategy=init_to_generated(generate=heuristic), max_plate_nesting=self.max_plate_nesting, - max_tree_depth=max_tree_depth) + jit_compile=options.pop("jit_compile", False), + jit_options=options.pop("jit_options", None), + ignore_jit_warnings=options.pop("ignore_jit_warnings", True), + target_accept_prob=options.pop("target_accept_prob", 0.8), + max_tree_depth=options.pop("max_tree_depth", 5)) if options.pop("arrowhead_mass", False): kernel.mass_matrix_adapter = ArrowheadMassMatrix() @@ -700,8 +703,8 @@ def _quantized_model(self): # Manually enumerate. curr, logp = quantize_enumerate(auxiliary, min=0, max=self.population, num_quant_bins=self.num_quant_bins) - curr = OrderedDict(zip(self.compartments, curr)) - logp = OrderedDict(zip(self.compartments, logp)) + curr = OrderedDict(zip(self.compartments, curr.unbind(0))) + logp = OrderedDict(zip(self.compartments, logp.unbind(0))) curr.update(non_compartmental) # Truncate final value from the right then pad initial value onto the left. @@ -778,7 +781,7 @@ def _relaxed_model(self): auxiliary, non_compartmental = self._sample_auxiliary() # Split tensors into current state. - curr = dict(zip(self.compartments, auxiliary)) + curr = dict(zip(self.compartments, auxiliary.unbind(0))) curr.update(non_compartmental) # Truncate final value from the right then pad initial value onto the left. diff --git a/tests/contrib/epidemiology/test_models.py b/tests/contrib/epidemiology/test_models.py index d445299ec7..4f88924dd3 100644 --- a/tests/contrib/epidemiology/test_models.py +++ b/tests/contrib/epidemiology/test_models.py @@ -30,6 +30,9 @@ {"num_quant_bins": 16}, {"num_quant_bins": 2, "haar": True}, {"arrowhead_mass": True}, + {"jit_compile": True}, + {"jit_compile": True, "haar_full_mass": 2}, + {"jit_compile": True, "num_quant_bins": 2}, ], ids=str) def test_simple_sir_smoke(duration, forecast, options): population = 100 @@ -429,6 +432,9 @@ def test_regional_smoke(duration, forecast, options): {"haar": True}, {"haar_full_mass": 2}, {"num_quant_bins": 2}, + {"jit_compile": True}, + {"jit_compile": True, "haar_full_mass": 2}, + {"jit_compile": True, "num_quant_bins": 2}, ], ids=str) def test_hetero_regional_smoke(duration, forecast, options): num_regions = 6 diff --git a/tests/test_examples.py b/tests/test_examples.py index 49a2acb049..459266a43e 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -33,19 +33,19 @@ 'contrib/autoname/mixture.py --num-epochs=1', 'contrib/autoname/tree_data.py --num-epochs=1', 'contrib/cevae/synthetic.py --num-epochs=1', - 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2', - 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2', - 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -k=1', - 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2 -k=1', - 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar', - 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=4', - 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -hfm=3', - 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -a', - 'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -o=0.2', - 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2', - 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar', - 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 -hfm=3', - 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 -nb=4', + 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2', + 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2', + 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -k=1', + 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2 -k=1', + 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar', + 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=4', + 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -hfm=3', + 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -a', + 'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -o=0.2', + 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2', + 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar', + 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 -hfm=3', + 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 -nb=4', 'contrib/forecast/bart.py --num-steps=2 --stride=99999', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000', 'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000', @@ -108,11 +108,11 @@ 'air/main.py --num-steps=1 --cuda', 'baseball.py --num-samples=200 --warmup-steps=100 --num-chains=2 --cuda', 'contrib/cevae/synthetic.py --num-epochs=1 --cuda', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --cuda', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --cuda', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar --cuda', - 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --cuda', - 'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar --cuda', + 'contrib/epidemiology/sir.py --nojit -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --cuda', + 'contrib/epidemiology/sir.py --nojit -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --cuda', + 'contrib/epidemiology/sir.py --nojit -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar --cuda', + 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --cuda', + 'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar --cuda', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --cuda', 'lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --cuda', 'dmm/dmm.py --num-epochs=1 --cuda', @@ -157,6 +157,8 @@ def xfail_jit(*args): 'baseball.py --num-samples=200 --warmup-steps=100 --jit', 'contrib/autoname/mixture.py --num-epochs=1 --jit', 'contrib/cevae/synthetic.py --num-epochs=1 --jit', + 'contrib/epidemiology/sir.py --jit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2', + 'contrib/epidemiology/regional.py --jit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2', xfail_jit('lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --jit'), xfail_jit('contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit'), xfail_jit('dmm/dmm.py --num-epochs=1 --jit'),