simplify memory profiling code
Signed-off-by: Travis Johnson <[email protected]>
tjohnson31415 committed Nov 21, 2024
1 parent d1e7c2a commit e2788e5
Showing 1 changed file with 20 additions and 31 deletions.
vllm/worker/worker.py (51 changes: 20 additions & 31 deletions)
@@ -188,9 +188,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
 
-        # Profiling may cause allocations within torch and outside of torch, we
-        # measure these separately.
-        free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
         start_time = time.time()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
@@ -199,30 +196,26 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
 
         self._assert_memory_footprint_increased_during_profiling()
 
-        model_memory_usage = self.init_gpu_memory - free_memory_pre_profile
+        torch.cuda.empty_cache()
+
+        free_memory, total_gpu_memory = torch.cuda.mem_get_info()
+        memory_stats = torch.cuda.memory_stats()
+        torch_memory = memory_stats["allocated_bytes.all.current"]
+        torch_peak = memory_stats["allocated_bytes.all.peak"]
+
+        # The baseline memory usage with no inference requests
+        # This includes any persistent memory allocated during profiling,
+        # eg from NCCL or internal buffers for quantization
+        baseline_memory = self.init_gpu_memory - free_memory
 
-        # Get the peak memory recorded by torch during profiling.
+        # The spike in memory recorded by torch during profiling.
         # We assume that no significant temporary allocations occur outside of
-        # torch during the profiling
-        peak_torch_memory_usage = torch.cuda.memory_stats(
-        )["allocated_bytes.all.peak"] - torch_memory_pre_profile
+        # torch during inference
+        inference_memory_spike = torch_peak - torch_memory
 
-        # Check for persistent memory allocated outside of torch. NCCL
-        # operations, for example, can allocate a few GB
-        gc.collect()
-        torch.cuda.empty_cache()
-        torch_memory_usage = torch.cuda.memory_stats(
-        )["allocated_bytes.all.current"] - torch_memory_pre_profile
-        total_memory_usage = self.init_gpu_memory - torch.cuda.mem_get_info(
-        )[0]
-        non_torch_memory_usage = (total_memory_usage - torch_memory_usage -
-                                  model_memory_usage)
-
-        # Peak memory usage expected during inference
-        peak_memory = peak_torch_memory_usage + non_torch_memory_usage
         available_kv_cache_memory = (
             total_gpu_memory * self.cache_config.gpu_memory_utilization -
-            model_memory_usage - peak_memory)
+            baseline_memory - inference_memory_spike)
 
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
@@ -239,24 +232,20 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
 
         logger.info(
             "Memory profiling results:"
-            "duration=%.2f seconds, "
+            " duration=%.2f seconds,"
             " total_gpu_memory=%.2fGiB"
             " gpu_memory_utilization=%.2f"
             " target_allocation=%.2fGiB"
-            " model_size=%.2fGiB"
-            " peak_memory=%.2fGiB"
-            " torch_memory=%.2fGiB"
-            " non_torch_memory=%.2fGiB"
+            " baseline=%.2fGiB"
+            " max_inference_spike=%.2fGiB"
             " kv_cache_size=%.2fGiB",
             time.time() - start_time,
             total_gpu_memory / (1024**3),
             self.cache_config.gpu_memory_utilization,
             total_gpu_memory / (1024**3) *
             self.cache_config.gpu_memory_utilization,
-            model_memory_usage / (1024**3),
-            peak_memory / (1024**3),
-            peak_torch_memory_usage / (1024**3),
-            non_torch_memory_usage / (1024**3),
+            baseline_memory / (1024**3),
+            inference_memory_spike / (1024**3),
             available_kv_cache_memory / (1024**3),
         )

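For reference, the accounting introduced by this commit can be reproduced in isolation. The sketch below is not part of the commit: every number is a made-up stand-in for what torch.cuda.mem_get_info() and torch.cuda.memory_stats() would report on a real GPU, and cache_block_size is a hypothetical per-block KV-cache footprint.

# Standalone sketch of the simplified memory accounting (illustrative numbers
# only; in the worker these values come from torch.cuda.mem_get_info() and
# torch.cuda.memory_stats()).
GiB = 1024**3

total_gpu_memory = 80 * GiB    # total device memory
init_gpu_memory = 79 * GiB     # free memory recorded when the worker started
free_memory = 60 * GiB         # free memory measured after the profiling run
torch_memory = 17 * GiB        # allocated_bytes.all.current after profiling
torch_peak = 22 * GiB          # allocated_bytes.all.peak during profiling
gpu_memory_utilization = 0.90
cache_block_size = 16 * 2 * 2 * 32 * 8 * 128   # hypothetical bytes per KV block

# Baseline: memory still held with no requests in flight (weights plus
# persistent non-torch allocations such as NCCL buffers).
baseline_memory = init_gpu_memory - free_memory              # 19 GiB

# Spike: transient torch allocations seen only during the profiling run.
inference_memory_spike = torch_peak - torch_memory           # 5 GiB

available_kv_cache_memory = (
    total_gpu_memory * gpu_memory_utilization -
    baseline_memory - inference_memory_spike)                # 48 GiB

num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
print(f"kv_cache_size={available_kv_cache_memory / GiB:.2f}GiB "
      f"num_gpu_blocks={num_gpu_blocks}")

With these numbers the target allocation is 72.00 GiB (80 GiB at 0.90 utilization); subtracting the 19 GiB baseline and the 5 GiB spike leaves 48.00 GiB for the KV cache, or 24576 of the 2 MiB blocks assumed here.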