From 211fe91aa88730c04df439298d8103a587302493 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 30 Oct 2024 02:41:38 -0700
Subject: [PATCH] [TPU] Correctly profile peak memory usage & Upgrade PyTorch
 XLA (#9438)

---
 Dockerfile.tpu                                   |  2 +-
 docs/source/getting_started/tpu-installation.rst |  4 ++--
 vllm/worker/tpu_worker.py                        | 15 ++++++++-------
 3 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/Dockerfile.tpu b/Dockerfile.tpu
index bdfab3f61910f..dd8f9ad4714a9 100644
--- a/Dockerfile.tpu
+++ b/Dockerfile.tpu
@@ -1,4 +1,4 @@
-ARG NIGHTLY_DATE="20240828"
+ARG NIGHTLY_DATE="20241017"
 ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
 
 FROM $BASE_IMAGE
diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst
index 217028839e347..edba209986f6a 100644
--- a/docs/source/getting_started/tpu-installation.rst
+++ b/docs/source/getting_started/tpu-installation.rst
@@ -56,8 +56,8 @@ First, install the dependencies:
 $ pip uninstall torch torch-xla -y
 
 $ # Install PyTorch and PyTorch XLA.
-$ export DATE="20240828"
-$ export TORCH_VERSION="2.5.0"
+$ export DATE="20241017"
+$ export TORCH_VERSION="2.6.0"
 $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
 $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
 
diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py
index fe819b9f4b3a8..de6f7ab0072fd 100644
--- a/vllm/worker/tpu_worker.py
+++ b/vllm/worker/tpu_worker.py
@@ -133,18 +133,19 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
 
         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()
 
-        dtype_btyes = get_dtype_size(self.cache_dtype)
-        block_size = self.cache_config.block_size
-        block_size_bytes = (dtype_btyes * block_size * num_layers * 2 *
-                            head_size * num_kv_heads)
-
-        # Calculate the TPU KV cache size based on profiling.
+        # Get the maximum amount of memory used by the model weights and
+        # intermediate activations.
         m = xm.get_memory_info(self.device)
         total_memory_size = m["bytes_limit"]
+        profiled = m["peak_bytes_used"]  # Weights + intermediate activations.
+
+        # Calculate the TPU KV cache size based on profiling.
         usable_memory_size = int(total_memory_size *
                                  self.cache_config.gpu_memory_utilization)
-        profiled = m["bytes_used"]  # Weights + intermediate activations.
         tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
+        dtype_bytes = get_dtype_size(self.cache_dtype)
+        block_size_bytes = (dtype_bytes * self.cache_config.block_size *
+                            num_layers * 2 * head_size * num_kv_heads)
         num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
         num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.
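
A quick way to sanity-check the new sizing logic: the hunk above switches the profiled figure from bytes_used (memory in use at the moment of the query) to peak_bytes_used (the high-water mark reached during the profiling run), so short-lived activations that were already freed before the measurement still count against the budget. Below is a minimal, self-contained Python sketch of the same arithmetic; the function name and every number in the example are hypothetical stand-ins, not vLLM APIs, since at runtime the inputs come from xm.get_memory_info(self.device) and the cache/model configs.

# A minimal, self-contained sketch of the sizing logic in the hunk above.
# All names and numbers here are illustrative assumptions.

def num_kv_cache_blocks(bytes_limit: int, peak_bytes_used: int,
                        gpu_memory_utilization: float, dtype_size: int,
                        block_size: int, num_layers: int, head_size: int,
                        num_kv_heads: int) -> int:
    # Cap the budget at a fraction of total device memory, then subtract
    # the peak memory profiled for weights + intermediate activations.
    usable_memory_size = int(bytes_limit * gpu_memory_utilization)
    kv_cache_bytes = max(usable_memory_size - peak_bytes_used, 0)
    # Per-block footprint: block_size tokens, stored once for keys and
    # once for values (the factor 2), per layer and per KV head.
    block_size_bytes = (dtype_size * block_size * num_layers * 2 *
                        head_size * num_kv_heads)
    num_blocks = kv_cache_bytes // block_size_bytes
    return (num_blocks // 8) * 8  # Round down to a multiple of 8.

# Made-up example: 16 GiB device, 11 GiB peak usage during profiling,
# 0.9 utilization, a 2-byte KV cache dtype, 16-token blocks, 32 layers,
# head size 128, 8 KV heads.
print(num_kv_cache_blocks(16 * 2**30, 11 * 2**30, 0.9, 2, 16, 32, 128, 8))

Run as-is, the example prints 1736: the roughly 3.4 GiB left after reserving 90% of a 16 GiB device and subtracting the 11 GiB peak holds 1740 two-MiB blocks, which rounds down to 1736.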