Skip to content

Commit

Permalink
add more quantization strategies to contrib.epi (#2440)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjankowiak authored Apr 24, 2020
1 parent e284f64 commit b52a139
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 39 deletions.
4 changes: 3 additions & 1 deletion examples/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def main(args):
obs = dataset["obs"]

# Run inference.
model = SIRModel(args.population, args.recovery_time, obs)
model = SIRModel(args.population, args.recovery_time, obs,
num_quant_bins=args.num_bins)
samples = infer(args, model)

# Evaluate fit.
Expand All @@ -158,6 +159,7 @@ def main(args):
parser.add_argument("-w", "--warmup-steps", default=100, type=int)
parser.add_argument("-t", "--max-tree-depth", default=5, type=int)
parser.add_argument("-r", "--rng-seed", default=0, type=int)
parser.add_argument("-nb", "--num-bins", default=4, type=int)
parser.add_argument("--double", action="store_true")
parser.add_argument("--cuda", action="store_true")
parser.add_argument("--verbose", action="store_true")
Expand Down
17 changes: 13 additions & 4 deletions pyro/contrib/epidemiology/compartmental.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,12 @@ def transition_bwd(self, params, prev, curr, t): ...
:param list compartments: A list of strings of compartment names.
:param int duration:
:param int population:
:param int num_quant_bins: The number of quantization bins to use. Note that
computational cost is exponential in `num_quant_bins`. Defaults to 4.
"""

def __init__(self, compartments, duration, population):
def __init__(self, compartments, duration, population, *,
num_quant_bins=4):
super().__init__()

assert isinstance(duration, int)
Expand All @@ -91,6 +94,10 @@ def __init__(self, compartments, duration, population):
assert len(compartments) == len(set(compartments))
self.compartments = compartments

assert isinstance(num_quant_bins, int)
assert num_quant_bins >= 2
self.num_quant_bins = num_quant_bins

# Inference state.
self.samples = {}

Expand Down Expand Up @@ -364,7 +371,8 @@ def _sequential_model(self):
aux_t = auxiliary[..., t]
prev = curr
curr = {name: quantize("{}_{}".format(name, t), aux,
min=0, max=self.population)
min=0, max=self.population,
num_quant_bins=self.num_quant_bins)
for name, aux in zip(self.compartments, aux_t.unbind(-1))}
self.transition_bwd(params, prev, curr, t)

Expand All @@ -383,7 +391,8 @@ def _vectorized_model(self):
.to_event(2))

# Manually enumerate.
curr, logp = quantize_enumerate(auxiliary, min=0, max=self.population)
curr, logp = quantize_enumerate(auxiliary, min=0, max=self.population,
num_quant_bins=self.num_quant_bins)
curr = OrderedDict(zip(self.compartments, curr))
logp = OrderedDict(zip(self.compartments, logp))

Expand All @@ -398,7 +407,7 @@ def _vectorized_model(self):
# Reshape to support broadcasting, similar to EnumMessenger.
C = len(self.compartments)
T = self.duration
Q = 4 # Number of quantization points.
Q = self.num_quant_bins # Number of quantization points.

def enum_shape(position):
shape = [T] + [1] * (2 * C)
Expand Down
8 changes: 6 additions & 2 deletions pyro/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@ class SIRModel(CompartmentalModel):
:param int population:
:param float recovery_time:
:param iterable data: Time series of new observed infections.
:param int data: Time series of new observed infections.
:param int num_quant_bins: The number of quantization bins to use. Note that
computational cost is exponential in `num_quant_bins`. Defaults to 4.
"""

def __init__(self, population, recovery_time, data):
def __init__(self, population, recovery_time, data, *,
num_quant_bins=4):
compartments = ("S", "I") # R is implicit.
duration = len(data)
super().__init__(compartments, duration, population)
super().__init__(compartments, duration, population, num_quant_bins=num_quant_bins)

assert isinstance(recovery_time, float)
assert recovery_time > 0
Expand Down
156 changes: 128 additions & 28 deletions pyro/contrib/epidemiology/util.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,161 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import numpy

import torch

import pyro
import pyro.distributions as dist
from pyro.ops.tensor_utils import safe_log


def quantize(name, x_real, min, max):
# W16 is an 8 x 10 table of polynomial coefficients used to construct the
# num_quant_bins=16 quantization strategy in compute_bin_probs().
# Each row holds the 10 coefficients of one polynomial of degree at most 9,
# ordered from the constant term up to the 9th power. compute_bin_probs()
# evaluates every row at both s and 1 - s to obtain the 16 bin probabilities.

W16 = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1562511562511555e-07],
[1.1562511562511557e-07, 1.04062604062604e-06, 4.16250416250416e-06,
9.712509712509707e-06, 1.456876456876456e-05, 1.4568764568764562e-05,
9.712509712509707e-06, 4.16250416250416e-06, 1.04062604062604e-06, -6.937506937506934e-07],
[5.839068339068337e-05, 0.0002591158841158841, 0.0005036630036630038,
0.0005536130536130536, 0.00036421911421911425, 0.00013111888111888106,
9.712509712509736e-06, -1.2487512487512482e-05, -5.2031302031302014e-06, 1.6187516187516182e-06],
[0.0018637612387612374, 0.004983558108558107, 0.005457042957042955,
0.0029234654234654212, 0.000568181818181818, -0.0001602564102564102,
-8.741258741258739e-05, 4.162504162504162e-06, 9.365634365634364e-06, -1.7536475869809201e-06],
[0.015560115039281694, 0.025703289765789755, 0.015009296259296255,
0.0023682336182336166, -0.000963966588966589, -0.00029380341880341857,
5.6656306656306665e-05, 1.5956265956265953e-05, -6.417193917193917e-06, 7.515632515632516e-07],
[0.057450111616778265, 0.05790875790875791, 0.014424464424464418,
-0.0030303030303030303, -0.0013791763791763793, 0.00011655011655011669,
5.180005180005181e-05, -8.325008325008328e-06, 3.4687534687534703e-07, 0.0],
[0.12553422657589322, 0.072988122988123, -0.0011641136641136712,
-0.006617456617456618, -0.00028651903651903725, 0.00027195027195027195,
3.2375032375032334e-06, -5.550005550005552e-06, 3.4687534687534703e-07, 0.0],
[0.21761806865973532, 1.7482707128494565e-17, -0.028320290820290833,
0.0, 0.0014617327117327117, 0.0,
-3.561253561253564e-05, 0.0, 3.4687534687534714e-07, 0.0]]

