-
-
Notifications
You must be signed in to change notification settings - Fork 988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Start a contrib.epidemiology module #2437
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks great!
|
||
# Sample initial values. | ||
state = self.initialize(params) | ||
state = {i: torch.tensor(float(value)) for i, value in state.items()} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is this for exactly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.initialize
currently returns deterministic initial states. Because those states are used in torch.nn.functional.pad
they must be Python scalars. However in this function those initial states need to start out as tensors.
The state
dict in this function acts similarly to the state
dict in SMCFilter: it is a framework-managed storage location where users can read and write values. Note the self.transition_fwd(params, state, t)
call below updates this state dict in-place.
logp = pyro.distributions.hmm._sequential_logmatmulexp(logp) | ||
logp = logp.reshape(-1).logsumexp(0) | ||
warn_if_nan(logp) | ||
pyro.factor("transition", logp) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what happened to the previous -log(4)
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have neglected that from this model because (1) it does not affect inference (it does not change gradients), and (2) we will probably be moving to stochastic initial state soon anyway.
Addresses #2426
This refactors examples/sir_hmc.py into a more reusable framework that separates (1) clear user-facing model specification from (2) intricate effect handler code for inference and prediction. I have preserved the existing examples/sir_hmc.py as a minipyro-like concrete explanation of the new module; it thus serves as architecture documentation. This also makes small changes to sir_hmc.py to keep the two versions aligned (the old concrete and new abstract versions).
Whereas
sir_hmc.py
contains four models, I was able to reduce user-facing modeling code down to a few methods and only a single duplication: we need a forward.transition_fwd()
method and also a.transition_bwd()
method; these contain reverse versions of dynamic equations and cannot easily be automatically generated.After this PR I plan to simplify sir_hmc.py by moving all the fancy DCT and spline stuff into the new module.
Tested
sir_hmc.py