[1/N] torch.compile user interface design (vllm-project#10237)
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: OmerD <[email protected]>
youkaichao authored and omer-dayan committed Nov 14, 2024
1 parent 925ed8c commit 4e8d0fc
Showing 4 changed files with 55 additions and 37 deletions.
14 changes: 9 additions & 5 deletions tests/compile/piecewise/test_simple.py
@@ -12,10 +12,9 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.compilation.levels import CompilationLevel
+from vllm.config import VllmConfig
 from vllm.utils import direct_register_custom_op
 
-os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
-
 global_counter = 0
 
 # create a library to hold the custom op
@@ -48,7 +47,11 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
 @support_torch_compile
 class SillyModel(nn.Module):
 
-    def __init__(self) -> None:
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = '',
+                 **kwargs) -> None:
         super().__init__()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -74,11 +77,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 def test_simple_piecewise_compile():
 
-    model = SillyModel()
-
     directory = os.path.dirname(__file__)
     config = os.path.join(directory, "piecewise_compilation_config.json")
     os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
+    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
+
+    model = SillyModel(vllm_config=VllmConfig(), prefix='')
 
     inputs = torch.randn(100).cuda()
 
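For orientation, the change to test_simple.py above boils down to a new construction contract: classes decorated with @support_torch_compile take vllm_config and prefix as keyword-only constructor arguments, and the compilation level is read from the environment when the instance is built rather than at import time. Below is a minimal sketch of that pattern; TinyModel is a made-up stand-in, not the test's SillyModel, and an empty VllmConfig() is assumed to be sufficient for construction, as it is in the updated test.

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class TinyModel(nn.Module):
    """Hypothetical stand-in model illustrating the new constructor contract."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '',
                 **kwargs) -> None:
        # The decorator's wrapped __init__ forwards vllm_config and prefix here.
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


# An empty VllmConfig is enough at construction time; no model/cache/parallel
# configs are required, mirroring SillyModel(vllm_config=VllmConfig(), prefix='').
model = TinyModel(vllm_config=VllmConfig(), prefix='')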
21 changes: 14 additions & 7 deletions tests/compile/piecewise/test_toy_llama.py
@@ -19,6 +19,7 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.compilation.levels import CompilationLevel
+from vllm.config import VllmConfig
 from vllm.plugins import set_compilation_config
 from vllm.utils import direct_register_custom_op
 
@@ -195,9 +196,15 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class LlamaModel(nn.Module):
 
-    def __init__(self, config: LlamaConfig) -> None:
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 config: LlamaConfig,
+                 prefix: str = '',
+                 **kwargs) -> None:
         super().__init__()
         self.embedding_tokens = nn.Embedding(
             num_embeddings=config.vocab_size,
@@ -265,10 +272,9 @@ def run_model(llama_config,
             CompilationLevel.NO_COMPILATION)
         set_compilation_config(None)
 
-    cls = LlamaModel
-    if use_compile:
-        cls = support_torch_compile(LlamaModel)
-    model = cls(llama_config).eval().cuda()
+    model = LlamaModel(config=llama_config,
+                       vllm_config=VllmConfig(),
+                       prefix="").eval().cuda()
 
     B = 16  # max batch size
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
@@ -357,7 +363,6 @@ def test_toy_llama():
 def benchmark():
     os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
     from triton.testing import do_bench
-    cls = support_torch_compile(LlamaModel)
 
     # similar to llama 3.1-8B
     llama_config = LlamaConfig(hidden_size=4096,
@@ -390,7 +395,9 @@ def benchmark():
     else:
         set_compilation_config(None)
 
-    model = cls(llama_config).eval().cuda().to(torch.bfloat16)
+    model = LlamaModel(config=llama_config,
+                       vllm_config=VllmConfig(),
+                       prefix="").eval().cuda().to(torch.bfloat16)
 
     B = 256  # max batch size
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
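The toy-llama test follows the same shift: support_torch_compile is applied once, at class-definition time, and whether a given instance is actually compiled is decided from the environment when it is constructed. A rough sketch of that toggle, using the environment variable exactly as the tests above do (the commented-out construction line is illustrative only):

import os

from vllm.compilation.levels import CompilationLevel

# Choose eager vs. piecewise compilation before building the model; the
# decorator now reads this in __init__ instead of at decoration time.
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
# model = LlamaModel(config=llama_config, vllm_config=VllmConfig(), prefix="")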
27 changes: 14 additions & 13 deletions vllm/compilation/decorators.py
@@ -6,6 +6,7 @@
 import vllm.envs as envs
 from vllm.compilation.levels import CompilationLevel
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils import supports_dynamo
@@ -110,26 +111,26 @@ def _support_torch_compile(cls: type,
     """
     A decorator to add support for compiling the forward method of a class.
     """
-
-    # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
-    # will handle the compilation, so we don't need to do anything here.
-    if envs.VLLM_TORCH_COMPILE_LEVEL in [
-            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
-    ] or not supports_dynamo():
+    if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
+        # support decorating multiple times
         return cls
 
     # take care of method resolution order
     # make sure super().__init__ is called on the base class
     # other than TorchCompileWrapperWithCustomDispatcher
-    if TorchCompileWrapperWithCustomDispatcher not in cls.__bases__:
-        # support decorating multiple times
-        cls.__bases__ = cls.__bases__ + (
-            TorchCompileWrapperWithCustomDispatcher, )
+    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
 
     old_init = cls.__init__  # type: ignore
 
-    def __init__(self, *args, **kwargs):
-        old_init(self, *args, **kwargs)
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs):
+        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
+        # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
+        # will handle the compilation, so we don't need to do anything here.
+        self.do_not_compile = envs.VLLM_TORCH_COMPILE_LEVEL in [
+            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
+        ] or not supports_dynamo()
+        if self.do_not_compile:
+            return
         TorchCompileWrapperWithCustomDispatcher.__init__(self)
 
     cls.__init__ = __init__  # type: ignore
@@ -138,7 +139,7 @@ def __call__(self, *args, **kwargs):
         # torch.compiler.is_compiling() means we are inside the compilation
         # e.g. TPU has the compilation logic in model runner, so we don't
         # need to compile the model inside.
-        if torch.compiler.is_compiling():
+        if self.do_not_compile or torch.compiler.is_compiling():
             return self.forward(*args, **kwargs)
 
         # the first compilation needs to have dynamic shapes marked
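To make the new control flow in decorators.py concrete, here is a simplified, self-contained sketch rather than vllm's actual implementation: the compile/no-compile decision moves from decoration time into __init__, is stored per instance as do_not_compile, and __call__ falls back to the plain forward when compilation is disabled. The helper compile_enabled and the caching of the compiled callable are illustrative assumptions, not vllm code.

import os

import torch


def compile_enabled() -> bool:
    # Stand-in for the checks on envs.VLLM_TORCH_COMPILE_LEVEL and
    # supports_dynamo() that the wrapped __init__ above performs.
    return os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0") not in ("0", "1")


def support_compile_sketch(cls: type) -> type:
    old_init = cls.__init__

    def __init__(self, *, vllm_config=None, prefix: str = "", **kwargs):
        old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs)
        # Decided per instance at construction time, not at decoration time.
        self.do_not_compile = not compile_enabled()

    def __call__(self, *args, **kwargs):
        if self.do_not_compile or torch.compiler.is_compiling():
            # Eager path: compilation disabled, or already inside a compile.
            return self.forward(*args, **kwargs)
        if not hasattr(self, "_compiled_forward"):
            # Placeholder for the custom-dispatcher machinery in the real wrapper.
            self._compiled_forward = torch.compile(self.forward)
        return self._compiled_forward(*args, **kwargs)

    cls.__init__ = __init__
    cls.__call__ = __call__
    return cls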
30 changes: 18 additions & 12 deletions vllm/config.py
@@ -2054,12 +2054,15 @@ class VllmConfig:
     simplifies passing around the distinct configurations in the codebase.
     """
 
-    model_config: ModelConfig
-    cache_config: CacheConfig
-    parallel_config: ParallelConfig
-    scheduler_config: SchedulerConfig
-    device_config: DeviceConfig
-    load_config: LoadConfig
+    model_config: ModelConfig = field(default=None, init=True)  # type: ignore
+    cache_config: CacheConfig = field(default=None, init=True)  # type: ignore
+    parallel_config: ParallelConfig = field(default=None,
+                                            init=True)  # type: ignore
+    scheduler_config: SchedulerConfig = field(default=None,
+                                              init=True)  # type: ignore
+    device_config: DeviceConfig = field(default=None,
+                                        init=True)  # type: ignore
+    load_config: LoadConfig = field(default=None, init=True)  # type: ignore
     lora_config: Optional[LoRAConfig] = None
     speculative_config: Optional[SpeculativeConfig] = None
     decoding_config: Optional[DecodingConfig] = None
@@ -2104,11 +2107,14 @@ def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig":
     def __post_init__(self):
         """Verify configs are valid & consistent with each other.
         """
-        self.model_config.verify_async_output_proc(self.parallel_config,
-                                                   self.speculative_config,
-                                                   self.device_config)
-        self.model_config.verify_with_parallel_config(self.parallel_config)
-        self.cache_config.verify_with_parallel_config(self.parallel_config)
+        if self.model_config is not None:
+            self.model_config.verify_async_output_proc(self.parallel_config,
+                                                       self.speculative_config,
+                                                       self.device_config)
+            self.model_config.verify_with_parallel_config(self.parallel_config)
+
+        if self.cache_config is not None:
+            self.cache_config.verify_with_parallel_config(self.parallel_config)
 
         if self.lora_config:
             self.lora_config.verify_with_model_config(self.model_config)
@@ -2162,4 +2168,4 @@ def __str__(self):
             self.scheduler_config.num_scheduler_steps,
             self.cache_config.enable_prefix_caching,
             self.model_config.use_async_output_proc,
-            self.model_config.mm_processor_kwargs)
+            self.model_config.mm_processor_kwargs)
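The net effect of the config.py change is that VllmConfig becomes default-constructible, which is what lets the updated tests pass vllm_config=VllmConfig(): every core sub-config now defaults to None, and __post_init__ runs only the cross-checks whose configs are present. A small illustration, assuming (as the tests above do) that the remaining, unshown validation is likewise guarded:

from vllm.config import VllmConfig

# No ModelConfig/CacheConfig/ParallelConfig has to be supplied any more;
# the corresponding verification steps are simply skipped.
config = VllmConfig()
assert config.model_config is None
assert config.cache_config is None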
