
Cross-entropy computations trigger checkify NaN error #187

Open
hr0nix opened this issue Aug 15, 2022 · 0 comments

hr0nix commented Aug 15, 2022

JAX recently added a transformation, checkify, that allows one to validate user-defined assertions as well as check for some standard issues, such as division by zero or NaNs arising in computations. Unfortunately, some of these checks are currently violated by the distrax code. For example, the following test currently passes:

import distrax
import jax.numpy as jnp
import pytest
from jax.experimental import checkify

def test_distrax_ce_nan():
    def func_to_check(logits, labels):
        return distrax.Categorical(probs=labels).cross_entropy(distrax.Categorical(logits=logits))
    checked_func = checkify.checkify(func_to_check, errors=checkify.nan_checks)
    err, val = checked_func(logits=jnp.log(jnp.asarray([0.5, 0.5])), labels=jnp.asarray([1.0, 0.0]))
    with pytest.raises(ValueError):
        err.throw()

The problem lies in the way cross-entropy is computed when some of the probabilities are zero: the current implementation computes zero times negative infinity, which results in NaNs, and then filters the NaNs out using jnp.where. While the final result is correct, this implementation triggers a checkify error because of the intermediate NaNs. This becomes problematic when one wants to enable NaN checks at the top level.
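For illustration, here is a minimal sketch (not distrax's actual implementation; the function names are hypothetical) contrasting the NaN-producing pattern with a variant that zeroes out the offending log terms before multiplying, so no NaN is ever materialised:

import jax.numpy as jnp

def cross_entropy_with_intermediate_nan(p, log_q):
    # 0 * (-inf) evaluates to NaN before jnp.where masks it,
    # which is what trips checkify's NaN check.
    terms = p * log_q
    return -jnp.sum(jnp.where(p == 0, 0.0, terms))

def cross_entropy_nan_free(p, log_q):
    # Zero out the log terms wherever p == 0 *before* multiplying,
    # so the product is 0 * 0 rather than 0 * (-inf).
    safe_log_q = jnp.where(p == 0, 0.0, log_q)
    return -jnp.sum(p * safe_log_q)

Under checkify.checkify(..., errors=checkify.nan_checks), the first variant would be expected to report a NaN error while the second should not.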

A list of possible solutions that come to mind:

Thanks!
