How to install compatible versions of Jax and Jaxlib for a particular CUDA version? #8358
-
I am having trouble getting both
Is there a better way?
-
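JAX dropped support for CUDA 10.X in jaxlib version 0.1.72 (see https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-0172-oct-12-2021), and according to the changelog, the JAX version current at that release was v0.2.16. So you should be able to use something like
Note that if you already have a version of JAX installed, pip uses the locally installed version to parse extras (which feels like a bug to me), so you have to work around it by first installing the appropriate JAX version without extras: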
I think you can also use the
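Because pip consults the locally installed JAX when it resolves extras, it can help to check which jax and jaxlib versions are already present before pinning or upgrading. The snippet below is a minimal sketch, not part of the original answer. It assumes Python 3.8+ for the standard-library `importlib.metadata` module, and `installed_versions` is a hypothetical helper name. It only reads package metadata and never imports JAX.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib")):
    """Map each distribution name to its installed version string, or None.

    installed_versions is a hypothetical helper for illustration; it reads
    package metadata only and does not import JAX itself.
    """
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in the current environment
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

If `jax` reports a version other than the one you intend to pin, installing that exact version without extras first, as suggested in this thread, keeps pip from resolving extras against the stale local copy.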
-
CUDA 10.1 is buggy and causes wrong outputs in some cases, which is the key reason we dropped support for it. I really do mean that CUDA (
-
From https://pypi.org/project/jax/:

pip install --upgrade pip

CUDA 12 installation:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

CUDA 11 installation:

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
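After running one of these commands, a quick way to confirm that JAX actually picked up the CUDA build is to ask it which backend and devices it sees. This is a sketch under stated assumptions, not part of the PyPI instructions: it relies on `jax.default_backend()` and `jax.devices()` as found in recent JAX releases, and `jax_backend_summary` is a hypothetical helper name. It returns a notice instead of raising when JAX is not installed.

```python
import importlib.util


def jax_backend_summary() -> str:
    """Describe the backend JAX will use (hypothetical helper for illustration)."""
    # Check for JAX first so the function also works in an environment without it
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax

    # default_backend() is typically "gpu" when the CUDA wheel loaded correctly
    # and "cpu" when JAX fell back to running on the CPU.
    backend = jax.default_backend()
    devices = jax.devices()
    return f"jax {jax.__version__}: backend={backend}, devices={devices}"


if __name__ == "__main__":
    print(jax_backend_summary())
```

If this reports `cpu` on a machine with an NVIDIA GPU, one likely cause is an installed jaxlib that does not match the local CUDA setup, which is the mismatch this thread is about.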