Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SVI inference for CompartmentalModel #2529

Merged
merged 16 commits into from
Jun 16, 2020
Merged

Implement SVI inference for CompartmentalModel #2529

merged 16 commits into from
Jun 16, 2020

Conversation

fritzo
Copy link
Member

@fritzo fritzo commented Jun 16, 2020

Addresses #2426

This implements SVI inference for the moment-matched relaxed continuous model. Although this approach is doubly approximate, it is very fast and appears to be plausibly accurate. In follow-up PRs I plan to explore ways to make this more exact, e.g. NeuTra or IWAE or importance-sampling.

Summary of changes:

  • lots of arg plumbing
  • a new .fit_svi() method, very vanilla Pyro
  • updated reshaping logic in ._sample_auxiliary(); updated its callers

Tested

  • added smoke tests for shapes
  • added smoke tests for examples/.../sir.py and regional.py
  • ran locally to verify plausible behavior
 python examples/contrib/epidemiology/sir.py --svi -p 10000 -d 60 -f 30 -e 3 -tau 12 -m 0.01 -k 4 -R0 0.4 -M 0.3 --plot --haar -m 0.1 -M 0.5
Simulating from a SuperspreadingSEIRModel
Observed 4023/6908 infections:
0 0 1 0 0 2 0 0 0 1 2 2 2 1 1 3 0 3 3 1 7 7 4 13 12 9 13 10 23 13 33 26 24 29 47 53 55 61 68 77 78 80 113 131 142 151 142 184 193 205 226 215 213 229 189 218 208 195 155 150
INFO 	 Heuristic init: R0=0.393, k=2.5, rho=0.512
INFO 	 Running inference...
INFO 	 step 0 loss = 1.113e+04
INFO 	 step 200 loss = 19.42
INFO 	 step 400 loss = 24.3
INFO 	 step 600 loss = 13.41
INFO 	 step 800 loss = 13.46
INFO 	 step 1000 loss = 9.879
INFO 	 step 1200 loss = 6.904
INFO 	 step 1400 loss = 7.494
INFO 	 step 1600 loss = 6.158
INFO 	 step 1800 loss = 6.399
INFO 	 step 2000 loss = 5.442
INFO 	 step 2200 loss = 5.293
INFO 	 step 2400 loss = 5.163
INFO 	 step 2600 loss = 4.887
INFO 	 step 2800 loss = 4.854
INFO 	 step 3000 loss = 4.488
INFO 	 step 3200 loss = 4.568
INFO 	 step 3400 loss = 4.44
INFO 	 step 3600 loss = 4.394
INFO 	 step 3800 loss = 4.484
INFO 	 step 4000 loss = 4.4
INFO 	 step 4200 loss = 4.379
INFO 	 step 4400 loss = 4.399
INFO 	 step 4600 loss = 4.436
INFO 	 step 4800 loss = 4.303
INFO 	 step 5000 loss = 4.322
INFO 	 SVI took 64.7 seconds, 77.3 step/sec
R0: truth = 0.4, estimate = 0.373 ± 0.00666
rho: truth = 0.5, estimate = 0.525 ± 0.00814
k: truth = 4, estimate = 2.34 ± 1.3

image
image

Copy link
Collaborator

@martinjankowiak martinjankowiak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how long do all the contrib/epi tests take?

@martinjankowiak martinjankowiak merged commit fe9540f into dev Jun 16, 2020
@fritzo
Copy link
Member Author

fritzo commented Jun 16, 2020

Thanks for reviewing!

how long do all the contrib/epi tests take?

2minutes 20sec on my laptop, but they're all in the integration stage (not the unit stage), so they don't add to total test latency.

@fritzo fritzo deleted the sir-svi branch July 14, 2020 01:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants