Skip to content

Commit

Permalink
Update references to JAX's GitHub repo
Browse files Browse the repository at this point in the history
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 702886874
  • Loading branch information
jakeharmon8 authored and copybara-github committed Dec 5, 2024
1 parent 6138208 commit 4fd8a50
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions vmoe/scripts/install_gce.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ source env/bin/activate;

if ( which nvidia-smi &> /dev/null ); then
# This assumes CUDA 11 and cuDNN 8.2.
# Check https://github.com/google/jax#pip-installation-gpu-cuda for alternatives.
# Check https://github.com/jax-ml/jax#pip-installation-gpu-cuda for alternatives.
pip install -q 'jax[cuda]' -f https://storage.googleapis.com/jax-releases/jax_releases.html;
else
# Since pjit does not work on CPUs, if nvidia-smi is not found, we assume that
Expand All @@ -31,7 +31,7 @@ fi;
python3 -m pip install -q --upgrade pip;
# Upgrade the following packages from GIT, since we use some features not part
# of any release yet.
pip install -q --upgrade git+https://github.com/google/jax.git;
pip install -q --upgrade git+https://github.com/jax-ml/jax.git;
pip install -q --upgrade git+https://github.com/google/flax.git;
pip install -q --upgrade git+https://github.com/google/CommonLoopUtils.git;
# Install the rest of necessary packages from PyPi.
Expand Down

0 comments on commit 4fd8a50

Please sign in to comment.