Remove DCT and higher order spline (#2438)
fritzo authored Apr 24, 2020
1 parent c11b83c · commit 560aafb
Showing 2 changed files with 27 additions and 105 deletions.
128 changes: 27 additions & 101 deletions examples/sir_hmc.py
@@ -19,17 +19,13 @@
from collections import OrderedDict

import torch
from torch.distributions import biject_to, constraints
from torch.distributions.transforms import ComposeTransform

import pyro
import pyro.distributions as dist
import pyro.distributions.hmm
import pyro.poutine as poutine
from pyro.distributions.transforms.discrete_cosine import DiscreteCosineTransform
from pyro.infer import MCMC, NUTS, config_enumerate, infer_discrete
from pyro.infer.autoguide import init_to_value
from pyro.infer.reparam import DiscreteCosineReparam
from pyro.ops.tensor_utils import convolve, safe_log
from pyro.util import warn_if_nan

@@ -243,54 +239,29 @@ def hook_fn(kernel, *unused):
#
# We first define a helper to create enumerated Categorical sites.

def quantize(name, x_real, min, max, spline_order=3):
def quantize(name, x_real, min, max):
"""
Randomly quantize in a way that preserves probability mass.
We use a piecewise polynomial spline of order 3 or 5.
We use a piecewise polynomial spline of order 3.
"""
assert spline_order in [3, 5]
assert min < max
lb = x_real.detach().floor()

# This cubic spline interpolates over the nearest four integers, ensuring
# piecewise quadratic gradients.
s = x_real - lb
ss = s * s
t = 1 - s
tt = t * t

if spline_order == 3:
# This cubic spline interpolates over the nearest four integers, ensuring
# piecewise quadratic gradients.
probs = torch.stack([
t * tt,
4 + ss * (3 * s - 6),
4 + tt * (3 * t - 6),
s * ss,
], dim=-1) * (1/6)
elif spline_order == 5:
# This quintic spline interpolates over the nearest eight integers, ensuring
# piecewise quartic gradients.
sss = ss * s
ssss = ss * ss
sssss = sss * ss

ttt = tt * t
tttt = tt * tt
ttttt = ttt * tt

probs = torch.stack([
2 * ttttt,
2 + 10 * t + 20 * tt + 20 * ttt + 10 * tttt - 7 * ttttt,
55 + 115 * t + 70 * tt - 9 * ttt - 25 * tttt + 7 * ttttt,
302 - 100 * ss + 10 * ssss,
302 - 100 * tt + 10 * tttt,
55 + 115 * s + 70 * ss - 9 * sss - 25 * ssss + 7 * sssss,
2 + 10 * s + 20 * ss + 20 * sss + 10 * ssss - 7 * sssss,
2 * sssss
], dim=-1) * (1/840)

probs = torch.stack([
t * tt,
4 + ss * (3 * s - 6),
4 + tt * (3 * t - 6),
s * ss,
], dim=-1) * (1/6)
q = pyro.sample("Q_" + name, dist.Categorical(probs)).type_as(x_real)

x = lb + q - {3: 1, 5: 3}[spline_order]
x = lb + q - 1
x = torch.max(x, 2 * min - 1 - x)
x = torch.min(x, 2 * max + 1 - x)
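
For reference, the retained cubic spline has two properties worth keeping in mind: its four weights always sum to 1, and, away from the min/max reflection boundaries, the quantization is mean-preserving. A minimal standalone check using only torch (illustrative value, mirroring the weights above):

    import torch

    x_real = torch.tensor(3.7)            # any real value to be quantized
    lb = x_real.floor()
    s = x_real - lb
    ss, t = s * s, 1 - s
    tt = t * t
    probs = torch.stack([t * tt, 4 + ss * (3 * s - 6), 4 + tt * (3 * t - 6), s * ss]) * (1 / 6)

    support = lb + torch.arange(-1., 3.)                    # candidate integers lb-1 .. lb+2
    assert torch.isclose(probs.sum(), torch.tensor(1.))     # probability mass is preserved
    assert torch.isclose((probs * support).sum(), x_real)   # quantization is mean-preserving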

@@ -317,8 +288,8 @@ def continuous_model(args, data):
I_curr = torch.tensor(1.)
for t, datum in poutine.markov(enumerate(data)):
S_prev, I_prev = S_curr, I_curr
S_curr = quantize("S_{}".format(t), S_aux[..., t], min=0, max=args.population, spline_order=args.spline_order)
I_curr = quantize("I_{}".format(t), I_aux[..., t], min=0, max=args.population, spline_order=args.spline_order)
S_curr = quantize("S_{}".format(t), S_aux[..., t], min=0, max=args.population)
I_curr = quantize("I_{}".format(t), I_aux[..., t], min=0, max=args.population)

# Now we reverse the computation.
S2I = S_prev - S_curr
@@ -353,29 +324,15 @@ def heuristic_init(args, data):
recovery = torch.arange(30.).div(args.recovery_time).neg().exp()
I_aux = convolve(S2I, recovery)[:len(data)].clamp(min=0.5)

# Also initialize DCT transformed coordinates.
t = ComposeTransform([biject_to(constraints.interval(-0.5, args.population + 0.5)).inv,
DiscreteCosineTransform(dim=-1)])

return {
"R0": torch.tensor(2.0),
"rho": torch.tensor(0.5),
"S_aux": S_aux,
"I_aux": I_aux,
"S_aux_dct": t(S_aux),
"I_aux_dct": t(I_aux),
}


# One trick to improve inference geometry is to reparameterize the S_aux,I_aux
# variables via DiscreteCosineReparam. This allows HMC's diagonal mass matrix
# adaptation to learn different step sizes for high- and low-frequency
# directions. We can apply that outside of the model, during inference.

def infer_hmc_cont(model, args, data):
if args.dct:
rep = DiscreteCosineReparam()
model = poutine.reparam(model, {"S_aux": rep, "I_aux": rep})
init_values = heuristic_init(args, data)
return _infer_hmc(args, data, model, init_values=init_values)
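
heuristic_init above approximates the infectious count by smearing new infections forward in time with an exponentially decaying recovery kernel. A hedged sketch of that idea with made-up numbers, using an explicit loop in place of convolve(S2I, recovery)[:len(data)]:

    import torch

    recovery_time = 4.0
    S2I = torch.tensor([1., 2., 4., 3., 2.])                   # hypothetical new infections per step
    kernel = torch.arange(30.).div(recovery_time).neg().exp()  # exp(-k / recovery_time): decaying weight on past infections

    T = len(S2I)
    I_approx = torch.zeros(T)
    for t in range(T):
        for u in range(t + 1):
            I_approx[t] += S2I[u] * kernel[t - u]               # causal convolution, truncated to T steps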

@@ -386,53 +343,28 @@ def infer_hmc_cont(model, args, data):
# with 4 * 4 = 16 states, and then manually perform variable elimination (the
# factors here don't quite conform to DiscreteHMM's interface).

def quantize_enumerate(x_real, min, max, spline_order=3):
def quantize_enumerate(x_real, min, max):
"""
Randomly quantize in a way that preserves probability mass.
We use a piecewise polynomial spline of order 3 or 5.
We use a piecewise polynomial spline of order 3.
"""
assert spline_order in [3, 5]
assert min < max
lb = x_real.detach().floor()

# This cubic spline interpolates over the nearest four integers, ensuring
# piecewise quadratic gradients.
s = x_real - lb
ss = s * s
t = 1 - s
tt = t * t

if spline_order == 3:
# This cubic spline interpolates over the nearest four integers, ensuring
# piecewise quadratic gradients.
probs = torch.stack([
t * tt,
4 + ss * (3 * s - 6),
4 + tt * (3 * t - 6),
s * ss,
], dim=-1) * (1/6)
elif spline_order == 5:
# This quintic spline interpolates over the nearest eight integers, ensuring
# piecewise quartic gradients.
sss = ss * s
ssss = ss * ss
sssss = sss * ss

ttt = tt * t
tttt = tt * tt
ttttt = ttt * tt

probs = torch.stack([
2 * ttttt,
2 + 10 * t + 20 * tt + 20 * ttt + 10 * tttt - 7 * ttttt,
55 + 115 * t + 70 * tt - 9 * ttt - 25 * tttt + 7 * ttttt,
302 - 100 * ss + 10 * ssss,
302 - 100 * tt + 10 * tttt,
55 + 115 * s + 70 * ss - 9 * sss - 25 * ssss + 7 * sssss,
2 + 10 * s + 20 * ss + 20 * sss + 10 * ssss - 7 * sssss,
2 * sssss
], dim=-1) * (1/840)

probs = torch.stack([
t * tt,
4 + ss * (3 * s - 6),
4 + tt * (3 * t - 6),
s * ss,
], dim=-1) * (1/6)
logits = safe_log(probs)
q = torch.arange(-1., 3.) if spline_order == 3 else torch.arange(-3., 5.)
q = torch.arange(-1., 3.)

x = lb.unsqueeze(-1) + q
x = torch.max(x, 2 * min - 1 - x)
@@ -453,14 +385,14 @@ def vectorized_model(args, data):
.mask(False).expand(data.shape).to_event(1))

# Manually enumerate.
S_curr, S_logp = quantize_enumerate(S_aux, min=0, max=args.population, spline_order=args.spline_order)
I_curr, I_logp = quantize_enumerate(I_aux, min=0, max=args.population, spline_order=args.spline_order)
S_curr, S_logp = quantize_enumerate(S_aux, min=0, max=args.population)
I_curr, I_logp = quantize_enumerate(I_aux, min=0, max=args.population)
# Truncate final value from the right then pad initial value onto the left.
S_prev = torch.nn.functional.pad(S_curr[:-1], (0, 0, 1, 0), value=args.population - 1)
I_prev = torch.nn.functional.pad(I_curr[:-1], (0, 0, 1, 0), value=1)
# Reshape to support broadcasting, similar to EnumMessenger.
T = len(data)
Q = 4 if args.spline_order == 3 else 8 # Number of quantization points.
Q = 4
S_prev = S_prev.reshape(T, Q, 1, 1, 1)
I_prev = I_prev.reshape(T, 1, Q, 1, 1)
S_curr = S_curr.reshape(T, 1, 1, Q, 1)
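
For intuition about the manual enumeration being set up here, a self-contained toy sketch (made-up factors, not the model's actual terms): each enumerated variable gets its own Q-sized axis, and those axes are eliminated with logsumexp in a forward pass, as in the HMM forward algorithm.

    import torch

    T, Q = 5, 4
    trans_logp = torch.randn(T, Q, Q)   # toy factor coupling the previous and current quantized state
    log_alpha = torch.zeros(Q)          # uniform (unnormalized) initial state

    for t in range(T):
        # Broadcast onto separate axes, add log-factors, then sum out the previous state.
        log_alpha = (log_alpha.unsqueeze(-1) + trans_logp[t]).logsumexp(dim=0)
    total_logp = log_alpha.logsumexp(dim=0)   # eliminate the final state as well
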
@@ -547,9 +479,6 @@ def predict(args, data, samples, truth=None):
# algorithm. We'll add these new samples to the existing dict of samples.
model = poutine.condition(continuous_model, samples)
model = particle_plate(model)
if args.dct: # Apply the same reparameterizer as during inference.
rep = DiscreteCosineReparam()
model = poutine.reparam(model, {"S_aux": rep, "I_aux": rep})
model = infer_discrete(model, first_available_dim=-2)
with poutine.trace() as tr:
model(args, data)
@@ -656,13 +585,10 @@ def main(args):
help="use the full enumeration model")
parser.add_argument("-s", "--sequential", action="store_true",
help="use the sequential continuous model")
parser.add_argument("--dct", action="store_true",
help="use discrete cosine reparameterizer")
parser.add_argument("-n", "--num-samples", default=200, type=int)
parser.add_argument("-w", "--warmup-steps", default=100, type=int)
parser.add_argument("-t", "--max-tree-depth", default=5, type=int)
parser.add_argument("-r", "--rng-seed", default=0, type=int)
parser.add_argument("-so", "--spline-order", default=3, type=int)
parser.add_argument("--double", action="store_true")
parser.add_argument("--jit", action="store_true")
parser.add_argument("--cuda", action="store_true")
4 changes: 0 additions & 4 deletions tests/test_examples.py
@@ -79,8 +79,6 @@
'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum',
'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential',
'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2',
'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 --dct',
'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 -so=5',
'smcfilter.py --num-timesteps=3 --num-particles=10',
'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide custom',
'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide auto',
@@ -127,7 +125,6 @@
'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --cuda',
'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --cuda',
'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --cuda',
'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --dct --cuda',
'vae/vae.py --num-epochs=1 --cuda',
'vae/ss_vae_M2.py --num-epochs=1 --cuda',
'vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda',
@@ -168,7 +165,6 @@ def xfail_jit(*args):
'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --jit',
'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --jit',
'sir_hmc.py -t=2 -w=2 -n=4 -p=10000 --jit',
'sir_hmc.py -t=2 -w=2 -n=4 -p=10000 --dct --jit',
xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --aux-loss --jit'),
'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --jit',
'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --jit',
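
These invocations are exercised through the example test harness; assuming the repository's usual pytest setup, the remaining sir_hmc cases can be selected with:

    pytest tests/test_examples.py -k sir_hmc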
