From 4e8d0fc327f2c0fde6ade6e0efc22d7d301126a1 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Mon, 11 Nov 2024 18:01:06 -0800
Subject: [PATCH] [1/N] torch.compile user interface design (#10237)

Signed-off-by: youkaichao
Signed-off-by: OmerD
---
 tests/compile/piecewise/test_simple.py    | 14 +++++++----
 tests/compile/piecewise/test_toy_llama.py | 21 ++++++++++------
 vllm/compilation/decorators.py            | 27 ++++++++++----------
 vllm/config.py                            | 30 ++++++++++++++---------
 4 files changed, 55 insertions(+), 37 deletions(-)

diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py
index fcfe80d8e4041..c631850ecdedb 100644
--- a/tests/compile/piecewise/test_simple.py
+++ b/tests/compile/piecewise/test_simple.py
@@ -12,10 +12,9 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.compilation.levels import CompilationLevel
+from vllm.config import VllmConfig
 from vllm.utils import direct_register_custom_op
 
-os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
-
 global_counter = 0
 
 # create a library to hold the custom op
@@ -48,7 +47,11 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
 @support_torch_compile
 class SillyModel(nn.Module):
 
-    def __init__(self) -> None:
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = '',
+                 **kwargs) -> None:
         super().__init__()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -74,11 +77,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 def test_simple_piecewise_compile():
-    model = SillyModel()
-
     directory = os.path.dirname(__file__)
     config = os.path.join(directory, "piecewise_compilation_config.json")
     os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
+    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
+
+    model = SillyModel(vllm_config=VllmConfig(), prefix='')
 
     inputs = torch.randn(100).cuda()
 
diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py
index 73fa9e9906936..c363a587a818e 100644
--- a/tests/compile/piecewise/test_toy_llama.py
+++ b/tests/compile/piecewise/test_toy_llama.py
@@ -19,6 +19,7 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.compilation.levels import CompilationLevel
+from vllm.config import VllmConfig
 from vllm.plugins import set_compilation_config
 from vllm.utils import direct_register_custom_op
 
@@ -195,9 +196,15 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class LlamaModel(nn.Module):
 
-    def __init__(self, config: LlamaConfig) -> None:
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 config: LlamaConfig,
+                 prefix: str = '',
+                 **kwargs) -> None:
         super().__init__()
         self.embedding_tokens = nn.Embedding(
             num_embeddings=config.vocab_size,
@@ -265,10 +272,9 @@ def run_model(llama_config,
             CompilationLevel.NO_COMPILATION)
         set_compilation_config(None)
 
-    cls = LlamaModel
-    if use_compile:
-        cls = support_torch_compile(LlamaModel)
-    model = cls(llama_config).eval().cuda()
+    model = LlamaModel(config=llama_config,
+                       vllm_config=VllmConfig(),
+                       prefix="").eval().cuda()
 
     B = 16  # max batch size
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
@@ -357,7 +363,6 @@ def test_toy_llama():
 def benchmark():
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
     from triton.testing import do_bench
-    cls = support_torch_compile(LlamaModel)
 
     # similar to llama 3.1-8B
     llama_config = LlamaConfig(hidden_size=4096,
@@ -390,7 +395,9 @@ def benchmark():
     else:
         set_compilation_config(None)
 
-    model = cls(llama_config).eval().cuda().to(torch.bfloat16)
+    model = LlamaModel(config=llama_config,
+                       vllm_config=VllmConfig(),
+                       prefix="").eval().cuda().to(torch.bfloat16)
 
     B = 256  # max batch size
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 3053e57e0b63b..ca1e96a33c014 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -6,6 +6,7 @@
 import vllm.envs as envs
 from vllm.compilation.levels import CompilationLevel
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils import supports_dynamo
@@ -110,26 +111,26 @@ def _support_torch_compile(cls: type,
     """
     A decorator to add support for compiling the forward method of a class.
     """
-
-    # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
-    # will handle the compilation, so we don't need to do anything here.
-    if envs.VLLM_TORCH_COMPILE_LEVEL in [
-            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
-    ] or not supports_dynamo():
+    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
+        # support decorating multiple times
         return cls
 
     # take care of method resolution order
     # make sure super().__init__ is called on the base class
     # other than TorchCompileWrapperWithCustomDispatcher
-    if TorchCompileWrapperWithCustomDispatcher not in cls.__bases__:
-        # support decorating multiple times
-        cls.__bases__ = cls.__bases__ + (
-            TorchCompileWrapperWithCustomDispatcher, )
+    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
 
     old_init = cls.__init__  # type: ignore
 
-    def __init__(self, *args, **kwargs):
-        old_init(self, *args, **kwargs)
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
+        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
+        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
+        # will handle the compilation, so we don't need to do anything here.
+        self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [
+            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
+        ] or not supports_dynamo()
+        if self.do_not_compile:
+            return
         TorchCompileWrapperWithCustomDispatcher.__init__(self)
 
     cls.__init__ = __init__  # type: ignore
@@ -138,7 +139,7 @@ def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
         # need to compile the model inside.
-        if torch.compiler.is_compiling():
+        if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
         # the first compilation needs to have dynamic shapes marked
diff --git a/vllm/config.py b/vllm/config.py
index 9c0d2d4764332..fc06b4fe853a6 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2054,12 +2054,15 @@ class VllmConfig:
     simplifies passing around the distinct configurations in the codebase.
""" - model_config: ModelConfig - cache_config: CacheConfig - parallel_config: ParallelConfig - scheduler_config: SchedulerConfig - device_config: DeviceConfig - load_config: LoadConfig + model_config: ModelConfig = field(default=None, init=True) # type: ignore + cache_config: CacheConfig = field(default=None, init=True) # type: ignore + parallel_config: ParallelConfig = field(default=None, + init=True) # type: ignore + scheduler_config: SchedulerConfig = field(default=None, + init=True) # type: ignore + device_config: DeviceConfig = field(default=None, + init=True) # type: ignore + load_config: LoadConfig = field(default=None, init=True) # type: ignore lora_config: Optional[LoRAConfig] = None speculative_config: Optional[SpeculativeConfig] = None decoding_config: Optional[DecodingConfig] = None @@ -2104,11 +2107,14 @@ def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig": def __post_init__(self): """Verify configs are valid & consistent with each other. """ - self.model_config.verify_async_output_proc(self.parallel_config, - self.speculative_config, - self.device_config) - self.model_config.verify_with_parallel_config(self.parallel_config) - self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.model_config is not None: + self.model_config.verify_async_output_proc(self.parallel_config, + self.speculative_config, + self.device_config) + self.model_config.verify_with_parallel_config(self.parallel_config) + + if self.cache_config is not None: + self.cache_config.verify_with_parallel_config(self.parallel_config) if self.lora_config: self.lora_config.verify_with_model_config(self.model_config) @@ -2162,4 +2168,4 @@ def __str__(self): self.scheduler_config.num_scheduler_steps, self.cache_config.enable_prefix_caching, self.model_config.use_async_output_proc, - self.model_config.mm_processor_kwargs) \ No newline at end of file + self.model_config.mm_processor_kwargs)