
[Core] Enhance memory profiling in determine_num_available_blocks with error handling and fallback #9996

Open
wants to merge 8 commits into base: main
56 changes: 33 additions & 23 deletions vllm/worker/worker.py
@@ -179,33 +179,45 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
    KV blocks may be allocated without OOMs.

    The engine will first conduct a profiling of the existing memory usage.
    Then, it calculates the maximum possible number of GPU and CPU blocks
    that can be allocated with the remaining free memory.

    .. tip::
        You may limit the usage of GPU memory
        by adjusting the `gpu_memory_utilization` parameter.
    """
    # Check if GPU blocks override is provided; skip profiling if so.
    num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override

    if num_gpu_blocks_override is not None:
        # Calculate num_cpu_blocks based on available swap space.
        cache_block_size = self.get_cache_block_size_bytes()
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size) if cache_block_size else 0
        num_cpu_blocks = max(num_cpu_blocks, 0)

        logger.info(
            "Using num_gpu_blocks_override=%d, calculated num_cpu_blocks=%d",
            num_gpu_blocks_override, num_cpu_blocks)

        return num_gpu_blocks_override, num_cpu_blocks

    # Proceed with full profiling if no override is provided.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()

    # Profile model memory usage with a forward pass.
    self.model_runner.profile_run()
    torch.cuda.synchronize()

    self._assert_memory_footprint_increased_during_profiling()

    # Gather peak memory allocation.
    peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

    # Adjust for non-torch memory allocations; NCCL operations, for
    # example, can use a few GB during a forward pass.
    torch.cuda.empty_cache()
    torch_allocated_bytes = torch.cuda.memory_stats(
    )["allocated_bytes.all.current"]
@@ -219,33 +219,31 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
        total_gpu_memory * self.cache_config.gpu_memory_utilization -
        peak_memory)

    # Calculate the number of blocks for both GPU and CPU.
    cache_block_size = self.get_cache_block_size_bytes()
    num_gpu_blocks = int(available_kv_cache_memory //
                         cache_block_size) if cache_block_size else 0
    num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                         cache_block_size) if cache_block_size else 0
    num_gpu_blocks, num_cpu_blocks = (max(num_gpu_blocks, 0),
                                      max(num_cpu_blocks, 0))

    logger.info(
        "Memory profiling results: total_gpu_memory=%.2fGiB"
        " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
        " memory_usage_post_profile=%.2fGiB"
        " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
        " gpu_memory_utilization=%.2f",
        total_gpu_memory / (1024**3),
        (total_gpu_memory - free_memory_pre_profile) / (1024**3),
        (peak_memory - non_torch_allocations) / (1024**3),
        total_allocated_bytes / (1024**3),
        non_torch_allocations / (1024**3),
        available_kv_cache_memory / (1024**3),
        self.cache_config.gpu_memory_utilization,
    )

    # Final cleanup.
    if self.model_runner.lora_manager:
        self.model_runner.remove_all_loras()
    gc.collect()
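
To make the sizing arithmetic concrete, here is a hedged worked example of the formulas in the second hunk. All inputs are made up (an 80 GiB device, 2 MiB cache blocks, 4 GiB of swap), not results from a real profile; `gpu_memory_utilization` is the knob the docstring's tip refers to:

```python
# Illustrative inputs (assumed, not from a real run):
total_gpu_memory = 80 * 1024**3      # 80 GiB device
gpu_memory_utilization = 0.90        # cache_config.gpu_memory_utilization
peak_memory = 30 * 1024**3           # peak usage, incl. non-torch allocations
cache_block_size = 2 * 1024**2       # bytes per KV-cache block
swap_space_bytes = 4 * 1024**3       # cache_config.swap_space_bytes

# Same formulas as the patch, including the zero-block-size guard.
available_kv_cache_memory = (
    total_gpu_memory * gpu_memory_utilization - peak_memory)
num_gpu_blocks = (int(available_kv_cache_memory // cache_block_size)
                  if cache_block_size else 0)
num_cpu_blocks = (int(swap_space_bytes // cache_block_size)
                  if cache_block_size else 0)
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)

print(num_gpu_blocks, num_cpu_blocks)  # 21504 2048
```

The same `if cache_block_size else 0` guard is what lets the override fast path in the first hunk compute `num_cpu_blocks` without risking a ZeroDivisionError.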