[Misc] Add tqdm progress bar during graph capture (vllm-project#11349)
Signed-off-by: mgoin <[email protected]>
mgoin authored Dec 20, 2024
1 parent 7801f56 commit b880ffb
Showing 1 changed file with 13 additions and 5 deletions.
vllm/worker/model_runner.py (13 additions, 5 deletions)
@@ -13,6 +13,7 @@
 import torch
 import torch.distributed
 import torch.nn as nn
+from tqdm import tqdm
 
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
@@ -21,7 +22,8 @@
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import get_kv_transfer_group, get_pp_group
-from vllm.distributed.parallel_state import graph_capture
+from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
+                                             graph_capture)
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
@@ -1413,8 +1415,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
         logger.info("Capturing cudagraphs for decoding. This may lead to "
                     "unexpected consequences if the model is not static. To "
                     "run the model in eager mode, set 'enforce_eager=True' or "
-                    "use '--enforce-eager' in the CLI.")
-        logger.info("If out-of-memory error occurs during cudagraph capture,"
+                    "use '--enforce-eager' in the CLI. "
+                    "If out-of-memory error occurs during cudagraph capture,"
                     " consider decreasing `gpu_memory_utilization` or "
                     "switching to eager mode. You can also reduce the "
                     "`max_num_seqs` as needed to decrease memory usage.")
@@ -1451,8 +1453,14 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
             # memory usage of CUDA graph.
             for virtual_engine in range(
                     self.parallel_config.pipeline_parallel_size):
-                for batch_size in \
-                        self.vllm_config.compilation_config.capture_sizes:
+                # Only rank 0 should print progress bar during capture
+                capture_sizes = (
+                    tqdm(
+                        self.vllm_config.compilation_config.capture_sizes,
+                        desc="Capturing CUDA graph shapes",
+                    ) if get_tensor_model_parallel_rank() == 0 else
+                    self.vllm_config.compilation_config.capture_sizes)
+                for batch_size in capture_sizes:
                     attn_metadata = (
                         self.attn_state.graph_capture_get_metadata_for_batch(
                             batch_size,
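
For context, the change wraps the capture-size loop in a tqdm bar only on tensor-parallel rank 0, so a multi-GPU run prints a single progress bar instead of one interleaved bar per worker. Below is a minimal standalone sketch of that pattern; the helper name, rank value, and size list are illustrative, not taken from the vLLM codebase:

from tqdm import tqdm

def with_progress(iterable, rank, desc):
    # Wrap the iterable in a tqdm progress bar only on rank 0; the
    # other ranks iterate the same items silently, so N workers
    # render one bar rather than N interleaved ones.
    return tqdm(iterable, desc=desc) if rank == 0 else iterable

# Illustrative usage: every rank loops over the same capture sizes,
# but only rank 0 prints progress. In a real distributed setup the
# rank would come from the runtime, e.g. torch.distributed.get_rank().
capture_sizes = [256, 128, 64, 32, 16, 8, 4, 2, 1]  # hypothetical sizes
for batch_size in with_progress(capture_sizes, rank=0, desc="Capturing CUDA graph shapes"):
    pass  # capture a CUDA graph for this batch size

Because tqdm is constructed on one rank only, the remaining ranks pay no display overhead and the iteration order stays identical across workers.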
