diff --git a/examples/sir_hmc.py b/examples/sir_hmc.py
index 2fc83c85bd..a9fa427301 100644
--- a/examples/sir_hmc.py
+++ b/examples/sir_hmc.py
@@ -341,8 +341,8 @@ def heuristic_init(args, data):
 
 def infer_hmc_cont(model, args, data):
     if args.dct:
-        model = poutine.reparam(model, {"S_aux": DiscreteCosineReparam(),
-                                        "I_aux": DiscreteCosineReparam()})
+        rep = DiscreteCosineReparam()
+        model = poutine.reparam(model, {"S_aux": rep, "I_aux": rep})
     init_values = heuristic_init(args, data)
     return _infer_hmc(args, data, model, init_values=init_values)
 
@@ -478,11 +478,14 @@ def predict(args, data, samples, truth=None):
     particle_plate = pyro.plate("particles", args.num_samples, dim=-1)
 
     # First we sample discrete auxiliary variables from the continuous
-    # variables sampled in vectorized_model. Here infer_discrete runs a
-    # forward-filter backward-sample algorithm. We'll add these new samples to
-    # the existing dict of samples.
+    # variables sampled in vectorized_model. This samples only time steps
+    # [0:duration]. Here infer_discrete runs a forward-filter backward-sample
+    # algorithm. We'll add these new samples to the existing dict of samples.
     model = poutine.condition(continuous_model, samples)
     model = particle_plate(model)
+    if args.dct:  # Apply the same reparameterizer as during inference.
+        rep = DiscreteCosineReparam()
+        model = poutine.reparam(model, {"S_aux": rep, "I_aux": rep})
     model = infer_discrete(model, first_available_dim=-2)
     with poutine.trace() as tr:
         model(data, args.population)
@@ -490,8 +493,9 @@
                for name, site in tr.trace.nodes.items()
                if site["type"] == "sample"}
 
-    # Next we'll run the forward generative process in discrete_model. Again
-    # we'll update the dict of samples.
+    # Next we'll run the forward generative process in discrete_model. This
+    # samples time steps [duration:duration+forecast]. Again we'll update the
+    # dict of samples.
     extended_data = list(data) + [None] * args.forecast
     model = poutine.condition(discrete_model, samples)
     model = particle_plate(model)
@@ -501,7 +505,8 @@
                for name, site in tr.trace.nodes.items()
                if site["type"] == "sample"}
 
-    # Concatenate sequential time series into tensors.
+    # Finally we'll concatenate the sequentially sampled values into contiguous
+    # tensors. This operates on the entire time interval [0:duration+forecast].
for key in ("S", "I", "S2I", "I2R"): pattern = key + "_[0-9]+" series = [value diff --git a/tests/test_examples.py b/tests/test_examples.py index 196534e0fa..2cea6778f9 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -74,10 +74,10 @@ 'rsa/schelling.py --num-samples=10', 'rsa/schelling_false.py --num-samples=10', 'rsa/semantic_parsing.py --num-samples=10', - 'sir_hmc.py -w=2 -n=4 -d=2 -m=1 --enum', - 'sir_hmc.py -w=2 -n=4 -d=2 -p=10000 --sequential', - 'sir_hmc.py -w=2 -n=4 -d=100 -p=10000 -f 2', - 'sir_hmc.py -w=2 -n=4 -d=100 -p=10000 --dct', + 'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum', + 'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential', + 'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2', + 'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 --dct', 'smcfilter.py --num-timesteps=3 --num-particles=10', 'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide custom', 'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide auto', @@ -119,10 +119,10 @@ 'hmm.py --num-steps=1 --truncate=10 --model=4 --tmc --tmc-num-samples=2 --cuda', 'hmm.py --num-steps=1 --truncate=10 --model=5 --tmc --tmc-num-samples=2 --cuda', 'hmm.py --num-steps=1 --truncate=10 --model=6 --tmc --tmc-num-samples=2 --cuda', - 'sir_hmc.py -w=2 -n=4 -d=2 -m=1 --enum --cuda', - 'sir_hmc.py -w=2 -n=4 -d=2 -p=10000 --sequential --cuda', - 'sir_hmc.py -w=2 -n=4 -d=100 -p=10000 --cuda', - 'sir_hmc.py -w=2 -n=4 -d=100 -p=10000 --dct --cuda', + 'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --cuda', + 'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --cuda', + 'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --cuda', + 'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --dct --cuda', 'vae/vae.py --num-epochs=1 --cuda', 'vae/ss_vae_M2.py --num-epochs=1 --cuda', 'vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda', @@ -160,10 +160,10 @@ def xfail_jit(*args): 'lda.py --num-steps=2 --num-words=100 --num-docs=100 --num-words-per-doc=8 --jit', 'minipyro.py --backend=pyro --jit', 'minipyro.py --jit', - 'sir_hmc.py -w=2 -n=4 -d=2 -m=1 --enum --jit', - 'sir_hmc.py -w=2 -n=4 -d=2 -p=10000 --sequential --jit', - 'sir_hmc.py -w=2 -n=4 -p=10000 --jit', - 'sir_hmc.py -w=2 -n=4 -p=10000 --dct --jit', + 'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --jit', + 'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --jit', + 'sir_hmc.py -t=2 -w=2 -n=4 -p=10000 --jit', + 'sir_hmc.py -t=2 -w=2 -n=4 -p=10000 --dct --jit', xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --aux-loss --jit'), 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --jit', 'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --jit',