Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix DiscreteUniform.enumerate_support with non-trivial batch shape #1859

Merged
merged 2 commits into from
Sep 16, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix DiscreteUniform enumerate_support
fehiepsi committed Sep 5, 2024
commit 016ba850e7987aa7619cb95d8fc06768ca93a13d
6 changes: 3 additions & 3 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
@@ -469,9 +469,9 @@ def enumerate_support(self, expand=True):
raise NotImplementedError(
"Inhomogeneous `high` not supported by `enumerate_support`."
)
values = (self.low + jnp.arange(np.amax(self.high - self.low) + 1)).reshape(
(-1,) + (1,) * len(self.batch_shape)
)
low = jnp.reshape(self.low, -1)[0]
high = jnp.reshape(self.high, -1)[0]
values = jnp.arange(low, high + 1).reshape((-1,) + (1,) * len(self.batch_shape))
if expand:
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values
19 changes: 10 additions & 9 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
@@ -2742,13 +2742,14 @@ def test_generated_sample_distribution(
@pytest.mark.parametrize(
"jax_dist, params, support",
[
(dist.BernoulliLogits, (5.0,), jnp.arange(2)),
(dist.BernoulliProbs, (0.5,), jnp.arange(2)),
(dist.BinomialLogits, (4.5, 10), jnp.arange(11)),
(dist.BinomialProbs, (0.5, 11), jnp.arange(12)),
(dist.BetaBinomial, (2.0, 0.5, 12), jnp.arange(13)),
(dist.CategoricalLogits, (np.array([3.0, 4.0, 5.0]),), jnp.arange(3)),
(dist.CategoricalProbs, (np.array([0.1, 0.5, 0.4]),), jnp.arange(3)),
(dist.BernoulliLogits, (5.0,), np.arange(2)),
(dist.BernoulliProbs, (0.5,), np.arange(2)),
(dist.BinomialLogits, (4.5, 10), np.arange(11)),
(dist.BinomialProbs, (0.5, 11), np.arange(12)),
(dist.BetaBinomial, (2.0, 0.5, 12), np.arange(13)),
(dist.CategoricalLogits, (np.array([3.0, 4.0, 5.0]),), np.arange(3)),
(dist.CategoricalProbs, (np.array([0.1, 0.5, 0.4]),), np.arange(3)),
(dist.DiscreteUniform, (2, 4), np.arange(2, 5)),
],
)
@pytest.mark.parametrize("batch_shape", [(5,), ()])
@@ -3333,8 +3334,8 @@ def test_normal_log_cdf():
"value",
[
-15.0,
jnp.array([[-15.0], [-10.0], [-5.0]]),
jnp.array([[[-15.0], [-10.0], [-5.0]], [[-14.0], [-9.0], [-4.0]]]),
np.array([[-15.0], [-10.0], [-5.0]]),
np.array([[[-15.0], [-10.0], [-5.0]], [[-14.0], [-9.0], [-4.0]]]),
],
)
def test_truncated_normal_log_prob_in_tail(value):