HMC #4
Open · 1 task
alecandido opened this issue Apr 27, 2022 · 3 comments
Labels: roadmap (Track progresses on the project)
@alecandido (Member)

Here is the roadmap for the HMC implementation:

  • 0: basic GP regression, with the model and NUTS from pymc3 (a minimal sketch is below)
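
For concreteness, a minimal sketch of what item 0 could look like using pymc3's marginal GP and its default NUTS sampler. The toy data, kernel, and priors here are placeholder choices for illustration, not the actual model:

```python
import numpy as np
import pymc3 as pm

# toy data (placeholder, not the real fit)
X = np.linspace(0, 1, 20)[:, None]
y = np.sin(2 * np.pi * X).ravel() + 0.1 * np.random.randn(20)

with pm.Model() as model:
    # placeholder hyperpriors
    ls = pm.Gamma("ls", alpha=2, beta=1)        # kernel lengthscale
    eta = pm.HalfNormal("eta", sigma=1)         # output scale
    sigma = pm.HalfNormal("sigma", sigma=0.1)   # noise level

    cov = eta ** 2 * pm.gp.cov.ExpQuad(1, ls)
    gp = pm.gp.Marginal(cov_func=cov)
    gp.marginal_likelihood("y_obs", X=X, y=y, noise=sigma)

    # pm.sample uses NUTS by default for continuous variables
    trace = pm.sample(1000, tune=1000, target_accept=0.9)
```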
@alecandido (Member Author)

This is just a note to myself (@alecandido); the same considerations as in #2 (comment) apply.

@Gattocrucco (Collaborator) commented Apr 27, 2022

The problem with pymc3 is that we want to take second derivatives, and then pymc3 takes the gradient on top of that. pymc4 supports JAX as a backend, so maybe (not sure) you can do it easily; in pymc3 you would need to code all the derivatives manually, and redo them if you change the model (see slide 24 of my seminar on lsqfitgp).
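
To make the nesting concrete: with a JAX backend the stacked differentiation composes mechanically. A toy sketch, where the function f and the likelihood are only stand-ins and not the actual GP model:

```python
import jax
import jax.numpy as jnp

# toy function standing in for a model that needs second derivatives internally
def f(x, theta):
    return jnp.sin(theta * x)

def log_likelihood(theta, x, data):
    # second derivative of f w.r.t. x enters the model...
    d2f = jax.vmap(jax.grad(jax.grad(f)), in_axes=(0, None))(x, theta)
    return -0.5 * jnp.sum((data - d2f) ** 2)

# ...and NUTS then needs the gradient w.r.t. theta on top of that:
# a third level of differentiation, with no manual derivatives to maintain
grad_loglik = jax.grad(log_likelihood, argnums=0)

x = jnp.linspace(0.0, 1.0, 10)
data = jnp.zeros(10)
print(grad_loglik(1.0, x, data))
```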

Currently, in the tests with lsqfitgp, I'm doing a Laplace approximation for the nonlinearities and the hyperparameters, after appropriately transforming the hyperparameters (log for positive parameters, etc.). This often works well, especially considering that we are not interested in the hyperparameters per se: we only use them as a way of specifying a flexible prior distribution over the PDFs, so we care about the predictive error and not about getting the tails of the hyperparameter posteriors right.
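
A minimal sketch of that recipe (mode of the transformed posterior, then a Gaussian with covariance from the inverse Hessian); the density below is a stand-in, not the lsqfitgp posterior:

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.optimize import minimize

# stand-in negative log posterior of one positive hyperparameter,
# parametrized through z = log(scale) so the domain is unconstrained
def neg_log_post(z):
    scale = jnp.exp(z[0])
    return 0.5 * scale ** 2 - 3.0 * z[0]

# mode in the transformed space
res = minimize(lambda z: float(neg_log_post(jnp.asarray(z))), x0=np.zeros(1))

# Laplace approximation: Gaussian centered at the mode,
# covariance = inverse Hessian of the negative log posterior
hess = jax.hessian(neg_log_post)(jnp.asarray(res.x))
cov = jnp.linalg.inv(hess)
```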

Moreover, I've looked at the error on the currently fitted PDFs and it's small, so overall I think the fit would work without MCMC.

If we end up really needing it, we could first test the fit with lsqfitgp, where it's easy to change the model, and then hardcode everything in pymc3 once we are sure of which kernels we want to use. Another alternative is to use JAX-based beta software (numpyro, pymc4, tinygp) and do some things on our own, but without computing all the derivatives; considering that lsqfitgp is written with autograd, I could just as well port lsqfitgp to JAX and then plug its marginal likelihood into any NUTS implementation.
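
As a sketch of that last option: assuming a JAX-traceable (negative) log density were available from a ported lsqfitgp, numpyro's NUTS accepts a raw potential function directly. The potential below is just a stand-in Gaussian:

```python
import jax
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS

# stand-in for the negative log marginal likelihood (plus hyperpriors)
# that a JAX port of lsqfitgp could expose
def potential(z):
    return 0.5 * jnp.sum(z ** 2)

kernel = NUTS(potential_fn=potential)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), init_params=jnp.zeros(3))
samples = mcmc.get_samples()
```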

@alecandido added the roadmap label on Apr 27, 2022
@alecandido (Member Author)

OK, the idea is to have a NUTS-based implementation; as for which library provides it, I don't have any preference.

I guess the proof of concept is worth the effort, even if in the end we choose a pure lsqfitgp implementation (meaning kriging, not HMC).
I had a look at numpyro, as you know, but it looked bloated to me. I'm aware that pymc3 can be bloated as well, as shown in your slides.

I'm taking on the burden of providing one implementation or the other, but if you can help me, I'd be glad to accept your insights (and even practical help).

Furthermore, I had a look at the implementation of NUTS in pymc3, and it's rather concise: thus even an implementation from scratch should not be hard (though I'd study it and the abstract algorithm first, and cook up my own implementation in my own optimized/favorite way).
If we can plug it into lsqfitgp, so much the better.
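
For a sense of scale, a bare-bones HMC transition (plain HMC with a fixed trajectory length; NUTS adds dynamic trajectory-length selection on top of this) fits in a few lines of JAX. This is only an illustrative sketch, not the pymc3 code:

```python
import jax
import jax.numpy as jnp

def hmc_step(key, q, log_prob, step_size=0.1, n_leapfrog=20):
    """One HMC transition targeting exp(log_prob)."""
    grad_logp = jax.grad(log_prob)
    key_p, key_u = jax.random.split(key)
    p = jax.random.normal(key_p, q.shape)            # resample momentum

    # leapfrog integration of the Hamiltonian dynamics
    p_new = p + 0.5 * step_size * grad_logp(q)
    q_new = q
    for _ in range(n_leapfrog - 1):
        q_new = q_new + step_size * p_new
        p_new = p_new + step_size * grad_logp(q_new)
    q_new = q_new + step_size * p_new
    p_new = p_new + 0.5 * step_size * grad_logp(q_new)

    # Metropolis accept/reject on the total energy
    h_old = -log_prob(q) + 0.5 * jnp.sum(p ** 2)
    h_new = -log_prob(q_new) + 0.5 * jnp.sum(p_new ** 2)
    accept = jnp.log(jax.random.uniform(key_u)) < h_old - h_new
    return jnp.where(accept, q_new, q)

# e.g. one step on a standard normal target
q1 = hmc_step(jax.random.PRNGKey(0), jnp.zeros(3), lambda q: -0.5 * jnp.sum(q ** 2))
```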
