Skip to content

Commit

Permalink
[torch.compile] add warning for unsupported models (vllm-project#10622)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Nov 25, 2024
1 parent 7c2134b commit 6581378
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
1 change: 1 addition & 0 deletions vllm/compilation/counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

@dataclasses.dataclass
class CompilationCounter:
num_models_seen: int = 0
num_graphs_seen: int = 0
# including the splitting ops
num_piecewise_graphs_seen: int = 0
Expand Down
2 changes: 2 additions & 0 deletions vllm/compilation/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch

from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationLevel, VllmConfig
from vllm.logger import init_logger
Expand Down Expand Up @@ -130,6 +131,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
] or not supports_dynamo()
if self.do_not_compile:
return
compilation_counter.num_models_seen += 1
TorchCompileWrapperWithCustomDispatcher.__init__(
self, compilation_level=vllm_config.compilation_config.level)

Expand Down
15 changes: 15 additions & 0 deletions vllm/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def set_current_vllm_config(vllm_config: "VllmConfig"):
"""
global _current_vllm_config
old_vllm_config = _current_vllm_config
from vllm.compilation.counter import compilation_counter
from vllm.config import CompilationLevel
num_models_seen = compilation_counter.num_models_seen
try:
_current_vllm_config = vllm_config
yield
Expand All @@ -88,6 +91,18 @@ def set_current_vllm_config(vllm_config: "VllmConfig"):
vllm_config.compilation_config.enabled_custom_ops)
logger.debug("disabled custom ops: %s",
vllm_config.compilation_config.disabled_custom_ops)
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
and compilation_counter.num_models_seen == num_models_seen:
# If the model supports compilation,
# compilation_counter.num_models_seen should be increased
# by at least 1.
# If it is not increased, it means the model does not support
# compilation (does not have @support_torch_compile decorator).
logger.warning(
"`torch.compile` is turned on, but the model %s"
" does not support it. Please open an issue on GitHub "
"if you want it to be supported.",
vllm_config.model_config.model)
_current_vllm_config = old_vllm_config


Expand Down

0 comments on commit 6581378

Please sign in to comment.