add quintic spline quantization strategy to sir_hmc.py #2434

Merged
merged 4 commits on Apr 23, 2020
160 changes: 109 additions & 51 deletions examples/sir_hmc.py
@@ -73,12 +73,12 @@ def global_model(population):
return rate_s, prob_i, rho


-def discrete_model(data, population):
+def discrete_model(args, data):
# Sample global parameters.
-    rate_s, prob_i, rho = global_model(population)
+    rate_s, prob_i, rho = global_model(args.population)

# Sequentially sample time-local variables.
-    S = torch.tensor(population - 1.)
+    S = torch.tensor(args.population - 1.)
I = torch.tensor(1.)
for t, datum in enumerate(data):
S2I = pyro.sample("S2I_{}".format(t),
@@ -105,7 +105,7 @@ def generate_data(args):
for attempt in range(100):
with poutine.trace() as tr:
with poutine.condition(data=params):
-                discrete_model(empty_data, args.population)
+                discrete_model(args, empty_data)

# Concatenate sequential time series into tensors.
obs = torch.stack([site["value"]
@@ -150,22 +150,22 @@ def generate_data(args):
# The following model is equivalent to the discrete_model:

@config_enumerate
-def reparameterized_discrete_model(data, population):
+def reparameterized_discrete_model(args, data):
# Sample global parameters.
-    rate_s, prob_i, rho = global_model(population)
+    rate_s, prob_i, rho = global_model(args.population)

# Sequentially sample time-local variables.
-    S_curr = torch.tensor(population - 1.)
+    S_curr = torch.tensor(args.population - 1.)
I_curr = torch.tensor(1.)
for t, datum in enumerate(data):
# Sample reparameterizing variables.
# When reparameterizing to a factor graph, we ignore the density via
# .mask(False); the distributions are used only for initialization.
S_prev, I_prev = S_curr, I_curr
S_curr = pyro.sample("S_{}".format(t),
-                         dist.Binomial(population, 0.5).mask(False))
+                         dist.Binomial(args.population, 0.5).mask(False))
I_curr = pyro.sample("I_{}".format(t),
-                         dist.Binomial(population, 0.5).mask(False))
+                         dist.Binomial(args.population, 0.5).mask(False))

# Now we reverse the computation.
S2I = S_prev - S_curr
@@ -218,7 +218,7 @@ def hook_fn(kernel, *unused):
mcmc = MCMC(kernel, hook_fn=hook_fn,
num_samples=args.num_samples,
warmup_steps=args.warmup_steps)
-    mcmc.run(data, population=args.population)
+    mcmc.run(args, data)
mcmc.summary()
if args.plot:
import matplotlib.pyplot as plt
@@ -242,53 +242,82 @@ def hook_fn(kernel, *unused):
#
# We first define a helper to create enumerated Categorical sites.

-def quantize(name, x_real, min, max):
-    """Randomly quantize in a way that preserves probability mass."""
+def quantize(name, x_real, min, max, spline_order=3):
+    """
+    Randomly quantize in a way that preserves probability mass.
+    We use a piecewise polynomial spline of order 3 or 5.
+    """
+    assert spline_order in [3, 5]
assert min < max
lb = x_real.detach().floor()

-    # This cubic spline interpolates over the nearest four integers, ensuring
-    # piecewise quadratic gradients.
s = x_real - lb
ss = s * s
t = 1 - s
tt = t * t
-    probs = torch.stack([
-        t * tt,
-        4 + ss * (3 * s - 6),
-        4 + tt * (3 * t - 6),
-        s * ss,
-    ], dim=-1) * (1/6)
-    q = pyro.sample("Q_" + name, dist.Categorical(probs)).type_as(x_real) - 1
-
-    x = lb + q

+    if spline_order == 3:
+        # This cubic spline interpolates over the nearest four integers, ensuring
+        # piecewise quadratic gradients.
+        probs = torch.stack([
+            t * tt,
+            4 + ss * (3 * s - 6),
+            4 + tt * (3 * t - 6),
+            s * ss,
+        ], dim=-1) * (1/6)
+    elif spline_order == 5:
+        # This quintic spline interpolates over the nearest eight integers, ensuring
+        # piecewise quartic gradients.
+        sss = ss * s
+        ssss = ss * ss
+        sssss = sss * ss
+
+        ttt = tt * t
+        tttt = tt * tt
+        ttttt = ttt * tt
+
+        probs = torch.stack([
+            2 * ttttt,
+            2 + 10 * t + 20 * tt + 20 * ttt + 10 * tttt - 7 * ttttt,
+            55 + 115 * t + 70 * tt - 10 * ttt - 25 * tttt + 7 * ttttt,
+            302 - 100 * ss + 10 * ssss,
+            302 - 100 * tt + 10 * tttt,
+            55 + 115 * s + 70 * ss - 10 * sss - 25 * ssss + 7 * sssss,
+            2 + 10 * s + 20 * ss + 20 * sss + 10 * ssss - 7 * sssss,
+            2 * sssss,
+        ], dim=-1) * (1/840)

+    q = pyro.sample("Q_" + name, dist.Categorical(probs)).type_as(x_real)
+
+    x = lb + q - {3: 1, 5: 3}[spline_order]
x = torch.max(x, 2 * min - 1 - x)
x = torch.min(x, 2 * max + 1 - x)

return pyro.deterministic(name, x)
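
As a quick numerical check on these coefficients, the following standalone snippet (a sketch, not part of this diff; spline_weights is a hypothetical helper that restates the probs formulas above) verifies the two properties the docstring promises: the weights sum to one (a partition of unity), and the randomized rounding preserves the mean, E[x] = x_real.

import torch

def spline_weights(s, spline_order=3):
    # Restates the probs formulas from quantize() above as functions of the
    # fractional part s in [0, 1); rows are integer offsets from floor(x).
    t = 1 - s
    ss, tt = s * s, t * t
    if spline_order == 3:
        probs = torch.stack([
            t * tt,
            4 + ss * (3 * s - 6),
            4 + tt * (3 * t - 6),
            s * ss,
        ]) * (1 / 6)
        offsets = torch.arange(-1., 3.).double()
    else:
        sss, ttt = ss * s, tt * t
        ssss, tttt = ss * ss, tt * tt
        sssss, ttttt = sss * ss, ttt * tt
        probs = torch.stack([
            2 * ttttt,
            2 + 10 * t + 20 * tt + 20 * ttt + 10 * tttt - 7 * ttttt,
            55 + 115 * t + 70 * tt - 10 * ttt - 25 * tttt + 7 * ttttt,
            302 - 100 * ss + 10 * ssss,
            302 - 100 * tt + 10 * tttt,
            55 + 115 * s + 70 * ss - 10 * sss - 25 * ssss + 7 * sssss,
            2 + 10 * s + 20 * ss + 20 * sss + 10 * ssss - 7 * sssss,
            2 * sssss,
        ]) * (1 / 840)
        offsets = torch.arange(-3., 5.).double()
    return probs, offsets

for order in [3, 5]:
    s = torch.rand(1000, dtype=torch.float64)
    probs, offsets = spline_weights(s, order)
    assert torch.allclose(probs.sum(0), torch.ones_like(s))      # partition of unity
    assert torch.allclose((probs * offsets[:, None]).sum(0), s)  # preserves the mean

Both checks pass for either order. The quintic spline spreads mass over eight neighbors instead of four, buying two extra orders of gradient smoothness at the cost of a larger enumeration domain.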


# Now we can define another equivalent model.

@config_enumerate
-def continuous_model(data, population):
+def continuous_model(args, data):
# Sample global parameters.
-    rate_s, prob_i, rho = global_model(population)
+    rate_s, prob_i, rho = global_model(args.population)

# Sample reparameterizing variables.
S_aux = pyro.sample("S_aux",
-                        dist.Uniform(-0.5, population + 0.5)
+                        dist.Uniform(-0.5, args.population + 0.5)
.mask(False).expand(data.shape).to_event(1))
I_aux = pyro.sample("I_aux",
-                        dist.Uniform(-0.5, population + 0.5)
+                        dist.Uniform(-0.5, args.population + 0.5)
.mask(False).expand(data.shape).to_event(1))

# Sequentially sample time-local variables.
-    S_curr = torch.tensor(population - 1.)
+    S_curr = torch.tensor(args.population - 1.)
I_curr = torch.tensor(1.)
for t, datum in poutine.markov(enumerate(data)):
S_prev, I_prev = S_curr, I_curr
-        S_curr = quantize("S_{}".format(t), S_aux[..., t], min=0, max=population)
-        I_curr = quantize("I_{}".format(t), I_aux[..., t], min=0, max=population)
+        S_curr = quantize("S_{}".format(t), S_aux[..., t], min=0, max=args.population, spline_order=args.spline_order)
+        I_curr = quantize("I_{}".format(t), I_aux[..., t], min=0, max=args.population, spline_order=args.spline_order)

# Now we reverse the computation.
S2I = S_prev - S_curr
@@ -356,53 +385,81 @@ def infer_hmc_cont(model, args, data):
# with 4 * 4 = 16 states, and then manually perform variable elimination (the
# factors here don't quite conform to DiscreteHMM's interface).

-def quantize_enumerate(x_real, min, max):
-    """Quantize, then manually enumerate."""
+def quantize_enumerate(x_real, min, max, spline_order=3):
+    """
+    Randomly quantize in a way that preserves probability mass.
+    We use a piecewise polynomial spline of order 3 or 5.
+    """
+    assert spline_order in [3, 5]
assert min < max
lb = x_real.detach().floor()

-    # This cubic spline interpolates over the nearest four integers, ensuring
-    # piecewise quadratic gradients.
s = x_real - lb
ss = s * s
t = 1 - s
tt = t * t
-    probs = torch.stack([
-        t * tt,
-        4 + ss * (3 * s - 6),
-        4 + tt * (3 * t - 6),
-        s * ss,
-    ], dim=-1) * (1/6)

+    if spline_order == 3:
+        # This cubic spline interpolates over the nearest four integers, ensuring
+        # piecewise quadratic gradients.
+        probs = torch.stack([
+            t * tt,
+            4 + ss * (3 * s - 6),
+            4 + tt * (3 * t - 6),
+            s * ss,
+        ], dim=-1) * (1/6)
+    elif spline_order == 5:
+        # This quintic spline interpolates over the nearest eight integers, ensuring
+        # piecewise quartic gradients.
+        sss = ss * s
+        ssss = ss * ss
+        sssss = sss * ss
+
+        ttt = tt * t
+        tttt = tt * tt
+        ttttt = ttt * tt
+
+        probs = torch.stack([
+            2 * ttttt,
+            2 + 10 * t + 20 * tt + 20 * ttt + 10 * tttt - 7 * ttttt,
+            55 + 115 * t + 70 * tt - 10 * ttt - 25 * tttt + 7 * ttttt,
+            302 - 100 * ss + 10 * ssss,
+            302 - 100 * tt + 10 * tttt,
+            55 + 115 * s + 70 * ss - 10 * sss - 25 * ssss + 7 * sssss,
+            2 + 10 * s + 20 * ss + 20 * sss + 10 * ssss - 7 * sssss,
+            2 * sssss,
+        ], dim=-1) * (1/840)

logits = safe_log(probs)
-    q = torch.arange(-1., 3.)
+    q = torch.arange(-1., 3.) if spline_order == 3 else torch.arange(-3., 5.)

x = lb.unsqueeze(-1) + q
x = torch.max(x, 2 * min - 1 - x)
x = torch.min(x, 2 * max + 1 - x)
return x, logits
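
For intuition, here is a hypothetical usage sketch (assuming torch and the quantize_enumerate above are in scope): the returned tensors gain one rightmost dimension of size Q that enumerates the candidate integer values and their log-weights, which vectorized_model below combines into joint factors.

x_aux = torch.tensor([2.3, 7.9, 0.1])                                 # made-up auxiliary values
x, logits = quantize_enumerate(x_aux, min=0, max=10)                  # cubic: Q = 4
assert x.shape == logits.shape == (3, 4)
x, logits = quantize_enumerate(x_aux, min=0, max=10, spline_order=5)  # quintic: Q = 8
assert x.shape == logits.shape == (3, 8)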


-def vectorized_model(data, population):
+def vectorized_model(args, data):
# Sample global parameters.
-    rate_s, prob_i, rho = global_model(population)
+    rate_s, prob_i, rho = global_model(args.population)

# Sample reparameterizing variables.
S_aux = pyro.sample("S_aux",
-                        dist.Uniform(-0.5, population + 0.5)
+                        dist.Uniform(-0.5, args.population + 0.5)
.mask(False).expand(data.shape).to_event(1))
I_aux = pyro.sample("I_aux",
-                        dist.Uniform(-0.5, population + 0.5)
+                        dist.Uniform(-0.5, args.population + 0.5)
.mask(False).expand(data.shape).to_event(1))

# Manually enumerate.
-    S_curr, S_logp = quantize_enumerate(S_aux, min=0, max=population)
-    I_curr, I_logp = quantize_enumerate(I_aux, min=0, max=population)
+    S_curr, S_logp = quantize_enumerate(S_aux, min=0, max=args.population, spline_order=args.spline_order)
+    I_curr, I_logp = quantize_enumerate(I_aux, min=0, max=args.population, spline_order=args.spline_order)
# Truncate final value from the right then pad initial value onto the left.
-    S_prev = torch.nn.functional.pad(S_curr[:-1], (0, 0, 1, 0), value=population - 1)
+    S_prev = torch.nn.functional.pad(S_curr[:-1], (0, 0, 1, 0), value=args.population - 1)
I_prev = torch.nn.functional.pad(I_curr[:-1], (0, 0, 1, 0), value=1)
# Reshape to support broadcasting, similar to EnumMessenger.
T = len(data)
-    Q = 4  # Number of quantization points.
+    Q = 4 if args.spline_order == 3 else 8  # Number of quantization points.
S_prev = S_prev.reshape(T, Q, 1, 1, 1)
I_prev = I_prev.reshape(T, 1, Q, 1, 1)
S_curr = S_curr.reshape(T, 1, 1, Q, 1)
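
These reshapes give each enumerated variable its own tensor dimension, so ordinary broadcast arithmetic enumerates all Q**4 joint assignments at every time step before the factors are eliminated (in code elided from this hunk). A toy shape check with made-up values, mirroring the model's S2I/I2R recurrences:

import torch

T, Q = 10, 4                   # e.g. 10 time steps with the cubic spline
S_prev = torch.randn(T, Q, 1, 1, 1)
I_prev = torch.randn(T, 1, Q, 1, 1)
S_curr = torch.randn(T, 1, 1, Q, 1)
I_curr = torch.randn(T, 1, 1, 1, Q)

S2I = S_prev - S_curr          # broadcasts to (T, Q, 1, Q, 1)
I2R = I_prev - I_curr + S2I    # broadcasts to (T, Q, Q, Q, Q)
assert I2R.shape == (T, Q, Q, Q, Q)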
@@ -494,7 +551,7 @@ def predict(args, data, samples, truth=None):
model = poutine.reparam(model, {"S_aux": rep, "I_aux": rep})
model = infer_discrete(model, first_available_dim=-2)
with poutine.trace() as tr:
-        model(data, args.population)
+        model(args, data)
samples = {name: site["value"]
for name, site in tr.trace.nodes.items()
if site["type"] == "sample"}
@@ -506,7 +563,7 @@ def predict(args, data, samples, truth=None):
model = poutine.condition(discrete_model, samples)
model = particle_plate(model)
with poutine.trace() as tr:
-        model(extended_data, args.population)
+        model(args, extended_data)
samples = {name: site["value"]
for name, site in tr.trace.nodes.items()
if site["type"] == "sample"}
@@ -604,6 +661,7 @@ def main(args):
parser.add_argument("-w", "--warmup-steps", default=100, type=int)
parser.add_argument("-t", "--max-tree-depth", default=5, type=int)
parser.add_argument("-r", "--rng-seed", default=0, type=int)
parser.add_argument("-so", "--spline-order", default=3, type=int)
parser.add_argument("--double", action="store_true")
parser.add_argument("--jit", action="store_true")
parser.add_argument("--cuda", action="store_true")
1 change: 1 addition & 0 deletions tests/test_examples.py
@@ -78,6 +78,7 @@
'sir_hmc.py -t=2 -w=2 -n=4 -d=2 -p=10000 --sequential',
'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2',
'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 --dct',
+    'sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 -so=5',
'smcfilter.py --num-timesteps=3 --num-particles=10',
'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide custom',
'sparse_gamma_def.py --num-epochs=2 --eval-particles=2 --eval-frequency=1 --guide auto',
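
The added test line doubles as a usage example: the quintic strategy can be exercised locally with a smoke-test-sized run such as

python examples/sir_hmc.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 -so=5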