diff --git a/Dockerfile.tpu b/Dockerfile.tpu index b43442e4c0af1..0a507b6ecdf60 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -9,12 +9,6 @@ RUN apt-get update && apt-get install -y \ git \ ffmpeg libsm6 libxext6 libgl1 -# Install the TPU and Pallas dependencies. -RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html -RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html - # Build vLLM. COPY . . ARG GIT_REPO_CHECK=0 @@ -25,7 +19,6 @@ ENV VLLM_TARGET_DEVICE="tpu" RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ python3 -m pip install \ - 'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \ -r requirements-tpu.txt RUN python3 setup.py develop diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst index f0c812b941c1f..75ab2b6ba02dc 100644 --- a/docs/source/getting_started/tpu-installation.rst +++ b/docs/source/getting_started/tpu-installation.rst @@ -119,27 +119,19 @@ Uninstall the existing `torch` and `torch_xla` packages: pip uninstall torch torch-xla -y -Install `torch` and `torch_xla` +Install build dependencies: .. code-block:: bash - pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu - pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html + pip install -r requirements-tpu.txt + sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev -Install JAX and Pallas: +Run the setup script: .. code-block:: bash - pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html - pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html - -Install other build dependencies: + VLLM_TARGET_DEVICE="tpu" python setup.py develop -.. code-block:: bash - - pip install -r requirements-tpu.txt - VLLM_TARGET_DEVICE="tpu" python setup.py develop - sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev Provision Cloud TPUs with GKE ----------------------------- @@ -168,45 +160,6 @@ Run the Docker image with the following command: $ # Make sure to add `--privileged --net host --shm-size=16G`. $ docker run --privileged --net host --shm-size=16G -it vllm-tpu - -.. _build_from_source_tpu: - -Build from source ------------------ - -You can also build and install the TPU backend from source. - -First, install the dependencies: - -.. code-block:: console - - $ # (Recommended) Create a new conda environment. - $ conda create -n myenv python=3.10 -y - $ conda activate myenv - - $ # Clean up the existing torch and torch-xla packages. - $ pip uninstall torch torch-xla -y - - $ # Install PyTorch and PyTorch XLA. - $ export DATE="20241017" - $ export TORCH_VERSION="2.6.0" - $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl - $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl - - $ # Install JAX and Pallas. - $ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html - $ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html - - $ # Install other build dependencies. - $ pip install -r requirements-tpu.txt - - -Next, build vLLM from source. This will only take a few seconds: - -.. code-block:: console - - $ VLLM_TARGET_DEVICE="tpu" python setup.py develop - .. note:: Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape. diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 4c606cf0a9105..f9a0770804e55 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -2,6 +2,22 @@ -r requirements-common.txt # Dependencies for TPU -# Currently, the TPU backend uses a nightly version of PyTorch XLA. -# You can install the dependencies in Dockerfile.tpu. +cmake>=3.26 +ninja +packaging +setuptools-scm>=8 +wheel +jinja2 ray[default] + +# Install torch_xla +--pre +--extra-index-url https://download.pytorch.org/whl/nightly/cpu +--find-links https://storage.googleapis.com/libtpu-releases/index.html +--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html +torch==2.6.0.dev20241028+cpu +torchvision==0.20.0.dev20241028+cpu +torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl +jaxlib==0.4.32.dev20240829 +jax==0.4.32.dev20240829