Better jax.debug.print #25842

Open
yliu120 opened this issue Jan 10, 2025 · 3 comments
Labels
enhancement New feature or request


Contributor

yliu120 commented Jan 10, 2025

Problems we want to solve

jax.debug.print currently works in most environments, but it is quite costly:

  1. Always all-gathers data under GSPMD
    Under GSPMD, debug.print supports only maximal sharding (inherited from io_callback), so its operands are gathered onto a single device. Any program containing debug.print therefore becomes much less efficient, because a lot of extra communication is inserted.

  2. Round-trips through the Python interpreter and NumPy
    Every print leaves the compiled program and goes through Python and NumPy, and this overhead makes printing slow.

  3. Unstable IR that disrupts AOT compilation
    The current implementation is based on io_callback, which embeds an opaque pointer to the callback in the StableHLO IR, so each machine ends up with different IR. (Related to [Lowering] Stable IR #25123)
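To make problems 1 and 3 concrete, here is a minimal sketch, assuming a recent JAX release and four CPU devices simulated with the `--xla_force_host_platform_device_count` XLA flag. It prints a sharded array from inside `jit` and checks that the print lowers to a custom call in the StableHLO text (the host-callback custom call that carries the opaque callback reference). The function name `f` and the mesh axis name `"x"` are illustrative only.

```python
import os

# Simulate 4 CPU devices; this flag must be set before JAX is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over 4 (simulated) CPU devices, with x split along it.
mesh = Mesh(np.array(jax.devices("cpu")[:4]), ("x",))
x = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P("x")))


@jax.jit
def f(x):
  # jax.debug.print is built on a host callback, so it lowers to a custom call.
  jax.debug.print("x = {}", x)
  return x * 2


# Inspect the lowered StableHLO: the print appears as a callback custom call.
lowered_text = f.lower(x).as_text()
print("custom_call" in lowered_text)

y = f(x)
jax.effects_barrier()  # wait for the asynchronous print to flush
```

Searching `lowered_text` for the callback's sharding annotation is a quick way to see the single-device placement described in problem 1, though the exact attribute spelling depends on the JAX version and partitioner in use.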

How to make it better

  1. Problem 3 (unstable IR) is a work in progress.
  2. We can implement the debug.print callback with FFI in XLA so that it becomes a precompiled callback in jaxlib. The custom call will have sharding propagated to it and will be partitioned the same way as its inputs. In Python, it would look like this:
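     ```python
     import jax

     def debug_print(x):
       # Hypothetical FFI target "debug_print"; no such handler ships in jaxlib today.
       # The custom call follows the same sharding as `x` and returns it unchanged.
       return jax.ffi.ffi_call(
           "debug_print",
           jax.ShapeDtypeStruct(x.shape, x.dtype),
           has_side_effect=True,
       )(x)
     ```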
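Until a precompiled handler exists, one way to see only each device's local data with existing APIs is to call `jax.debug.print` inside a fully manual `shard_map`, where each device's callback receives just its own block. The sketch below assumes a recent JAX release and four simulated CPU devices; the helper name `print_local` is hypothetical. It still goes through Python, so it avoids the all-gather (problem 1) but not the interpreter overhead or the IR instability (problems 2 and 3).

```python
import os

# Simulate 4 CPU devices; this flag must be set before JAX is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
  from jax import shard_map  # newer JAX releases
except ImportError:
  from jax.experimental.shard_map import shard_map  # older JAX releases

mesh = Mesh(np.array(jax.devices("cpu")[:4]), ("x",))
local_shapes = []  # recorded once, at trace time


def print_local(block):
  # Inside shard_map, `block` is this device's local shard only.
  local_shapes.append(block.shape)
  # One print per device, with no gather of the full array.
  jax.debug.print("local shard: {}", block)
  return block


print_shards = jax.jit(
    shard_map(print_local, mesh=mesh, in_specs=P("x"), out_specs=P("x"))
)
x = jnp.arange(8.0)
y = print_shards(x)
jax.effects_barrier()  # wait for the asynchronous prints to flush
```

The order in which the four per-device lines appear is not guaranteed.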
yliu120 added the enhancement (New feature or request) label Jan 10, 2025
Contributor Author

yliu120 commented Jan 10, 2025

This issue reflects an offline conversation with Peter (@hawkinsp) and @dfm.

Collaborator

dfm commented Jan 10, 2025

Also pinging @danielsuo who has been thinking about this.

Contributor Author

yliu120 commented Jan 14, 2025

A few related code samples scattered across PRs:

  1. Making custom_partitioning calls work with debug.print ([io_callback] Adds test for io_callback being used inside custom partitioning #23620)
  2. Making debug.print (or io_callback) work inside a partial-auto shard_map ([Shmap/PartialAuto] Temporary solution for debug.print inside a partial-auto shard map. #25705)

copybara-service bot pushed a commit that referenced this issue Jan 15, 2025
As discussed in #25842, since JAX's current logging mechanisms (e.g. `jax.debug.print`) are built on callbacks, logging a sharded array requires an expensive all-gather operation. It would sometimes be useful to be able to separately print the local data shard on each worker.

These parallel changes to XLA and JAX are meant as an experiment to demonstrate the custom GSPMD partitioning logic needed for this behavior. I'm currently using a new FFI handler that doesn't do anything, but this is sufficient to test the partitioning logic. It should be feasible to apply this same logic to a custom call encapsulating a callback.

PiperOrigin-RevId: 715862954
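For comparison, outside compiled code each process can already read just its local data through `jax.Array.addressable_shards`, without gathering the full array. This is a host-side, eager sketch assuming four simulated CPU devices; it does not cover printing from inside `jit`, which is what the experiment above targets.

```python
import os

# Simulate 4 CPU devices; this flag must be set before JAX is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices("cpu")[:4]), ("x",))
x = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P("x")))

# Each addressable shard holds only its local block; reading one shard
# copies just that block to the host.
local_blocks = {}
for shard in x.addressable_shards:
  start = shard.index[0].start  # offset of this block along the sharded axis
  block = np.asarray(shard.data)
  local_blocks[start] = block
  print(shard.device, shard.index, block)
```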
copybara-service bot pushed a commit to openxla/xla that referenced this issue Jan 15, 2025
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Jan 15, 2025