Implement continuous-relaxed inference for CompartmentalModel #2522
Conversation
This is so cheap and apparently accurate that I'm tempted to make it the default strategy.
```diff
@@ -371,7 +374,7 @@ def heuristic():
     logger.info("Running inference...")
     max_tree_depth = options.pop("max_tree_depth", 5)
     full_mass = options.pop("full_mass", self.full_mass)
-    model = self._vectorized_model
+    model = self._relaxed_model if self.relaxed else self._vectorized_model
```
Is `_vectorized_model` still the most appropriate name?
Sure, I've renamed to `_quantized_model()`. Note `_sequential_model()` is compatible with both quantized and relaxed inference, so each model is now named by its distinguishing feature. Thanks for reviewing!
Addresses #2426
Replaces #2510, #2513
This implements continuous-valued moment-matched inference for `CompartmentalModel`s. The new algorithm preserves the existing modeling language and requires only the replacement of `dist.ExtendedBinomial()` with `binomial_dist()`.

This algorithm is exact in the large-sample, large-population limit, similar to the requirement of popular ODE models. It differs from popular ODE models in that it models latent stochasticity (e.g. an `Rt` stochastic process).

Performance
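As a rough illustration of why the relaxed likelihood is both cheap and accurate, here is a minimal sketch of a moment-matched Gaussian relaxation of the binomial. This is illustrative only; Pyro's actual `binomial_dist()` implementation may differ (e.g. variance clamping, overdispersion).

```python
import math

def relaxed_binomial_log_prob(total_count, probs, value):
    """Moment-matched Gaussian relaxation of Binomial(total_count, probs).

    The Normal matches the binomial's mean n*p and variance n*p*(1-p);
    its log density needs only one special function call, a single log()
    for the normalizing constant.
    """
    mean = total_count * probs
    var = total_count * probs * (1.0 - probs)
    return -0.5 * ((value - mean) ** 2 / var + math.log(2.0 * math.pi * var))

def exact_binomial_log_prob(n, p, k):
    """Exact binomial log-pmf, for comparison; needs three lgamma() calls."""
    return (math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)
            + k * math.log(p) + (n - k) * math.log(1.0 - p))

# In a large population the relaxation is nearly exact near the mode:
n, p = 10000, 0.3
print(relaxed_binomial_log_prob(n, p, 3000.0))  # close to the exact value below
print(exact_binomial_log_prob(n, p, 3000))
```

Note the relaxed version also has full real support, so a proposed latent state can stray slightly outside the discrete support and simply pay a large penalty rather than being rejected outright.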
This inference algorithm is much faster than enumeration. Advantages include:

- No `num_quant_bins**num_compartments` constant overhead factor (which was 4096 for SEIR models and `num_quant_bins=4`).
- `dist.Normal.log_prob()` includes only a single special function, the `log()` for the normalizing constant; by contrast `BetaBinomial` requires either multiple expensive `lgamma()` or `log()` calls.
- `dist.Normal` has full support, so infeasible solutions are allowed but heavily penalized. This leads to fewer rejections, larger step size, and faster adaptation during warmup.

Accuracy
In large populations, this algorithm mixes much better than enumeration.

| Method | Time |
| --- | --- |
| Relaxed | 74 seconds |
| Enumeration | 461 seconds |

In small populations, this algorithm still appears to be accurate.

| Method | Time |
| --- | --- |
| Relaxed | 84 seconds |
| Enumeration | 1337 seconds |
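A quick back-of-the-envelope calculation of the speedups implied by the wall-clock times above:

```python
# Wall-clock times reported above, in seconds.
timings = {
    "large population": {"relaxed": 74, "enumeration": 461},
    "small population": {"relaxed": 84, "enumeration": 1337},
}

for regime, t in timings.items():
    speedup = t["enumeration"] / t["relaxed"]
    print(f"{regime}: relaxed is {speedup:.1f}x faster than enumeration")
```

That is roughly a 6.2x speedup in the large-population experiment and 15.9x in the small-population one.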
Tested

- `sir.py`
- `regional.py`