Skip to content

Commit

Permalink
Autodetect CUDA version to install the correct PyTorch in CI
Browse files Browse the repository at this point in the history
  • Loading branch information
ssheorey committed Aug 2, 2024
1 parent 7a6cff0 commit 2576d1d
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions util/ci_utils.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ LOW_MEM_USAGE=${LOW_MEM_USAGE:-OFF}
# ML
TENSORFLOW_VER="2.13.0"
TORCH_VER="2.0.1"
TORCH_CPU_GLNX_VER="${TORCH_VER}+cpu"
TORCH_CUDA_GLNX_VER="${TORCH_VER}+cu117" # match CUDA_VERSION in docker/docker_build.sh
TORCH_MACOS_VER="${TORCH_VER}"
TORCH_REPO_URL="https://download.pytorch.org/whl/torch/"
# Python
PIP_VER="23.2.1"
Expand All @@ -53,7 +50,8 @@ install_python_dependencies() {
if [[ "with-cuda" =~ ^($options)$ ]]; then
TF_ARCH_NAME=tensorflow
TF_ARCH_DISABLE_NAME=tensorflow-cpu
TORCH_GLNX="torch==$TORCH_CUDA_GLNX_VER"
CUDA_VER=$(nvcc --version | grep "release " | cut -c33-37 | sed 's|[^0-9]||g') # e.g.: 117, 118, 121, ...
TORCH_GLNX="torch==${TORCH_VER}+cu${CUDA_VER}"
else
# tensorflow-cpu wheels for macOS arm64 are not available
if [[ "$OSTYPE" == "darwin"* ]]; then
Expand All @@ -63,7 +61,7 @@ install_python_dependencies() {
TF_ARCH_NAME=tensorflow-cpu
TF_ARCH_DISABLE_NAME=tensorflow
fi
TORCH_GLNX="torch==$TORCH_CPU_GLNX_VER"
TORCH_GLNX="torch==${TORCH_VER}+cpu"
fi

# TODO: modify other locations to use requirements.txt
Expand All @@ -83,7 +81,7 @@ install_python_dependencies() {
python -m pip install -U "${TORCH_GLNX}" -f "$TORCH_REPO_URL" tensorboard

elif [[ "$OSTYPE" == "darwin"* ]]; then
python -m pip install -U torch=="$TORCH_MACOS_VER" -f "$TORCH_REPO_URL" tensorboard
python -m pip install -U torch=="$TORCH_VER" -f "$TORCH_REPO_URL" tensorboard
else
echo "unknown OS $OSTYPE"
exit 1
Expand Down

0 comments on commit 2576d1d

Please sign in to comment.