
[Pallas] python for loop in kernel function is unbearably slow compared to fori_loop #26812

Open
ysngshn opened this issue Feb 27, 2025 · 5 comments
Labels
bug Something isn't working

Comments

@ysngshn

ysngshn commented Feb 27, 2025

Description

A Python for loop of 1000 iterations in a Pallas kernel function is painfully slow to compile (~1 hour!). Switching to an equivalent jax.lax.fori_loop implementation is much faster (no noticeable lag). For reference, an equivalent naive JAX implementation without a Pallas kernel, using a Python for loop, takes several minutes to JIT-compile.

To reproduce it, change the following part of the benchmark script

https://github.com/ysngshn/jax-pallas-benchmark/blob/main/run_benchmark.py#L39-L47

    h0 = h0_ref[...]

    def _step(i, h_prev):
        # Slice out timestep i along the sequence dimension.
        idx = slice_all[:seq_dim] + (i,) + slice_all[seq_dim+1:]
        # Recurrent update: elementwise affine step followed by the activation.
        h_next = activation(x_ref[idx] + whh * h_prev)
        o_ref[idx] = h_next
        return h_next

    _ = fori_loop(0, seqlen, _step, h0)

to

    h = h0_ref[...]
    
    for i in range(seqlen):
        idx = slice_all[:seq_dim] + (i,) + slice_all[seq_dim+1:]
        h = activation(x_ref[idx] + whh * h)
        o_ref[idx] = h

in the _indrnn_elementwise_kernel Pallas kernel function and rerun the script.
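The effect can be seen even without Pallas: tracing a Python for loop unrolls it, so the emitted program grows linearly with the iteration count, while fori_loop stays constant-size. Below is a minimal, self-contained sketch (a toy recurrence, not the benchmark's kernel) that counts equations in the traced jaxpr:

```python
import jax
import jax.numpy as jnp

def unrolled(h, n):
    # Python loop: tracing replays the body n times in the program.
    for _ in range(n):
        h = jnp.tanh(h + 1.0)
    return h

def rolled(h, n):
    # fori_loop: the body is traced once, regardless of n.
    return jax.lax.fori_loop(0, n, lambda i, h: jnp.tanh(h + 1.0), h)

h = jnp.zeros((4,))
small = len(jax.make_jaxpr(lambda h: unrolled(h, 10))(h).jaxpr.eqns)
large = len(jax.make_jaxpr(lambda h: unrolled(h, 100))(h).jaxpr.eqns)
fixed = len(jax.make_jaxpr(lambda h: rolled(h, 100))(h).jaxpr.eqns)
print(small, large, fixed)  # the unrolled trace grows linearly; the rolled one does not
```

With 1000 iterations of a full Pallas kernel body, the same linear blow-up lands on the Triton compiler instead of XLA.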

Console output from JAX for the 1 hour freeze:

2025-02-26 22:31:14.268744: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_wrapped] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2025-02-26 23:25:30.248847: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 56m15.98014819s

********************************
[Compiling module jit_wrapped] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************

System info (python version, jaxlib version, accelerator, etc.)

Python 3.11, jax 0.4.35, jaxlib 0.4.34, numpy: 1.26.4, Linux system, GPU platform

@ysngshn ysngshn added the bug Something isn't working label Feb 27, 2025
@justinjfu
Collaborator

Generally for a long loop like this you should be using a JAX loop construct like fori to reduce compile times, because otherwise compilation scales with the number of loop iterations. Because Pallas emits low-level code (Triton-IR in this case), the unrolled programs get very large, even compared to equivalent non-Pallas JAX code.

Sometimes the compiler can optimize better with an unrolled loop, in which case you could try unrolling a few iterations using the unroll=N argument. But I don't think there are many practical use cases where you would want to unroll 1000 steps. I'm inclined to say that this is mostly "working as intended" and a fast compile time in this case is a nice-to-have but not necessarily a bug. However, if you experience long compile times even without long unrolled loops, that would be a much more urgent issue.
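For reference, a minimal sketch of partial unrolling via jax.lax.fori_loop's `unroll` argument (a toy body, not the benchmark's kernel): the loop stays rolled at trace time, and the lowering replays `unroll` consecutive body steps per loop iteration.

```python
import jax
import jax.numpy as jnp

def body(i, h):
    # Toy recurrent step standing in for the real kernel body.
    return jnp.tanh(h + 1.0)

h0 = jnp.zeros((8,))
rolled = jax.lax.fori_loop(0, 1000, body, h0)             # fully rolled
partial = jax.lax.fori_loop(0, 1000, body, h0, unroll=8)  # 8 body steps per loop iteration
print(jnp.allclose(rolled, partial))
```

Small `unroll` values trade a modestly larger program for fewer loop-carried dependencies; `unroll=1000` would reproduce the fully unrolled blow-up.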

@ysngshn
Author

ysngshn commented Feb 27, 2025

The thing is that compiling a non-Pallas for-loop implementation that does the same thing (1000 steps) is still much faster; my benchmark script also includes this case, and there is no console warning about slow compilation for it.

Is the Pallas kernel function compiled by something other than XLA? If both cases use XLA, why is one much slower than the other?

@ysngshn
Author

ysngshn commented Feb 27, 2025

Also, interestingly, the naive non-Pallas Python-for-loop-based function runs the fastest after compilation, faster than the compiled Pallas implementation ...

@justinjfu
Collaborator

Is pallas kernel function compiled by something else and not XLA? If both cases use XLA, why is one much slower than the other?

No, Pallas does not use XLA.

In your case you're using the Triton backend for Pallas since you're on GPU. Pallas outputs Triton-IR (https://triton-lang.org/main/dialects/dialects.html#) which is then compiled into PTX via the MLIR compiler framework.

Also, an interesting thing is that the naive non-Pallas python-for-loop-based function runs the fastest after compile, faster than compiled Pallas implementation ...

There are already well-optimized kernels on GPUs (e.g. via cudnn) that JAX will call when possible, and I'm guessing that's what's happening here. Pallas is useful when the kernels don't exist yet or when the compiler isn't doing what you want. You can probably do some napkin math based on the specs of your hardware (FLOP/s and bandwidth) to see if there's even any improvement you could make with a custom kernel.
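The napkin math above can be sketched as a quick roofline-style estimate. All the numbers below are assumptions for illustration (hypothetical A100-like specs and a made-up problem size), not measurements:

```python
# Assumed hardware specs (hypothetical, roughly A100-class; substitute your own).
peak_flops = 19.5e12   # fp32 FLOP/s (assumption)
peak_bw = 1.555e12     # HBM bytes/s (assumption)

# Made-up problem size for an elementwise RNN over a 1000-step sequence.
batch, hidden, seqlen = 32, 1024, 1000
elems = batch * hidden * seqlen

# Rough accounting: ~3 fp32 accesses per element (read x, read/write h),
# ~4 FLOPs per element (multiply, add, activation approximation).
bytes_moved = elems * 4 * 3
flops = elems * 4

t_compute = flops / peak_flops
t_memory = bytes_moved / peak_bw
print(f"compute-bound ~{t_compute:.2e}s, memory-bound ~{t_memory:.2e}s")
# If t_memory >> t_compute, the op is bandwidth-limited: no custom kernel
# can beat the memory-bound time, so that is the target to compare against.
```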

@ysngshn
Author

ysngshn commented Feb 27, 2025

In your case you're using the Triton backend for Pallas since you're on GPU. Pallas outputs Triton-IR (https://triton-lang.org/main/dialects/dialects.html#) which is then compiled into PTX via the MLIR compiler framework.

I see! Thanks for the clarification!

There's already well-optimized kernels on GPUs (e.g. via cudnn) that JAX will call when possible and I'm guessing that's whats happening here. ....

I am implementing a custom RNN here, so I would be surprised if cudnn comes into play. I suspect instead that XLA is able to find an optimized unrolling + blockspec combination, and that is why it beats my handwritten Pallas code.

Anyway, since this is not really a bug but more like a feature request, please feel free to close this issue, unless you want to keep it open for improving Pallas kernel compile times with Python for loops. Thank you!
