Skip to content

Commit

Permalink
[misc] add torch.compile compatibility check (vllm-project#10618)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Nov 25, 2024
1 parent 6581378 commit 25d806e
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tests/v1/engine/test_engine_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

engine_args = EngineArgs(model=MODEL_NAME)
engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
vllm_config = engine_args.create_engine_config()
executor_class = AsyncLLM._get_executor_cls(vllm_config)
client = EngineCoreClient.make_client(
Expand Down
14 changes: 14 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2394,6 +2394,20 @@ def __post_init__(self):
self.compilation_config.pass_config.enable_reshape = False
self.compilation_config.level = CompilationLevel.PIECEWISE

if self.cache_config is not None and \
self.cache_config.cpu_offload_gb > 0 and \
self.compilation_config.level != CompilationLevel.NO_COMPILATION:
logger.warning(
"CPU offload is not supported with `torch.compile` yet."
" Disabling `torch.compile`.")
self.compilation_config.level = CompilationLevel.NO_COMPILATION

if self.lora_config is not None and self.compilation_config.level !=\
CompilationLevel.NO_COMPILATION:
logger.warning("LoRA is not supported with `torch.compile` yet. "
"Disabling `torch.compile`.")
self.compilation_config.level = CompilationLevel.NO_COMPILATION

current_platform.check_and_update_config(self)

def __str__(self):
Expand Down
7 changes: 7 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,13 @@ def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model

# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
# CompilationConfig object
if isinstance(self.compilation_config, (int, dict)):
self.compilation_config = CompilationConfig.from_cli(
json.dumps(self.compilation_config))

# Setup plugins
from vllm.plugins import load_general_plugins
load_general_plugins()
Expand Down

0 comments on commit 25d806e

Please sign in to comment.