
Add deterministic ADVI #564


Open · wants to merge 10 commits into main

Conversation

martiningram

Hi everyone,

I'm one of the authors of the paper on deterministic ADVI. There is an open feature request for this in PyMC here, so I thought I'd kick things off with this PR.

In simple terms, DADVI is like ADVI, but rather than using a new draw to estimate its objective at each step, it uses a fixed set of draws throughout the optimisation. That means (1) it can use regular off-the-shelf optimisers rather than stochastic optimisation, making convergence more reliable, and (2) it's possible to apply techniques that improve the variance estimates. Both are covered in the paper, along with tools to assess how large the error from using fixed draws is.
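To make the idea concrete, here is a minimal sketch (illustrative names, not this PR's API) of a fixed-draw objective that can be handed straight to scipy.optimize:

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize


def make_neg_elbo(log_posterior, fixed_draws):
    # fixed_draws: [M, D] standard-normal draws, generated once and reused,
    # which is what makes the objective deterministic.
    D = fixed_draws.shape[1]

    def neg_elbo(eta):
        means, log_sds = eta[:D], eta[D:]
        # Reparameterise every fixed draw with the current variational parameters.
        zs = means + jnp.exp(log_sds) * fixed_draws        # [M, D]
        expected_logp = jax.vmap(log_posterior)(zs).mean()
        entropy = jnp.sum(log_sds)                         # Gaussian entropy up to a constant
        return -(expected_logp + entropy)

    return jax.jit(jax.value_and_grad(neg_elbo))


def fit_dadvi(log_posterior, fixed_draws, eta0):
    val_and_grad = make_neg_elbo(log_posterior, fixed_draws)

    def fun(eta):
        value, grad = val_and_grad(jnp.asarray(eta))
        return float(value), np.asarray(grad, dtype=np.float64)

    # Any deterministic off-the-shelf optimiser works, since the objective
    # is the same function at every call.
    return minimize(fun, eta0, jac=True, method="L-BFGS-B")
```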

This PR covers only the first part: optimising ADVI with fixed draws. I thought I'd start simple, and I'm hoping it already addresses a real problem with ADVI, namely the difficulty of assessing convergence.

In addition to adding the code, there is an example notebook in notebooks/deterministic_advi_example.ipynb. It fits DADVI to the PyMC basic linear regression example. I can add more examples, but I thought I'd start simple.

I mostly lifted the code from my research repository, so there are probably some style differences. Let me know what would be important to change.

Note that JAX is needed, but there shouldn't be any other new dependencies.

Very keen to hear what you all think! :)

All the best,
Martin


@jessegrabowski added the enhancements (New feature or request) label on Aug 14, 2025
@jessegrabowski
Member

This is super cool -- I'm very excited to look more closely over the next few days.

Since you're ultimately building a loss function and sending it to scipy.optimize, do you think we could re-use any of the machinery that exists for doing that in the laplace_approx module, for example this or this?

@ricardoV94
Member

What exactly needs JAX?

@martiningram
Author

martiningram commented Aug 14, 2025

Thank you both very much for having a look so quickly!

@jessegrabowski Good point, yes maybe! I'll take a look.

@ricardoV94 Currently, JAX is used to compute the HVP (Hessian-vector product) and the Jacobian of the objective. That involves computing a gradient for each of the fixed draws and then taking an average. What's quite nice in JAX is that this can be done with vmap easily: https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-1d6e8b962a8c3ca803c55bea43c19863223ed50ae3814acc55424834ade1215cR44
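For illustration, the pattern looks roughly like this (a sketch with made-up names, not the PR's actual code): per-draw gradients come out of one batched vmap call and are averaged, and the HVP is forward-mode over that averaged gradient.

```python
import jax


def mean_grad(objective_single_draw, eta, draws):
    # objective_single_draw(eta, z) -> scalar contribution of one draw z.
    # vmap over the draw axis gives an [M, 2D] matrix of gradients in one call.
    per_draw = jax.vmap(jax.grad(objective_single_draw), in_axes=(None, 0))(eta, draws)
    return per_draw.mean(axis=0)


def mean_hvp(objective_single_draw, eta, draws, v):
    # Hessian-vector product of the averaged objective via forward-over-reverse.
    grad_fn = lambda e: mean_grad(objective_single_draw, e, draws)
    return jax.jvp(grad_fn, (eta,), (v,))[1]
```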

That said, JAX isn't strictly necessary. Anything that can provide the DADVIFuns is fine: https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-48ee4e85c0ff57f5b8af20dfd608bd0e37c3a2c76169a7bbe499e77ff3802d9dR13 . In fact, I have code in the original research repo that turns the regular HVP and gradient function into the DADVIFuns, but I think it'll be slower because of the for loops, e.g. here.

Are you concerned about the JAX dependency? If so, maybe I could have a go at a JAX-free version using the code just mentioned, and support JAX only as an optional extra. I do think it would be nice to have, since it's probably more efficient and would hopefully also run fast on GPUs. But I'm interested in your thoughts.

Also, I see one of the pre-commit checks seems to be failing. I can do the work to make the pre-commit hooks happy; sorry I haven't done that yet.

@zaxtax
Contributor

zaxtax commented Aug 14, 2025 via email

@ricardoV94
Member

ricardoV94 commented Aug 14, 2025

> @ricardoV94 Currently, JAX is used to compute the HVP (Hessian-vector product) and the Jacobian of the objective. That involves computing a gradient for each of the fixed draws and then taking an average. What's quite nice in JAX is that this can be done with vmap easily […]

PyTensor has the equivalent hessian_vector_product and jacobian, and vectorize_graph, which does the same as vmap (or more, if you have multiple batch dimensions).

The reason I ask is that if you don't have anything JAX-specific, you can still end up using JAX, but also C or Numba, which may be better for certain users.

@martiningram
Author

@ricardoV94 Oh cool, thanks, I didn't realise! I'll take a look and see if I can use those. I agree it would be nice to support as many users as possible.

@ricardoV94
Member

Happy to assist. If you're vectorizing the Jacobian, you probably want to build it with jacobian(vectorize=True), which vectorizes further more nicely.

Everything is described here, although a bit scattered: https://pytensor.readthedocs.io/en/latest/tutorial/gradients.html
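For reference, a tiny sketch of those PyTensor counterparts (assuming the current pytensor.gradient APIs; the toy objective is illustrative):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.gradient import jacobian

x = pt.vector("x")
y = pt.sum(x**2)                        # toy scalar objective

g = pytensor.grad(y, x)                 # gradient, analogous to jax.grad
H = jacobian(g, x, vectorize=True)      # Hessian as the vectorised Jacobian of the gradient
f = pytensor.function([x], [g, H])
```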

@martiningram
Author

martiningram commented Aug 15, 2025

Hey @ricardoV94 (and potentially others!), I think I could use your advice on the vectorisation. I've read enough that I could do it without the functions discussed here, but I'd really like to get this vectorised for speed.

To explain a bit: the code expects an implementation of DADVIFuns (https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-48ee4e85c0ff57f5b8af20dfd608bd0e37c3a2c76169a7bbe499e77ff3802d9dR13). The first of these functions takes two inputs:

  1. The variational parameter vector eta, which is all the means concatenated with the log_sds of the variational parameters. This will have length 2D, where D is the number of parameters in the model (first D is means, second D is log_sds).
  2. A matrix of draws of shape [M, D], with D as before and M the number of draws.

The function should then return the estimate of the KL divergence using these draws, as well as its gradient with respect to the variational parameters. The KL divergence is the sum of the entropy of the approximation (a simple function of the variational parameters only) and the average of the log posterior densities over the draws. That average is the part I'd like to vectorise.
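For concreteness, a hypothetical signature matching this description might look as follows (illustrative only; the actual DADVIFuns in the PR may differ in detail):

```python
from typing import Callable, NamedTuple, Optional
import numpy as np


class DADVIFunsSketch(NamedTuple):
    # eta: [2 * D] vector (first D entries are means, second D are log_sds);
    # draws: [M, D] matrix of fixed standard-normal draws.
    # Returns the KL divergence estimate and its gradient w.r.t. eta ([2 * D]).
    kl_est_and_grad: Callable[[np.ndarray, np.ndarray], tuple[float, np.ndarray]]
    # Optional HVP of the KL estimate w.r.t. eta, useful for second-order optimisers.
    kl_est_hvp: Optional[Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray]] = None
```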

Now in JAX, the way I do this is to...:

  1. Compute the log posterior density for a single draw (i.e. one row in the matrix of draws): https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-1d6e8b962a8c3ca803c55bea43c19863223ed50ae3814acc55424834ade1215cR37
  2. vmap this computation across all the draws and then take the mean here: https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-1d6e8b962a8c3ca803c55bea43c19863223ed50ae3814acc55424834ade1215cR42

Thanks to vectorize_graph, I am hoping I can do something similar with PyTensor. My planned strategy is to:

  1. Define the variational parameter vector eta in pytensor, and transform a single draw with this vector
  2. Use clone_replace to make the transformed draw the input to the graph, rather than the current input
  3. Use vectorize_graph to vectorise with respect to the draws
  4. Compute the mean of the densities and get gradients of this mean with respect to the variational parameter vector

This makes sense in my head, but the problem I see is that the PyMC model's logp seems to expect a dictionary rather than a flat vector. So, as part of the step from the new inputs to the density, I need to turn the flat vector into that dictionary. PyMC has DictToArrayBijection, which does this, but I don't think I can use it as part of the PyTensor graph.

So, in essence, I think I need code that does what DictToArrayBijection does in pure PyTensor. Is there something like that? Or is there another way I'm missing? It would be great if I could just get a logp function that already takes a flat vector as input -- is there a way to get to that?

Thanks a lot for your help :)

@jessegrabowski
Member

If you get the logp of a pymc model using model.logp (rather than compile_logp or one of the jax helpers), it will just return the symbolic logp graph, which you can then do all your vectorization/replacements on.

The path followed by the laplace code is to freeze the model and extract the negative logp, then create a flat vector input that replaces the individual value inputs, then compile the loss_and_grads/hess/hessp functions (optionally in JAX).

My hope is that once you have the correct loss function for DADVI, you should be able to pass it directly into scipy_optimize_funcs_from_loss and re-use all of that machinery.

The four steps you outline seem correct to me. pymc.pytensorf.join_nonshared_inputs is the function I think you're looking for to do the pack/unpack operation on the different parameters; I linked to its usage above.
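For concreteness, a sketch of how those pieces might fit together (assuming the join_nonshared_inputs and vectorize_graph APIs of recent PyMC/PyTensor releases; the eta handling and names are illustrative, not the PR's final code):

```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import join_nonshared_inputs
from pytensor.graph.replace import vectorize_graph


def compile_dadvi_objective(model: pm.Model):
    point = model.initial_point()
    D = int(sum(np.asarray(v).size for v in point.values()))

    # 1. Express the logp as a graph of a single flat unconstrained vector.
    [logp], flat_input = join_nonshared_inputs(
        point=point, outputs=[model.logp()], inputs=model.value_vars
    )

    # 2. Transform one standard-normal draw with the variational parameters
    #    and make it the input of the logp graph.
    eta = pt.vector("eta")                      # [2 * D]: means then log_sds
    means, log_sds = eta[:D], eta[D:]
    draw = pt.vector("draw")                    # one draw, length D
    logp_one = pytensor.clone_replace(logp, {flat_input: means + pt.exp(log_sds) * draw})

    # 3. Vectorise over a matrix of draws.
    draws = pt.matrix("draws")                  # [M, D]
    logp_all = vectorize_graph(logp_one, replace={draw: draws})

    # 4. Average, add the entropy term, and take the gradient w.r.t. eta.
    neg_elbo = -(logp_all.mean() + pt.sum(log_sds))
    grad = pytensor.grad(neg_elbo, eta)
    return pytensor.function([eta, draws], [neg_elbo, grad])
```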

@martiningram
Author

Thanks a lot @jessegrabowski . I'll give it a go!

@martiningram
Author

Hey all, I think I've made good progress with the PyTensor version. A first version is here: https://github.com/pymc-devs/pymc-extras/pull/564/files#diff-1b6e7da940ec73fce49f5e13ae1db5369ec011cb0b55974ec04d81e519e923f6R55

I think the only major thing missing is transforming the draws from the unconstrained space back into the constrained space. Is there a code snippet anyone could point me to? Thanks for all the helpful advice you've already given!
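One possible route, borrowing the trick the laplace code relies on: model.unobserved_value_vars expresses the constrained variables and deterministics as graphs of the value variables, so they can be compiled against the same flat input (a sketch, names illustrative):

```python
import pymc as pm
import pytensor
from pymc.pytensorf import join_nonshared_inputs


def constrain_draws(model: pm.Model, flat_draws):
    # flat_draws: [M, D] draws in the flat unconstrained (value-variable) space.
    outputs, flat_input = join_nonshared_inputs(
        point=model.initial_point(),
        outputs=model.unobserved_value_vars,   # constrained vars + deterministics
        inputs=model.value_vars,
    )
    to_constrained = pytensor.function([flat_input], outputs)
    return [to_constrained(draw) for draw in flat_draws]
```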

@zaxtax
Contributor

zaxtax commented Aug 16, 2025 via email
