
Implement continuous-relaxed inference for CompartmentalModel #2522

Merged 5 commits into dev on Jun 12, 2020

Conversation

fritzo (Member) commented Jun 11, 2020

Addresses #2426
Replaces #2510, #2513

This implements continuous-valued moment-matched inference for CompartmentalModels.
The new algorithm preserves the existing modeling language and requires only replacement of dist.ExtendedBinomial() with binomial_dist().

This algorithm is exact in the large-sample, large-population limit, similar to the assumptions required by popular ODE models. It differs from popular ODE models in that:

  1. The approximation is accomplished via nonstandard interpretation of discrete distributions. That is, we perform approximate inference on the exact discrete-valued model (and forecast forward using exact discrete-valued dynamics), whereas ODEs perform (large-sample asymptotically) exact inference on an approximate model.
  2. This model retains discrete-time dynamics, whereas ODEs are continuous-time (with adaptive time-discretization provided by numerical integrators).
  3. This model is stochastic, whereas most ODE models are deterministic (conditioned on the Rt stochastic process).
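The moment-matching idea can be sketched in a few lines of plain Python: replace a Binomial by a Normal with the same mean and variance. This is an illustrative sketch only; the actual binomial_dist() in Pyro differs in detail (e.g. how it clamps variance and handles overdispersion), and the 1e-8 floor below is an assumption.

```python
import math

def relaxed_binomial_logpdf(k, n, p):
    """Moment-matched Normal relaxation of Binomial(n, p).log_prob(k).

    Illustrative sketch, not Pyro's binomial_dist() implementation.
    """
    mean = n * p
    var = max(n * p * (1 - p), 1e-8)  # keep the log finite as p -> 0 or 1
    # A single special function, log(), for the normalizing constant.
    return -0.5 * math.log(2 * math.pi * var) - (k - mean) ** 2 / (2 * var)

def binomial_logpmf(k, n, p):
    """Exact Binomial log-pmf; note the three lgamma() calls."""
    return (math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
            + k * math.log(p) + (n - k) * math.log(1 - p))

# Near the mode of a large-count Binomial the two agree closely:
print(binomial_logpmf(300, 1000, 0.3), relaxed_binomial_logpdf(300.0, 1000, 0.3))
```

The approximation error shrinks as counts grow, which is why the relaxation is exact only in the large-sample, large-population limit described above.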

Performance

This inference algorithm is much faster than enumeration. Advantages include:

  • there is no longer a num_quant_bins**num_compartments constant overhead factor (which was 4096 for SEIR models with num_quant_bins=4).
  • dist.Normal.log_prob() involves only a single special function, the log() for the normalizing constant; by contrast, BetaBinomial.log_prob() requires multiple expensive lgamma() or log() calls.
  • dist.Normal has full support, so infeasible solutions are allowed but heavily penalized. This leads to fewer rejections, larger step size, and faster adaptation during warmup.
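The full-support point in the last bullet can be seen numerically: the exact Binomial assigns log-probability -inf to counts outside [0, n], which forces an outright rejection, while a Normal relaxation assigns a finite but heavily penalized value with a usable gradient. A minimal sketch (the n=100, p=0.3 numbers are illustrative, not taken from this PR):

```python
import math

def normal_logpdf(x, mean, var):
    """Log-density of Normal(mean, sqrt(var)) at x."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

# Moment-matched relaxation of Binomial(n=100, p=0.3):
n, p = 100, 0.3
mean, var = n * p, n * p * (1 - p)

feasible = normal_logpdf(30.0, mean, var)    # an ordinary, plausible count
infeasible = normal_logpdf(-5.0, mean, var)  # impossible under the exact
                                             # Binomial (log_prob = -inf)
# The infeasible point is heavily penalized yet finite, so HMC gets a
# gradient pointing back toward the feasible region instead of a rejection.
print(feasible, infeasible)
```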

Accuracy

In large populations, this algorithm mixes much better than enumeration.

Relaxed

74 seconds

$ python examples/contrib/epidemiology/sir.py -p=1e7 -d=120 -f=30 -hfm=10 --plot --relax
Simulating from a SimpleSIRModel
Observed 227440/454291 infections:
7 5 2 2 6 0 3 7 5 6 14 10 9 14 10 8 10 10 12 9 14 23 24 21 20 20 17 26 21 38 28 31 33 39 31 53 44 54 44 64 64 53 79 57 67 81 100 115 118 112 155 138 183 171 179 183 167 201 230 266 256 296 303 325 385 378 424 431 497 498 540 591 611 665 708 814 822 848 966 1045 1165 1228 1224 1359 1457 1577 1658 1835 1946 2006 2266 2290 2474 2786 2874 3145 3347 3576 3703 4049 4205 4490 4825 5159 5450 5711 6377 6720 7133 7624 8081 8741 9398 9824 10486 11060 11742 12335 13194 14004
INFO 	 Running inference...
INFO 	 Heuristic init: R0=1.51, rho=0.519
Sample: 100%|██████████████████████████████████████████| 400/400 [01:14,  5.35it/s, step size=1.10e-02, acc. prob=0.870]


                                   mean       std    median      5.0%     95.0%     n_eff     r_hat
                           R0      1.50      0.00      1.50      1.50      1.50     49.45      1.00
                          rho      0.52      0.00      0.52      0.52      0.52      4.98      1.18
  auxiliary_haar_split_0[0,0]     42.10      0.01     42.10     42.09     42.12      4.65      1.38
  auxiliary_haar_split_0[0,1]     -9.02      0.01     -9.02     -9.03     -9.01      3.20      1.76


Enumeration

461 seconds

$ python examples/contrib/epidemiology/sir.py -p=1e7 -d=120 -f=30 -hfm=10 --plot
Simulating from a SimpleSIRModel
Observed 227440/454291 infections:
7 5 2 2 6 0 3 7 5 6 14 10 9 14 10 8 10 10 12 9 14 23 24 21 20 20 17 26 21 38 28 31 33 39 31 53 44 54 44 64 64 53 79 57 67 81 100 115 118 112 155 138 183 171 179 183 167 201 230 266 256 296 303 325 385 378 424 431 497 498 540 591 611 665 708 814 822 848 966 1045 1165 1228 1224 1359 1457 1577 1658 1835 1946 2006 2266 2290 2474 2786 2874 3145 3347 3576 3703 4049 4205 4490 4825 5159 5450 5711 6377 6720 7133 7624 8081 8741 9398 9824 10486 11060 11742 12335 13194 14004
INFO 	 Running inference...
INFO 	 Heuristic init: R0=1.54, rho=0.453
Sample: 100%|██████████████████████████████████████████| 400/400 [07:41,  1.15s/it, step size=9.92e-03, acc. prob=0.884]


                                   mean       std    median      5.0%     95.0%     n_eff     r_hat
                           R0      1.51      0.00      1.51      1.50      1.51     64.47      1.03
                          rho      0.45      0.00      0.45      0.45      0.45      4.55      1.51
  auxiliary_haar_split_0[0,0]     41.04      0.01     41.04     41.02     41.06      3.08      2.57
  auxiliary_haar_split_0[0,1]     -9.07      0.00     -9.07     -9.07     -9.06     12.56      1.02


In small populations, this algorithm still appears to be accurate.

Relaxed

84 seconds

% python examples/contrib/epidemiology/sir.py -p=1000 -e=3 -tau=7 -m=0.1 -M=0.5 -d=60 -f=30 -hfm=10 --plot --relax
Simulating from a SimpleSEIRModel
Observed 111/197 infections:
1 0 0 0 1 0 0 0 1 0 0 0 1 1 1 2 1 0 1 1 1 0 1 2 5 1 0 2 1 3 3 1 2 1 1 0 2 5 2 2 3 5 2 5 5 3 1 0 4 5 4 1 2 3 3 3 6 6 2 2
INFO 	 Running inference...
INFO 	 Heuristic init: R0=1.53, rho=0.644
Sample: 100%|██████████████████████████████████████████| 400/400 [01:24,  4.71it/s, step size=1.49e-01, acc. prob=0.869]


                                  mean       std    median      5.0%     95.0%     n_eff     r_hat
                          R0      1.51      0.14      1.51      1.28      1.70     56.44      1.00
                         rho      0.64      0.05      0.63      0.57      0.71      4.50      1.66
 auxiliary_haar_split_0[0,0]     13.29      0.43     13.41     12.55     13.88      3.14      1.91
 auxiliary_haar_split_0[0,1]     -2.65      0.10     -2.65     -2.82     -2.51     11.03      1.09


Enumeration

1337 seconds

% python examples/contrib/epidemiology/sir.py -p=1000 -e=3 -tau=7 -m=0.1 -M=0.5 -d=60 -f=30 -hfm=10 --plot
Simulating from a SimpleSEIRModel
Observed 111/197 infections:
1 0 0 0 1 0 0 0 1 0 0 0 1 1 1 2 1 0 1 1 1 0 1 2 5 1 0 2 1 3 3 1 2 1 1 0 2 5 2 2 3 5 2 5 5 3 1 0 4 5 4 1 2 3 3 3 6 6 2 2
INFO 	 Running inference...
INFO 	 Heuristic init: R0=1.45, rho=0.588
Sample: 100%|██████████████████████████████████████████| 400/400 [22:17,  3.34s/it, step size=1.38e-01, acc. prob=0.851]


                                  mean       std    median      5.0%     95.0%     n_eff     r_hat
                          R0      1.64      0.16      1.64      1.39      1.90      4.93      1.54
                         rho      0.52      0.04      0.52      0.47      0.58     29.34      1.15
 auxiliary_haar_split_0[0,0]     11.96      0.23     11.94     11.57     12.39      5.13      1.44
 auxiliary_haar_split_0[0,1]     -2.55      0.14     -2.56     -2.74     -2.31     18.67      1.01


Tested

  • unit tests for moment matching approximations
  • smoke tests of inference
  • added to examples sir.py and regional.py
  • evaluate accuracy in large populations
  • evaluate accuracy in small populations

fritzo (Member, Author) commented Jun 11, 2020

This is so cheap and apparently accurate that I'm tempted to make it the default strategy.

```diff
@@ -371,7 +374,7 @@ def heuristic():
         logger.info("Running inference...")
         max_tree_depth = options.pop("max_tree_depth", 5)
         full_mass = options.pop("full_mass", self.full_mass)
-        model = self._vectorized_model
+        model = self._relaxed_model if self.relaxed else self._vectorized_model
```
A Collaborator commented:
is vectorized_model still the most appropriate name?

fritzo (Member, Author) replied Jun 12, 2020:
Sure, I've renamed to _quantized_model(). Note _sequential_model() is compatible with both quantized and relaxed inference, so each model is now named by its distinguishing feature.

fritzo (Member, Author) commented Jun 12, 2020

Thanks for reviewing!

@martinjankowiak martinjankowiak merged commit d18fec8 into dev Jun 12, 2020
@fritzo fritzo deleted the sir-relax-3 branch July 14, 2020 01:24