From 551603feffd9b4ba98ccdd34e02e403e04db88c1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 16 Dec 2024 13:32:25 -0800 Subject: [PATCH] [core] overhaul memory profiling and fix backward compatibility (#10511) Signed-off-by: youkaichao --- tests/entrypoints/llm/test_gpu_utilization.py | 25 ++++ tests/entrypoints/llm/test_lazy_outlines.py | 2 +- tests/test_utils.py | 44 +++++- tests/worker/test_profile.py | 18 +-- vllm/engine/arg_utils.py | 11 +- vllm/utils.py | 125 +++++++++++++++++- vllm/worker/multi_step_model_runner.py | 3 +- vllm/worker/worker.py | 68 ++++------ 8 files changed, 236 insertions(+), 60 deletions(-) create mode 100644 tests/entrypoints/llm/test_gpu_utilization.py diff --git a/tests/entrypoints/llm/test_gpu_utilization.py b/tests/entrypoints/llm/test_gpu_utilization.py new file mode 100644 index 0000000000000..c2dab300ecefb --- /dev/null +++ b/tests/entrypoints/llm/test_gpu_utilization.py @@ -0,0 +1,25 @@ +from vllm import LLM, SamplingParams + + +def test_gpu_memory_utilization(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + # makes sure gpu_memory_utilization is per-instance limit, + # not a global limit + llms = [ + LLM(model="facebook/opt-125m", + gpu_memory_utilization=0.3, + enforce_eager=True) for i in range(3) + ] + for llm in llms: + outputs = llm.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/entrypoints/llm/test_lazy_outlines.py b/tests/entrypoints/llm/test_lazy_outlines.py index 2c53676c5f5dd..bf609b38a94f5 100644 --- a/tests/entrypoints/llm/test_lazy_outlines.py +++ b/tests/entrypoints/llm/test_lazy_outlines.py @@ -36,7 +36,7 @@ def run_lmfe(sample_regex): llm = LLM(model="facebook/opt-125m", enforce_eager=True, guided_decoding_backend="lm-format-enforcer", - gpu_memory_utilization=0.6) + gpu_memory_utilization=0.3) sampling_params = SamplingParams(temperature=0.8, top_p=0.95) outputs = llm.generate( prompts=[ diff --git a/tests/test_utils.py b/tests/test_utils.py index a731b11eae81c..0bc9e5bc32a46 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,11 +5,13 @@ from typing import AsyncIterator, Tuple import pytest +import torch from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs, - get_open_port, merge_async_iterators, supports_kw) + get_open_port, memory_profiling, merge_async_iterators, + supports_kw) -from .utils import error_on_warning +from .utils import error_on_warning, fork_new_process_for_each_test @pytest.mark.asyncio @@ -270,3 +272,41 @@ def test_supports_kw(callable,kw_name,requires_kw_only, requires_kw_only=requires_kw_only, allow_var_kwargs=allow_var_kwargs ) == is_supported + + +@fork_new_process_for_each_test +def test_memory_profiling(): + # Fake out some model loading + inference memory usage to test profiling + # Memory used by other processes will show up as cuda usage outside of torch + from vllm.distributed.device_communicators.cuda_wrapper import ( + CudaRTLibrary) + lib = CudaRTLibrary() + # 512 MiB allocation outside of this instance + handle1 = lib.cudaMalloc(512 * 1024 * 1024) + + baseline_memory_in_bytes = \ + torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0] + + # load weights + + weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32) + + weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB + + with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes, + weights_memory_in_bytes=weights_memory_in_bytes) as result: + # make a memory spike, 1 GiB + spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32) + del spike + + # Add some extra non-torch memory 256 MiB (simulate NCCL) + handle2 = lib.cudaMalloc(256 * 1024 * 1024) + + # Check that the memory usage is within 5% of the expected values + non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024) # noqa + torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024) # noqa + assert abs(non_torch_ratio - 1) <= 0.05 + assert abs(torch_peak_ratio - 1) <= 0.05 + del weights + lib.cudaFree(handle1) + lib.cudaFree(handle2) diff --git a/tests/worker/test_profile.py b/tests/worker/test_profile.py index 194ea2aa506f4..79233c75714de 100644 --- a/tests/worker/test_profile.py +++ b/tests/worker/test_profile.py @@ -31,10 +31,6 @@ def test_gpu_memory_profiling(): is_driver_worker=True, ) - # Load the model so we can profile it - worker.init_device() - worker.load_model() - # Set 10GiB as the total gpu ram to be device-agnostic def mock_mem_info(): current_usage = torch.cuda.memory_stats( @@ -46,20 +42,24 @@ def mock_mem_info(): from unittest.mock import patch with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info): + # Load the model so we can profile it + worker.init_device() + worker.load_model() gpu_blocks, _ = worker.determine_num_available_blocks() - # Peak vram usage by torch should be 0.7077 GiB + # Peak vram usage by torch should be 0.47 GiB + # Model weights take 0.25 GiB # No memory should be allocated outside of torch # 9.0 GiB should be the utilization target - # 8.2923 GiB should be available for the KV cache + # 8.28 GiB should be available for the KV cache block_size = CacheEngine.get_cache_block_size( engine_config.cache_config, engine_config.model_config, engine_config.parallel_config) - expected_blocks = (8.2923 * 1024**3) // block_size + expected_blocks = (8.28 * 1024**3) // block_size # Check within a small tolerance for portability # Hardware, kernel, or dependency changes could all affect memory # utilization. - # A 10 block tolerance here should be about 6MB of wiggle room. - assert abs(gpu_blocks - expected_blocks) < 10 + # A 100 block tolerance here should be about 60MB of wiggle room. + assert abs(gpu_blocks - expected_blocks) < 100 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0aa367a173b6c..06b8542779dc0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -487,11 +487,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help='The fraction of GPU memory to be used for the model ' 'executor, which can range from 0 to 1. For example, a value of ' '0.5 would imply 50%% GPU memory utilization. If unspecified, ' - 'will use the default value of 0.9. This is a global gpu memory ' - 'utilization limit, for example if 50%% of the gpu memory is ' - 'already used before vLLM starts and --gpu-memory-utilization is ' - 'set to 0.9, then only 40%% of the gpu memory will be allocated ' - 'to the model executor.') + 'will use the default value of 0.9. This is a per-instance ' + 'limit, and only applies to the current vLLM instance.' + 'It does not matter if you have another vLLM instance running ' + 'on the same GPU. For example, if you have two vLLM instances ' + 'running on the same GPU, you can set the GPU memory utilization ' + 'to 0.5 for each instance.') parser.add_argument( '--num-gpu-blocks-override', type=int, diff --git a/vllm/utils.py b/vllm/utils.py index 45e682ac15782..73d2ae25f15ca 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -23,10 +23,12 @@ from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task from collections import UserDict, defaultdict from collections.abc import Iterable, Mapping +from dataclasses import dataclass, field from functools import lru_cache, partial, wraps from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, - Dict, Generic, Hashable, List, Literal, Optional, - OrderedDict, Set, Tuple, Type, TypeVar, Union, overload) + Dict, Generator, Generic, Hashable, List, Literal, + Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union, + overload) from uuid import uuid4 import numpy as np @@ -1664,3 +1666,122 @@ def kill_process_tree(pid: int): # Finally kill the parent with contextlib.suppress(ProcessLookupError): os.kill(pid, signal.SIGKILL) + + +@dataclass +class MemorySnapshot: + """Memory snapshot.""" + torch_peak_in_bytes: int = 0 + torch_memory_in_bytes: int = 0 + timestamp: float = 0.0 + + def measure(self): + self.torch_peak_in_bytes = torch.cuda.memory_stats( + )["allocated_bytes.all.peak"] + self.torch_memory_in_bytes = torch.cuda.memory_stats( + )["allocated_bytes.all.current"] + self.timestamp = time.time() + + def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot": + """support a - b""" + return MemorySnapshot( + torch_peak_in_bytes=self.torch_peak_in_bytes - + other.torch_peak_in_bytes, + torch_memory_in_bytes=self.torch_memory_in_bytes - + other.torch_memory_in_bytes, + timestamp=self.timestamp - other.timestamp) + + +@dataclass +class MemoryProfilingResult: + """Memory profiling result. + """ # noqa + baseline_memory_in_bytes: int = 0 + non_kv_cache_memory_in_bytes: int = 0 + torch_peak_increase_in_bytes: int = 0 + non_torch_increase_in_bytes: int = 0 + weights_memory_in_bytes: float = 0 + before_profile: MemorySnapshot = field(default_factory=MemorySnapshot) + after_profile: MemorySnapshot = field(default_factory=MemorySnapshot) + profile_time: float = 0.0 + + +@contextlib.contextmanager +def memory_profiling( + baseline_memory_in_bytes: int, weights_memory_in_bytes: int +) -> Generator[MemoryProfilingResult, None, None]: + """Memory profiling context manager. + baseline_memory_in_bytes: memory used by all the components other than + the current vLLM instance. It contains: memory used by other processes, memory + used by another vLLM instance in the same process, etc. It is usually measured + before the current vLLM instance initialize the device. And we assume it is + constant during the profiling of the current vLLM instance. + weights_memory_in_bytes: memory used by PyTorch when loading the model weights. + Note that, before loading the model weights, we also initialize the device + and distributed environment, which may consume some memory. This part is not + included in the weights_memory_in_bytes because PyTorch does not control it. + + The memory in one GPU can be classified into 3 categories: + 1. memory used by anything other than the current vLLM instance. + 2. memory used by torch in the current vLLM instance. + 3. memory used in the current vLLM instance, but not by torch. + + A quantitive example: + + Before creating the current vLLM instance: + category 1: 1 GiB + category 2: 0 GiB + category 3: 0 GiB + + After creating the current vLLM instance and loading the model, + (i.e. before profiling): + category 1: 1 GiB + category 2: 2 GiB (model weights take 2 GiB) + category 3: 0.5 GiB (memory used by NCCL) + + During profiling (peak): + category 1: 1 GiB + category 2: 4 GiB (peak activation tensors take 2 GiB) + category 3: 1 GiB (memory used by NCCL + buffers for some attention backends) + + After profiling: + category 1: 1 GiB + category 2: 3 GiB (after garbage-collecting activation tensors) + category 3: 1 GiB (memory used by NCCL + buffers for some attention backends) + + In this case, non-kv cache takes 5 GiB in total, including: + a. 2 GiB used by the model weights (category 2) + b. 2 GiB reserved for the peak activation tensors (category 2) + c. 1 GiB used by non-torch components (category 3) + + The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`. + + The increase of ``torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.). + + (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`), + subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`. + """ # noqa + torch.cuda.reset_peak_memory_stats() + + result = MemoryProfilingResult() + + result.baseline_memory_in_bytes = baseline_memory_in_bytes + # the part of memory used for holding the model weights + result.weights_memory_in_bytes = weights_memory_in_bytes + + result.before_profile.measure() + + yield result + + gc.collect() + torch.cuda.empty_cache() + + result.after_profile.measure() + + diff = result.after_profile - result.before_profile + result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes + current_cuda_memory_bytes = torch.cuda.mem_get_info( + )[1] - torch.cuda.mem_get_info()[0] + result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes # noqa + result.profile_time = diff.timestamp + result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index e08a61e31fe42..18b03bf1bfb56 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -645,7 +645,8 @@ def _advance_step(self, model_input: StatefulModelInput, return model_input def load_model(self) -> None: - return self._base_model_runner.load_model() + self._base_model_runner.load_model() + self.model_memory_usage = self._base_model_runner.model_memory_usage def save_sharded_state( self, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index a368bb9ee9a5b..f51b51d433d3d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,6 @@ """A GPU worker class.""" import gc import os -import time from typing import Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -22,6 +21,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) +from vllm.utils import GiB_bytes, memory_profiling from vllm.worker.cache_engine import CacheEngine from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner @@ -192,33 +192,22 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: torch.cuda.reset_peak_memory_stats() free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info() - start_time = time.time() # Execute a forward pass with dummy inputs to profile the memory usage # of the model. - self.model_runner.profile_run() - torch.cuda.synchronize() + with memory_profiling(baseline_memory_in_bytes=total_gpu_memory - + self.init_gpu_memory, + weights_memory_in_bytes=self.model_runner. + model_memory_usage) as result: + self.model_runner.profile_run() + torch.cuda.synchronize() self._assert_memory_footprint_increased_during_profiling() - # Get the peak memory allocation recorded by torch - peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"] - - # Check for any memory left around that may have been allocated on the - # gpu outside of `torch`. NCCL operations, for example, can use a few - # GB during a forward pass - torch.cuda.empty_cache() - torch_allocated_bytes = torch.cuda.memory_stats( - )["allocated_bytes.all.current"] - total_allocated_bytes = torch.cuda.mem_get_info( - )[1] - torch.cuda.mem_get_info()[0] - non_torch_allocations = total_allocated_bytes - torch_allocated_bytes - if non_torch_allocations > 0: - peak_memory += non_torch_allocations - - available_kv_cache_memory = ( - total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) + memory_for_current_instance = total_gpu_memory * \ + self.cache_config.gpu_memory_utilization + available_kv_cache_memory = (memory_for_current_instance - + result.non_kv_cache_memory_in_bytes) # Calculate the number of blocks that can be allocated with the # profiled peak memory. @@ -233,24 +222,23 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) - end_time = time.time() - logger.info( - "Memory profiling results: " - "duration=%.2f seconds, " - "total_gpu_memory=%.2fGiB, " - "initial_memory_usage=%.2fGiB, " - "peak_torch_memory=%.2fGiB, " - "memory_usage_post_profile=%.2fGiB, " - "non_torch_memory=%.2fGiB, " - "kv_cache_size=%.2fGiB, " - "gpu_memory_utilization=%.2f.", end_time - start_time, - total_gpu_memory / (1024**3), - (total_gpu_memory - free_memory_pre_profile) / (1024**3), - (peak_memory - non_torch_allocations) / (1024**3), - total_allocated_bytes / (1024**3), - non_torch_allocations / (1024**3), - available_kv_cache_memory / (1024**3), - self.cache_config.gpu_memory_utilization) + msg = (f"Memory profiling takes {result.profile_time:.2f} seconds\n" + "the current vLLM instance can use " + "total_gpu_memory " + f"({(total_gpu_memory / GiB_bytes):.2f}GiB)" + " x gpu_memory_utilization " + f"({self.cache_config.gpu_memory_utilization:.2f})" + f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n" + "model weights take " + f"{(result.weights_memory_in_bytes / GiB_bytes):.2f}GiB;" + " non_torch_memory takes " + f"{(result.non_torch_increase_in_bytes / GiB_bytes):.2f}GiB;" + " PyTorch activation peak memory takes " + f"{(result.torch_peak_increase_in_bytes / GiB_bytes):.2f}GiB;" + " the rest of the memory reserved for KV Cache is " + f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.") + + logger.info(msg) # Final cleanup if self.model_runner.lora_manager: