Trouble getting correct inverse-gamma samples from jax.random.gamma #15155

Answered by jakevdp
PaulScemama asked this question in Q&A

I think you're using different forms of the scale parameter in each expression. Try it this way instead:

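import jax

# shape = 2.0 and scale = 2.0 are inferred from the analytical means printed below
shape, scale = 2.0, 2.0

key = jax.random.PRNGKey(123)
# jax.random.gamma draws unit-rate gamma samples; dividing by `scale` gives Gamma(shape, rate=scale)
gamma_samples = jax.random.gamma(key, shape, (100000,)) / scale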
print(f"Gamma Sample mean: {gamma_samples.mean()}")
print(f"Gamma Analytical mean: {shape / scale}")
# Gamma Sample mean: 1.0009526014328003
# Gamma Analytical mean: 1.0

# JAX InvGamma: the reciprocal of a Gamma(shape, rate=scale) sample is InvGamma(shape, scale)
invgamma_samples = 1 / gamma_samples
print(f"InvGamma Sample mean: {invgamma_samples.mean()}")
print(f"InvGamma Analytical mean: {scale / (shape - 1)}")
# InvGamma Sample mean: 1.9969054460525513
# InvGamma Analytical mean: 2.0
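
For reference, here is a minimal self-contained sketch of the same idea wrapped in a helper. The name sample_inverse_gamma and the values shape_param = 5.0 and scale_param = 4.0 are illustrative assumptions, not something from the thread. It relies on the fact that jax.random.gamma samples a gamma distribution with unit rate, so dividing the scale by such a sample should give an inverse-gamma sample with that scale:

```python
import jax


def sample_inverse_gamma(key, shape_param, scale_param, num_samples):
    """Draw InvGamma(shape_param, scale_param) samples (hypothetical helper)."""
    # jax.random.gamma draws Gamma(shape_param) samples with unit rate.
    unit_gamma = jax.random.gamma(key, shape_param, (num_samples,))
    # If G ~ Gamma(a, rate=1), then scale / G ~ InvGamma(a, scale).
    return scale_param / unit_gamma


key = jax.random.PRNGKey(0)
shape_param, scale_param = 5.0, 4.0  # illustrative values, not from the thread
samples = sample_inverse_gamma(key, shape_param, scale_param, 100_000)

# The inverse-gamma mean scale / (shape - 1) only exists for shape > 1.
analytical_mean = scale_param / (shape_param - 1)
print(f"InvGamma Sample mean: {samples.mean():.4f}")
print(f"InvGamma Analytical mean: {analytical_mean:.4f}")
```

With these values the analytical mean is 1.0, and the sample mean should land close to it. For shape parameters at or below 1 the mean does not exist, so the comparison is only meaningful for shape_param > 1.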

Answer selected by PaulScemama