diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index c631850ecdedb..45f56cbbd4b16 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -11,8 +11,8 @@ from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig +from vllm.plugins import set_current_vllm_config from vllm.utils import direct_register_custom_op global_counter = 0 @@ -82,7 +82,9 @@ def test_simple_piecewise_compile(): os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) - model = SillyModel(vllm_config=VllmConfig(), prefix='') + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = SillyModel(vllm_config=vllm_config, prefix='') inputs = torch.randn(100).cuda() diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index c363a587a818e..8032304e95806 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -15,12 +15,10 @@ from torch.library import Library from vllm.compilation.compile_context import set_compile_context -from vllm.compilation.config import CompilationConfig from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig -from vllm.plugins import set_compilation_config +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.plugins import set_compilation_config, set_current_vllm_config from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -272,9 +270,11 @@ def run_model(llama_config, CompilationLevel.NO_COMPILATION) set_compilation_config(None) - model = LlamaModel(config=llama_config, - vllm_config=VllmConfig(), - prefix="").eval().cuda() + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = LlamaModel(config=llama_config, + vllm_config=vllm_config, + prefix="").eval().cuda() B = 16 # max batch size input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() @@ -395,9 +395,11 @@ def benchmark(): else: set_compilation_config(None) - model = LlamaModel(config=llama_config, - vllm_config=VllmConfig(), - prefix="").eval().cuda().to(torch.bfloat16) + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + model = LlamaModel(config=llama_config, + vllm_config=vllm_config, + prefix="").eval().cuda().to(torch.bfloat16) B = 256 # max batch size input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 833589ba5dc9f..08747ebc58b75 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -3,7 +3,7 @@ import pytest -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from vllm.utils import cuda_device_count_stateless from ..utils import compare_all_settings diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index f00334934cb46..4dfdfe21a67df 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,6 +1,6 @@ import pytest -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from ..utils import fork_new_process_for_each_test from .utils import TEST_MODELS, check_full_graph_support diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index e4d3defafb951..4db79b070fd8d 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -3,10 +3,10 @@ from compressed_tensors.quantization import FP8_DTYPE import vllm.envs as envs -from vllm.compilation.config import CompilationConfig from vllm.compilation.fusion import (FusionPass, find_auto_fn, find_auto_fn_maybe) from vllm.compilation.reshapes import RedundantReshapesPass +from vllm.config import CompilationConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( apply_fp8_linear) diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index 3668c1fab6b89..74f66baaa5ea1 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -3,6 +3,7 @@ import torch from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import CompilationLevel class MyMod(torch.nn.Module): @@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher): def __init__(self, model): self.model = model compiled_callable = torch.compile(self.forward, backend="eager") - super().__init__(compiled_callable) + super().__init__(compiled_callable, + compilation_level=CompilationLevel.DYNAMO_ONCE) def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): # this is the function to be compiled diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 222c63a342a4b..729f10676888b 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -4,7 +4,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from vllm.platforms import current_platform TEST_MODELS = [ diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index af267f804ffa7..c3219bc50646b 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -3,11 +3,13 @@ import pytest +from vllm.config import CompilationConfig, VllmConfig from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import (GeluAndMul, ReLUSquaredActivation, SiluAndMul) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.plugins import set_current_vllm_config # Registered subclass for test @@ -51,42 +53,40 @@ class Relu3(ReLUSquaredActivation): ]) def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int], default_on: bool): - os.environ["VLLM_CUSTOM_OPS"] = env os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level) + vllm_config = VllmConfig(compilation_config=CompilationConfig( + custom_ops=env.split(","))) + with set_current_vllm_config(vllm_config): + assert CustomOp.default_on() == default_on - # Reset default_on (computed once): - CustomOp.default_on.cache_clear() + ops_enabled = [bool(x) for x in ops_enabled] - assert CustomOp.default_on() == default_on + assert RMSNorm(1024).enabled() == ops_enabled[0] + assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] - ops_enabled = [bool(x) for x in ops_enabled] + assert SiluAndMul().enabled() == ops_enabled[1] + assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] - assert RMSNorm(1024).enabled() == ops_enabled[0] - assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] + assert GeluAndMul().enabled() == ops_enabled[2] + assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] - assert SiluAndMul().enabled() == ops_enabled[1] - assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] + # If registered, subclasses should follow their own name + assert Relu3().enabled() == ops_enabled[3] + assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] - assert GeluAndMul().enabled() == ops_enabled[2] - assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] + # Unregistered subclass + class SiluAndMul2(SiluAndMul): + pass - # If registered, subclasses should follow their own name - assert Relu3().enabled() == ops_enabled[3] - assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] - - # Unregistered subclass - class SiluAndMul2(SiluAndMul): - pass - - # Subclasses should not require registration - assert SiluAndMul2().enabled() == SiluAndMul().enabled() + # Subclasses should not require registration + assert SiluAndMul2().enabled() == SiluAndMul().enabled() @pytest.mark.parametrize( "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"]) def test_enabled_ops_invalid(env: str): - os.environ["VLLM_CUSTOM_OPS"] = env - CustomOp.default_on.cache_clear() - - with pytest.raises(AssertionError): - RMSNorm(1024).enabled() + with pytest.raises(Exception): # noqa + vllm_config = VllmConfig(compilation_config=CompilationConfig( + custom_ops=env.split(","))) + with set_current_vllm_config(vllm_config): + RMSNorm(1024).enabled() diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 86d9af88e49ea..941abe17a3378 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -5,7 +5,7 @@ import depyf -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel # disable custom dispatcher, let Dynamo takes over # all the control diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index 923d0f1680802..53b10c06135a1 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -1,6 +1,6 @@ import os -from vllm.compilation.levels import CompilationLevel +from vllm.config import CompilationLevel from ..utils import compare_two_settings diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 5682faa158069..22c613931f082 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -10,13 +10,12 @@ import torch.fx as fx import vllm.envs as envs +from vllm.config import CompilationConfig, CompilationLevel from vllm.logger import init_logger from vllm.utils import combine_fx_passes, weak_ref_tensors -from .config import CompilationConfig from .counter import compilation_counter from .fusion import FusionPass -from .levels import CompilationLevel from .reshapes import RedundantReshapesPass logger = init_logger(__name__) @@ -392,7 +391,10 @@ class VllmBackend: sym_tensor_indices: List[int] input_buffers: List[torch.Tensor] - def __init__(self, post_grad_passes: Sequence[Callable] = ()): + def __init__( + self, + compilation_configs: CompilationConfig, + ): global global_graph_pool if global_graph_pool is None: global_graph_pool = torch.cuda.graph_pool_handle() @@ -401,11 +403,13 @@ def __init__(self, post_grad_passes: Sequence[Callable] = ()): # streams, it might not be safe to share a global pool. # only investigate this when we use multiple streams self.graph_pool = global_graph_pool - self.post_grad_passes = post_grad_passes + self.post_grad_passes = [] self.sym_tensor_indices = [] self.input_buffers = [] + self.compilation_configs = compilation_configs + # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -437,10 +441,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: assert not self._called, "VllmBackend can only be called once" self.graph = graph - # config is read now, because only here can + # config is updated now, because only here can # we get the sizes to capture for cudagraph # from compilation context - self.compilation_configs = CompilationConfig.select_and_init_config() + self.compilation_configs.init_during_runtime() self.add_passes_to_config() self.split_gm, self.piecewise_graphs = split_graph( @@ -688,4 +692,6 @@ def select_default_backend(level: int) -> Union[str, Callable]: return backend_str assert level == CompilationLevel.PIECEWISE - return VllmBackend() + from vllm.plugins import get_current_vllm_config + compilation_config = get_current_vllm_config().compilation_config + return VllmBackend(compilation_config) diff --git a/vllm/compilation/config.py b/vllm/compilation/config.py deleted file mode 100644 index 3e663505c627d..0000000000000 --- a/vllm/compilation/config.py +++ /dev/null @@ -1,159 +0,0 @@ -import copy -from pathlib import Path -from typing import Any, Dict, List, Optional - -from pydantic import BaseModel, Field, PrivateAttr - -import vllm.envs as envs -from vllm.logger import init_logger - -from .compile_context import get_compile_context - -logger = init_logger(__name__) - - -class CompilationConfig(BaseModel): - """ - Configuration for compilation. - It has two parts: - - CudaGraph capture: - - use_cudagraph: whether to use cudagraph inside compilation. - - False: cudagraph inside compilation is not used. - - True: cudagraph inside compilation is used. It requires - that all input buffers have fixed addresses. - Note that this is orthogonal to the cudagraph capture out - side of compilation. - TODO: move outside cudagraph logic into compilation. - torch.compile will handle cudagraph capture logic in the future. - - cudagraph_capture_sizes: sizes to capture cudagraph. - - None: capture sizes are inferred from compilation context. - - List[int]: capture sizes are specified. - - cudagraph_num_of_warmups: number of warmup runs for cudagraph. - It means the first several runs will be treated as warmup runs. - Only after that, the execution will be recorded, and the recorded - cudagraph will be used for subsequent runs. - - cudagraph_copy_inputs: whether to copy input tensors for - cudagraph. If the caller can guarantee that the same input buffers - are always used, it can set this to False. Otherwise, it should - set this to True, and the compiler will copy the input to an - internally managed buffer. Default is False. - - Inductor compilation: - - use_inductor: whether to use inductor compilation. - - False: inductor compilation is not used. graph runs in eager. - - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for different sizes specified - in inductor_compile_sizes, using configurations - in inductor_compile_config. - - inductor_compile_sizes: sizes to compile for inductor. - - inductor_specialize_for_cudagraph_no_more_than: an optional integer - to specialize inductor for cudagraph sizes no more than the - specified size. It is useful when we want to specialize inductor - with a subset of cudagraph sizes. - - inductor_compile_config: additional configurations for inductor. - - None: use default configurations. - - inductor_passes: additional passes for inductor. It is a dictionary - from pass name to pass function qualified name. We use function - name because the config uses json format. If we pass the config - from Python, functions can also be passed directly via Python object - constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` - - Custom inductor passes: - - dump_graph_stages: list of stages for which we want to dump the graph. - Each pass defines its own stages (before, after, maybe in-between). - - dump_graph_dir: directory to dump the graph. Default is . - - enable_fusion: whether to enable the custom fusion pass. - TODO better pass enabling system. - - Why we have different sizes for cudagraph and inductor: - - cudagraph: a cudagraph captured for a specific size can only be used - for the same size. We need to capture all the sizes we want to use. - - inductor: a graph compiled by inductor for a general shape can be used - for different sizes. Inductor can also compile for specific sizes, - where it can have more information to optimize the graph with fully - static shapes. However, we find the general shape compilation is - sufficient for most cases. It might be beneficial to compile for - certain small batchsizes, where inductor is good at optimizing. - """ - use_inductor: bool = True - inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None - inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict) - inductor_compile_config: Dict = Field(default_factory=dict) - inductor_passes: Dict[str, str] = Field(default_factory=dict) - - use_cudagraph: bool = False - non_cudagraph_ops: List[str] = Field(default_factory=list) - cudagraph_num_of_warmups: int = 0 - cudagraph_capture_sizes: Optional[List[int]] = None - cudagraph_copy_inputs: bool = False - - dump_graph_stages: List[str] = Field(default_factory=list) - dump_graph_dir: Path = Field(default=Path(".")) - enable_fusion: bool = True - - # not configurable, computed after init - compile_sizes: List[int] = PrivateAttr - capture_sizes: List[int] = PrivateAttr - - def model_post_init(self, __context: Any) -> None: - for k, v in self.inductor_passes.items(): - if not isinstance(v, str): - assert callable(v), ( - f"pass {k} should be a function or a qualified name") - self.inductor_compile_config[k] = v - continue - - # resolve function from qualified name - names = v.split(".") - module = ".".join(names[:-1]) - func_name = names[-1] - func = __import__(module).__dict__[func_name] - self.inductor_compile_config[k] = func - - def init_during_runtime(self): - """To complete the initialization of config, - we need to know the compile context, which is only available - during the first run of the model. - """ - context = get_compile_context() - context = copy.deepcopy(context) if context is not None else [] - sizes_to_specialize: List[int] = context - if self.cudagraph_capture_sizes is None: - self.capture_sizes = sizes_to_specialize - else: - self.capture_sizes = self.cudagraph_capture_sizes - logger.info(("cudagraph sizes specified by model runner" - " %s is overridden by config %s"), - sizes_to_specialize, self.cudagraph_capture_sizes) - if self.inductor_specialize_for_cudagraph_no_more_than is not None: - assert self.inductor_compile_sizes is None, ( - "inductor_compile_sizes should be None when " - "inductor_specialize_for_cudagraph_no_more_than is not None") - self.compile_sizes = [ - x for x in self.capture_sizes - if x <= self.inductor_specialize_for_cudagraph_no_more_than - ] - else: - assert self.inductor_compile_sizes is not None, ( - "inductor_compile_sizes should not be None when " - "inductor_specialize_for_cudagraph_no_more_than is None") - self.compile_sizes = self.inductor_compile_sizes - - @staticmethod - def select_and_init_config() -> "CompilationConfig": - """The order of selecting config is: - 1. Use the config specified in environment variable. - 2. Use the config specified in plugins. - 3. Use the default config. - """ - config_path = envs.VLLM_TORCH_COMPILE_CONFIG - if config_path is not None: - with open(config_path) as json_file: - config = CompilationConfig.model_validate_json( - json_file.read()) - else: - from vllm.plugins import get_compilation_config - predefined_config = get_compilation_config() - config = predefined_config if predefined_config is not None else ( - CompilationConfig()) - - config.init_during_runtime() - return config diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index ca1e96a33c014..4b78491bc5a48 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -3,10 +3,8 @@ import torch -import vllm.envs as envs -from vllm.compilation.levels import CompilationLevel from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils import supports_dynamo @@ -126,12 +124,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner # will handle the compilation, so we don't need to do anything here. - self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [ + self.do_not_compile = \ + vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS ] or not supports_dynamo() if self.do_not_compile: return - TorchCompileWrapperWithCustomDispatcher.__init__(self) + TorchCompileWrapperWithCustomDispatcher.__init__( + self, compilation_level=vllm_config.compilation_config.level) cls.__init__ = __init__ # type: ignore diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index eb43604b1399b..e6a3afef85e1b 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -6,8 +6,8 @@ from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, fwd_only, register_replacement) -from vllm.compilation.config import CompilationConfig from vllm.compilation.inductor_pass import InductorPass +from vllm.config import CompilationConfig from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index b23351fa19759..8082a08b40019 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -2,7 +2,7 @@ import torch -from vllm.compilation.config import CompilationConfig +from vllm.config import CompilationConfig # yapf: disable from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( diff --git a/vllm/compilation/levels.py b/vllm/compilation/levels.py deleted file mode 100644 index 19a3a2b526870..0000000000000 --- a/vllm/compilation/levels.py +++ /dev/null @@ -1,8 +0,0 @@ -# constants for the levels of the compilation process - - -class CompilationLevel: - NO_COMPILATION = 0 - DYNAMO_AS_IS = 1 - DYNAMO_ONCE = 2 - PIECEWISE = 3 diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 7366ed4d16b0b..2a1aecc11ce26 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -8,8 +8,7 @@ import torch import vllm.envs as envs - -from .levels import CompilationLevel +from vllm.config import CompilationLevel class TorchCompileWrapperWithCustomDispatcher: @@ -25,7 +24,9 @@ class TorchCompileWrapperWithCustomDispatcher: `torch.compile` over the forward method. """ - def __init__(self, compiled_callable: Optional[Callable] = None): + def __init__(self, + compiled_callable: Optional[Callable] = None, + compilation_level: int = 0): if compiled_callable is None: # default compilation settings @@ -38,7 +39,7 @@ def __init__(self, compiled_callable: Optional[Callable] = None): backend = get_torch_compile_backend() if backend is None: from vllm.compilation.backends import select_default_backend - backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL) + backend = select_default_backend(compilation_level) compiled_callable = torch.compile( self.forward, @@ -54,7 +55,7 @@ def __init__(self, compiled_callable: Optional[Callable] = None): # subclasses can use this to switch between the custom dispatcher # and the default Dynamo guard mechanism. self.use_custom_dispatcher: bool = \ - envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE + compilation_level >= CompilationLevel.DYNAMO_ONCE def __call__(self, *args, **kwargs): """Implement the dispatch logic here, beyond the torch.compile level. diff --git a/vllm/config.py b/vllm/config.py index 64b2f75e092de..7e37edbe594b1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3,10 +3,12 @@ import json import warnings from dataclasses import dataclass, field, replace +from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List, Literal, Mapping, Optional, Set, Tuple, Type, Union) import torch +from pydantic import BaseModel, Field, PrivateAttr from transformers import PretrainedConfig import vllm.envs as envs @@ -2052,6 +2054,185 @@ def __post_init__(self): f"installed. Original error:\n{otel_import_error_traceback}") +class CompilationLevel: + # constants for the levels of the compilation process + NO_COMPILATION = 0 + DYNAMO_AS_IS = 1 + DYNAMO_ONCE = 2 + PIECEWISE = 3 + + +class CompilationConfig(BaseModel): + """ + Configuration for compilation. + It has three parts: + - Top-level Compilation control: + - level: the level of compilation. + - 0: no compilation. + - 1: dynamo as is. + - 2: dynamo once. + - 3: piecewise compilation. + - custom_ops: fine-grained control over which custom ops to enable/disable. + Use 'all' to enable all, 'none' to disable all. + Also specify a list of custom op names to enable (prefixed with a '+'), + or disable (prefixed with a '-'). + Examples: + - 'all,-op1' to enable all except op1 + - 'none,+op1,+op2' to enable only op1 and op2 + By default, all custom ops are enabled when running without Inductor + and disabled when running with Inductor (compile_level >= Inductor). + - CudaGraph capture: + - use_cudagraph: whether to use cudagraph inside compilation. + - False: cudagraph inside compilation is not used. + - True: cudagraph inside compilation is used. It requires + that all input buffers have fixed addresses. + Note that this is orthogonal to the cudagraph capture out + side of compilation. + TODO: move outside cudagraph logic into compilation. + torch.compile will handle cudagraph capture logic in the future. + - cudagraph_capture_sizes: sizes to capture cudagraph. + - None: capture sizes are inferred from compilation context. + - List[int]: capture sizes are specified. + - cudagraph_num_of_warmups: number of warmup runs for cudagraph. + It means the first several runs will be treated as warmup runs. + Only after that, the execution will be recorded, and the recorded + cudagraph will be used for subsequent runs. + - cudagraph_copy_inputs: whether to copy input tensors for + cudagraph. If the caller can guarantee that the same input buffers + are always used, it can set this to False. Otherwise, it should + set this to True, and the compiler will copy the input to an + internally managed buffer. Default is False. + - Inductor compilation: + - use_inductor: whether to use inductor compilation. + - False: inductor compilation is not used. graph runs in eager. + - True: inductor compilation is used. one graph for symbolic shape + is compiled. In addition, compile for different sizes specified + in inductor_compile_sizes, using configurations + in inductor_compile_config. + - inductor_compile_sizes: sizes to compile for inductor. + - inductor_specialize_for_cudagraph_no_more_than: an optional integer + to specialize inductor for cudagraph sizes no more than the + specified size. It is useful when we want to specialize inductor + with a subset of cudagraph sizes. + - inductor_compile_config: additional configurations for inductor. + - None: use default configurations. + - inductor_passes: additional passes for inductor. It is a dictionary + from pass name to pass function qualified name. We use function + name because the config uses json format. If we pass the config + from Python, functions can also be passed directly via Python object + constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` + - custom inductor passes: + - dump_graph_stages: list of stages for which we want to dump the graph. + Each pass defines its own stages (before, after, maybe in-between). + - dump_graph_dir: directory to dump the graph. Default is . + - enable_fusion: whether to enable the custom fusion pass. + TODO better pass enabling system. + + Why we have different sizes for cudagraph and inductor: + - cudagraph: a cudagraph captured for a specific size can only be used + for the same size. We need to capture all the sizes we want to use. + - inductor: a graph compiled by inductor for a general shape can be used + for different sizes. Inductor can also compile for specific sizes, + where it can have more information to optimize the graph with fully + static shapes. However, we find the general shape compilation is + sufficient for most cases. It might be beneficial to compile for + certain small batchsizes, where inductor is good at optimizing. + """ # noqa + level: int = 0 + custom_ops: List[str] = Field(default_factory=list) + + use_inductor: bool = True + inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None + inductor_compile_sizes: Optional[List[int]] = Field(default_factory=dict) + inductor_compile_config: Dict = Field(default_factory=dict) + inductor_passes: Dict[str, str] = Field(default_factory=dict) + + use_cudagraph: bool = False + non_cudagraph_ops: List[str] = Field(default_factory=list) + cudagraph_num_of_warmups: int = 0 + cudagraph_capture_sizes: Optional[List[int]] = None + cudagraph_copy_inputs: bool = False + + dump_graph_stages: List[str] = Field(default_factory=list) + dump_graph_dir: Path = Field(default=Path(".")) + enable_fusion: bool = True + + # not configurable, computed after init + compile_sizes: List[int] = PrivateAttr + capture_sizes: List[int] = PrivateAttr + + def model_post_init(self, __context: Any) -> None: + self.level = envs.VLLM_TORCH_COMPILE_LEVEL + + count_none = self.custom_ops.count("none") + count_all = self.custom_ops.count("all") + assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" + + for k, v in self.inductor_passes.items(): + if not isinstance(v, str): + assert callable(v), ( + f"pass {k} should be a function or a qualified name") + self.inductor_compile_config[k] = v + continue + + # resolve function from qualified name + names = v.split(".") + module = ".".join(names[:-1]) + func_name = names[-1] + func = __import__(module).__dict__[func_name] + self.inductor_compile_config[k] = func + + def init_during_runtime(self): + """To complete the initialization of config, + we need to know the compile context, which is only available + during the first run of the model. + """ + from vllm.compilation.compile_context import get_compile_context + context = get_compile_context() + context = copy.deepcopy(context) if context is not None else [] + sizes_to_specialize: List[int] = context + if self.cudagraph_capture_sizes is None: + self.capture_sizes = sizes_to_specialize + else: + self.capture_sizes = self.cudagraph_capture_sizes + logger.info(("cudagraph sizes specified by model runner" + " %s is overridden by config %s"), + sizes_to_specialize, self.cudagraph_capture_sizes) + if self.inductor_specialize_for_cudagraph_no_more_than is not None: + assert self.inductor_compile_sizes is None, ( + "inductor_compile_sizes should be None when " + "inductor_specialize_for_cudagraph_no_more_than is not None") + self.compile_sizes = [ + x for x in self.capture_sizes + if x <= self.inductor_specialize_for_cudagraph_no_more_than + ] + else: + assert self.inductor_compile_sizes is not None, ( + "inductor_compile_sizes should not be None when " + "inductor_specialize_for_cudagraph_no_more_than is None") + self.compile_sizes = self.inductor_compile_sizes + + @staticmethod + def select_and_init_config() -> "CompilationConfig": + """The order of selecting config is: + 1. Use the config specified in environment variable. + 2. Use the config specified in plugins. + 3. Use the default config. + """ + config_path = envs.VLLM_TORCH_COMPILE_CONFIG + if config_path is not None: + with open(config_path) as json_file: + config = CompilationConfig.model_validate_json( + json_file.read()) + else: + from vllm.plugins import get_compilation_config + predefined_config = get_compilation_config() + config = predefined_config if predefined_config is not None else ( + CompilationConfig()) + + return config + + @dataclass class VllmConfig: """Dataclass which contains all vllm-related configuration. This @@ -2073,6 +2254,8 @@ class VllmConfig: observability_config: Optional[ObservabilityConfig] = None prompt_adapter_config: Optional[PromptAdapterConfig] = None quant_config: Optional[QuantizationConfig] = None + compilation_config: CompilationConfig = field(default=None, + init=True) # type: ignore @staticmethod def _get_quantization_config( @@ -2133,6 +2316,12 @@ def __post_init__(self): self.quant_config = VllmConfig._get_quantization_config( self.model_config, self.load_config) + if self.compilation_config is None: + self.compilation_config = CompilationConfig.select_and_init_config( + ) + + current_platform.check_and_update_config(self) + def __str__(self): return ("model=%r, speculative_config=%r, tokenizer=%r, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " diff --git a/vllm/envs.py b/vllm/envs.py index f320e35971f94..716e835a555f1 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -69,7 +69,6 @@ VLLM_SKIP_P2P_CHECK: bool = False VLLM_TORCH_COMPILE_LEVEL: int = 0 VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None - VLLM_CUSTOM_OPS: List[str] = [] VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = False @@ -217,18 +216,6 @@ def get_default_config_root(): "VLLM_TORCH_COMPILE_CONFIG": lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None), - # Fine-grained control over which custom ops to enable/disable. - # Use 'all' to enable all, 'none' to disable all. - # Also specify a list of custom op names to enable (prefixed with a '+'), - # or disable (prefixed with a '-'). - # Examples: - # - 'all,-op1' to enable all except op1 - # - 'none,+op1,+op2' to enable only op1 and op2 - # By default, all custom ops are enabled when running without Inductor - # and disabled when running with Inductor (compile_level >= Inductor). - "VLLM_CUSTOM_OPS": - lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","), - # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 24d75f4df4e02..6ae7d7cf6964f 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -1,12 +1,10 @@ -from functools import lru_cache from typing import Dict, Type import torch.nn as nn -import vllm.envs as envs -from vllm.compilation.levels import CompilationLevel from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.plugins import get_current_vllm_config from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -87,6 +85,8 @@ def dispatch_forward(self): @classmethod def enabled(cls) -> bool: # if no name, then it was not registered + compilation_config = get_current_vllm_config().compilation_config + custom_ops = compilation_config.custom_ops if not hasattr(cls, "name"): print_warning_once( f"Custom op {cls.__name__} was not registered, " @@ -94,22 +94,25 @@ def enabled(cls) -> bool: f"It will be enabled/disabled based on the global settings.") return CustomOp.default_on() - enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS - disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS + enabled = f"+{cls.name}" in custom_ops + disabled = f"-{cls.name}" in custom_ops assert not (enabled and disabled), f"Cannot enable and disable {cls.name}" return (CustomOp.default_on() or enabled) and not disabled - # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE - # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence. @staticmethod - @lru_cache def default_on() -> bool: - count_none = envs.VLLM_CUSTOM_OPS.count("none") - count_all = envs.VLLM_CUSTOM_OPS.count("all") - assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" - return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE and \ + """ + On by default if level < CompilationLevel.PIECEWISE + Specifying 'all' or 'none' in custom_op takes precedence. + """ + from vllm.config import CompilationLevel + compilation_config = get_current_vllm_config().compilation_config + custom_ops = compilation_config.custom_ops + count_none = custom_ops.count("none") + count_all = custom_ops.count("all") + return compilation_config.level < CompilationLevel.PIECEWISE and \ not count_none > 0 or count_all > 0 # Dictionary of all custom ops (classes, indexed by registered name). diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 140b61fe6d56a..0f8b81c3ef40c 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -42,6 +42,7 @@ safetensors_weights_iterator) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.plugins import set_current_vllm_config from vllm.utils import is_pin_memory_available @@ -97,7 +98,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: all_params = [param.name for param in signatures.parameters.values()] if "vllm_config" in all_params and "prefix" in all_params: # new-style model class - return model_class(vllm_config=vllm_config, prefix=prefix) + with set_current_vllm_config(vllm_config): + return model_class(vllm_config=vllm_config, prefix=prefix) msg = ("vLLM model class should accept `vllm_config` and `prefix` as " "input arguments. Possibly you have an old-style model class" " registered from out of tree and it is used for new vLLM version. " @@ -121,7 +123,8 @@ def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: kwargs["lora_config"] = vllm_config.lora_config if "scheduler_config" in all_params: kwargs["scheduler_config"] = vllm_config.scheduler_config - return model_class(**kwargs) + with set_current_vllm_config(vllm_config): + return model_class(**kwargs) class BaseModelLoader(ABC): diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 81d8bdae2383c..970c0d1be617e 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,10 +1,15 @@ import enum import random -from typing import NamedTuple, Optional, Tuple, Union +from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union import numpy as np import torch +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + class PlatformEnum(enum.Enum): CUDA = enum.auto() @@ -129,6 +134,19 @@ def seed_everything(cls, seed: int) -> None: np.random.seed(seed) torch.manual_seed(seed) + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + """ + Check and update the configuration for the current platform. + + It can raise an exception if the configuration is not compatible with + the current platform, or it can update the configuration to make it + compatible with the current platform. + + The config is passed by reference, so it can be modified in place. + """ + pass + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 8d0ce47df4040..c2e22bfc09f22 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,18 +1,16 @@ import os +from typing import TYPE_CHECKING import torch -import vllm.envs as envs -from vllm.compilation.levels import CompilationLevel from vllm.plugins import set_torch_compile_backend from .interface import Platform, PlatformEnum -if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: - os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE) - -assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE,\ - "TPU does not support Inductor." +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None set_torch_compile_backend("openxla") @@ -31,3 +29,12 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @classmethod def inference_mode(cls): return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + from vllm.config import CompilationLevel + compilation_config = vllm_config.compilation_config + if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: + compilation_config.level = CompilationLevel.DYNAMO_ONCE + assert compilation_config.level < CompilationLevel.PIECEWISE,\ + "TPU does not support Inductor." diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 7b1bbb14c5302..c20b9ec891d5d 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,11 +1,11 @@ import logging +from contextlib import contextmanager from typing import TYPE_CHECKING, Callable, Optional, Union import vllm.envs as envs if TYPE_CHECKING: - from vllm.compilation.config import CompilationConfig - from vllm.config import VllmConfig + from vllm.config import CompilationConfig, VllmConfig else: CompilationConfig = None VllmConfig = None @@ -72,3 +72,29 @@ def set_compilation_config(config: Optional[CompilationConfig]): def get_compilation_config() -> Optional[CompilationConfig]: return _compilation_config + + +_current_vllm_config: Optional[VllmConfig] = None + + +@contextmanager +def set_current_vllm_config(vllm_config: VllmConfig): + """ + Temporarily set the current VLLM config. + Used during model initialization. + We save the current VLLM config in a global variable, + so that all modules can access it, e.g. custom ops + can access the VLLM config to determine how to dispatch. + """ + global _current_vllm_config + old_vllm_config = _current_vllm_config + try: + _current_vllm_config = vllm_config + yield + finally: + _current_vllm_config = old_vllm_config + + +def get_current_vllm_config() -> VllmConfig: + assert _current_vllm_config is not None, "Current VLLM config is not set." + return _current_vllm_config diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index eebd1de96537f..d60f93a44f6dd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,4 +1,3 @@ -import os import time from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple @@ -8,11 +7,8 @@ import torch.distributed import torch.nn as nn -from vllm import envs from vllm.compilation.compile_context import set_compile_context -from vllm.compilation.config import CompilationConfig -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger @@ -99,7 +95,7 @@ def __init__( pin_memory=self.pin_memory, ) - self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL + self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. @@ -517,9 +513,9 @@ def load_model(self) -> None: # CUDA graphs do not work properly with the custom CUDA kernels. # FIXME(woosuk): Disable inductor to reduce the compilation time # and avoid any potential issues with the inductor. - os.environ["VLLM_CUSTOM_OPS"] = "none" set_compilation_config( CompilationConfig( + custom_ops=["none"], use_cudagraph=True, non_cudagraph_ops=["vllm.unified_v1_flash_attention"], use_inductor=True, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 22ee3f9f863e4..fd89f95445565 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -19,8 +19,7 @@ from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState from vllm.compilation.compile_context import set_compile_context -from vllm.compilation.levels import CompilationLevel -from vllm.config import VllmConfig +from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture @@ -1142,8 +1141,8 @@ def load_model(self) -> None: "provided. Defaulting to scaling factors of 1.0. " "This may lead to less accurate results!") - if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \ - and supports_dynamo(): + if self.vllm_config.compilation_config.level ==\ + CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): from vllm.plugins import get_torch_compile_backend backend = get_torch_compile_backend() or "eager" self.model = torch.compile( diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index a721186137328..d7a641857a613 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -140,7 +140,7 @@ def load_model(self) -> None: model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.wait_device_ops() - self.model = ModelWrapper(model) + self.model = ModelWrapper(model, self.vllm_config) def _dummy_run( self, @@ -669,13 +669,15 @@ def execute_model( class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): - def __init__(self, model: nn.Module): + def __init__(self, model: nn.Module, vllm_config: VllmConfig): self.model = model compiled_callable = torch.compile(self.forward, backend="openxla", fullgraph=True, dynamic=False) - super().__init__(compiled_callable) + super().__init__( + compiled_callable, + compilation_level=vllm_config.compilation_config.level) def __call__(self, *args, is_prompt: bool, **kwargs): if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: