Add find_MAP with close JAX integration and fix bug with Laplace fit #385
Conversation
Hey, nice one, yeah I agree, we should only have one
The current behaviour of
The behaviour when you only pass a subset of variables isn't really desirable in my opinion (see #345 (comment)), so we put a warning. So as you say:
Agree, that's the best plan. Judging by your docs and a quick glance at your code, I think you're basically doing the same thing. The current implementation is a few lines of code and a few docs, so I reckon
Then it should be safe to delete the existing code and we can go back to one
I would love a generic optimiser in p u r e pytensor, but I can see looking at your code that there are a lot of fancy extras that would take a large effort to write in pytensor. Still, if we want to go back to one of our efforts with a fixed point operator (pymc-devs/pytensor#978 and pymc-devs/pytensor#944), we could probably write one. Happy to look at your code and review properly later in the week if you'd like me to. Let me know. Otherwise, I'll leave it to the core devs.
That would be appreciated.
Agree with what @theorashid said. This
No objections about your custom library wrapper.
tagging @theorashid -- I couldn't pick you as a reviewer? I did a major refactor of this. I broke the marriage to jax and generalized the find_MAP function. Files have been renamed to reflect this. I also merged the two laplace approaches. The biggest change is that I removed the ability to choose
yea sorry, I'm just a normal, but I'll give it a review. Will do it at some point in the next 2 weeks.
Minor suggestions, PR looks amazing!
@jessegrabowski can we close #376 with this PR? Do you have a test that covers something like it?
- Rename function `laplace` -> `sample_laplace_posterior`
sweet, all done?
For now, though I'd still appreciate it if you could have a look and open issues on any bugs/shortcomings you find.
I managed to follow the code through and it looks good to me. Happy you got rid of the option to fit on a subset of variables, which didn't make sense to me anyway. If it passes the original test then it should be good. You can do something about the other comments if you want, but maybe not because we are e x p e r i m e n t a l
H_inv = get_nearest_psd(H_inv)
if on_bad_cov == "warn":
    _log.warning(
        "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD "
For my understanding, what sort of scenarios/models would give a Hessian that is not PSD? And is using the closest PSD a good idea?
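For context, "nearest PSD" usually means projecting onto the PSD cone by clipping negative eigenvalues; non-PSD (inverse) Hessians typically show up when the optimizer stops at a saddle point or in a flat, weakly identified direction, or from numerical error. A minimal numpy sketch of that projection -- not necessarily how `get_nearest_psd` is implemented in this PR:

```python
import numpy as np

def nearest_psd(a: np.ndarray) -> np.ndarray:
    # Symmetrize, then zero out negative eigenvalues. For a symmetric matrix
    # this is the closest PSD matrix in Frobenius norm.
    sym = (a + a.T) / 2
    eigvals, eigvecs = np.linalg.eigh(sym)
    return (eigvecs * np.clip(eigvals, 0.0, None)) @ eigvecs.T
```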
Check for `jax` installation before any computation if `gradient_backend = 'jax'`
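A minimal sketch of the kind of early check being suggested (the helper name `_check_gradient_backend` is made up; the real placement and signature in the PR may differ):

```python
def _check_gradient_backend(gradient_backend: str) -> None:
    # Fail fast, before building any graph, if JAX was requested but isn't installed.
    if gradient_backend == "jax":
        try:
            import jax  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "gradient_backend='jax' requires jax to be installed."
            ) from err
```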
Closes #376
This PR adds code to run `find_MAP` using JAX. I'm using JAX for gradients, because I found the compile times were faster. Open to suggestions/rebuke.

It also adds a `fit_laplace` function, which is bad because we already have a `fit_laplace` function. This one has a slightly different objective, though -- it isn't meant to be used as a step sampler on a subset of model variables. Instead, it is meant to be used on the MAP result to give an approximation to the full posterior. My function also lets you do the Laplace approximation in the transformed space, then do a sample-wise reverse transformation. I think this is legit, and it lets you obtain approximate posteriors that respect the domain of the prior. Tagging @theorashid so we can resolve the differences.
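To make the transformed-space idea concrete, here is a tiny illustrative sketch (not the PR's actual API -- the parameter, the `mu`/`H_inv` names, and the `exp` backward transform are all made up for the example):

```python
import numpy as np

rng = np.random.default_rng(0)

# Pretend sigma is a positive parameter that the model fits on the log scale.
mu = np.array([0.3])        # MAP estimate of log(sigma) in the unconstrained space
H_inv = np.array([[0.05]])  # inverse Hessian at the MAP, used as the approximate covariance

# Gaussian (Laplace) approximation in the unconstrained space...
log_sigma_draws = rng.multivariate_normal(mu, H_inv, size=1000)

# ...then apply the backward transform to each draw, so every sample
# respects the domain of the prior (sigma > 0).
sigma_draws = np.exp(log_sigma_draws)
```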
The last point is that I added a dependency on `better_optimize`. This is a package I wrote that basically rips out the wrapper code used in PyMC's `find_MAP` and applies it to arbitrary optimization problems. It is more feature-complete than the PyMC wrapper -- it supports all optimizer modes for `scipy.optimize.minimize` and `scipy.optimize.root`, and it also helps get keywords to the right place in those functions (who can ever remember whether an argument goes in `method_kwargs` or in the function itself?). I plan to add support for `basinhopping` as well, which will be nice for really hairy minimizations.

I could see an objection to adding another dependency, but 1) it's a lightweight wrapper around functionality that doesn't really belong in PyMC anyway, and 2) it's a big value-add compared to working directly with the `scipy.optimize` functions, which have gnarly, inconsistent signatures.
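To illustrate the keyword-routing pain with the raw API (plain `scipy.optimize` here, nothing from `better_optimize`): solver-specific settings like `maxiter` have to go in the `options` dict of `scipy.optimize.minimize`, not as top-level keyword arguments, which is easy to get wrong.

```python
import numpy as np
from scipy.optimize import minimize

def rosen(x):
    # Standard Rosenbrock test function.
    return np.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2 + (1.0 - x[:-1]) ** 2)

# maxiter lives inside `options`; passing maxiter=100 directly to minimize()
# raises a TypeError because it is not a keyword of minimize itself.
res = minimize(rosen, x0=np.zeros(5), method="L-BFGS-B", options={"maxiter": 100})
print(res.x, res.success)
```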