diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py
index 50fe50983..a18e55bf4 100644
--- a/numpyro/distributions/__init__.py
+++ b/numpyro/distributions/__init__.py
@@ -8,6 +8,7 @@
     NegativeBinomial2,
     NegativeBinomialLogits,
     NegativeBinomialProbs,
+    BetaNegativeBinomial,
     ZeroInflatedNegativeBinomial2,
 )
 from numpyro.distributions.continuous import (
@@ -124,6 +125,7 @@
     "Beta",
     "BetaBinomial",
+    "BetaNegativeBinomial",
     "BetaProportion",
     "Binomial",
     "BinomialLogits",
     "BinomialProbs",
diff --git a/numpyro/distributions/conjugate.py b/numpyro/distributions/conjugate.py
index de8904c2e..42d81a0d9 100644
--- a/numpyro/distributions/conjugate.py
+++ b/numpyro/distributions/conjugate.py
@@ -296,6 +296,111 @@ def __init__(self, mean, concentration, *, validate_args=None):
         super().__init__(concentration, rate, validate_args=validate_args)
 
 
+class BetaNegativeBinomial(Distribution):
+    r"""
+    Beta negative binomial distribution.
+
+    Also known as the inverse Markov-Polya distribution and the generalized
+    Waring distribution. Samples are generated by drawing
+    :math:`p \sim \text{Beta}(\alpha, \beta)` and then sampling from
+    :class:`~numpyro.distributions.NegativeBinomialProbs` with ``total_count``
+    :math:`r` and ``probs`` :math:`1 - p`.
+
+    The probability mass function is defined as:
+
+    .. math::
+
+        f(n | r, \alpha, \beta) = \frac{\Gamma(n + r)}{n! \Gamma(r)}
+        \frac{\text{B}(\beta + n, \alpha + r)}{\text{B}(\beta, \alpha)}
+
+    where :math:`n \in \mathbb{N}` is the count, :math:`r \in \mathbb{R}^+` is
+    the number of successes (``total_count``), :math:`\alpha \in \mathbb{R}^+`
+    and :math:`\beta \in \mathbb{R}^+` are the concentration parameters of the
+    beta distribution (``concentration1`` and ``concentration0``),
+    :math:`\text{B}` is the beta function, and :math:`\Gamma` is the gamma
+    function.
+
+    :param total_count: positive parameter :math:`r` of the negative binomial.
+    :param concentration1: first concentration parameter (:math:`\alpha`) of
+        the beta distribution.
+    :param concentration0: second concentration parameter (:math:`\beta`) of
+        the beta distribution.
+    """
+
+    arg_constraints = {
+        "total_count": constraints.positive,
+        "concentration1": constraints.positive,
+        "concentration0": constraints.positive,
+    }
+    support = constraints.nonnegative_integer
+    pytree_data_fields = ("total_count", "concentration1", "concentration0", "_beta")
+
+    def __init__(
+        self, total_count, concentration1, concentration0, *, validate_args=None
+    ):
+        self.concentration1, self.concentration0, self.total_count = promote_shapes(
+            concentration1, concentration0, total_count
+        )
+        batch_shape = lax.broadcast_shapes(
+            jnp.shape(concentration1), jnp.shape(concentration0), jnp.shape(total_count)
+        )
+        # Broadcast so the Beta prior shares this distribution's batch shape.
+        concentration1 = jnp.broadcast_to(concentration1, batch_shape)
+        concentration0 = jnp.broadcast_to(concentration0, batch_shape)
+
+        self._beta = Beta(concentration1, concentration0)
+        super(BetaNegativeBinomial, self).__init__(
+            batch_shape=batch_shape, validate_args=validate_args
+        )
+
+    def sample(self, key, sample_shape=()):
+        assert is_prng_key(key)
+        key_beta, key_negbin = random.split(key)
+        # ``self._beta`` was built from fully broadcast parameters, so ``p``
+        # already has shape ``sample_shape + batch_shape`` and the negative
+        # binomial below is sampled without extra sample dimensions.
+        p = self._beta.sample(key_beta, sample_shape)
+        q = 1.0 - p
+        return NegativeBinomialProbs(self.total_count, q).sample(key_negbin)
+
+    @validate_sample
+    def log_prob(self, value):
+        # log f(n) = log Gamma(n + r) - log n! - log Gamma(r)
+        #     + log B(beta + n, alpha + r) - log B(beta, alpha),
+        # with alpha = concentration1 and beta = concentration0.
+        return (
+            -gammaln(value + 1)
+            + gammaln(value + self.total_count)
+            - gammaln(self.total_count)
+            + betaln(
+                value + self.concentration0, self.concentration1 + self.total_count
+            )
+            - betaln(self.concentration1, self.concentration0)
+        )
+
+    @property
+    def mean(self):
+        # r * beta / (alpha - 1); reported as infinite unless alpha > 1.
+        return jnp.where(
+            self.concentration1 > 1,
+            self.total_count * self.concentration0 / (self.concentration1 - 1),
+            jnp.inf,
+        )
+
+    @property
+    def variance(self):
+        # r * beta * (r + alpha - 1) * (beta + alpha - 1) / ((alpha - 1)^2 * (alpha - 2)); infinite unless alpha > 2.
+        return jnp.where(
+            self.concentration1 > 2,
+            self.total_count
+            * self.concentration0
+            * (self.total_count + self.concentration1 - 1)
+            * (self.concentration0 + self.concentration1 - 1)
+            / ((self.concentration1 - 1) ** 2 * (self.concentration1 - 2)),
+            jnp.inf,
+        )
+
+
 def ZeroInflatedNegativeBinomial2(
     mean, concentration, *, gate=None, gate_logits=None, validate_args=None
 ):
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 228634b07..652740812 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -1024,6 +1024,20 @@ def get_sp_dist(jax_dist):
         np.array([5.0, 3.0]),
         np.array([10, 12]),
     ),
+    T(dist.BetaNegativeBinomial, 5.0, 2.0, 2.0),
+    T(dist.BetaNegativeBinomial, np.array([1.0, 5.0, 10.0]), 2.0, 2.0),
+    T(
+        dist.BetaNegativeBinomial,
+        np.array([1.0, 5.0, 10.0]),
+        np.array([2.0, 3.0, 4.0]),
+        2.0,
+    ),
+    T(
+        dist.BetaNegativeBinomial,
+        np.array([1.0, 5.0, 10.0]),
+        np.array([2.0, 3.0, 4.0]),
+        np.array([2.0, 1.0, 0.5]),
+    ),
     T(dist.BernoulliProbs, 0.2),
     T(dist.BernoulliProbs, np.array([0.2, 0.7])),
     T(dist.BernoulliLogits, np.array([-1.0, 3.0])),