From 6853db53defc3cdb4307df2f6a6bc7dc39a24c69 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Fri, 24 Apr 2020 05:55:00 -0700
Subject: [PATCH] Remove DCT and higher order spline

---
 examples/sir_hmc.py    | 128 +++++++++--------------------------------
 tests/test_examples.py |   4 --
 2 files changed, 27 insertions(+), 105 deletions(-)

diff --git a/examples/sir_hmc.py b/examples/sir_hmc.py
index b96dcdc308..15a02d6fbe 100644
--- a/examples/sir_hmc.py
+++ b/examples/sir_hmc.py
@@ -19,17 +19,13 @@
 from collections import OrderedDict

 import torch
-from torch.distributions import biject_to, constraints
-from torch.distributions.transforms import ComposeTransform

 import pyro
 import pyro.distributions as dist
 import pyro.distributions.hmm
 import pyro.poutine as poutine
-from pyro.distributions.transforms.discrete_cosine import DiscreteCosineTransform
 from pyro.infer import MCMC, NUTS, config_enumerate, infer_discrete
 from pyro.infer.autoguide import init_to_value
-from pyro.infer.reparam import DiscreteCosineReparam
 from pyro.ops.tensor_utils import convolve, safe_log
 from pyro.util import warn_if_nan

@@ -243,54 +239,29 @@ def hook_fn(kernel, *unused):
 #
 # We first define a helper to create enumerated Categorical sites.

-def quantize(name, x_real, min, max, spline_order=3):
+def quantize(name, x_real, min, max):
     """
     Randomly quantize in a way that preserves probability mass.
-    We use a piecewise polynomial spline of order 3 or 5.
+    We use a piecewise polynomial spline of order 3.
     """
-    assert spline_order in [3, 5]
     assert min < max
     lb = x_real.detach().floor()

+    # This cubic spline interpolates over the nearest four integers, ensuring
+    # piecewise quadratic gradients.
     s = x_real - lb
     ss = s * s
     t = 1 - s
     tt = t * t
-
-    if spline_order == 3:
-        # This cubic spline interpolates over the nearest four integers, ensuring
-        # piecewise quadratic gradients.
-        probs = torch.stack([
-            t * tt,
-            4 + ss * (3 * s - 6),
-            4 + tt * (3 * t - 6),
-            s * ss,
-        ], dim=-1) * (1/6)
-    elif spline_order == 5:
-        # This quintic spline interpolates over the nearest eight integers, ensuring
-        # piecewise quartic gradients.
-        sss = ss * s
-        ssss = ss * ss
-        sssss = sss * ss
-
-        ttt = tt * t
-        tttt = tt * tt
-        ttttt = ttt * tt
-
-        probs = torch.stack([
-            2 * ttttt,
-            2 + 10 * t + 20 * tt + 20 * ttt + 10 * tttt - 7 * ttttt,
-            55 + 115 * t + 70 * tt - 9 * ttt - 25 * tttt + 7 * ttttt,
-            302 - 100 * ss + 10 * ssss,
-            302 - 100 * tt + 10 * tttt,
-            55 + 115 * s + 70 * ss - 9 * sss - 25 * ssss + 7 * sssss,
-            2 + 10 * s + 20 * ss + 20 * sss + 10 * ssss - 7 * sssss,
-            2 * sssss
-        ], dim=-1) * (1/840)
-
+    probs = torch.stack([
+        t * tt,
+        4 + ss * (3 * s - 6),
+        4 + tt * (3 * t - 6),
+        s * ss,
+    ], dim=-1) * (1/6)
     q = pyro.sample("Q_" + name, dist.Categorical(probs)).type_as(x_real)

-    x = lb + q - {3: 1, 5: 3}[spline_order]
+    x = lb + q - 1
     x = torch.max(x, 2 * min - 1 - x)
     x = torch.min(x, 2 * max + 1 - x)
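
The cubic spline retained above is a partition of unity: for every s in
[0, 1] the four weights sum to 1, and the randomized rounding it induces is
mean-preserving away from the boundaries (E[x] == x_real). A standalone
sanity check, not part of the patch; spline_probs below merely restates the
probs computation from quantize():

    import torch

    def spline_probs(x_real):
        # Restate the cubic-spline weights from quantize()/quantize_enumerate().
        lb = x_real.floor()
        s = x_real - lb
        ss = s * s
        t = 1 - s
        tt = t * t
        probs = torch.stack([
            t * tt,
            4 + ss * (3 * s - 6),
            4 + tt * (3 * t - 6),
            s * ss,
        ], dim=-1) * (1/6)
        return lb, probs

    x_real = torch.rand(1000) * 100
    lb, probs = spline_probs(x_real)
    q = torch.arange(-1., 3.)  # the four integer offsets, as in x = lb + q - 1
    assert torch.allclose(probs.sum(-1), torch.ones(1000))  # weights sum to 1
    mean = (probs * (lb.unsqueeze(-1) + q)).sum(-1)
    assert torch.allclose(mean, x_real, atol=1e-4)  # rounding is unbiased
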
@@ -317,8 +288,8 @@ def continuous_model(args, data):
     I_curr = torch.tensor(1.)
     for t, datum in poutine.markov(enumerate(data)):
         S_prev, I_prev = S_curr, I_curr
-        S_curr = quantize("S_{}".format(t), S_aux[..., t], min=0, max=args.population, spline_order=args.spline_order)
-        I_curr = quantize("I_{}".format(t), I_aux[..., t], min=0, max=args.population, spline_order=args.spline_order)
+        S_curr = quantize("S_{}".format(t), S_aux[..., t], min=0, max=args.population)
+        I_curr = quantize("I_{}".format(t), I_aux[..., t], min=0, max=args.population)

         # Now we reverse the computation.
         S2I = S_prev - S_curr
@@ -353,29 +324,15 @@ def heuristic_init(args, data):
     recovery = torch.arange(30.).div(args.recovery_time).neg().exp()
     I_aux = convolve(S2I, recovery)[:len(data)].clamp(min=0.5)

-    # Also initialize DCT transformed coordinates.
-    t = ComposeTransform([biject_to(constraints.interval(-0.5, args.population + 0.5)).inv,
-                          DiscreteCosineTransform(dim=-1)])
-
     return {
         "R0": torch.tensor(2.0),
         "rho": torch.tensor(0.5),
         "S_aux": S_aux,
         "I_aux": I_aux,
-        "S_aux_dct": t(S_aux),
-        "I_aux_dct": t(I_aux),
     }


-# One trick to improve inference geometry is to reparameterize the S_aux,I_aux
-# variables via DiscreteCosineReparam. This allows HMC's diagonal mass matrix
-# adaptation to learn different step sizes for high- and low-frequency
-# directions. We can apply that outside of the model, during inference.
-
 def infer_hmc_cont(model, args, data):
-    if args.dct:
-        rep = DiscreteCosineReparam()
-        model = poutine.reparam(model, {"S_aux": rep, "I_aux": rep})
     init_values = heuristic_init(args, data)
     return _infer_hmc(args, data, model, init_values=init_values)

@@ -386,53 +343,28 @@ def infer_hmc_cont(model, args, data):
 # with 4 * 4 = 16 states, and then manually perform variable elimination (the
 # factors here don't quite conform to DiscreteHMM's interface).

-def quantize_enumerate(x_real, min, max, spline_order=3):
+def quantize_enumerate(x_real, min, max):
     """
     Randomly quantize in a way that preserves probability mass.
-    We use a piecewise polynomial spline of order 3 or 5.
+    We use a piecewise polynomial spline of order 3.
     """
-    assert spline_order in [3, 5]
     assert min < max
     lb = x_real.detach().floor()

+    # This cubic spline interpolates over the nearest four integers, ensuring
+    # piecewise quadratic gradients.
     s = x_real - lb
     ss = s * s
     t = 1 - s
     tt = t * t
-
-    if spline_order == 3:
-        # This cubic spline interpolates over the nearest four integers, ensuring
-        # piecewise quadratic gradients.
-        probs = torch.stack([
-            t * tt,
-            4 + ss * (3 * s - 6),
-            4 + tt * (3 * t - 6),
-            s * ss,
-        ], dim=-1) * (1/6)
-    elif spline_order == 5:
-        # This quintic spline interpolates over the nearest eight integers, ensuring
-        # piecewise quartic gradients.
-        sss = ss * s
-        ssss = ss * ss
-        sssss = sss * ss
-
-        ttt = tt * t
-        tttt = tt * tt
-        ttttt = ttt * tt
-
-        probs = torch.stack([
-            2 * ttttt,
-            2 + 10 * t + 20 * tt + 20 * ttt + 10 * tttt - 7 * ttttt,
-            55 + 115 * t + 70 * tt - 9 * ttt - 25 * tttt + 7 * ttttt,
-            302 - 100 * ss + 10 * ssss,
-            302 - 100 * tt + 10 * tttt,
-            55 + 115 * s + 70 * ss - 9 * sss - 25 * ssss + 7 * sssss,
-            2 + 10 * s + 20 * ss + 20 * sss + 10 * ssss - 7 * sssss,
-            2 * sssss
-        ], dim=-1) * (1/840)
-
+    probs = torch.stack([
+        t * tt,
+        4 + ss * (3 * s - 6),
+        4 + tt * (3 * t - 6),
+        s * ss,
+    ], dim=-1) * (1/6)
     logits = safe_log(probs)
-    q = torch.arange(-1., 3.) if spline_order == 3 else torch.arange(-3., 5.)
+    q = torch.arange(-1., 3.)

     x = lb.unsqueeze(-1) + q
     x = torch.max(x, 2 * min - 1 - x)
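
In the vectorized model below, quantize_enumerate returns all four
quantization points per site at once, so each time step carries a
Q * Q = 16-state (S, I) pair, and the transition factors broadcast over
previous and current values before being summed out. A shape-only sketch of
that broadcasting, using dummy tensors in place of the model's real values:

    import torch

    T, Q = 10, 4  # Q = 4 quantization points for the cubic spline
    S_prev = torch.randn(T, Q).reshape(T, Q, 1, 1, 1)
    I_prev = torch.randn(T, Q).reshape(T, 1, Q, 1, 1)
    S_curr = torch.randn(T, Q).reshape(T, 1, 1, Q, 1)
    I_curr = torch.randn(T, Q).reshape(T, 1, 1, 1, Q)
    # A log-factor touching all four variables broadcasts to the full joint
    # shape, which manual variable elimination then reduces axis by axis.
    joint = S_prev + I_prev + S_curr + I_curr
    assert joint.shape == (T, Q, Q, Q, Q)
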
@@ -453,14 +385,14 @@ def vectorized_model(args, data):
         .mask(False).expand(data.shape).to_event(1))

     # Manually enumerate.
-    S_curr, S_logp = quantize_enumerate(S_aux, min=0, max=args.population, spline_order=args.spline_order)
-    I_curr, I_logp = quantize_enumerate(I_aux, min=0, max=args.population, spline_order=args.spline_order)
+    S_curr, S_logp = quantize_enumerate(S_aux, min=0, max=args.population)
+    I_curr, I_logp = quantize_enumerate(I_aux, min=0, max=args.population)

     # Truncate final value from the right then pad initial value onto the left.
     S_prev = torch.nn.functional.pad(S_curr[:-1], (0, 0, 1, 0), value=args.population - 1)
     I_prev = torch.nn.functional.pad(I_curr[:-1], (0, 0, 1, 0), value=1)

     # Reshape to support broadcasting, similar to EnumMessenger.
     T = len(data)
-    Q = 4 if args.spline_order == 3 else 8  # Number of quantization points.
+    Q = 4
     S_prev = S_prev.reshape(T, Q, 1, 1, 1)
     I_prev = I_prev.reshape(T, 1, Q, 1, 1)
     S_curr = S_curr.reshape(T, 1, 1, Q, 1)
@@ -547,9 +479,6 @@ def predict(args, data, samples, truth=None):
     # algorithm. We'll add these new samples to the existing dict of samples.
     model = poutine.condition(continuous_model, samples)
     model = particle_plate(model)
-    if args.dct:  # Apply the same reparameterizer as during inference.
-        rep = DiscreteCosineReparam()
-        model = poutine.reparam(model, {"S_aux": rep, "I_aux": rep})
     model = infer_discrete(model, first_available_dim=-2)
     with poutine.trace() as tr:
         model(args, data)
@@ -656,13 +585,10 @@ def main(args):
                         help="use the full enumeration model")
     parser.add_argument("-s", "--sequential", action="store_true",
                         help="use the sequential continuous model")
-    parser.add_argument("--dct", action="store_true",
-                        help="use discrete cosine reparameterizer")
     parser.add_argument("-n", "--num-samples", default=200, type=int)
     parser.add_argument("-w", "--warmup-steps", default=100, type=int)
     parser.add_argument("-t", "--max-tree-depth", default=5, type=int)
     parser.add_argument("-r", "--rng-seed", default=0, type=int)
-    parser.add_argument("-so", "--spline-order", default=3, type=int)
     parser.add_argument("--double", action="store_true")
     parser.add_argument("--jit", action="store_true")
     parser.add_argument("--cuda", action="store_true")
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 2480c48abc..5e933cfc1a 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -79,8 +79,6 @@
     'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum',
     'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential',
     'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2',
-    'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 --dct',
-    'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 -so=5',
     'smcfilter.py --num-timesteps=3 --num-particles=10',
     'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide custom',
     'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide auto',
@@ -127,7 +125,6 @@
     'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --cuda',
     'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --cuda',
     'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --cuda',
-    'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --dct --cuda',
     'vae/vae.py --num-epochs=1 --cuda',
     'vae/ss_vae_M2.py --num-epochs=1 --cuda',
     'vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda',
@@ -168,7 +165,6 @@ def xfail_jit(*args):
     'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --jit',
     'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --jit',
     'sir_hmc.py -t=2 -w=2 -n=4 -p=10000 --jit',
-    'sir_hmc.py -t=2 -w=2 -n=4 -p=10000 --dct --jit',
     xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --aux-loss --jit'),
     'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --jit',
     'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --jit',
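
Note that while the --dct flag and its plumbing are removed above,
DiscreteCosineReparam itself remains in Pyro, and, as the deleted comment
explained, it can still be applied outside the model at inference time. A
minimal sketch, assuming the example's continuous_model and _infer_hmc are in
scope:

    import pyro.poutine as poutine
    from pyro.infer.reparam import DiscreteCosineReparam

    # Transform S_aux and I_aux to discrete-cosine coordinates so that HMC's
    # diagonal mass matrix can adapt per-frequency step sizes.
    rep = DiscreteCosineReparam()
    model = poutine.reparam(continuous_model, {"S_aux": rep, "I_aux": rep})
    # model can now replace continuous_model in _infer_hmc(); note that
    # heuristic_init() no longer supplies DCT-transformed initial values.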