Trouble getting correct inverse-gamma samples from jax.random.gamma
#15155
Hi! I'm trying to get samples from an inverse gamma distribution by taking the reciprocal of samples from `jax.random.gamma`:

```python
import jax

shape = 2
scale = 2

### Jax Gamma
key = jax.random.PRNGKey(123)
gamma_samples = jax.random.gamma(key, shape, (100000,)) * scale
print(f"Gamma Sample mean: {gamma_samples.mean()}")
print(f"Gamma Analytical mean: {shape * scale}")
# Gamma Sample mean: 4.003810405731201
# Gamma Analytical mean: 4

## Jax InvGamma
invgamma_samples = 1 / gamma_samples
print(f"InvGamma Sample mean: {invgamma_samples.mean()}")
print(f"InvGamma Analytical mean: {scale / (shape - 1)}")
# InvGamma Sample mean: 0.4992263913154602
# InvGamma Analytical mean: 2.0
```

I'm using the analytical means for both the gamma and inverse-gamma distributions, where the latter requires …

Thanks in advance!
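The mismatch can be traced with a short derivation. It assumes that `jax.random.gamma(key, a)` draws unit-scale Gamma(a, 1) samples and that InvGamma(α, β) uses the shape/scale convention with mean β/(α − 1) for α > 1, which is the formula used above; here α = `shape` = 2 and β = `scale` = 2.

```math
\begin{aligned}
G &\sim \mathrm{Gamma}(\alpha, 1) && \text{(unit-scale draws)} \\
\beta G &\sim \mathrm{Gamma}(\alpha,\ \text{scale } \beta)
  && \Rightarrow\ \tfrac{1}{\beta G} \sim \mathrm{InvGamma}(\alpha,\ 1/\beta), \quad
  \mathbb{E}\!\left[\tfrac{1}{\beta G}\right] = \frac{1/\beta}{\alpha - 1} = \frac{1/2}{2 - 1} = 0.5 \\
G / \beta &\sim \mathrm{Gamma}(\alpha,\ \text{rate } \beta)
  && \Rightarrow\ \tfrac{\beta}{G} \sim \mathrm{InvGamma}(\alpha,\ \beta), \quad
  \mathbb{E}\!\left[\tfrac{\beta}{G}\right] = \frac{\beta}{\alpha - 1} = \frac{2}{2 - 1} = 2
\end{aligned}
```

The second line matches the observed sample mean of about 0.4992; the analytical value of 2.0 corresponds to the third line, where `scale` acts as a rate on the gamma side.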
I think you're using different forms of the scale parameter in each expression: `shape * scale` treats `scale` as the gamma's scale, while `scale / (shape - 1)` assumes the reciprocal comes from a gamma with *rate* `scale`. Try it this way instead:

```python
import jax

shape = 2
scale = 2

# Dividing by `scale` treats it as a rate, so 1 / gamma_samples ~ InvGamma(shape, scale)
key = jax.random.PRNGKey(123)
gamma_samples = jax.random.gamma(key, shape, (100000,)) / scale
print(f"Gamma Sample mean: {gamma_samples.mean()}")
print(f"Gamma Analytical mean: {shape / scale}")
# Gamma Sample mean: 1.0009526014328003
# Gamma Analytical mean: 1.0

## Jax InvGamma
invgamma_samples = 1 / gamma_samples
print(f"InvGamma Sample mean: {invgamma_samples.mean()}")
print(f"InvGamma Analytical mean: {scale / (shape - 1)}")
# InvGamma Sample mean: 1.9969054460525513
# InvGamma Analytical mean: 2.0
```
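For reuse, the fix can be wrapped in a small function. This is a minimal sketch: `sample_inverse_gamma` and `inverse_gamma_mean` are hypothetical helpers (not part of JAX), and they assume the shape/scale parameterization used in this thread (mean `scale / (a - 1)`). The example uses `a = 5` rather than 2 because the inverse gamma's variance is infinite for `a <= 2`, so sample means at `a = 2` converge slowly and can vary noticeably between keys.

```python
import jax


def sample_inverse_gamma(key, a, scale, shape=()):
    """Draw InvGamma(a, scale) samples (hypothetical helper, not a JAX API).

    If G ~ Gamma(a, 1), then G / scale ~ Gamma(a, rate=scale), and its
    reciprocal scale / G ~ InvGamma(a, scale).
    """
    g = jax.random.gamma(key, a, shape)  # unit-scale Gamma(a, 1) draws
    return scale / g


def inverse_gamma_mean(a, scale):
    """Analytical mean scale / (a - 1); the mean is infinite for a <= 1."""
    if a <= 1:
        return float("inf")
    return scale / (a - 1)


key = jax.random.PRNGKey(0)
a, scale = 5.0, 2.0
samples = sample_inverse_gamma(key, a, scale, (100_000,))
print(f"Sample mean:     {float(samples.mean()):.4f}")
print(f"Analytical mean: {inverse_gamma_mean(a, scale):.4f}")  # 0.5000
```

With the question's values, `sample_inverse_gamma(key, 2.0, 2.0, (100000,))` computes the same quantity as the corrected code, up to floating-point rounding.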