Skip to content

Commit

Permalink
Refactor TPU requirements file and pin build dependencies (vllm-proje…
Browse files Browse the repository at this point in the history
…ct#10010)

Signed-off-by: Richard Liu <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
  • Loading branch information
richardsliu authored and tlrmchlsmth committed Nov 23, 2024
1 parent 84c8b51 commit 64c3e7d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 61 deletions.
7 changes: 0 additions & 7 deletions Dockerfile.tpu
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@ RUN apt-get update && apt-get install -y \
git \
ffmpeg libsm6 libxext6 libgl1

# Install the TPU and Pallas dependencies.
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

# Build vLLM.
COPY . .
ARG GIT_REPO_CHECK=0
Expand All @@ -25,7 +19,6 @@ ENV VLLM_TARGET_DEVICE="tpu"
RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=.git,target=.git \
python3 -m pip install \
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-r requirements-tpu.txt
RUN python3 setup.py develop

Expand Down
57 changes: 5 additions & 52 deletions docs/source/getting_started/tpu-installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,27 +119,19 @@ Uninstall the existing `torch` and `torch_xla` packages:
pip uninstall torch torch-xla -y
Install `torch` and `torch_xla`
Install build dependencies:

.. code-block:: bash
pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
pip install -r requirements-tpu.txt
sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
Install JAX and Pallas:
Run the setup script:

.. code-block:: bash
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Install other build dependencies:
VLLM_TARGET_DEVICE="tpu" python setup.py develop
.. code-block:: bash
pip install -r requirements-tpu.txt
VLLM_TARGET_DEVICE="tpu" python setup.py develop
sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
Provision Cloud TPUs with GKE
-----------------------------
Expand Down Expand Up @@ -168,45 +160,6 @@ Run the Docker image with the following command:
$ # Make sure to add `--privileged --net host --shm-size=16G`.
$ docker run --privileged --net host --shm-size=16G -it vllm-tpu
.. _build_from_source_tpu:

Build from source
-----------------

You can also build and install the TPU backend from source.

First, install the dependencies:

.. code-block:: console
$ # (Recommended) Create a new conda environment.
$ conda create -n myenv python=3.10 -y
$ conda activate myenv
$ # Clean up the existing torch and torch-xla packages.
$ pip uninstall torch torch-xla -y
$ # Install PyTorch and PyTorch XLA.
$ export DATE="20241017"
$ export TORCH_VERSION="2.6.0"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
$ # Install JAX and Pallas.
$ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
$ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
$ # Install other build dependencies.
$ pip install -r requirements-tpu.txt
Next, build vLLM from source. This will only take a few seconds:

.. code-block:: console
$ VLLM_TARGET_DEVICE="tpu" python setup.py develop
.. note::

Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape.
Expand Down
20 changes: 18 additions & 2 deletions requirements-tpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@
-r requirements-common.txt

# Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA.
# You can install the dependencies in Dockerfile.tpu.
cmake>=3.26
ninja
packaging
setuptools-scm>=8
wheel
jinja2
ray[default]

# Install torch_xla
--pre
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.6.0.dev20241028+cpu
torchvision==0.20.0.dev20241028+cpu
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl
jaxlib==0.4.32.dev20240829
jax==0.4.32.dev20240829

0 comments on commit 64c3e7d

Please sign in to comment.