JAX recently added a transformation, checkify, that validates user-defined assertions and can also check for common issues such as division by zero or NaNs arising in computations. Unfortunately, some of these checks are currently violated by the distrax code. For example, this test currently passes:
The problem lies in how cross-entropy is computed when some of the probabilities are zero: the current implementation computes zero times negative infinity, which produces NaNs, and then filters the NaNs out using jnp.where. Although the final result is correct, these intermediate NaNs trigger a checkify error. This becomes a problem when one wants to enable NaN checks at the top level.
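A minimal sketch of the pattern described above, assuming the cross-entropy is written as "multiply, then mask with jnp.where". The function cross_entropy_naive and the example inputs are hypothetical stand-ins for illustration, not distrax's actual code. The plain call returns a finite value, but running it under checkify with checkify.nan_checks reports the NaN produced by the intermediate 0 * -inf.

```python
import jax.numpy as jnp
from jax.experimental import checkify


def cross_entropy_naive(probs, log_q):
    """Hypothetical multiply-then-mask cross-entropy H(p, q) = -sum p * log q."""
    # Where probs == 0 and log_q == -inf, probs * log_q evaluates 0 * -inf = NaN.
    # jnp.where then discards those entries, so the returned value is finite.
    terms = jnp.where(probs == 0.0, 0.0, probs * log_q)
    return -jnp.sum(terms, axis=-1)


# p puts zero mass on the last outcome, and so does q, so log q is -inf there.
p = jnp.array([0.5, 0.5, 0.0])
log_q = jnp.log(jnp.array([0.25, 0.75, 0.0]))  # last entry is -inf

print(cross_entropy_naive(p, log_q))  # finite: the plain computation looks fine

# With NaN checks enabled, the NaN from the multiplication is reported even
# though it never reaches the output.
checked = checkify.checkify(cross_entropy_naive, errors=checkify.nan_checks)
err, out = checked(p, log_q)
print(err.get())  # an error message instead of None; err.throw() would raise
```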
A list of possible solutions that come to mind:
Change the cross-entropy computation to avoid producing NaNs, for instance by replacing negative infinities in the logits with finite values (the result is not affected, because those entries are masked by jnp.where anyway).
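One way the first suggestion could look, as a sketch under the same assumptions as before (cross_entropy_safe is a hypothetical name, not a distrax function). Here the log-probabilities are masked to 0.0 wherever probs == 0, before the multiplication, so the product is 0 * 0 instead of 0 * -inf. Masking on the zero-probability condition rather than on every -inf keeps the case probs > 0 with log q = -inf intact, where the cross-entropy should still come out as +inf.

```python
import jax.numpy as jnp
from jax.experimental import checkify


def cross_entropy_safe(probs, log_q):
    """Hypothetical NaN-free cross-entropy H(p, q) = -sum p * log q."""
    zero_mass = probs == 0.0
    # Replace log_q (possibly -inf) with a finite value where probs == 0.
    # These entries are masked again below, so the result is unchanged.
    safe_log_q = jnp.where(zero_mass, 0.0, log_q)
    terms = jnp.where(zero_mass, 0.0, probs * safe_log_q)
    return -jnp.sum(terms, axis=-1)


p = jnp.array([0.5, 0.5, 0.0])
log_q = jnp.log(jnp.array([0.25, 0.75, 0.0]))  # last entry is -inf

checked = checkify.checkify(cross_entropy_safe, errors=checkify.nan_checks)
err, out = checked(p, log_q)
print(err.get())  # None: no intermediate NaN was produced
print(out)
```

The double jnp.where costs one extra select but keeps every intermediate finite, which is also the usual way to avoid NaN gradients flowing through the masked branch.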
Thanks!