[TPU] Use mark_dynamic to reduce compilation time (vllm-project#7340)
WoosukKwon authored Aug 11, 2024
1 parent 4c5d8e8 commit 90bab18
Showing 3 changed files with 50 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.tpu
@@ -1,4 +1,4 @@
ARG NIGHTLY_DATE="20240726"
ARG NIGHTLY_DATE="20240808"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"

FROM $BASE_IMAGE
4 changes: 2 additions & 2 deletions docs/source/getting_started/tpu-installation.rst
@@ -56,7 +56,7 @@ First, install the dependencies:
$ pip uninstall torch torch-xla -y
$ # Install PyTorch and PyTorch XLA.
$ export DATE="+20240726"
$ export DATE="+20240808"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl
@@ -65,7 +65,7 @@ First, install the dependencies:
$ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
$ # Install other build dependencies.
-$ pip install packaging aiohttp
+$ pip install -r requirements-tpu.txt
Next, build vLLM from source. This will only take a few seconds:
60 changes: 47 additions & 13 deletions vllm/worker/tpu_model_runner.py
@@ -147,19 +147,7 @@ def load_model(self) -> None:
        )
        model = model.eval()
        xm.wait_device_ops()
-
-        model = ModelWrapper(model)
-        # NOTE(woosuk): There are two stages of compilation: torch.compile and
-        # XLA compilation. Setting dynamic=True can reduce the torch.compile
-        # overhead by reusing the FX graph for different shapes.
-        # However, the XLA graph will still require static shapes and needs to
-        # be re-compiled for every different shape. This overhead is inevitable
-        # in the first run, but can be skipped afterwards as we cache the XLA
-        # graphs on disk (VLLM_XLA_CACHE_PATH).
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=True)
+        self.model = CompiledModelWrapper(model)

def _dummy_run(
self,
@@ -697,6 +685,52 @@ def forward(
return next_token_ids


+class CompiledModelWrapper:
+
+    def __init__(self, model: nn.Module):
+        model = ModelWrapper(model)
+        self.model = torch.compile(model,
+                                   backend="openxla",
+                                   fullgraph=True,
+                                   dynamic=False)
+
+    def __call__(
+        self,
+        token_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        input_lens: torch.Tensor,
+        t: torch.Tensor,
+        p: torch.Tensor,
+        num_samples: int,
+        kv_caches: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]],
+    ) -> torch.Tensor:
+        # NOTE(woosuk): There are two stages of compilation: torch.compile and
+        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
+        # overhead by reusing the FX graph for different shapes.
+        # However, the XLA graph will still require static shapes and needs to
+        # be re-compiled for every different shape. This overhead is inevitable
+        # in the first run, but can be skipped afterwards as we cache the XLA
+        # graphs on disk (VLLM_XLA_CACHE_PATH).
+        if attn_metadata.num_prefills > 0:
+            # Prefill: the sequence (token) dimension is dynamic.
+            torch._dynamo.mark_dynamic(token_ids, 1)
+            torch._dynamo.mark_dynamic(position_ids, 1)
+            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
+        else:
+            # Decode: the batch dimension is dynamic.
+            torch._dynamo.mark_dynamic(token_ids, 0)
+            torch._dynamo.mark_dynamic(position_ids, 0)
+            torch._dynamo.mark_dynamic(input_lens, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
+            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
+            torch._dynamo.mark_dynamic(t, 0)
+            torch._dynamo.mark_dynamic(p, 0)
+        return self.model(token_ids, position_ids, attn_metadata, input_lens,
+                          t, p, num_samples, kv_caches)
+
+
def _get_padded_prefill_len(x: int) -> int:
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
# length to be a multiple of 16. We pad the prompt length to the nearest
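The diff is truncated above, so the body of _get_padded_prefill_len is not shown; only its comment is visible, stating that the Pallas FlashAttention kernel requires the sequence length to be a multiple of 16. As a hedged illustration of that constraint only (a hypothetical helper, not the commit's actual implementation, which may round differently), a round-up sketch could look like:

def round_up_to_multiple(x: int, multiple: int = 16) -> int:
    # Hypothetical helper: round x up to the next multiple of `multiple`.
    # Illustrates the multiple-of-16 requirement described in the truncated
    # comment; the real _get_padded_prefill_len may differ.
    return -(-x // multiple) * multiple

assert round_up_to_multiple(1) == 16
assert round_up_to_multiple(16) == 16
assert round_up_to_multiple(17) == 32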

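A note for readers: the key API in this commit is torch._dynamo.mark_dynamic(tensor, dim), which tells Dynamo to treat one dimension of a tensor as symbolic, so a single FX graph is reused across shapes instead of being re-traced per shape. Below is a minimal, self-contained sketch of the technique, not code from this commit: it runs on CPU with torch.compile's default backend (the commit uses backend="openxla" on TPU), and the model and sizes are illustrative.

import torch
import torch.nn as nn

model = nn.Linear(16, 16)
compiled = torch.compile(model, fullgraph=True)

# Mark dim 0 (batch) dynamic, as the decode branch above does for token_ids.
# Dynamo then traces one graph with a symbolic batch size rather than
# recompiling for every batch size. Dynamo specializes sizes 0 and 1, so the
# dynamic sizes here start at 2.
for batch_size in (2, 4, 8):
    x = torch.randn(batch_size, 16)
    torch._dynamo.mark_dynamic(x, 0)
    print(compiled(x).shape)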