# Converted to a numpy array so it can be wrapped with torch.from_numpy().
W16 = numpy.array(W16)


def compute_bin_probs(s, num_quant_bins=4):
    """
    Compute categorical probabilities for a quantization scheme with
    ``num_quant_bins`` many bins.

    The probabilities are piecewise polynomials in ``s`` and ``t = 1 - s`` and
    are mirror-symmetric: the result for ``1 - s`` is the result for ``s``
    with the bin axis reversed.

    :param torch.Tensor s: A real-valued tensor with values in [0, 1].
    :param int num_quant_bins: Number of quantization bins; one of 4, 8, 12
        or 16. Defaults to 4.
    :returns: Probabilities of shape ``s.shape + (num_quant_bins,)``.
    :rtype: torch.Tensor
    :raises ValueError: If ``num_quant_bins`` is not a supported value.
    """
    if num_quant_bins not in [4, 8, 12, 16]:
        raise ValueError("Supported quantization strategies have 4, 8, 12, or 16 bins")

    t = 1 - s
    ss = s * s
    tt = t * t

    if num_quant_bins == 4:
        # This cubic spline interpolates over the nearest four integers, ensuring
        # piecewise quadratic gradients.
        probs = torch.stack([
            t * tt,
            4 + ss * (3 * s - 6),
            4 + tt * (3 * t - 6),
            s * ss,
        ], dim=-1) * (1/6)
    elif num_quant_bins == 8:
        # This quintic spline interpolates over the nearest eight integers, ensuring
        # piecewise quartic gradients.
        s3 = ss * s
        s4 = ss * ss
        s5 = s3 * ss

        t3 = tt * t
        t4 = tt * tt
        t5 = t3 * tt

        # NOTE: the cubic coefficient of the third and sixth pieces must be
        # -10; only then do the eight pieces sum to exactly 840 for every s
        # (so that the probabilities are normalized) and join smoothly at the
        # knots.
        probs = torch.stack([
            2 * t5,
            2 + 10 * t + 20 * tt + 20 * t3 + 10 * t4 - 7 * t5,
            55 + 115 * t + 70 * tt - 10 * t3 - 25 * t4 + 7 * t5,
            302 - 100 * ss + 10 * s4,
            302 - 100 * tt + 10 * t4,
            55 + 115 * s + 70 * ss - 10 * s3 - 25 * s4 + 7 * s5,
            2 + 10 * s + 20 * ss + 20 * s3 + 10 * s4 - 7 * s5,
            2 * s5,
        ], dim=-1) * (1/840)
    elif num_quant_bins == 12:
        # This septic spline interpolates over the nearest 12 integers
        s3 = ss * s
        s4 = ss * ss
        s5 = s3 * ss
        s6 = s3 * s3
        s7 = s4 * s3

        t3 = tt * t
        t4 = tt * tt
        t5 = t3 * tt
        t6 = t3 * t3
        t7 = t4 * t3

        probs = torch.stack([
            693 * t7,
            693 + 4851 * t + 14553 * tt + 24255 * t3 + 24255 * t4 + 14553 * t5 + 4851 * t6 - 3267 * t7,
            84744 + 282744 * t + 382536 * tt + 249480 * t3 + 55440 * t4 - 24948 * t5 - 18018 * t6 + 5445 * t7,
            1017423 + 1823283 * t + 1058211 * tt + 51975 * t3 - 148995 * t4 - 18711 * t5 + 20097 * t6 - 3267 * t7,
            3800016 + 3503808 * t + 365904 * tt - 443520 * t3 - 55440 * t4 + 33264 * t5 - 2772 * t6,
            8723088 - 1629936 * ss + 110880 * s4 - 2772 * s6,
            8723088 - 1629936 * tt + 110880 * t4 - 2772 * t6,
            3800016 + 3503808 * s + 365904 * ss - 443520 * s3 - 55440 * s4 + 33264 * s5 - 2772 * s6,
            1017423 + 1823283 * s + 1058211 * ss + 51975 * s3 - 148995 * s4 - 18711 * s5 + 20097 * s6 - 3267 * s7,
            84744 + 282744 * s + 382536 * ss + 249480 * s3 + 55440 * s4 - 24948 * s5 - 18018 * s6 + 5445 * s7,
            693 + 4851 * s + 14553 * ss + 24255 * s3 + 24255 * s4 + 14553 * s5 + 4851 * s6 - 3267 * s7,
            693 * s7,
        ], dim=-1) * (1/32931360)
    elif num_quant_bins == 16:
        # This nonic spline interpolates over the nearest 16 integers
        w16 = torch.from_numpy(W16).to(s.device).type_as(s)
        # Create the exponents and the gather index on the same device as `s`
        # so that this branch also works when `s` does not live on the
        # default device.
        exponents = torch.arange(10., dtype=s.dtype, device=s.device)
        s_powers = s.unsqueeze(-1).unsqueeze(-1).pow(exponents)
        t_powers = t.unsqueeze(-1).unsqueeze(-1).pow(exponents)
        # Evaluate each of the 8 polynomials (rows of W16) at t and at s.
        splines_t = (w16 * t_powers).sum(-1)
        splines_s = (w16 * s_powers).sum(-1)
        # Positions 0-7 of the concatenation are the t-splines and 8-15 the
        # s-splines; the index interleaves them into bin order.
        index = torch.tensor([0, 1, 2, 3, 4, 5, 6, 15, 7, 14, 13, 12, 11, 10, 9, 8],
                             device=s.device)
        probs = torch.cat([splines_t, splines_s], dim=-1)
        probs = probs.index_select(-1, index)

    return probs


