diff --git a/tests/models/embedding/vision_language/test_llava_next.py b/tests/models/embedding/vision_language/test_llava_next.py
index bab8d3897579e..329c6ba279f89 100644
--- a/tests/models/embedding/vision_language/test_llava_next.py
+++ b/tests/models/embedding/vision_language/test_llava_next.py
@@ -2,6 +2,7 @@
 
 import pytest
 import torch.nn.functional as F
+import transformers
 from transformers import AutoModelForVision2Seq
 
 from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
@@ -85,6 +86,9 @@ def _run_test(
     )
 
 
+@pytest.mark.skipif(transformers.__version__.startswith("4.46"),
+                    reason="Model broken with changes in transformers 4.46")
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 def test_models_text(
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 9773ba8cec779..1206424ae1e3f 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -1,5 +1,6 @@
 import copy
 import dataclasses
+import time
 from contextlib import ExitStack
 from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
 from unittest.mock import patch
@@ -14,6 +15,7 @@
 
 from .counter import compilation_counter
 from .inductor_pass import InductorPass
+from .monitor import end_monitoring_torch_compile
 from .pass_manager import PostGradPassManager
 
 logger = init_logger(__name__)
@@ -22,22 +24,23 @@
 def wrap_inductor(graph,
                   example_inputs,
                   additional_inductor_config,
-                  do_logging=False,
+                  compilation_config: CompilationConfig,
+                  graph_index: int = 0,
+                  num_graphs: int = 1,
                   runtime_shape: Optional[int] = None,
                   use_inductor: bool = True):
+    if graph_index == 0:
+        # before compiling the first graph, record the start time
+        global compilation_start_time
+        compilation_start_time = time.time()
+
     if not use_inductor:
         return graph
 
     compilation_counter.num_inductor_compilations += 1
 
-    if do_logging:
-        if runtime_shape is None:
-            logger.info("Compiling a graph for general shape")
-        else:
-            logger.info("Compiling a graph for shape %s", runtime_shape)
-
     from torch._inductor import config
-    current_config = config.shallow_copy_dict()
+    current_config = config.get_config_copy()
     from torch._inductor.compile_fx import compile_fx
 
     if additional_inductor_config is not None:
@@ -52,7 +55,23 @@ def wrap_inductor(graph,
     # inductor can inplace modify the graph, so we need to copy it
     # see https://github.com/pytorch/pytorch/issues/138980
     graph = copy.deepcopy(graph)
-    return compile_fx(graph, example_inputs, config_patches=current_config)
+    compiled_graph = compile_fx(graph,
+                                example_inputs,
+                                config_patches=current_config)
+
+    # after compiling the last graph, record the end time
+    if graph_index == num_graphs - 1:
+        now = time.time()
+        elapsed = now - compilation_start_time
+        compilation_config.compilation_time += elapsed
+        if runtime_shape is None:
+            logger.info("Compiling a graph for general shape takes %.2f s",
+                        elapsed)
+        else:
+            logger.info("Compiling a graph for shape %s takes %.2f s",
+                        runtime_shape, elapsed)
+
+    return compiled_graph
 
 
 @dataclasses.dataclass
@@ -114,6 +133,8 @@ def split_graph(graph: fx.GraphModule,
 # we share the global graph pool among all the backends
 global_graph_pool = None
 
+compilation_start_time = 0.0
+
 
 class PiecewiseCompileInterpreter(torch.fx.Interpreter):
     """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
@@ -157,12 +178,15 @@ def call_module(self, target: torch.fx.node.Target,
         sym_shape_indices = [
             i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
         ]
+        global compilation_start_time
         compiled_graph_for_general_shape = wrap_inductor(
             submod,
             args,
             self.compilation_configs.inductor_compile_config,
+            self.compilation_configs,
+            graph_index=index,
+            num_graphs=len(self.compile_submod_names),
             runtime_shape=None,
-            do_logging=index == 0,
             use_inductor=self.compilation_configs.use_inductor)
 
         self.module.__dict__[target] = PiecewiseBackend(
@@ -379,6 +403,8 @@ def __init__(self, graph: fx.GraphModule,
         # the entries for different shapes that we need to either
         # compile or capture cudagraph
         self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {}
+        self.to_be_compiled_sizes: Set[int] = self.compile_sizes.union(
+            self.capture_sizes)
         for shape in self.compile_sizes.union(self.capture_sizes):
             self.concrete_size_entries[shape] = ConcreteSizeEntry(
                 runtime_shape=shape,
@@ -389,6 +415,9 @@ def __init__(self, graph: fx.GraphModule,
     def __call__(self, *args) -> Any:
         if not self.first_run_finished:
             self.first_run_finished = True
+            # no specific sizes to compile
+            if self.is_last_graph and not self.to_be_compiled_sizes:
+                end_monitoring_torch_compile(self.compilation_configs)
             return self.compiled_graph_for_general_shape(*args)
 
         runtime_shape = args[self.sym_shape_indices[0]]
@@ -403,15 +432,22 @@ def __call__(self, *args) -> Any:
 
         if entry.need_to_compile and not entry.compiled:
             entry.compiled = True
+            self.to_be_compiled_sizes.remove(runtime_shape)
             # args are real arguments
             entry.runnable = wrap_inductor(
                 self.graph,
                 args,
                 self.compilation_configs.inductor_compile_config,
+                self.compilation_configs,
+                graph_index=self.piecewise_compile_index,
+                num_graphs=self.total_piecewise_compiles,
                 runtime_shape=runtime_shape,
-                do_logging=self.is_first_graph,
                 use_inductor=self.compilation_configs.use_inductor)
 
+            # finished compilations for all required shapes
+            if self.is_last_graph and not self.to_be_compiled_sizes:
+                end_monitoring_torch_compile(self.compilation_configs)
+
         if not entry.use_cudagraph:
             return entry.runnable(*args)
 
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 8700243c9d904..a32dced57e5b3 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -11,6 +11,8 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils import supports_dynamo
 
+from .monitor import start_monitoring_torch_compile
+
 logger = init_logger(__name__)
 
 _T = TypeVar("_T", bound=type[nn.Module])
@@ -155,6 +157,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
         TorchCompileWrapperWithCustomDispatcher.__init__(
             self, compilation_level=vllm_config.compilation_config.level)
 
+        if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE:
+            start_monitoring_torch_compile(vllm_config.compilation_config)
+
     cls.__init__ = __init__
 
     def __call__(self, *args, **kwargs):
diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py
new file mode 100644
index 0000000000000..f718e46423212
--- /dev/null
+++ b/vllm/compilation/monitor.py
@@ -0,0 +1,14 @@
+from vllm.config import CompilationConfig, CompilationLevel
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def start_monitoring_torch_compile(compilation_config: CompilationConfig):
+    pass
+
+
+def end_monitoring_torch_compile(compilation_config: CompilationConfig):
+    if compilation_config.level == CompilationLevel.PIECEWISE:
+        logger.info("graph compilation takes %.2f s in total",
+                    compilation_config.compilation_time)
diff --git a/vllm/config.py b/vllm/config.py
index 5c904914a71cf..a5e2702035a5c 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2281,6 +2281,7 @@ def model_post_init(self, __context: Any) -> None:
     # keep track of enabled and disabled custom ops
     enabled_custom_ops: Counter[str] = PrivateAttr
     disabled_custom_ops: Counter[str] = PrivateAttr
+    compilation_time: float = PrivateAttr
 
     # Per-model forward context
     # Mainly used to store attention cls
@@ -2319,6 +2320,7 @@ def model_post_init(self, __context: Any) -> None:
         self.enabled_custom_ops = Counter()
         self.disabled_custom_ops = Counter()
         self.static_forward_context = {}
+        self.compilation_time = 0.0
 
     def init_backend(self) -> Union[str, Callable]:
         if self.level == CompilationLevel.NO_COMPILATION:
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 1f3c6197ba1a8..26a8c94099a11 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -473,6 +473,7 @@ def _initialize_kv_caches(self) -> None:
         The workers will determine the number of blocks in both the GPU cache
         and the swap CPU cache.
         """
+        start = time.time()
         num_gpu_blocks, num_cpu_blocks = (
             self.model_executor.determine_num_available_blocks())
 
@@ -488,6 +489,9 @@ def _initialize_kv_caches(self) -> None:
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
         self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
+        elapsed = time.time() - start
+        logger.info(("init engine (profile, create kv cache, "
+                     "warmup model) took %.2f seconds"), elapsed)
 
     @classmethod
     def _get_executor_cls(cls,
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 397a33eed3896..751eb3b40a68d 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -67,6 +67,7 @@ def __init__(
 
     def _initialize_kv_caches(self,
                               cache_config: CacheConfig) -> Tuple[int, int]:
+        start = time.time()
        num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
         )
 
@@ -80,6 +81,9 @@ def _initialize_kv_caches(self,
 
         num_cpu_blocks = 0
         self.model_executor.initialize_cache(num_gpu_blocks)
+        elapsed = time.time() - start
+        logger.info(("init engine (profile, create kv cache, "
+                     "warmup model) took %.2f seconds"), elapsed)
         return num_gpu_blocks, num_cpu_blocks
 
     def add_request(self, request: EngineCoreRequest):