diff --git a/examples/contrib/epidemiology/sir.py b/examples/contrib/epidemiology/sir.py index 440839c92a..75fdb7f0e7 100644 --- a/examples/contrib/epidemiology/sir.py +++ b/examples/contrib/epidemiology/sir.py @@ -131,7 +131,8 @@ def main(args): obs = dataset["obs"] # Run inference. - model = SIRModel(args.population, args.recovery_time, obs) + model = SIRModel(args.population, args.recovery_time, obs, + num_quant_bins=args.num_bins) samples = infer(args, model) # Evaluate fit. @@ -158,6 +159,7 @@ def main(args): parser.add_argument("-w", "--warmup-steps", default=100, type=int) parser.add_argument("-t", "--max-tree-depth", default=5, type=int) parser.add_argument("-r", "--rng-seed", default=0, type=int) + parser.add_argument("-nb", "--num-bins", default=4, type=int) parser.add_argument("--double", action="store_true") parser.add_argument("--cuda", action="store_true") parser.add_argument("--verbose", action="store_true") diff --git a/pyro/contrib/epidemiology/compartmental.py b/pyro/contrib/epidemiology/compartmental.py index 583f7c6ffc..8f59672d55 100644 --- a/pyro/contrib/epidemiology/compartmental.py +++ b/pyro/contrib/epidemiology/compartmental.py @@ -73,9 +73,12 @@ def transition_bwd(self, params, prev, curr, t): ... :param list compartments: A list of strings of compartment names. :param int duration: :param int population: + :param int num_quant_bins: The number of quantization bins to use. Note that + computational cost is exponential in `num_quant_bins`. Defaults to 4. """ - def __init__(self, compartments, duration, population): + def __init__(self, compartments, duration, population, *, + num_quant_bins=4): super().__init__() assert isinstance(duration, int) @@ -91,6 +94,10 @@ def __init__(self, compartments, duration, population): assert len(compartments) == len(set(compartments)) self.compartments = compartments + assert isinstance(num_quant_bins, int) + assert num_quant_bins >= 2 + self.num_quant_bins = num_quant_bins + # Inference state. self.samples = {} @@ -364,7 +371,8 @@ def _sequential_model(self): aux_t = auxiliary[..., t] prev = curr curr = {name: quantize("{}_{}".format(name, t), aux, - min=0, max=self.population) + min=0, max=self.population, + num_quant_bins=self.num_quant_bins) for name, aux in zip(self.compartments, aux_t.unbind(-1))} self.transition_bwd(params, prev, curr, t) @@ -383,7 +391,8 @@ def _vectorized_model(self): .to_event(2)) # Manually enumerate. - curr, logp = quantize_enumerate(auxiliary, min=0, max=self.population) + curr, logp = quantize_enumerate(auxiliary, min=0, max=self.population, + num_quant_bins=self.num_quant_bins) curr = OrderedDict(zip(self.compartments, curr)) logp = OrderedDict(zip(self.compartments, logp)) @@ -398,7 +407,7 @@ def _vectorized_model(self): # Reshape to support broadcasting, similar to EnumMessenger. C = len(self.compartments) T = self.duration - Q = 4 # Number of quantization points. + Q = self.num_quant_bins # Number of quantization points. def enum_shape(position): shape = [T] + [1] * (2 * C) diff --git a/pyro/contrib/epidemiology/sir.py b/pyro/contrib/epidemiology/sir.py index 62fa24fec2..c72cd9f702 100644 --- a/pyro/contrib/epidemiology/sir.py +++ b/pyro/contrib/epidemiology/sir.py @@ -17,12 +17,16 @@ class SIRModel(CompartmentalModel): :param int population: :param float recovery_time: :param iterable data: Time series of new observed infections. + :param int data: Time series of new observed infections. + :param int num_quant_bins: The number of quantization bins to use. Note that + computational cost is exponential in `num_quant_bins`. Defaults to 4. """ - def __init__(self, population, recovery_time, data): + def __init__(self, population, recovery_time, data, *, + num_quant_bins=4): compartments = ("S", "I") # R is implicit. duration = len(data) - super().__init__(compartments, duration, population) + super().__init__(compartments, duration, population, num_quant_bins=num_quant_bins) assert isinstance(recovery_time, float) assert recovery_time > 0 diff --git a/pyro/contrib/epidemiology/util.py b/pyro/contrib/epidemiology/util.py index f09c6bd954..16cd8acc1f 100644 --- a/pyro/contrib/epidemiology/util.py +++ b/pyro/contrib/epidemiology/util.py @@ -1,6 +1,8 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import numpy + import torch import pyro @@ -8,54 +10,152 @@ from pyro.ops.tensor_utils import safe_log -def quantize(name, x_real, min, max): +# this 8 x 10 tensor encodes the coefficients of 8 10-dimensional polynomials +# that are used to construct the num_quant_bins=16 quantization strategy + +W16 = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.1562511562511555e-07], + [1.1562511562511557e-07, 1.04062604062604e-06, 4.16250416250416e-06, + 9.712509712509707e-06, 1.456876456876456e-05, 1.4568764568764562e-05, + 9.712509712509707e-06, 4.16250416250416e-06, 1.04062604062604e-06, -6.937506937506934e-07], + [5.839068339068337e-05, 0.0002591158841158841, 0.0005036630036630038, + 0.0005536130536130536, 0.00036421911421911425, 0.00013111888111888106, + 9.712509712509736e-06, -1.2487512487512482e-05, -5.2031302031302014e-06, 1.6187516187516182e-06], + [0.0018637612387612374, 0.004983558108558107, 0.005457042957042955, + 0.0029234654234654212, 0.000568181818181818, -0.0001602564102564102, + -8.741258741258739e-05, 4.162504162504162e-06, 9.365634365634364e-06, -1.7536475869809201e-06], + [0.015560115039281694, 0.025703289765789755, 0.015009296259296255, + 0.0023682336182336166, -0.000963966588966589, -0.00029380341880341857, + 5.6656306656306665e-05, 1.5956265956265953e-05, -6.417193917193917e-06, 7.515632515632516e-07], + [0.057450111616778265, 0.05790875790875791, 0.014424464424464418, + -0.0030303030303030303, -0.0013791763791763793, 0.00011655011655011669, + 5.180005180005181e-05, -8.325008325008328e-06, 3.4687534687534703e-07, 0.0], + [0.12553422657589322, 0.072988122988123, -0.0011641136641136712, + -0.006617456617456618, -0.00028651903651903725, 0.00027195027195027195, + 3.2375032375032334e-06, -5.550005550005552e-06, 3.4687534687534703e-07, 0.0], + [0.21761806865973532, 1.7482707128494565e-17, -0.028320290820290833, + 0.0, 0.0014617327117327117, 0.0, + -3.561253561253564e-05, 0.0, 3.4687534687534714e-07, 0.0]] + +W16 = numpy.array(W16) + + +def compute_bin_probs(s, num_quant_bins=3): + """ + Compute categorical probabilities for a quantization scheme with num_quant_bins many + bins. `s` is a real-valued tensor with values in [0, 1]. Returns probabilities + of shape `s.shape` + `(num_quant_bins,)` + """ + if num_quant_bins not in [4, 8, 12, 16]: + raise ValueError("Supported quantization strategies have 4, 8, 12, or 16 bins") + + t = 1 - s + ss = s * s + tt = t * t + + if num_quant_bins == 4: + # This cubic spline interpolates over the nearest four integers, ensuring + # piecewise quadratic gradients. + probs = torch.stack([ + t * tt, + 4 + ss * (3 * s - 6), + 4 + tt * (3 * t - 6), + s * ss, + ], dim=-1) * (1/6) + elif num_quant_bins == 8: + # This quintic spline interpolates over the nearest eight integers, ensuring + # piecewise quartic gradients. + s3 = ss * s + s4 = ss * ss + s5 = s3 * ss + + t3 = tt * t + t4 = tt * tt + t5 = t3 * tt + + probs = torch.stack([ + 2 * t5, + 2 + 10 * t + 20 * tt + 20 * t3 + 10 * t4 - 7 * t5, + 55 + 115 * t + 70 * tt - 9 * t3 - 25 * t4 + 7 * t5, + 302 - 100 * ss + 10 * s4, + 302 - 100 * tt + 10 * t4, + 55 + 115 * s + 70 * ss - 9 * s3 - 25 * s4 + 7 * s5, + 2 + 10 * s + 20 * ss + 20 * s3 + 10 * s4 - 7 * s5, + 2 * s5 + ], dim=-1) * (1/840) + elif num_quant_bins == 12: + # This septic spline interpolates over the nearest 12 integers + s3 = ss * s + s4 = ss * ss + s5 = s3 * ss + s6 = s3 * s3 + s7 = s4 * s3 + + t3 = tt * t + t4 = tt * tt + t5 = t3 * tt + t6 = t3 * t3 + t7 = t4 * t3 + + probs = torch.stack([ + 693 * t7, + 693 + 4851 * t + 14553 * tt + 24255 * t3 + 24255 * t4 + 14553 * t5 + 4851 * t6 - 3267 * t7, + 84744 + 282744 * t + 382536 * tt + 249480 * t3 + 55440 * t4 - 24948 * t5 - 18018 * t6 + 5445 * t7, + 1017423 + 1823283 * t + 1058211 * tt + 51975 * t3 - 148995 * t4 - 18711 * t5 + 20097 * t6 - 3267 * t7, + 3800016 + 3503808 * t + 365904 * tt - 443520 * t3 - 55440 * t4 + 33264 * t5 - 2772 * t6, + 8723088 - 1629936 * ss + 110880.0 * s4 - 2772 * s6, + 8723088 - 1629936 * tt + 110880.0 * t4 - 2772 * t6, + 3800016 + 3503808 * s + 365904 * ss - 443520 * s3 - 55440 * s4 + 33264 * s5 - 2772 * s6, + 1017423 + 1823283 * s + 1058211 * ss + 51975 * s3 - 148995 * s4 - 18711 * s5 + 20097 * s6 - 3267 * s7, + 84744 + 282744 * s + 382536 * ss + 249480 * s3 + 55440 * s4 - 24948 * s5 - 18018 * s6 + 5445 * s7, + 693 + 4851 * s + 14553 * ss + 24255 * s3 + 24255 * s4 + 14553 * s5 + 4851 * s6 - 3267 * s7, + 693 * s7, + ], dim=-1) * (1/32931360) + elif num_quant_bins == 16: + # This nonic spline interpolates over the nearest 16 integers + w16 = torch.from_numpy(W16).to(s.device).type_as(s) + s_powers = s.unsqueeze(-1).unsqueeze(-1).pow(torch.arange(10.)) + t_powers = t.unsqueeze(-1).unsqueeze(-1).pow(torch.arange(10.)) + splines_t = (w16 * t_powers).sum(-1) + splines_s = (w16 * s_powers).sum(-1) + index = [0, 1, 2, 3, 4, 5, 6, 15, 7, 14, 13, 12, 11, 10, 9, 8] + probs = torch.cat([splines_t, splines_s], dim=-1) + probs = probs.index_select(-1, torch.tensor(index)) + + return probs + + +def quantize(name, x_real, min, max, num_quant_bins=4): """Randomly quantize in a way that preserves probability mass.""" assert min < max lb = x_real.detach().floor() - # This cubic spline interpolates over the nearest four integers, ensuring - # piecewise quadratic gradients. - s = x_real - lb - ss = s * s - t = 1 - s - tt = t * t - probs = torch.stack([ - t * tt, - 4 + ss * (3 * s - 6), - 4 + tt * (3 * t - 6), - s * ss, - ], dim=-1) * (1/6) + probs = compute_bin_probs(x_real - lb, num_quant_bins=num_quant_bins) + q = pyro.sample("Q_" + name, dist.Categorical(probs), infer={"enumerate": "parallel"}) - q = q.type_as(x_real) - 1 + q = q.type_as(x_real) - (num_quant_bins // 2 - 1) x = lb + q x = torch.max(x, 2 * min - 1 - x) x = torch.min(x, 2 * max + 1 - x) + return pyro.deterministic(name, x) -def quantize_enumerate(x_real, min, max): +def quantize_enumerate(x_real, min, max, num_quant_bins=4): """Quantize, then manually enumerate.""" assert min < max lb = x_real.detach().floor() - # This cubic spline interpolates over the nearest four integers, ensuring - # piecewise quadratic gradients. - s = x_real - lb - ss = s * s - t = 1 - s - tt = t * t - probs = torch.stack([ - t * tt, - 4 + ss * (3 * s - 6), - 4 + tt * (3 * t - 6), - s * ss, - ], dim=-1) * (1/6) + probs = compute_bin_probs(x_real - lb, num_quant_bins=num_quant_bins) logits = safe_log(probs) - q = torch.arange(-1., 3.) + + arange_min = 1 - num_quant_bins // 2 + arange_max = 1 + num_quant_bins // 2 + q = torch.arange(arange_min, arange_max) x = lb.unsqueeze(-1) + q x = torch.max(x, 2 * min - 1 - x) x = torch.min(x, 2 * max + 1 - x) + return x, logits diff --git a/tests/contrib/epidemiology/test_quant.py b/tests/contrib/epidemiology/test_quant.py new file mode 100644 index 0000000000..dbaca64d7a --- /dev/null +++ b/tests/contrib/epidemiology/test_quant.py @@ -0,0 +1,30 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import torch + +from pyro.contrib.epidemiology.util import compute_bin_probs + + +@pytest.mark.parametrize("num_quant_bins", [4, 8, 12, 16]) +def test_quantization_scheme(num_quant_bins, num_samples=1000 * 1000): + min, max = 0, 7 + probs = torch.zeros(max + 1) + + x = torch.linspace(-0.5, max + 0.5, num_samples) + bin_probs = compute_bin_probs(x - x.floor(), num_quant_bins=num_quant_bins) + x_floor = x.floor() + + q_min = 1 - num_quant_bins // 2 + q_max = 1 + num_quant_bins // 2 + + for k, q in enumerate(range(q_min, q_max)): + y = (x_floor + q).long() + y = torch.max(y, 2 * min - 1 - y) + y = torch.min(y, 2 * max + 1 - y) + probs.scatter_add_(0, y, bin_probs[:, k] / num_samples) + + max_deviation = (probs - 1.0 / (max + 1.0)).abs().max().item() + assert max_deviation < 1.0e-4 diff --git a/tests/test_examples.py b/tests/test_examples.py index 5e933cfc1a..894190f1d5 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -33,8 +33,9 @@ 'contrib/autoname/mixture.py --num-epochs=1', 'contrib/autoname/tree_data.py --num-epochs=1', 'contrib/cevae/synthetic.py --num-epochs=1', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 --dct=1', + 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2', + 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=8', + 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --dct=1', 'contrib/forecast/bart.py --num-steps=2 --stride=99999', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --batch-size=1000', 'contrib/gp/sv-dkl.py --binary --epochs=1 --num-inducing=4 --batch-size=1000', @@ -97,8 +98,9 @@ 'air/main.py --num-steps=1 --cuda', 'baseball.py --num-samples=200 --warmup-steps=100 --num-chains=2 --cuda', 'contrib/cevae/synthetic.py --num-epochs=1 --cuda', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 --cuda', - 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=100 -p=10000 -f 2 --dct=1 --cuda', + 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --cuda', + 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 -nb=16 --cuda', + 'contrib/epidemiology/sir.py -t=2 -w=2 -n=4 -d=20 -p=1000 -f 2 --dct=1 --cuda', 'contrib/gp/sv-dkl.py --epochs=1 --num-inducing=4 --cuda', 'lkj.py --n=50 --num-chains=1 --warmup-steps=100 --num-samples=200 --cuda', 'dmm/dmm.py --num-epochs=1 --cuda',