Remove DCT and higher order spline from sir_hmc.py #2438

Merged: 1 commit, Apr 24, 2020
examples/sir_hmc.py (27 additions, 101 deletions)
@@ -19,17 +19,13 @@
 from collections import OrderedDict

 import torch
-from torch.distributions import biject_to, constraints
-from torch.distributions.transforms import ComposeTransform

 import pyro
 import pyro.distributions as dist
 import pyro.distributions.hmm
 import pyro.poutine as poutine
-from pyro.distributions.transforms.discrete_cosine import DiscreteCosineTransform
 from pyro.infer import MCMC, NUTS, config_enumerate, infer_discrete
 from pyro.infer.autoguide import init_to_value
-from pyro.infer.reparam import DiscreteCosineReparam
 from pyro.ops.tensor_utils import convolve, safe_log
 from pyro.util import warn_if_nan
@@ -243,54 +239,29 @@ def hook_fn(kernel, *unused):
 #
 # We first define a helper to create enumerated Categorical sites.

-def quantize(name, x_real, min, max, spline_order=3):
+def quantize(name, x_real, min, max):
     """
     Randomly quantize in a way that preserves probability mass.
-    We use a piecewise polynomial spline of order 3 or 5.
+    We use a piecewise polynomial spline of order 3.
     """
-    assert spline_order in [3, 5]
     assert min < max
     lb = x_real.detach().floor()

+    # This cubic spline interpolates over the nearest four integers, ensuring
+    # piecewise quadratic gradients.
     s = x_real - lb
     ss = s * s
     t = 1 - s
     tt = t * t
-
-    if spline_order == 3:
-        # This cubic spline interpolates over the nearest four integers, ensuring
-        # piecewise quadratic gradients.
-        probs = torch.stack([
-            t * tt,
-            4 + ss * (3 * s - 6),
-            4 + tt * (3 * t - 6),
-            s * ss,
-        ], dim=-1) * (1/6)
-    elif spline_order == 5:
-        # This quintic spline interpolates over the nearest eight integers, ensuring
-        # piecewise quartic gradients.
-        sss = ss * s
-        ssss = ss * ss
-        sssss = sss * ss
-
-        ttt = tt * t
-        tttt = tt * tt
-        ttttt = ttt * tt
-
-        probs = torch.stack([
-            2 * ttttt,
-            2 + 10 * t + 20 * tt + 20 * ttt + 10 * tttt - 7 * ttttt,
-            55 + 115 * t + 70 * tt - 9 * ttt - 25 * tttt + 7 * ttttt,
-            302 - 100 * ss + 10 * ssss,
-            302 - 100 * tt + 10 * tttt,
-            55 + 115 * s + 70 * ss - 9 * sss - 25 * ssss + 7 * sssss,
-            2 + 10 * s + 20 * ss + 20 * sss + 10 * ssss - 7 * sssss,
-            2 * sssss
-        ], dim=-1) * (1/840)
-
+    probs = torch.stack([
+        t * tt,
+        4 + ss * (3 * s - 6),
+        4 + tt * (3 * t - 6),
+        s * ss,
+    ], dim=-1) * (1/6)
     q = pyro.sample("Q_" + name, dist.Categorical(probs)).type_as(x_real)

-    x = lb + q - {3: 1, 5: 3}[spline_order]
+    x = lb + q - 1
     x = torch.max(x, 2 * min - 1 - x)
     x = torch.min(x, 2 * max + 1 - x)

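Note on the retained cubic spline: a quick standalone sanity check (illustrative, not part of this PR) that the four weights form a proper distribution and that the randomized rounding is unbiased, i.e. the quantized value has mean x_real:

import torch

def cubic_probs(x_real):
    # Same weights as in quantize(): mass over the four nearest integers.
    lb = x_real.floor()
    s = x_real - lb
    ss = s * s
    t = 1 - s
    tt = t * t
    return torch.stack([
        t * tt,
        4 + ss * (3 * s - 6),
        4 + tt * (3 * t - 6),
        s * ss,
    ], dim=-1) * (1/6)

x = torch.tensor(7.3)
probs = cubic_probs(x)
support = x.floor() + torch.arange(-1., 3.)  # candidates 6., 7., 8., 9.
assert torch.allclose(probs.sum(), torch.tensor(1.))  # mass is preserved
assert torch.allclose((probs * support).sum(), x)     # rounding is unbiased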
@@ -317,8 +288,8 @@ def continuous_model(args, data):
     I_curr = torch.tensor(1.)
     for t, datum in poutine.markov(enumerate(data)):
         S_prev, I_prev = S_curr, I_curr
-        S_curr = quantize("S_{}".format(t), S_aux[..., t], min=0, max=args.population, spline_order=args.spline_order)
-        I_curr = quantize("I_{}".format(t), I_aux[..., t], min=0, max=args.population, spline_order=args.spline_order)
+        S_curr = quantize("S_{}".format(t), S_aux[..., t], min=0, max=args.population)
+        I_curr = quantize("I_{}".format(t), I_aux[..., t], min=0, max=args.population)

         # Now we reverse the computation.
         S2I = S_prev - S_curr
@@ -353,29 +324,15 @@ def heuristic_init(args, data):
     recovery = torch.arange(30.).div(args.recovery_time).neg().exp()
     I_aux = convolve(S2I, recovery)[:len(data)].clamp(min=0.5)

-    # Also initialize DCT transformed coordinates.
-    t = ComposeTransform([biject_to(constraints.interval(-0.5, args.population + 0.5)).inv,
-                          DiscreteCosineTransform(dim=-1)])
-
     return {
         "R0": torch.tensor(2.0),
         "rho": torch.tensor(0.5),
         "S_aux": S_aux,
         "I_aux": I_aux,
-        "S_aux_dct": t(S_aux),
-        "I_aux_dct": t(I_aux),
     }


-# One trick to improve inference geometry is to reparameterize the S_aux,I_aux
-# variables via DiscreteCosineReparam. This allows HMC's diagonal mass matrix
-# adaptation to learn different step sizes for high- and low-frequency
-# directions. We can apply that outside of the model, during inference.
-
 def infer_hmc_cont(model, args, data):
-    if args.dct:
-        rep = DiscreteCosineReparam()
-        model = poutine.reparam(model, {"S_aux": rep, "I_aux": rep})
     init_values = heuristic_init(args, data)
     return _infer_hmc(args, data, model, init_values=init_values)
@@ -386,53 +343,28 @@ def infer_hmc_cont(model, args, data):
 # with 4 * 4 = 16 states, and then manually perform variable elimination (the
 # factors here don't quite conform to DiscreteHMM's interface).

-def quantize_enumerate(x_real, min, max, spline_order=3):
+def quantize_enumerate(x_real, min, max):
     """
     Randomly quantize in a way that preserves probability mass.
-    We use a piecewise polynomial spline of order 3 or 5.
+    We use a piecewise polynomial spline of order 3.
     """
-    assert spline_order in [3, 5]
     assert min < max
     lb = x_real.detach().floor()

+    # This cubic spline interpolates over the nearest four integers, ensuring
+    # piecewise quadratic gradients.
     s = x_real - lb
     ss = s * s
     t = 1 - s
     tt = t * t
-
-    if spline_order == 3:
-        # This cubic spline interpolates over the nearest four integers, ensuring
-        # piecewise quadratic gradients.
-        probs = torch.stack([
-            t * tt,
-            4 + ss * (3 * s - 6),
-            4 + tt * (3 * t - 6),
-            s * ss,
-        ], dim=-1) * (1/6)
-    elif spline_order == 5:
-        # This quintic spline interpolates over the nearest eight integers, ensuring
-        # piecewise quartic gradients.
-        sss = ss * s
-        ssss = ss * ss
-        sssss = sss * ss
-
-        ttt = tt * t
-        tttt = tt * tt
-        ttttt = ttt * tt
-
-        probs = torch.stack([
-            2 * ttttt,
-            2 + 10 * t + 20 * tt + 20 * ttt + 10 * tttt - 7 * ttttt,
-            55 + 115 * t + 70 * tt - 9 * ttt - 25 * tttt + 7 * ttttt,
-            302 - 100 * ss + 10 * ssss,
-            302 - 100 * tt + 10 * tttt,
-            55 + 115 * s + 70 * ss - 9 * sss - 25 * ssss + 7 * sssss,
-            2 + 10 * s + 20 * ss + 20 * sss + 10 * ssss - 7 * sssss,
-            2 * sssss
-        ], dim=-1) * (1/840)
-
+    probs = torch.stack([
+        t * tt,
+        4 + ss * (3 * s - 6),
+        4 + tt * (3 * t - 6),
+        s * ss,
+    ], dim=-1) * (1/6)
     logits = safe_log(probs)
-    q = torch.arange(-1., 3.) if spline_order == 3 else torch.arange(-3., 5.)
+    q = torch.arange(-1., 3.)

     x = lb.unsqueeze(-1) + q
     x = torch.max(x, 2 * min - 1 - x)
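For context, a shape sketch of the enumerated variant (illustrative values; assumes quantize_enumerate returns the candidate values and their logits, as in the full file): each input element expands to the Q = 4 nearest integers plus matching log-weights.

import torch

x_real = torch.tensor([3.7, 0.2, 9.9])  # shape (T,) with T = 3
x, logits = quantize_enumerate(x_real, min=0, max=10)
assert x.shape == (3, 4)       # four candidate integers per time step
assert logits.shape == (3, 4)  # matching unnormalized log-probabilities
# Out-of-range candidates are reflected back into [min, max], e.g. the
# candidates for 0.2 are [0, 0, 1, 2] rather than [-1, 0, 1, 2].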
@@ -453,14 +385,14 @@ def vectorized_model(args, data):
                 .mask(False).expand(data.shape).to_event(1))

     # Manually enumerate.
-    S_curr, S_logp = quantize_enumerate(S_aux, min=0, max=args.population, spline_order=args.spline_order)
-    I_curr, I_logp = quantize_enumerate(I_aux, min=0, max=args.population, spline_order=args.spline_order)
+    S_curr, S_logp = quantize_enumerate(S_aux, min=0, max=args.population)
+    I_curr, I_logp = quantize_enumerate(I_aux, min=0, max=args.population)
     # Truncate final value from the right then pad initial value onto the left.
     S_prev = torch.nn.functional.pad(S_curr[:-1], (0, 0, 1, 0), value=args.population - 1)
     I_prev = torch.nn.functional.pad(I_curr[:-1], (0, 0, 1, 0), value=1)
     # Reshape to support broadcasting, similar to EnumMessenger.
     T = len(data)
-    Q = 4 if args.spline_order == 3 else 8  # Number of quantization points.
+    Q = 4  # Number of quantization points.
     S_prev = S_prev.reshape(T, Q, 1, 1, 1)
     I_prev = I_prev.reshape(T, 1, Q, 1, 1)
     S_curr = S_curr.reshape(T, 1, 1, Q, 1)
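The reshapes above implement enumeration by broadcasting. A toy sketch (random log-factors, not the example's actual ones) of how the four axes combine and how the previous-state dimensions can then be eliminated:

import torch

T, Q = 5, 4
S_prev = torch.randn(T, Q, 1, 1, 1)
I_prev = torch.randn(T, 1, Q, 1, 1)
S_curr = torch.randn(T, 1, 1, Q, 1)
I_curr = torch.randn(T, 1, 1, 1, Q)
# Adding log-factors broadcasts to one (T, Q, Q, Q, Q) joint factor ...
log_factor = S_prev + I_prev + S_curr + I_curr
assert log_factor.shape == (T, Q, Q, Q, Q)
# ... and logsumexp over the "prev" axes marginalizes them out. (The real
# model must instead chain these factors sequentially across time steps.)
marginal = log_factor.logsumexp(dim=(1, 2))
assert marginal.shape == (T, Q, Q)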
@@ -547,9 +479,6 @@ def predict(args, data, samples, truth=None):
     # algorithm. We'll add these new samples to the existing dict of samples.
     model = poutine.condition(continuous_model, samples)
     model = particle_plate(model)
-    if args.dct:  # Apply the same reparameterizer as during inference.
-        rep = DiscreteCosineReparam()
-        model = poutine.reparam(model, {"S_aux": rep, "I_aux": rep})
     model = infer_discrete(model, first_available_dim=-2)
     with poutine.trace() as tr:
         model(args, data)
@@ -656,13 +585,10 @@ def main(args):
                         help="use the full enumeration model")
     parser.add_argument("-s", "--sequential", action="store_true",
                         help="use the sequential continuous model")
-    parser.add_argument("--dct", action="store_true",
-                        help="use discrete cosine reparameterizer")
     parser.add_argument("-n", "--num-samples", default=200, type=int)
     parser.add_argument("-w", "--warmup-steps", default=100, type=int)
     parser.add_argument("-t", "--max-tree-depth", default=5, type=int)
     parser.add_argument("-r", "--rng-seed", default=0, type=int)
-    parser.add_argument("-so", "--spline-order", default=3, type=int)
     parser.add_argument("--double", action="store_true")
     parser.add_argument("--jit", action="store_true")
     parser.add_argument("--cuda", action="store_true")
tests/test_examples.py (0 additions, 4 deletions)

@@ -79,8 +79,6 @@
     'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum',
     'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential',
     'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2',
-    'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 --dct',
-    'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 -so=5',
     'smcfilter.py --num-timesteps=3 --num-particles=10',
     'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide custom',
     'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide auto',
@@ -127,7 +125,6 @@
     'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --cuda',
     'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --cuda',
     'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --cuda',
-    'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 --dct --cuda',
     'vae/vae.py --num-epochs=1 --cuda',
     'vae/ss_vae_M2.py --num-epochs=1 --cuda',
     'vae/ss_vae_M2.py --num-epochs=1 --aux-loss --cuda',
@@ -168,7 +165,6 @@ def xfail_jit(*args):
     'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -m=1 --enum --jit',
     'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential --jit',
     'sir_hmc.py -t=2 -w=2 -n=4 -p=10000 --jit',
-    'sir_hmc.py -t=2 -w=2 -n=4 -p=10000 --dct --jit',
     xfail_jit('vae/ss_vae_M2.py --num-epochs=1 --aux-loss --jit'),
     'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=parallel --jit',
     'vae/ss_vae_M2.py --num-epochs=1 --enum-discrete=sequential --jit',