def quantize(name, x_real, min, max, num_quant_bins=4):
    """
    Randomly quantize in a way that preserves probability mass.

    A categorical variable named ``"Q_" + name`` is sampled (marked for
    parallel enumeration) over ``num_quant_bins`` integers surrounding
    ``x_real``, and the resulting integer-valued tensor is recorded as a
    deterministic site called ``name``.

    :param str name: Name of the deterministic site holding the result.
    :param torch.Tensor x_real: Real-valued tensor to quantize.
    :param min: Lower bound of the quantized values; must be less than ``max``.
    :param max: Upper bound of the quantized values.
    :param int num_quant_bins: Number of quantization bins, forwarded to
        :func:`compute_bin_probs`. Defaults to 4.
    :returns: The value of the ``name`` deterministic site.
    """
    assert min < max
    # The floor is detached, so gradients flow only through the fractional
    # part that is passed to compute_bin_probs().
    lb = x_real.detach().floor()

    probs = compute_bin_probs(x_real - lb, num_quant_bins=num_quant_bins)

    q = pyro.sample("Q_" + name, dist.Categorical(probs),
                    infer={"enumerate": "parallel"})
    # Convert the bin index k into the integer offset k - (num_quant_bins // 2 - 1)
    # relative to lb.
    q = q.type_as(x_real) - (num_quant_bins // 2 - 1)

    x = lb + q
    # Reflect values that fall below min or above max back into [min, max].
    x = torch.max(x, 2 * min - 1 - x)
    x = torch.min(x, 2 * max + 1 - x)

    return pyro.deterministic(name, x)


def quantize_enumerate(x_real, min, max, num_quant_bins=4):
    """
    Quantize, then manually enumerate.

    :param torch.Tensor x_real: Real-valued tensor to quantize.
    :param min: Lower bound of the quantized values; must be less than ``max``.
    :param max: Upper bound of the quantized values.
    :param int num_quant_bins: Number of quantization bins, forwarded to
        :func:`compute_bin_probs`. Defaults to 4.
    :returns: A pair ``(x, logits)``. ``x`` holds the candidate quantized
        values along a new rightmost dimension of size ``num_quant_bins`` and
        ``logits`` holds the log of the corresponding bin probabilities.
    :rtype: tuple
    """
    assert min < max
    # The floor is detached, so gradients flow only through the fractional
    # part that is passed to compute_bin_probs().
    lb = x_real.detach().floor()

    probs = compute_bin_probs(x_real - lb, num_quant_bins=num_quant_bins)
    logits = safe_log(probs)

    # Integer offsets (relative to lb) of the enumerated bins, created with
    # the dtype and device of x_real so the addition below needs no implicit
    # type promotion or cross-device transfer.
    arange_min = 1 - num_quant_bins // 2
    arange_max = 1 + num_quant_bins // 2
    q = torch.arange(arange_min, arange_max,
                     dtype=x_real.dtype, device=x_real.device)

    x = lb.unsqueeze(-1) + q
    # Reflect values that fall below min or above max back into [min, max].
    x = torch.max(x, 2 * min - 1 - x)
    x = torch.min(x, 2 * max + 1 - x)

    return x, logits
30 changes: 30 additions & 0 deletions tests/contrib/epidemiology/test_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pytest

import torch

from pyro.contrib.epidemiology.util import compute_bin_probs


@pytest.mark.parametrize("num_quant_bins", [4, 8, 12, 16])
def test_quantization_scheme(num_quant_bins, num_samples=1000 * 1000):
    """
    Check that quantizing a uniform grid of reals spreads mass evenly.

    Points spaced evenly over [-0.5, 7.5] are quantized onto the integers
    0..7 (with reflection at both ends); the mass accumulated on every
    integer must stay within 1e-4 of the uniform value 1/8.
    """
    lower, upper = 0, 7
    mass = torch.zeros(upper + 1)

    grid = torch.linspace(-0.5, upper + 0.5, num_samples)
    grid_floor = grid.floor()
    bin_probs = compute_bin_probs(grid - grid_floor, num_quant_bins=num_quant_bins)

    first_offset = 1 - num_quant_bins // 2
    for k in range(num_quant_bins):
        target = (grid_floor + (first_offset + k)).long()
        # Reflect out-of-range integers back into [lower, upper].
        target = torch.max(target, 2 * lower - 1 - target)
        target = torch.min(target, 2 * upper + 1 - target)
        mass.scatter_add_(0, target, bin_probs[:, k] / num_samples)

    uniform = 1.0 / (upper + 1.0)
    max_deviation = (mass - uniform).abs().max().item()
    assert max_deviation < 1.0e-4
10 changes: 6 additions & 4 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@
'contrib/autoname/mixture.py --num-epochs=1',
'contrib/autoname/tree_data.py --num-epochs=1',
'contrib/cevae/synthetic.py --num-epochs=1',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 --dct=1',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=8',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --dct=1',
'contrib/forecast/bart.py --num-steps=2 --stride=99999',
'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000',
'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000',
Expand Down Expand Up @@ -97,8 +98,9 @@
'air/main.py --num-steps=1 --cuda',
'baseball.py --num-samples=200 --warmup-steps=100 --num-chains=2 --cuda',
'contrib/cevae/synthetic.py --num-epochs=1 --cuda',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 --cuda',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 --dct=1 --cuda',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --cuda',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --cuda',
'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --dct=1 --cuda',
'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --cuda',
'lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --cuda',
'dmm/dmm.py --num-epochs=1 --cuda',
Expand Down

0 comments on commit b52a139

Please sign in to comment.