SVI GammaPoisson: Forward-mode differentiation rule for 'digamma' not implemented #621
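Yes, I think you can make a feature request in JAX. Currently, it lacks a JVP rule; the following reproduces the problem:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import digamma, gammaln

jax.hessian(gammaln)(jnp.ones(3))  # second derivative of gammaln needs d/dx digamma
jax.vmap(jax.grad(digamma))(jnp.ones(3))  # grad needs a scalar output, so vmap over elements
```

I believe the best option is to request `scipy.special.polygamma` (so we can take the derivative of any polygamma, including digamma).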
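The reason polygamma is enough: the derivative of digamma is `polygamma(1, x)` (the trigamma function), so a JVP rule for digamma only needs polygamma. A minimal sketch of attaching that rule by hand with `jax.custom_jvp`, assuming a JAX release that ships `jax.scipy.special.polygamma`; the wrapper name `digamma_with_jvp` is hypothetical and this is not the fix that was merged upstream.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import digamma, polygamma


# Hypothetical wrapper: digamma with a hand-written forward-mode (JVP) rule.
@jax.custom_jvp
def digamma_with_jvp(x):
    return digamma(x)


@digamma_with_jvp.defjvp
def _digamma_with_jvp_rule(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    # d/dx digamma(x) = polygamma(1, x), the trigamma function.
    return digamma(x), polygamma(1, x) * x_dot


x = jnp.array([1.0, 2.0, 3.0])
# Forward mode goes through the rule directly.
primal_out, tangent_out = jax.jvp(digamma_with_jvp, (x,), (jnp.ones_like(x),))
# Reverse mode transposes the same (linear in x_dot) rule.
grads = jax.vmap(jax.grad(digamma_with_jvp))(x)
```

Because the tangent is linear in `x_dot`, JAX can derive reverse mode from the same rule, so one definition covers both `jax.jvp` and `jax.grad`.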
In Pyro we also provide a cheap approximate `log_beta()` function that uses only
Thanks, @fritzo! I just fixed this issue upstream. Using
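With a JAX release that contains the upstream fix (the exact version is not stated in this thread), the operations that used to fail can be checked directly. A minimal sketch, assuming such a release: it compares the forward-mode derivative of digamma and the Hessian diagonal of `gammaln` against the closed-form trigamma values trigamma(1) = pi^2/6 and trigamma(n + 1) = trigamma(n) - 1/n^2.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.special import digamma, gammaln

x = jnp.array([1.0, 2.0, 3.0])

# Forward-mode derivative of digamma (the operation from the original error).
_, digamma_tangent = jax.jvp(digamma, (x,), (jnp.ones_like(x),))

# gammaln is elementwise, so the Hessian of its sum is diagonal with entries
# d^2/dx^2 gammaln(x) = d/dx digamma(x).
hess_diag = jnp.diagonal(jax.hessian(lambda v: jnp.sum(gammaln(v)))(x))

# Closed-form trigamma values at 1, 2, 3.
expected = jnp.array([math.pi**2 / 6, math.pi**2 / 6 - 1.0, math.pi**2 / 6 - 1.25])
```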
Closed because this issue has been addressed upstream. Thanks @jkgiesler for bringing this up!
Excellent, thank you!
I'm trying to use SVI with a GammaPoisson distribution on jax version 0.1.57 and numpyro version 0.2.4.
This code will generate the following exception:
I was able to successfully use MCMC:
Does this seem right? If so, I'm happy to raise this issue with the jax team. Are there other distributions that are known to work with MCMC but not SVI in numpyro?