Verifying CUDA and CUDNN installation paths for JAX #15405
-
Correct. If you are using the pip wheel installation of CUDA, JAX only needs you to install the NVIDIA driver (https://github.com/google/jax#pip-installation-gpu-cuda-installed-via-pip-easier). The jax wheel prefers to use the pip-installed copy of CUDA/cuDNN if it is present. If it is not present, JAX will look for a system copy of CUDA/cuDNN. This is done primarily by adding a relative …

Both CUDA and cuDNN are required. It is permissible to, say, install CUDA on your system yourself and cuDNN using a pip wheel, provided the versions are compatible: it should work, but we don't test it ourselves.

I don't entirely follow the question about version pinning. Can you clarify? You can pin the versions of jax/jaxlib and all the CUDA wheels if you like, provided they are compatible.
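The lookup order described above (pip-installed copy first, system copy otherwise) can be approximated from Python as a rough diagnostic. This is a sketch, not JAX's actual loader logic: it assumes the pip-installed copies come from distributions whose names start with `nvidia-` (for example `nvidia-cudnn-cu12`), and it uses `ctypes.util.find_library`, which only approximates the dynamic linker's search for a system copy. The helper names `pip_nvidia_wheels` and `describe_cuda_lookup` are hypothetical.

```python
import ctypes.util
import importlib.metadata


def pip_nvidia_wheels():
    """Return {distribution name: version} for installed wheels named nvidia-*.

    Assumption: pip-installed CUDA/cuDNN copies are distributions whose names
    start with "nvidia-" (e.g. nvidia-cudnn-cu12).
    """
    found = {}
    for dist in importlib.metadata.distributions():
        name = dist.metadata.get("Name") or ""
        if name.lower().startswith("nvidia-"):
            found[name] = dist.version
    return found


def describe_cuda_lookup():
    """Hypothetical helper: report pip-installed and system copies side by side.

    ctypes.util.find_library only approximates how the dynamic linker would
    find a system library; it is not the exact search jaxlib performs.
    """
    return {
        "pip_wheels": pip_nvidia_wheels(),
        "system_cudart": ctypes.util.find_library("cudart"),  # None if not found
        "system_cudnn": ctypes.util.find_library("cudnn"),    # None if not found
    }


if __name__ == "__main__":
    print(describe_cuda_lookup())
```

If `pip_wheels` is empty, a `None` for `system_cudnn` is a strong hint that no cuDNN copy is reachable at all.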
-
To expand on this a little more: the following snippet doesn't actually confirm anything, because it also works with only CUDA installed.

```python
import jax

# Reports the GPU device even when cuDNN is missing
print(jax.numpy.array([1,]).devices())
```

However, this snippet only works with both CUDA and cuDNN:

```python
import jax

# The array is placed on the GPU device
m = jax.numpy.array([1,])
# Matrix multiplication will fail without cuDNN
print(m @ m)
```
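Following the same idea, a slightly broader smoke test can run a small matrix multiplication and a small convolution on the GPU and record which step fails. This is a sketch under the assumption that XLA sends GPU convolutions to cuDNN, so a failing convolution is a hint, not proof, that cuDNN is missing or incompatible. The function name `gpu_smoke_test` is hypothetical; it returns early if JAX is not installed or no GPU is visible.

```python
import importlib.util


def gpu_smoke_test():
    """Hypothetical helper: return {check name: bool} for a few GPU checks."""
    results = {"jax_installed": importlib.util.find_spec("jax") is not None}
    if not results["jax_installed"]:
        return results

    import jax
    import jax.numpy as jnp

    # jax.devices("gpu") raises (e.g. RuntimeError) if no GPU backend initializes.
    try:
        gpus = jax.devices("gpu")
    except Exception:
        gpus = []
    results["gpu_visible"] = bool(gpus)
    if not gpus:
        return results
    gpu = gpus[0]

    # Matrix multiplication, as in the snippet from the comment above.
    try:
        m = jax.device_put(jnp.ones((2, 2)), gpu)
        (m @ m).block_until_ready()
        results["matmul"] = True
    except Exception as err:  # report any failure instead of crashing
        print("matmul failed:", err)
        results["matmul"] = False

    # 2-D convolution; assumed to exercise cuDNN on the GPU backend.
    try:
        lhs = jax.device_put(jnp.ones((1, 1, 4, 4)), gpu)  # NCHW input
        rhs = jax.device_put(jnp.ones((1, 1, 3, 3)), gpu)  # OIHW kernel
        out = jax.lax.conv(lhs, rhs, window_strides=(1, 1), padding="SAME")
        out.block_until_ready()
        results["conv"] = True
    except Exception as err:
        print("convolution failed:", err)
        results["conv"] = False

    return results


if __name__ == "__main__":
    print(gpu_smoke_test())
```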
-
We are seeing a similar issue: the newest cuDNN available on our cluster is 8.9.4.
But we can't figure out how to install the NVIDIA pip wheels and use them correctly instead. Any solutions?
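When reporting a setup like this, it can help to include JAX's own environment summary. `jax.print_environment_info` is a public function in recent JAX releases; with `return_string=True` it returns the text instead of printing it. The wrapper `environment_report` is a hypothetical helper that returns `None` if JAX is not installed. The summary includes versions such as jax, jaxlib, and Python, plus `nvidia-smi` output when that command is available; it does not by itself show which cuDNN copy gets loaded.

```python
def environment_report():
    """Hypothetical wrapper: return JAX's environment summary as a string,
    or None if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None
    # return_string=True makes the function return the text instead of printing it.
    return jax.print_environment_info(return_string=True)


if __name__ == "__main__":
    print(environment_report() or "JAX is not installed")
```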
-
JAX seems to work with only the NVIDIA driver installed.
However, how can I tell whether the system installation of CUDA/cuDNN is present and its paths are set correctly?
A related question: if CUDA is required, is cuDNN also required?
Another related question: if CUDA/cuDNN is installed using wheels, how can I pin the versions of jax/jaxlib?
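For the pinning question, one approach consistent with the answer above is to record the exact versions of jax, jaxlib, and the NVIDIA wheels from a working environment as `name==version` lines for a requirements file. A minimal sketch, assuming the CUDA/cuDNN wheels are distributions whose names start with `nvidia-` and that newer JAX releases may also ship GPU support as `jax-cuda*` plugin wheels; `pinned_requirements` is a hypothetical helper, and it only records what is installed without checking that the versions are compatible.

```python
import importlib.metadata


def pinned_requirements(exact=("jax", "jaxlib"), prefixes=("nvidia-", "jax-cuda")):
    """Hypothetical helper: return sorted 'name==version' lines for installed
    distributions named exactly in `exact` or starting with one of `prefixes`.

    Assumptions: pip CUDA/cuDNN copies are "nvidia-*" wheels, and newer jax
    releases may ship GPU support as "jax-cuda*" plugin wheels.
    """
    pins = set()  # a set, since one distribution can be found on several paths
    for dist in importlib.metadata.distributions():
        name = dist.metadata.get("Name") or ""
        key = name.lower()
        if key in exact or key.startswith(tuple(prefixes)):
            pins.add(f"{name}=={dist.version}")
    return sorted(pins)


if __name__ == "__main__":
    # Save this output to a requirements file and reinstall with
    # `pip install -r <file>`; compatibility is NOT checked here.
    print("\n".join(pinned_requirements()))
```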