Description
JAX compilation freezes in versions 0.5.0 and 0.5.1 in seemingly unpredictable settings. Unfortunately, I was unable to produce a minimal reproducer, but the following instructions should suffice to reproduce the issue (tested on three different machines with different CUDA versions and drivers).
Setup:
The following code hangs indefinitely.
With an older version, however, it works.
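The reproducer itself is not shown here. A generic way to confirm that a hang happens during compilation rather than execution is to lower and compile ahead of time with `jax.jit(f).lower(...).compile()` in a child process guarded by a timeout. The sketch below is not the code from this report: `step`, the input shape, and the helper `runs_within` are hypothetical placeholders for the real program, and the 120-second budget is an arbitrary assumption.

```python
import subprocess
import sys
import textwrap

# Hypothetical stand-in for the hanging program; replace `step` and the
# input shape with the real ones from the reproducer.
COMPILE_SCRIPT = textwrap.dedent(
    """
    import jax
    import jax.numpy as jnp

    def step(x):
        return jnp.sin(x) @ x.T

    x = jnp.ones((8, 8))
    # Lower and compile ahead of time without executing, so a hang at this
    # point can be attributed to compilation rather than execution.
    compiled = jax.jit(step).lower(x).compile()
    print("compiled OK")
    """
)


def runs_within(script: str, timeout_s: float) -> bool:
    """Run `script` in a fresh interpreter; return False if it exceeds `timeout_s`."""
    try:
        result = subprocess.run(
            [sys.executable, "-c", script],
            capture_output=True,
            text=True,
            timeout=timeout_s,
        )
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child before re-raising, so nothing is left frozen.
        return False
    return result.returncode == 0


if __name__ == "__main__":
    print("finished in time:", runs_within(COMPILE_SCRIPT, timeout_s=120.0))
```

Running the same script in environments with 0.5.0/0.5.1 and with an older release gives a pass/fail signal per version without leaving a hung terminal behind.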
Oddly enough, choosing a different system also works (in both JAX versions). The two systems differ only in some tensor shapes.
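Because the working and hanging systems differ only in tensor shapes, note that `jax.jit` specializes on input shapes and dtypes. Each new shape triggers a fresh trace and a separate XLA compilation, so the two systems go through different compilations even with identical code. A minimal illustration with a hypothetical toy function `g`, counting traces through a Python side effect (which runs only at trace time, not on cached calls):

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def g(x):
    global trace_count
    # Python side effects run only while tracing, so this counts how many
    # shape/dtype specializations jit had to trace and compile.
    trace_count += 1
    return (x * 2.0).sum()


g(jnp.ones((3, 4)))  # first shape: trace + compile
g(jnp.ones((3, 4)))  # same shape and dtype: cached executable is reused
g(jnp.ones((5, 4)))  # new shape: trace + compile again
print(trace_count)   # 2
```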
I have attached zip files with the XLA dumps for all three experiments.
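For reference, dumps like these are typically produced with the XLA flag `--xla_dump_to`, passed through the `XLA_FLAGS` environment variable before JAX initializes its backend. The exact flags used for the attached dumps are not stated; the sketch below assumes only `--xla_dump_to`, uses a hypothetical toy function `f`, and zips the result with `shutil.make_archive`.

```python
import os
import shutil
import tempfile

# XLA reads XLA_FLAGS when the backend initializes, so set it before
# importing jax (appending to any flags that are already present).
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + f" --xla_dump_to={dump_dir}"
).strip()

import jax
import jax.numpy as jnp


@jax.jit
def f(x):  # hypothetical toy function
    return jnp.tanh(x) * 2.0


# Compiling and running f makes XLA write its HLO modules into dump_dir.
f(jnp.arange(4.0)).block_until_ready()

dumped = sorted(os.listdir(dump_dir))
# Produces dump_dir + ".zip" next to the dump directory.
archive = shutil.make_archive(dump_dir, "zip", root_dir=dump_dir)
print(len(dumped), "files dumped; archive:", archive)
```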
System info (python version, jaxlib version, accelerator, etc.)
xla_dump1.zip
xla_dump2.zip
xla_dump3.zip
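The system details requested by the issue template can be collected with `jax.print_environment_info()`. The dictionary below is an illustrative addition (its keys are arbitrary choices, not a JAX format) that gathers the same facts programmatically, which is handy when comparing the three machines side by side.

```python
import importlib.metadata
import platform

import jax

# Prints jax/jaxlib/numpy/python versions, devices and, when available, nvidia-smi output.
jax.print_environment_info()

# Hypothetical summary for diffing machines; keys are illustrative only.
info = {
    "python": platform.python_version(),
    "jax": jax.__version__,
    "jaxlib": importlib.metadata.version("jaxlib"),
    "backend": jax.default_backend(),
    "devices": [str(d) for d in jax.devices()],
}
print(info)
```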