[core] overhaul memory profiling and fix backward compatibility (#10511)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Dec 16, 2024
1 parent efbce85 commit 551603f
Showing 8 changed files with 236 additions and 60 deletions.
25 changes: 25 additions & 0 deletions tests/entrypoints/llm/test_gpu_utilization.py
@@ -0,0 +1,25 @@
+from vllm import LLM, SamplingParams
+
+
+def test_gpu_memory_utilization():
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+    # make sure gpu_memory_utilization is a per-instance limit,
+    # not a global limit
+    llms = [
+        LLM(model="facebook/opt-125m",
+            gpu_memory_utilization=0.3,
+            enforce_eager=True) for i in range(3)
+    ]
+    for llm in llms:
+        outputs = llm.generate(prompts, sampling_params)
+        for output in outputs:
+            prompt = output.prompt
+            generated_text = output.outputs[0].text
+            print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
2 changes: 1 addition & 1 deletion tests/entrypoints/llm/test_lazy_outlines.py
@@ -36,7 +36,7 @@ def run_lmfe(sample_regex):
     llm = LLM(model="facebook/opt-125m",
               enforce_eager=True,
               guided_decoding_backend="lm-format-enforcer",
-              gpu_memory_utilization=0.6)
+              gpu_memory_utilization=0.3)
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
     outputs = llm.generate(
         prompts=[
44 changes: 42 additions & 2 deletions tests/test_utils.py
@@ -5,11 +5,13 @@
 from typing import AsyncIterator, Tuple
 
 import pytest
+import torch
 
 from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
-                        get_open_port, merge_async_iterators, supports_kw)
+                        get_open_port, memory_profiling, merge_async_iterators,
+                        supports_kw)
 
-from .utils import error_on_warning
+from .utils import error_on_warning, fork_new_process_for_each_test
 
 
 @pytest.mark.asyncio
@@ -270,3 +272,41 @@ def test_supports_kw(callable,kw_name,requires_kw_only,
                        requires_kw_only=requires_kw_only,
                        allow_var_kwargs=allow_var_kwargs
                        ) == is_supported
+
+
+@fork_new_process_for_each_test
+def test_memory_profiling():
+    # Fake out some model loading + inference memory usage to test profiling
+    # Memory used by other processes will show up as cuda usage outside of torch
+    from vllm.distributed.device_communicators.cuda_wrapper import (
+        CudaRTLibrary)
+    lib = CudaRTLibrary()
+    # 512 MiB allocation outside of this instance
+    handle1 = lib.cudaMalloc(512 * 1024 * 1024)
+
+    baseline_memory_in_bytes = \
+        torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]
+
+    # load weights
+
+    weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32)
+
+    weights_memory_in_bytes = 128 * 1024 * 1024 * 4  # 512 MiB
+
+    with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
+                          weights_memory_in_bytes=weights_memory_in_bytes) as result:
+        # make a memory spike, 1 GiB
+        spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
+        del spike
+
+        # Add some extra non-torch memory, 256 MiB (simulate NCCL)
+        handle2 = lib.cudaMalloc(256 * 1024 * 1024)
+
+    # Check that the memory usage is within 5% of the expected values
+    non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024)  # noqa
+    torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024)  # noqa
+    assert abs(non_torch_ratio - 1) <= 0.05
+    assert abs(torch_peak_ratio - 1) <= 0.05
+    del weights
+    lib.cudaFree(handle1)
+    lib.cudaFree(handle2)
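For reference, the expected values asserted above follow from the tensor sizes: a torch.float32 element is 4 bytes, so the test's allocations work out as in this small standalone check (illustrative only, not part of the test):

MiB = 1024 * 1024

weights_bytes = 128 * 1024 * 1024 * 4   # fp32 weights tensor -> 512 MiB
spike_bytes = 256 * 1024 * 1024 * 4     # fp32 spike tensor -> ~1 GiB torch peak increase
nccl_like_bytes = 256 * MiB             # cudaMalloc outside torch -> non-torch increase

assert weights_bytes == 512 * MiB
assert spike_bytes == 1024 * MiB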
18 changes: 9 additions & 9 deletions tests/worker/test_profile.py
@@ -31,10 +31,6 @@ def test_gpu_memory_profiling():
         is_driver_worker=True,
     )
 
-    # Load the model so we can profile it
-    worker.init_device()
-    worker.load_model()
-
    # Set 10GiB as the total gpu ram to be device-agnostic
    def mock_mem_info():
        current_usage = torch.cuda.memory_stats(
@@ -46,20 +42,24 @@ def mock_mem_info():
 
     from unittest.mock import patch
     with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info):
+        # Load the model so we can profile it
+        worker.init_device()
+        worker.load_model()
         gpu_blocks, _ = worker.determine_num_available_blocks()
 
-    # Peak vram usage by torch should be 0.7077 GiB
+    # Peak vram usage by torch should be 0.47 GiB
     # Model weights take 0.25 GiB
     # No memory should be allocated outside of torch
     # 9.0 GiB should be the utilization target
-    # 8.2923 GiB should be available for the KV cache
+    # 8.28 GiB should be available for the KV cache
     block_size = CacheEngine.get_cache_block_size(
         engine_config.cache_config, engine_config.model_config,
         engine_config.parallel_config)
 
-    expected_blocks = (8.2923 * 1024**3) // block_size
+    expected_blocks = (8.28 * 1024**3) // block_size
 
     # Check within a small tolerance for portability
     # Hardware, kernel, or dependency changes could all affect memory
     # utilization.
-    # A 10 block tolerance here should be about 6MB of wiggle room.
-    assert abs(gpu_blocks - expected_blocks) < 10
+    # A 100 block tolerance here should be about 60MB of wiggle room.
+    assert abs(gpu_blocks - expected_blocks) < 100
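For context, the updated numbers above fall directly out of the mocked 10 GiB device and the figures quoted in the comments. A standalone sketch of that arithmetic (the variable names below are illustrative, not taken from the test):

GiB = 1024**3

total_gpu_memory = 10 * GiB                   # mock_mem_info() pretends the device has 10 GiB
utilization_target = 0.9 * total_gpu_memory   # 9.0 GiB at the default gpu_memory_utilization
torch_peak = 0.47 * GiB                       # peak torch allocation during profiling
weights = 0.25 * GiB                          # opt-125m weights
non_torch = 0.0                               # nothing is allocated outside torch in this test

kv_cache_bytes = utilization_target - torch_peak - weights - non_torch
print(kv_cache_bytes / GiB)                   # ~8.28, matching expected_blocks = (8.28 * 1024**3) // block_size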
11 changes: 6 additions & 5 deletions vllm/engine/arg_utils.py
@@ -487,11 +487,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             help='The fraction of GPU memory to be used for the model '
             'executor, which can range from 0 to 1. For example, a value of '
             '0.5 would imply 50%% GPU memory utilization. If unspecified, '
-            'will use the default value of 0.9. This is a global gpu memory '
-            'utilization limit, for example if 50%% of the gpu memory is '
-            'already used before vLLM starts and --gpu-memory-utilization is '
-            'set to 0.9, then only 40%% of the gpu memory will be allocated '
-            'to the model executor.')
+            'will use the default value of 0.9. This is a per-instance '
+            'limit, and only applies to the current vLLM instance. '
+            'It does not matter if you have another vLLM instance running '
+            'on the same GPU. For example, if you have two vLLM instances '
+            'running on the same GPU, you can set the GPU memory utilization '
+            'to 0.5 for each instance.')
         parser.add_argument(
             '--num-gpu-blocks-override',
             type=int,
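A minimal sketch of the per-instance behavior described in this help text, mirroring the new test_gpu_utilization.py above (the model and fractions here are illustrative):

from vllm import LLM

# Each instance applies its own limit; per the help text above, two vLLM
# instances on the same GPU can each set gpu_memory_utilization=0.5.
llm_a = LLM(model="facebook/opt-125m",
            gpu_memory_utilization=0.5,
            enforce_eager=True)
llm_b = LLM(model="facebook/opt-125m",
            gpu_memory_utilization=0.5,
            enforce_eager=True)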
125 changes: 123 additions & 2 deletions vllm/utils.py
@@ -23,10 +23,12 @@
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
 from collections import UserDict, defaultdict
 from collections.abc import Iterable, Mapping
+from dataclasses import dataclass, field
 from functools import lru_cache, partial, wraps
 from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
-                    Dict, Generic, Hashable, List, Literal, Optional,
-                    OrderedDict, Set, Tuple, Type, TypeVar, Union, overload)
+                    Dict, Generator, Generic, Hashable, List, Literal,
+                    Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
+                    overload)
 from uuid import uuid4
 
 import numpy as np
@@ -1664,3 +1666,122 @@ def kill_process_tree(pid: int):
     # Finally kill the parent
     with contextlib.suppress(ProcessLookupError):
         os.kill(pid, signal.SIGKILL)


+@dataclass
+class MemorySnapshot:
+    """Memory snapshot."""
+    torch_peak_in_bytes: int = 0
+    torch_memory_in_bytes: int = 0
+    timestamp: float = 0.0
+
+    def measure(self):
+        self.torch_peak_in_bytes = torch.cuda.memory_stats(
+        )["allocated_bytes.all.peak"]
+        self.torch_memory_in_bytes = torch.cuda.memory_stats(
+        )["allocated_bytes.all.current"]
+        self.timestamp = time.time()
+
+    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
+        """support a - b"""
+        return MemorySnapshot(
+            torch_peak_in_bytes=self.torch_peak_in_bytes -
+            other.torch_peak_in_bytes,
+            torch_memory_in_bytes=self.torch_memory_in_bytes -
+            other.torch_memory_in_bytes,
+            timestamp=self.timestamp - other.timestamp)
+
+
+@dataclass
+class MemoryProfilingResult:
+    """Memory profiling result.
+    """  # noqa
+    baseline_memory_in_bytes: int = 0
+    non_kv_cache_memory_in_bytes: int = 0
+    torch_peak_increase_in_bytes: int = 0
+    non_torch_increase_in_bytes: int = 0
+    weights_memory_in_bytes: float = 0
+    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
+    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
+    profile_time: float = 0.0


+@contextlib.contextmanager
+def memory_profiling(
+        baseline_memory_in_bytes: int, weights_memory_in_bytes: int
+) -> Generator[MemoryProfilingResult, None, None]:
+    """Memory profiling context manager.
+
+    baseline_memory_in_bytes: memory used by all the components other than
+        the current vLLM instance. It contains: memory used by other processes, memory
+        used by another vLLM instance in the same process, etc. It is usually measured
+        before the current vLLM instance initializes the device. We assume it is
+        constant during the profiling of the current vLLM instance.
+    weights_memory_in_bytes: memory used by PyTorch when loading the model weights.
+        Note that, before loading the model weights, we also initialize the device
+        and distributed environment, which may consume some memory. This part is not
+        included in the weights_memory_in_bytes because PyTorch does not control it.
+
+    The memory in one GPU can be classified into 3 categories:
+    1. memory used by anything other than the current vLLM instance.
+    2. memory used by torch in the current vLLM instance.
+    3. memory used in the current vLLM instance, but not by torch.
+
+    A quantitative example:
+
+    Before creating the current vLLM instance:
+        category 1: 1 GiB
+        category 2: 0 GiB
+        category 3: 0 GiB
+
+    After creating the current vLLM instance and loading the model
+    (i.e. before profiling):
+        category 1: 1 GiB
+        category 2: 2 GiB (model weights take 2 GiB)
+        category 3: 0.5 GiB (memory used by NCCL)
+
+    During profiling (peak):
+        category 1: 1 GiB
+        category 2: 4 GiB (peak activation tensors take 2 GiB)
+        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
+
+    After profiling:
+        category 1: 1 GiB
+        category 2: 3 GiB (after garbage-collecting activation tensors)
+        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
+
+    In this case, the non-KV-cache memory takes 5 GiB in total, including:
+    a. 2 GiB used by the model weights (category 2)
+    b. 2 GiB reserved for the peak activation tensors (category 2)
+    c. 1 GiB used by non-torch components (category 3)
+
+    The memory used for loading weights (a.) is directly given by the argument `weights_memory_in_bytes`.
+
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
+
+    (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
+    then subtract the baseline memory, the memory used by the model weights, and the diff of
+    `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
+    """  # noqa
+    torch.cuda.reset_peak_memory_stats()
+
+    result = MemoryProfilingResult()
+
+    result.baseline_memory_in_bytes = baseline_memory_in_bytes
+    # the part of memory used for holding the model weights
+    result.weights_memory_in_bytes = weights_memory_in_bytes
+
+    result.before_profile.measure()
+
+    yield result
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    result.after_profile.measure()
+
+    diff = result.after_profile - result.before_profile
+    result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes
+    current_cuda_memory_bytes = torch.cuda.mem_get_info(
+    )[1] - torch.cuda.mem_get_info()[0]
+    result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes  # noqa
+    result.profile_time = diff.timestamp
+    result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes  # noqa
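A minimal usage sketch of the new context manager, following the docstring and the new test_memory_profiling test; run_dummy_forward_pass() is a placeholder for whatever workload should be measured:

import torch

from vllm.utils import memory_profiling

# Baseline: everything resident on the GPU before this instance starts
# (total minus free, from torch.cuda.mem_get_info()).
free_bytes, total_bytes = torch.cuda.mem_get_info()
baseline_memory_in_bytes = total_bytes - free_bytes

# ... load the model weights here, tracking how many bytes they occupy ...
weights_memory_in_bytes = 512 * 1024 * 1024  # e.g. 512 MiB, as in the test above

with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
                      weights_memory_in_bytes=weights_memory_in_bytes) as result:
    run_dummy_forward_pass()  # placeholder for the profiling workload

# After the block, the bookkeeping identity from the docstring holds:
#   non_kv_cache = weights + torch_peak_increase + non_torch_increase
print(result.torch_peak_increase_in_bytes)
print(result.non_torch_increase_in_bytes)
print(result.non_kv_cache_memory_in_bytes)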
3 changes: 2 additions & 1 deletion vllm/worker/multi_step_model_runner.py
@@ -645,7 +645,8 @@ def _advance_step(self, model_input: StatefulModelInput,
         return model_input
 
     def load_model(self) -> None:
-        return self._base_model_runner.load_model()
+        self._base_model_runner.load_model()
+        self.model_memory_usage = self._base_model_runner.model_memory_usage
 
     def save_sharded_state(
         self,