[Bugfix][Core] Use torch.cuda.memory_stats() to profile peak memory usage #9352
@@ -0,0 +1,69 @@
import torch

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker


def test_gpu_memory_profiling():
    # Tests the gpu profiling that happens in order to determine the number of
    # KV cache blocks that we can allocate on the GPU.
    # This test mocks the maximum available gpu memory so that it can run on
    # any gpu setup.

    # Set up engine args to build a worker.
    engine_args = EngineArgs(model="facebook/opt-125m",
                             dtype="half",
                             load_format="dummy")
    engine_config = engine_args.create_engine_config()
    engine_config.cache_config.num_gpu_blocks = 1000
    engine_config.cache_config.num_cpu_blocks = 1000

    # Create the worker.
    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())
    worker = Worker(
        model_config=engine_config.model_config,
        parallel_config=engine_config.parallel_config,
        scheduler_config=engine_config.scheduler_config,
        device_config=engine_config.device_config,
        cache_config=engine_config.cache_config,
        load_config=engine_config.load_config,
        local_rank=0,
        rank=0,
        distributed_init_method=distributed_init_method,
        is_driver_worker=True,
    )

    # Load the model so we can profile it
    worker.init_device()
    worker.load_model()

    # Set 10GiB as the total gpu ram to be device-agnostic
    def mock_mem_info():
        current_usage = torch.cuda.memory_stats(
        )["allocated_bytes.all.current"]
        mock_total_bytes = 10 * 1024**3
        free = mock_total_bytes - current_usage

        return (free, mock_total_bytes)

    from unittest.mock import patch
    with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info):
        gpu_blocks, _ = worker.determine_num_available_blocks()

    # Peak vram usage by torch should be 0.7077 GiB
    # Non-torch allocations should be 0.0079 GiB
    # 9.0 GiB should be the utilization target
    # 8.2843 GiB should be available for the KV cache
    block_size = CacheEngine.get_cache_block_size(
        engine_config.cache_config, engine_config.model_config,
        engine_config.parallel_config)

    expected_blocks = (8.2843 * 1024**3) // block_size

    # Check within a small tolerance for portability
    # Hardware, kernel, or dependency changes could all affect memory
    # utilization
    assert abs(gpu_blocks - expected_blocks) < 5
Review discussion on the tolerance:

- Not sure this is a large enough tolerance TBH, but am good with setting it to something and adjusting in the future.
- Heh, well I also don't have too much context personally for how much this could swing with any hardware or software differences. At least this currently works on the A100s I tested on, and whatever worker nodes the CI runs have landed on today 😉 I also don't want to go too wide on the tolerance and end up having this test pass if some changes are accidentally made to the profiling code. This test should catch about 8 MB of memory allocated outside of torch, and 5 blocks should be about 3 MB in this configuration. I can bump it up to 10 so there's 6 MB of wiggle room if that sounds alright. I will also happily accept people yelling at me if this test becomes super flaky.
- Sounds good, I'll keep an eye on it.
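For readers who want to sanity-check the "5 blocks is about 3 MB" figure above, here is a rough back-of-the-envelope sketch of the KV cache block size for opt-125m in half precision. The model dimensions and the 16-token block size are assumptions about typical defaults, not values taken from this diff:

```python
# Rough KV cache block-size estimate (assumed: 12 layers, hidden size 768,
# fp16 elements, 16 tokens per block -- typical opt-125m / vLLM defaults).
num_layers = 12
hidden_size = 768
bytes_per_elem = 2       # fp16
tokens_per_block = 16

# Each block holds a key and a value vector per token, per layer.
bytes_per_block = 2 * tokens_per_block * num_layers * hidden_size * bytes_per_elem
print(bytes_per_block / 1024**2)      # ~0.56 MiB per block
print(5 * bytes_per_block / 1024**2)  # ~2.8 MiB, i.e. "5 blocks is about 3 MB"
```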
@@ -217,42 +217,76 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
         torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+
+        free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
+
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
         self.model_runner.profile_run()
+        torch.cuda.synchronize()
+
+        self._assert_memory_footprint_increased_during_profiling()
+
+        # Get the peak memory allocation recorded by torch
+        peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
+
+        # Check for any memory left around that may have been allocated on the
+        # gpu outside of `torch`. NCCL operations, for example, can use a few
+        # GB during a forward pass
+        torch.cuda.empty_cache()
+        # After emptying the torch cache, any other increase in gpu ram should
+        # be from non-torch allocations.
+        non_torch_allocations = free_memory_pre_profile - \
+            torch.cuda.mem_get_info()[0]
+        if non_torch_allocations > 0:
+            peak_memory += non_torch_allocations
+
+        available_kv_cache_memory = (
+            total_gpu_memory * self.cache_config.gpu_memory_utilization -
+            peak_memory)

         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
-        torch.cuda.synchronize()
-        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
-        # NOTE(woosuk): Here we assume that the other processes using the same
-        # GPU did not change their memory usage during the profiling.
-        peak_memory = self.init_gpu_memory - free_gpu_memory
-        assert peak_memory > 0, (
-            "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory}, current free memory"
-            f" {free_gpu_memory}. This happens when the GPU memory was "
-            "not properly cleaned up before initializing the vLLM instance.")
-
         cache_block_size = self.get_cache_block_size_bytes()
         if cache_block_size == 0:
             num_gpu_blocks = 0
             num_cpu_blocks = 0
         else:
-            num_gpu_blocks = int(
-                (total_gpu_memory * self.cache_config.gpu_memory_utilization -
-                 peak_memory) // cache_block_size)
+            num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
             num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                                  cache_block_size)
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
+
+        logger.info(
+            "Memory profiling results: total_gpu_memory=%.2fGiB"
+            " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
+            " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
+            " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
+            (total_gpu_memory - free_memory_pre_profile) / (1024**3),
+            (peak_memory - non_torch_allocations) / (1024**3),
+            non_torch_allocations / (1024**3),
+            available_kv_cache_memory / (1024**3),
+            self.cache_config.gpu_memory_utilization)
+
+        # Final cleanup
         if self.model_runner.lora_manager:
             self.model_runner.remove_all_loras()
         gc.collect()
Review discussion on lines 274 to 276 (the cleanup block above):

- Probably this [...] Also I'm not sure what the reason for the [...]
- yeah I'm not sure on that either, I'm guessing we're just trying to clean up everything we did during profiling, before the KV cache is allocated? re: moving the gc.collect(), I was trying to leave it later here in case there was something allocated outside torch that hadn't been GC'ed yet that we may need to account for in the peak memory usage. If we run all the cleanup and then check the free memory, then the only reason it would be lower is if there's a memory leak, right? idk, I could go either way. I'm not 100% sold we need the extra check for non-torch allocated memory since it's pretty flaky to try to check for. Think we should just back that out and leave the [...]
- update: @tjohnson31415 will give this a go to see if he can reproduce the NCCL allocations he was seeing that were blowing up vram usage. If this code catches it we'll keep it in, if not I'll back it out.
- It needs a small fix, but the [...] The call to [...] In my test with Llama-3.1-8B-Instruct w/ TP=8, moving [...]
- 🌶️🌶️🌶️! I agree the cleanup in general could better live in the profile run executions, but I do want to limit the blast radius here to this file. I'll leave as-is unless anybody feels strongly about refactoring into the individual model runners.
+
         torch.cuda.empty_cache()
+
         return num_gpu_blocks, num_cpu_blocks

+    def _assert_memory_footprint_increased_during_profiling(self):
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # GPU did not change their memory usage during the profiling.
+        free_gpu_memory, _ = torch.cuda.mem_get_info()
+        assert self.init_gpu_memory - free_gpu_memory > 0, (
+            "Error in memory profiling. "
+            f"Initial free memory {self.init_gpu_memory}, current free memory"
+            f" {free_gpu_memory}. This happens when the GPU memory was "
+            "not properly cleaned up before initializing the vLLM instance.")
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         """Allocate GPU and CPU KV cache with the specified number of blocks.
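To connect the new available_kv_cache_memory formula to the numbers asserted in the test above, here is a small arithmetic sketch. The 10 GiB total mirrors the test's mocked device, the 0.9 figure is assumed to be the default gpu_memory_utilization, and the peak and non-torch figures are the ones quoted in the test comments:

```python
GiB = 1024**3

total_gpu_memory = 10 * GiB        # mocked total in the test
gpu_memory_utilization = 0.9       # assumed default utilization target
peak_torch_memory = 0.7077 * GiB   # peak reported by torch.cuda.memory_stats()
non_torch_memory = 0.0079 * GiB    # extra usage seen via torch.cuda.mem_get_info()

# Same formula as the new worker code: utilization target minus peak usage.
peak_memory = peak_torch_memory + non_torch_memory
available_kv_cache_memory = (total_gpu_memory * gpu_memory_utilization -
                             peak_memory)
print(available_kv_cache_memory / GiB)  # ~8.284 GiB, matching the test's 8.2843
```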
Review discussion:

- I think this test was working before due to the over-estimation of peak memory usage of the model, which caused a smaller KV cache to be allocated. Two LLMs both set gpu_memory_utilization=0.3, but once the first LLM uses the full 30% of the gpu, there's no space left to allocate room for the second one. This setting is a bit confusing: how it has been coded is "The total GPU allocation may not exceed x% of the gpu memory when loading this model", but it looks like the test assumed the setting meant "You may not allocate more than x% of the gpu memory for this model, regardless of how much of the gpu memory ends up being allocated." In other words, it assumed this was a per-model limit and not a global limit on gpu memory allocation. Maybe that should be made more clear in the docs? (Just a comment for readers; I don't intend to make more docs changes in the scope of this PR.)
- Thanks for the explanation. Could you add a short comment, since the reader may find it odd that the second call sets gpu_memory_utilization differently from the first? Alternatively, it looks like the first LLM doesn't need to be live when the second one is created, so we could try to force it to be garbage collected, but I don't think it's worth jumping through hoops for this.
- I agree :D I added a small comment here for future readers.
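As an illustration of the semantics discussed in this thread (gpu_memory_utilization caps total GPU memory in use, not each model's own allocation), here is a minimal sketch; the model name and values are illustrative rather than the exact test being discussed:

```python
# Hypothetical sketch: why two engines cannot both run with
# gpu_memory_utilization=0.3 on one GPU. The limit applies to total GPU
# memory in use when the engine profiles, not to each model's own footprint.
from vllm import LLM

# The first engine may fill the device up to 30% of total GPU memory
# (weights plus KV cache).
llm_a = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.3)

# A second engine with the same 0.3 cap would profile a GPU that is already
# about 30% full, leaving no headroom under its own ceiling, so it would fail
# to initialize. Giving it a higher cap (e.g. 0.6) leaves room for both.
llm_b = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.6)
```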