Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor TPU requirements file and pin build dependencies #10010

Merged
merged 10 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions Dockerfile.tpu
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@ RUN apt-get update && apt-get install -y \
git \
ffmpeg libsm6 libxext6 libgl1

# Install the TPU and Pallas dependencies.
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

# Build vLLM.
COPY . .
ARG GIT_REPO_CHECK=0
Expand All @@ -25,7 +19,6 @@ ENV VLLM_TARGET_DEVICE="tpu"
RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=.git,target=.git \
python3 -m pip install \
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-r requirements-tpu.txt
RUN python3 setup.py develop

Expand Down
60 changes: 5 additions & 55 deletions docs/source/getting_started/tpu-installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -114,32 +114,21 @@ Clone the vLLM repository and go to the vLLM directory:
git clone https://github.com/vllm-project/vllm.git && cd vllm

Uninstall the existing `torch` and `torch_xla` packages:

.. code-block:: bash

pip uninstall torch torch-xla -y
richardsliu marked this conversation as resolved.
Show resolved Hide resolved

Install `torch` and `torch_xla`
Install build dependencies:

.. code-block:: bash

pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html

Install JAX and Pallas:

.. code-block:: bash

pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -r requirements-tpu.txt
sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev

Install other build dependencies:
Run the setup script:

.. code-block:: bash
VLLM_TARGET_DEVICE="tpu" python setup.py develop

pip install -r requirements-tpu.txt
VLLM_TARGET_DEVICE="tpu" python setup.py develop
sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev

Provision Cloud TPUs with GKE
-----------------------------
Expand Down Expand Up @@ -168,45 +157,6 @@ Run the Docker image with the following command:
$ # Make sure to add `--privileged --net host --shm-size=16G`.
$ docker run --privileged --net host --shm-size=16G -it vllm-tpu


.. _build_from_source_tpu:

Build from source
-----------------

You can also build and install the TPU backend from source.

First, install the dependencies:

.. code-block:: console

$ # (Recommended) Create a new conda environment.
$ conda create -n myenv python=3.10 -y
$ conda activate myenv

$ # Clean up the existing torch and torch-xla packages.
$ pip uninstall torch torch-xla -y

$ # Install PyTorch and PyTorch XLA.
$ export DATE="20241017"
$ export TORCH_VERSION="2.6.0"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl

$ # Install JAX and Pallas.
$ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
$ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

$ # Install other build dependencies.
$ pip install -r requirements-tpu.txt


Next, build vLLM from source. This will only take a few seconds:

.. code-block:: console

$ VLLM_TARGET_DEVICE="tpu" python setup.py develop

.. note::

Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape.
Expand Down
20 changes: 18 additions & 2 deletions requirements-tpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@
-r requirements-common.txt

# Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies in Dockerfile.tpu.
cmake>=3.26
ninja
packaging
setuptools-scm>=8
wheel
jinja2
ray[default]

# Install torch_xla
--pre
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.6.0.dev20241028+cpu
torchvision==0.20.0.dev20241028+cpu
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl
jaxlib==0.4.32.dev20240829
jax==0.4.32.dev20240829