Skip to content

Commit

Permalink
Support jit in CompartmentalModel (#2526)
Browse files Browse the repository at this point in the history
* Support jit in CompartmentalModel

* Enable jit by default in examples
  • Loading branch information
fritzo authored Jun 14, 2020
1 parent f7a5677 commit ab4c663
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 23 deletions.
3 changes: 3 additions & 0 deletions examples/contrib/epidemiology/regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def hook_fn(kernel, *unused):
num_quant_bins=args.num_bins,
haar=args.haar,
haar_full_mass=args.haar_full_mass,
jit_compile=args.jit,
hook_fn=hook_fn)

mcmc.summary()
Expand Down Expand Up @@ -149,6 +150,8 @@ def main(args):
parser.add_argument("--single", action="store_false", dest="double")
parser.add_argument("--rng-seed", default=0, type=int)
parser.add_argument("--cuda", action="store_true")
parser.add_argument("--jit", action="store_true", default=True)
parser.add_argument("--nojit", action="store_true", dest="jit")
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--plot", action="store_true")
args = parser.parse_args()
Expand Down
3 changes: 3 additions & 0 deletions examples/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def hook_fn(kernel, *unused):
num_quant_bins=args.num_bins,
haar=args.haar,
haar_full_mass=args.haar_full_mass,
jit_compile=args.jit,
hook_fn=hook_fn)

mcmc.summary()
Expand Down Expand Up @@ -299,6 +300,8 @@ def main(args):
parser.add_argument("--double", action="store_true", default=True)
parser.add_argument("--single", action="store_false", dest="double")
parser.add_argument("--cuda", action="store_true")
parser.add_argument("--jit", action="store_true", default=True)
parser.add_argument("--nojit", action="store_true", dest="jit")
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--plot", action="store_true")
args = parser.parse_args()
Expand Down
13 changes: 8 additions & 5 deletions pyro/contrib/epidemiology/compartmental.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,15 +382,18 @@ def heuristic():

# Configure a kernel.
logger.info("Running inference...")
max_tree_depth = options.pop("max_tree_depth", 5)
model = self._relaxed_model if self.relaxed else self._quantized_model
if haar:
model = haar.reparam(model)
kernel = NUTS(model,
full_mass=full_mass,
init_strategy=init_to_generated(generate=heuristic),
max_plate_nesting=self.max_plate_nesting,
max_tree_depth=max_tree_depth)
jit_compile=options.pop("jit_compile", False),
jit_options=options.pop("jit_options", None),
ignore_jit_warnings=options.pop("ignore_jit_warnings", True),
target_accept_prob=options.pop("target_accept_prob", 0.8),
max_tree_depth=options.pop("max_tree_depth", 5))
if options.pop("arrowhead_mass", False):
kernel.mass_matrix_adapter = ArrowheadMassMatrix()

Expand Down Expand Up @@ -700,8 +703,8 @@ def _quantized_model(self):
# Manually enumerate.
curr, logp = quantize_enumerate(auxiliary, min=0, max=self.population,
num_quant_bins=self.num_quant_bins)
curr = OrderedDict(zip(self.compartments, curr))
logp = OrderedDict(zip(self.compartments, logp))
curr = OrderedDict(zip(self.compartments, curr.unbind(0)))
logp = OrderedDict(zip(self.compartments, logp.unbind(0)))
curr.update(non_compartmental)

# Truncate final value from the right then pad initial value onto the left.
Expand Down Expand Up @@ -778,7 +781,7 @@ def _relaxed_model(self):
auxiliary, non_compartmental = self._sample_auxiliary()

# Split tensors into current state.
curr = dict(zip(self.compartments, auxiliary))
curr = dict(zip(self.compartments, auxiliary.unbind(0)))
curr.update(non_compartmental)

# Truncate final value from the right then pad initial value onto the left.
Expand Down
6 changes: 6 additions & 0 deletions tests/contrib/epidemiology/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
{"num_quant_bins": 16},
{"num_quant_bins": 2, "haar": True},
{"arrowhead_mass": True},
{"jit_compile": True},
{"jit_compile": True, "haar_full_mass": 2},
{"jit_compile": True, "num_quant_bins": 2},
], ids=str)
def test_simple_sir_smoke(duration, forecast, options):
population = 100
Expand Down Expand Up @@ -429,6 +432,9 @@ def test_regional_smoke(duration, forecast, options):
{"haar": True},
{"haar_full_mass": 2},
{"num_quant_bins": 2},
{"jit_compile": True},
{"jit_compile": True, "haar_full_mass": 2},
{"jit_compile": True, "num_quant_bins": 2},
], ids=str)
def test_hetero_regional_smoke(duration, forecast, options):
num_regions = 6
Expand Down
38 changes: 20 additions & 18 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,19 @@
'contrib/autoname/mixture.py --num-epochs=1',
'contrib/autoname/tree_data.py --num-epochs=1',
'contrib/cevae/synthetic.py --num-epochs=1',
'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2',
'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2',
'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -k=1',
'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2 -k=1',
'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar',
'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=4',
'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -hfm=3',
'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -a',
'contrib/epidemiology/sir.py -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -o=0.2',
'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2',
'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar',
'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 -hfm=3',
'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 -nb=4',
'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2',
'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2',
'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -k=1',
'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -e=2 -k=1',
'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar',
'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=4',
'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -hfm=3',
'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -a',
'contrib/epidemiology/sir.py --nojit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -o=0.2',
'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2',
'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar',
'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 -hfm=3',
'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 -nb=4',
'contrib/forecast/bart.py --num-steps=2 --stride=99999',
'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000',
'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000',
Expand Down Expand Up @@ -108,11 +108,11 @@
'air/main.py --num-steps=1 --cuda',
'baseball.py --num-samples=200 --warmup-steps=100 --num-chains=2 --cuda',
'contrib/cevae/synthetic.py --num-epochs=1 --cuda',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --cuda',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --cuda',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar --cuda',
'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --cuda',
'contrib/epidemiology/regional.py -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar --cuda',
'contrib/epidemiology/sir.py --nojit -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --cuda',
'contrib/epidemiology/sir.py --nojit -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --cuda',
'contrib/epidemiology/sir.py --nojit -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --haar --cuda',
'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --cuda',
'contrib/epidemiology/regional.py --nojit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2 --haar --cuda',
'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --cuda',
'lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --cuda',
'dmm/dmm.py --num-epochs=1 --cuda',
Expand Down Expand Up @@ -157,6 +157,8 @@ def xfail_jit(*args):
'baseball.py --num-samples=200 --warmup-steps=100 --jit',
'contrib/autoname/mixture.py --num-epochs=1 --jit',
'contrib/cevae/synthetic.py --num-epochs=1 --jit',
'contrib/epidemiology/sir.py --jit -np=128 -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2',
'contrib/epidemiology/regional.py --jit -t=2 -w=2 -n=4 -r=3 -d=20 -p=1000 -f 2',
xfail_jit('lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --jit'),
xfail_jit('contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --jit'),
xfail_jit('dmm/dmm.py --num-epochs=1 --jit'),
Expand Down

0 comments on commit ab4c663

Please sign in to comment.