Better jax.debug.print #25842
Labels: enhancement (New feature or request)
Comments
Also pinging @danielsuo who has been thinking about this.
A few code samples scattered around:
copybara-service bot pushed a commit that referenced this issue Jan 15, 2025
As discussed in #25842, since JAX's current logging mechanisms (e.g. `jax.debug.print`) are built on callbacks, logging a sharded array requires an expensive all-gather operation. It would sometimes be useful to be able to separately print the local data shard on each worker. These parallel changes to XLA and JAX are meant as an experiment to demonstrate the custom GSPMD partitioning logic needed for this behavior. I'm currently using a new FFI handler that doesn't do anything, but this is sufficient to test the partitioning logic. It should be feasible to apply this same logic to a custom call encapsulating a callback. PiperOrigin-RevId: 715862954
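To make "the local data shard on each worker" concrete, here is a minimal sketch that inspects shards outside of `jit` using the existing `addressable_shards` attribute of a `jax.Array`. It is not the custom GSPMD partitioning experiment described in the commit message. The mesh axis name `"x"` and the helper `local_shards` are assumptions made for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis over all visible devices (a single device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))  # "x" is an illustrative axis name

# Shard a 1-D array along "x"; its length is a multiple of the device count.
x = jax.device_put(jnp.arange(4 * devices.size), NamedSharding(mesh, P("x")))


def local_shards(arr):
    """Hypothetical helper: (device, index, data) for each shard this process can address."""
    return [(s.device, s.index, np.asarray(s.data)) for s in arr.addressable_shards]


# Each process only touches the shards it holds, so no all-gather is needed here.
for device, index, data in local_shards(x):
    print(f"{device}: x[{index}] = {data}")
```

This only works on concrete arrays outside a traced function. Printing the same per-device view from inside a jitted computation is what the issue asks for.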
copybara-service bot pushed a commit to openxla/xla that referenced this issue Jan 15, 2025
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Jan 15, 2025
Problems we want to solve
`jax.debug.print` currently works in most environments, but it is costly.
Always all-gathers data under GSPMD mode
Under GSPMD mode, it only supports maximal sharding, as inherited from `io_callback`, so a program that contains `debug.print` becomes much less efficient because a lot of communication is added.
Going across the Python interpreter/NumPy
Every print crosses into the Python interpreter and NumPy, and these overheads make printing slow.
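The two costs above can be seen in a small sketch: a jitted function that calls `jax.debug.print` on a sharded input. The mesh axis name `"x"`, the function `step`, and the array sizes are assumptions for illustration. The comments restate the behavior described in this issue rather than anything measured here.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))  # "x" is an illustrative axis name

# Input sharded along "x"; on a single device this is one shard.
x = jax.device_put(
    jnp.arange(4 * devices.size, dtype=jnp.float32),
    NamedSharding(mesh, P("x")),
)


@jax.jit
def step(x):
    y = x * 2.0
    # The message is formatted by a Python callback on the host, so the value
    # of y has to be delivered to that callback. Per this issue, under GSPMD
    # that means gathering the whole array rather than printing local shards.
    jax.debug.print("y = {y}", y=y)
    return y.sum()


total = step(x)
jax.effects_barrier()  # wait for any pending debug prints to be flushed
print(total)
```

On a single device there is nothing to gather, so the communication cost only appears once the mesh spans several devices.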
Unstable IR and disrupting AOT
The current implementation is based on `io_callback`, which puts the opaque pointer of the callback in the StableHLO IR, so each machine ends up with different IR. (Related to [Lowering] Stable IR #25123)
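One way to look at this is to lower a function ahead of time and search the IR text for the callback. This is a rough sketch: the function `f` is hypothetical, and the exact custom-call target name and how the callback is encoded vary across JAX versions and backends, so the filter just matches the substring "callback".

```python
import jax
import jax.numpy as jnp


def f(x):
    jax.debug.print("x = {x}", x=x)
    return x + 1


# AOT-style lowering: build the IR without executing the function.
lowered = jax.jit(f).lower(jnp.ones(3))
ir_text = lowered.as_text()

# Keep only the lines that mention a callback (may be empty on some backends).
callback_lines = [line.strip() for line in ir_text.splitlines() if "callback" in line]
for line in callback_lines:
    print(line)
```

Comparing this output across two processes or two runs is a quick check of whether the lowered IR is stable in a given setup.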
How to make it better