diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 0d701ba0a..3aa6e69a2 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -1250,7 +1250,7 @@ def sample(self, key, sample_shape=()): @validate_sample def log_prob(self, value): - log_prob = jnp.log(value == self.v) + log_prob = jnp.where(value == self.v, 0, -jnp.inf) log_prob = sum_rightmost(log_prob, len(self.event_shape)) return log_prob + self.log_density