Failed to launch ROCm kernel on MI250x #247
Comments
Hi @ffrancesco94. Internal ticket has been created to investigate this issue. Thanks! |
I |
I see. So the way to go for now is to compile everything from source? Or will a minor release be published soon? |
To clarify: the above comment was only to add from my experience as a user. I adjusted the wording a bit since it seems overly confident reading it now :-) |
Aah, I see. But it's still very good to get that input, so thank you, at least I know that I can compile things myself for the time being! |
Thanks for letting us know. I will shortly upload the latest wheels. |
That would mean that the ones built here are the same as the ones on PyPI? edit:typo |
I think the ones on PyPI might be behind; building from this repo should yield the latest publicly available JAX build for ROCm. |
@Ruturaj4 For the version jax0.0.5, rocm6.3.1 I am also having a similar error
When trying to run
|
@phambinhfin For me it was due to the driver on the host system being older (ROCm 6.0.x or 6.2.x), not sure if it helps in your case. |
@ffrancesco94 can you try installing jax 0.5.0 using pip? We have now fixed the PyPI issue and our plugins are up. |
@ffrancesco94, this issue will be closed for now due to inactivity. Please feel free to follow up below if the issue persists. Thanks! |
Description
Hi,
I tried to use this build of jax since the upstream one was giving me some trouble. I built a simple reproducer where I try to train a PINN using jax (I can upload it if you need) and I get the following error:
I installed this jax build following these instructions. Interestingly, I had a similar error when I was experimenting a while back and I was missing the jax-rocm60-pjrt and jax-rocm60-plugin packages, but they are installed now. I installed the packages in a Singularity container and ran from there, but it seems that the GPU is being correctly found.
System info (python version, jaxlib version, accelerator, etc.)
The accelerator is an AMD MI250X.
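Since a missing jax-rocm60-pjrt or jax-rocm60-plugin package produced a similar error before, a quick way to confirm what is actually installed inside the container is to query the package metadata. The snippet below is only a minimal sketch using the Python standard library: the package names are the ones mentioned above, the helper name `installed_versions` is made up for illustration, and the "rocm60" suffix is an assumption that would need adjusting for builds targeting another ROCm release.

```python
from importlib import metadata

# Distribution names as mentioned in this issue; the "rocm60" suffix is an
# assumption and differs for builds that target another ROCm release.
PACKAGES = ("jax", "jaxlib", "jax-rocm60-pjrt", "jax-rocm60-plugin")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

Running this inside the Singularity container lists the versions side by side, which makes a missing plugin or a jax/jaxlib mismatch easy to spot. It does not check the ROCm driver on the host.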