GPU pallas_call loses compiler params during second call when double jit-wrapped
#25714
Labels
bug
Description
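If you add a
print(triton_params)
after the line referenced below, the output shows that the compiler params are missing on the second call when pallas_call is double jit-wrapped:
jax/jax/_src/pallas/triton/pallas_call_registration.py
Line 63 in 57b2154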
This causes some performance problems in production as kernels can't get the right compiler params.
Reproducer
System info (python version, jaxlib version, accelerator, etc.)