[Pallas] python for loop in kernel function is unbearably slow compared to fori_loop #26812
Comments
Generally, for a long loop like this you should be using a JAX loop construct like jax.lax.fori_loop. Sometimes the compiler might be able to make better optimizations with an unrolled loop, in which case you could try unrolling a few iterations using the unroll argument.
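For concreteness, a minimal sketch of that advice might look like the following. This is not the benchmark's actual kernel; the shapes and the recurrence are made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Illustrative 1000-step elementwise recurrence written with jax.lax.fori_loop
# instead of a Python `for` loop, so the loop is compiled as a single loop
# construct rather than 1000 unrolled copies of its body.
def recurrence_kernel(x_ref, w_ref, o_ref):
    x = x_ref[...]
    w = w_ref[...]

    def body(_, h):
        # elementwise update; executed 1000 times inside one compiled loop
        return jnp.tanh(w * h + x)

    # jax.lax.fori_loop also accepts an `unroll` argument to unroll a few
    # iterations per step if the compiler benefits from it.
    o_ref[...] = jax.lax.fori_loop(0, 1000, body, jnp.zeros_like(x))

@jax.jit
def run(x, w):
    # requires a GPU backend; pallas_call(..., interpret=True) can be used
    # to run the kernel in interpreter mode elsewhere
    return pl.pallas_call(
        recurrence_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, w)

x = jnp.ones((128,), jnp.float32)
w = jnp.full((128,), 0.5, jnp.float32)
print(run(x, w).shape)  # (128,)
```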
The thing is that compiling a non-Pallas for-loop implementation (my benchmark script also includes this case) that does the same thing (1000 steps) is still much faster, and there is no console warning about a slow compile in that case. Is the Pallas kernel function compiled by something other than XLA? If both cases use XLA, why is one much slower than the other?
Also, an interesting observation: the naive non-Pallas Python-for-loop-based function runs the fastest after compilation, faster than the compiled Pallas implementation.
No, Pallas does not use XLA. In your case you're using the Triton backend for Pallas since you're on GPU. Pallas outputs Triton-IR (https://triton-lang.org/main/dialects/dialects.html#) which is then compiled into PTX via the MLIR compiler framework.
There are already well-optimized kernels on GPUs (e.g. via cuDNN) that JAX will call when possible, and I'm guessing that's what's happening here. Pallas is useful in cases where the kernels don't exist yet or when the compiler isn't doing what you want. You can probably do some napkin math based on the specs of your hardware (FLOP/s and bandwidth) to see whether there's even any improvement you could make with a custom kernel.
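To make that napkin math concrete, a rough roofline-style estimate could look like the sketch below; every hardware number and problem size in it is an assumption for illustration, not something taken from the benchmark.

```python
# Back-of-the-envelope estimate of lower bounds on kernel runtime.
# All numbers are assumptions (roughly A100-class hardware), not measurements.
seq_len, hidden = 1000, 4096              # hypothetical RNN dimensions
bytes_per_elem = 4                        # float32

flops = 2 * seq_len * hidden              # ~1 multiply + 1 add per element per step
bytes_moved = seq_len * hidden * bytes_per_elem * 2   # read inputs, write states

peak_flops = 19.5e12                      # assumed FP32 peak, FLOP/s
peak_bw = 1.5e12                          # assumed HBM bandwidth, bytes/s

compute_time = flops / peak_flops
memory_time = bytes_moved / peak_bw
print(f"compute-bound floor: {compute_time * 1e6:.2f} us")
print(f"memory-bound floor:  {memory_time * 1e6:.2f} us")
# This ignores the sequential dependence of an RNN, so it is only a lower
# bound: if a measured kernel is already close to max(compute_time,
# memory_time), a hand-written Pallas kernel has little headroom left.
```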
I see! Thanks for the clarification!
I am implementing a custom RNN here, so I would be surprised if cuDNN comes into play. I suspect instead that XLA is able to find an "optimized unrolling + blockspec" combination, and that's why it beats my handwritten Pallas code. Anyway, since this is not really a bug but more of a feature request, please feel free to close this issue, unless you want to keep it open for improving Pallas kernel compile time with Python for loops. Thank you!
Description
A Python for loop of 1000 iterations in a Pallas kernel function is painfully slow to run (~1 hour!). Switching to an equivalent jax.lax.fori_loop implementation is much faster (no noticeable lag). For reference, an equivalent naive JAX implementation without a Pallas kernel using a Python for loop takes several minutes to jit-compile.

To reproduce it, change the following part of the benchmark script
https://github.com/ysngshn/jax-pallas-benchmark/blob/main/run_benchmark.py#L39-L47
to
in the _indrnn_elementwise_kernel Pallas kernel function and rerun the script.

Console output from JAX for the 1 hour freeze:
System info (python version, jaxlib version, accelerator, etc.)
Python 3.11, jax 0.4.35, jaxlib 0.4.34, numpy: 1.26.4, Linux system, GPU platform