SVI GammaPoisson: Forward-mode differentiation rule for 'digamma' not implemented #621

Closed
jkgiesler opened this issue Jun 8, 2020 · 5 comments
Labels: jax (This issue is specific to JAX)

jkgiesler commented Jun 8, 2020

I'm trying to use SVI with a GammaPoisson distribution on jax version 0.1.57 and numpyro version 0.2.4.

import numpyro

from jax import lax
import jax.numpy as np
import jax.random as random

from numpyro.contrib.autoguide import (AutoContinuousELBO,
                                       AutoLaplaceApproximation)
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.infer import SVI
import numpyro.optim as optim
from numpyro import handlers

concentration_actual = 1.5
rate_actual = 2

# simulate observations from the true GammaPoisson
with handlers.seed(rng_seed=0):
    y = numpyro.sample('y', dist.GammaPoisson(concentration_actual, rate_actual), sample_shape=(10000,))

def model(y=y):
    # priors over the GammaPoisson parameters
    concentration = numpyro.sample('concentration', dist.Normal(3))
    rate = numpyro.sample('rate', dist.HalfNormal(3))

    return numpyro.sample(
        'obs',
        dist.GammaPoisson(concentration, rate),
        obs=y
    )

model_ala = AutoLaplaceApproximation(model)
svi = SVI(
    model,
    model_ala,
    optim.Adam(0.1),
    AutoContinuousELBO(),
    y=y
)
init_state = svi.init(random.PRNGKey(1))
# take 2000 optimization steps with lax.scan
state, loss = lax.scan(lambda x, i: svi.update(x), init_state, np.zeros(2000))
params = svi.get_params(state)
post = model_ala.sample_posterior(random.PRNGKey(2), params, (1000,))

This code will generate the following exception:

NotImplementedError: Forward-mode differentiation rule for 'digamma' not implemented.

I was able to successfully use MCMC:

from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive

kernel = NUTS(model)

model_mcmc = MCMC(
    kernel,
    num_warmup=500,
    num_samples=5000,
    num_chains=4,
    progress_bar=False
)


model_mcmc.run(
    random.PRNGKey(0),
    y=y
)

model_mcmc.print_summary()

r = random.PRNGKey(3)
rng_key, rng_key_ = random.split(r)
samples_1 = model_mcmc.get_samples()
predictive = Predictive(model, samples_1)
# the posterior samples are already passed to Predictive above,
# so only an rng key is needed here
predictions = predictive(rng_key_)

Does this seem right? If so, I'm happy to raise this issue with the JAX team. Are there other distributions that are known to work with MCMC but not SVI in numpyro?

fehiepsi commented Jun 8, 2020

Yes, I think you can make a FR in JAX. Currently it lacks a jvp rule for digamma, so we can't compute the hessian for models involving gammaln. You can test it with

import jax
# both of these fail while JAX has no derivative rule for digamma:
# the hessian of gammaln has to differentiate through digamma
jax.hessian(jax.scipy.special.gammaln)(jax.numpy.ones(3))
jax.grad(jax.scipy.special.digamma)(jax.numpy.ones(3))

I believe the best option is to request scipy.special.polygamma (so we can take derivatives of any polygamma, including digamma).
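
In the meantime, something along these lines might work as a user-side stopgap. This is only a rough sketch: it assumes a JAX version that already has jax.custom_jvp, the trigamma_approx series below is a hand-rolled approximation (not an existing JAX function), and it only helps code that calls these wrappers explicitly rather than numpyro's own distributions.

import jax
import jax.numpy as jnp
from jax.scipy import special

def trigamma_approx(x, shifts=6):
    # psi_1(x) = sum_{k=0}^{n-1} 1 / (x + k)^2 + psi_1(x + n)  (recurrence),
    # followed by a short asymptotic series at the shifted argument
    acc = 0.0
    for k in range(shifts):
        acc = acc + 1.0 / (x + k) ** 2
    z = x + shifts
    return acc + 1.0 / z + 1.0 / (2.0 * z**2) + 1.0 / (6.0 * z**3) - 1.0 / (30.0 * z**5)

@jax.custom_jvp
def digamma(x):
    return special.digamma(x)

@digamma.defjvp
def _digamma_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    # derivative of digamma is trigamma; supply it by hand
    return special.digamma(x), trigamma_approx(x) * x_dot

@jax.custom_jvp
def gammaln(x):
    return special.gammaln(x)

@gammaln.defjvp
def _gammaln_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    # route the first derivative through the wrapped digamma so that a
    # second (forward-mode) derivative can use the rule defined above
    return special.gammaln(x), digamma(x) * x_dot

# the computation that fails with the stock primitives now goes through
print(jax.hessian(lambda x: gammaln(x).sum())(jnp.ones(3)))

The proper fix is still to register the rule on the digamma primitive in JAX itself, which is what the feature request would cover.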

fritzo commented Jun 9, 2020

In Pyro we also provide a cheap approximate log_beta() function that uses only log() internally and so is differentiable. I've been using that for BetaBinomial, and I believe you could use it for NegativeBinomial as well.
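
Roughly, the idea (a sketch of the trick only, not Pyro's actual log_beta implementation) is Stirling's approximation plus the Beta-function recurrence to shift small arguments upward, so that only log() appears and derivatives of any order exist; the function and parameter names below are illustrative.

import jax
import jax.numpy as jnp

def log_beta_stirling(x, y):
    # Stirling's approximation applied to
    # log B(x, y) = log Gamma(x) + log Gamma(y) - log Gamma(x + y);
    # the -z terms cancel, so only log() remains and JAX can
    # differentiate the result to any order
    return ((x - 0.5) * jnp.log(x)
            + (y - 0.5) * jnp.log(y)
            - (x + y - 0.5) * jnp.log(x + y)
            + 0.5 * jnp.log(2.0 * jnp.pi))

def log_beta_approx(x, y, shifts=4):
    # improve accuracy for small arguments with the recurrence
    # log B(x, y) = log B(x + 1, y) + log(x + y) - log(x),
    # applied `shifts` times to each argument before using Stirling;
    # `shifts` plays the role of an accuracy knob like the tol mentioned below
    correction = 0.0
    for _ in range(shifts):
        correction = correction + jnp.log(x + y) - jnp.log(x)
        x = x + 1.0
    for _ in range(shifts):
        correction = correction + jnp.log(x + y) - jnp.log(y)
        y = y + 1.0
    return correction + log_beta_stirling(x, y)

# compare against the exact gammaln-based value and check differentiability
x = jnp.array([0.5, 1.5, 3.0])
y = jnp.array([2.0, 0.7, 5.0])
exact = (jax.scipy.special.gammaln(x) + jax.scipy.special.gammaln(y)
         - jax.scipy.special.gammaln(x + y))
print(log_beta_approx(x, y) - exact)
print(jax.hessian(lambda xx: log_beta_approx(xx, y).sum())(x))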

@fehiepsi
Thanks, @fritzo! I just fixed this issue upstream. Using log_beta sounds good to me, but we need to find a way to expose the tol argument.

@fehiepsi
Closing because this issue has been addressed upstream. Thanks, @jkgiesler, for bringing this up!

@jkgiesler (Author)
Excellent, thank you!
