How to install compatible versions of JAX and jaxlib for a particular CUDA version? #8358

Answered by jakevdp
davidrpugh asked this question in Q&A

JAX dropped support for CUDA 10.x in jaxlib 0.1.72 (see https://github.com/google/jax/blob/main/CHANGELOG.md#jaxlib-0172-oct-12-2021), and according to the changelog the current JAX version at that release was v0.2.16. So you should be able to use something like:

pip install "jax[cuda101]==0.2.16" -f https://storage.googleapis.com/jax-releases/jax_releases.html
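
To tell which side of that cutoff a given jaxlib release falls on, a version comparison is enough. This is a minimal sketch that assumes only the 0.1.72 cutoff stated above. The function name `jaxlib_supports_cuda10` is hypothetical, and `sort -V` is a GNU coreutils extension that may not exist on every system.

```shell
#!/usr/bin/env bash
# Hypothetical helper: returns 0 if the given jaxlib version predates 0.1.72,
# the release that (per the JAX changelog) dropped CUDA 10.x support.
jaxlib_supports_cuda10() {
  local version="$1"
  local cutoff="0.1.72"
  # The cutoff release itself no longer supports CUDA 10.x.
  if [ "$version" = "$cutoff" ]; then
    return 1
  fi
  # sort -V orders version strings numerically per component,
  # so if $version sorts first it is older than the cutoff.
  [ "$(printf '%s\n%s\n' "$version" "$cutoff" | sort -V | head -n 1)" = "$version" ]
}

# Usage:
#   jaxlib_supports_cuda10 0.1.71 && echo "before the CUDA 10.x cutoff"
```

Using `sort -V` instead of a plain string comparison matters here: lexically "0.1.100" would sort before "0.1.72", but version sort compares each numeric component.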

Note that if you already have a version of JAX installed, pip uses the local version to parse extras (which feels like a bug to me), so you have to work around it by first installing the appropriate JAX version without extras:

pip install jax==0.2.16
pip install "jax[cuda101]==0.2.16" -f https://storage.googleapis.com/jax-releas…
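
The two-step workaround can be wrapped in a small script so both steps always use the same version. This is a sketch under stated assumptions: the function name `install_jax_cuda101` is hypothetical, the releases URL is the one from the first install command, and the extras spec is quoted because shells such as zsh otherwise treat the square brackets as a glob pattern.

```shell
#!/usr/bin/env bash
# Hypothetical wrapper around the two-step install workaround.
JAX_VERSION="0.2.16"
JAX_RELEASES_URL="https://storage.googleapis.com/jax-releases/jax_releases.html"

install_jax_cuda101() {
  # Step 1: install the target JAX version without extras, so the
  # locally installed version matches the one whose extras pip will parse.
  pip install "jax==${JAX_VERSION}" || return 1
  # Step 2: install the same version with the cuda101 extra; quoting keeps
  # zsh and similar shells from globbing the [cuda101] brackets.
  pip install "jax[cuda101]==${JAX_VERSION}" -f "${JAX_RELEASES_URL}" || return 1
}

# Usage:
#   install_jax_cuda101
#   pip show jax jaxlib    # confirm which versions actually got installed
```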

Answer selected by davidrpugh