Commit 5672485

wangxiyuan and youkaichao authored and committed

[platform] support custom torch.compile backend key (vllm-project#11318)

Signed-off-by: wangxiyuan <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>

1 parent 04f14ac commit 5672485

File tree (5 files changed, +14, -5 lines):

  vllm/model_executor/layers/rejection_sampler.py
  vllm/model_executor/layers/vocab_parallel_embedding.py
  vllm/model_executor/models/commandr.py
  vllm/model_executor/models/phi3_small.py
  vllm/platforms/interface.py
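
Every hunk below applies the same pattern: the torch.compile decorator now reads its backend from the active platform instead of implicitly using PyTorch's default. A minimal standalone sketch of that pattern (the scale helper is my illustration, not code from this commit; since simple_compile_backend defaults to "inductor", behavior on stock platforms is unchanged):

import torch

from vllm.platforms import current_platform

# backend= accepts a callable or the name of a registered Dynamo backend;
# current_platform.simple_compile_backend is the string "inductor" by default.
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def scale(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

print(scale(torch.ones(4)))  # first call triggers compilation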

vllm/model_executor/layers/rejection_sampler.py (+2, -1)

@@ -9,6 +9,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.spec_decode_base_sampler import (
     SpecDecodeStochasticBaseSampler)
+from vllm.platforms import current_platform

 logger = init_logger(__name__)

@@ -368,7 +369,7 @@ def _smallest_positive_value(self) -> float:
 # Note that we always sample with replacement.
 # probs will be modified in place, but this is fine, as we pass
 # in a copy already.
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,

vllm/model_executor/layers/vocab_parallel_embedding.py (+1, -1)

@@ -133,7 +133,7 @@ def __post_init__(self):
         assert self.num_added_elements <= self.num_added_elements_padded


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def get_masked_input_and_mask(
         input_: torch.Tensor, org_vocab_start_index: int,
         org_vocab_end_index: int, num_org_vocab_padding: int,

vllm/model_executor/models/commandr.py (+2, -1)

@@ -45,6 +45,7 @@
     row_parallel_weight_loader)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
@@ -53,7 +54,7 @@
     maybe_prefix)


-@torch.compile
+@torch.compile(backend=current_platform.simple_compile_backend)
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)

vllm/model_executor/models/phi3_small.py (+3, -2)

@@ -20,6 +20,7 @@
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsPP
@@ -54,12 +55,12 @@ def weight_loader(self, param: torch.nn.Parameter,
         return load_column_parallel_weight(param, loaded_weight)


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def quick_gelu(x):
     return x * torch.sigmoid(1.702 * x)


-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def gegelu(input, limit: Optional[float] = None):
     a_gelu, a_linear = input[..., ::2], input[..., 1::2]
     if limit is not None:

vllm/platforms/interface.py (+6, -0)

@@ -82,6 +82,12 @@ class Platform:
     # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
     # use "CPU" as a fallback for platforms not registered in PyTorch
     dispatch_key: str = "CPU"
+    # The torch.compile backend for compiling simple and
+    # standalone functions. The default value is "inductor" to keep
+    # the same behavior as PyTorch.
+    # NOTE: for the forward part of the model, vLLM has another separate
+    # compilation strategy.
+    simple_compile_backend: str = "inductor"
     supported_quantization: list[str] = []

     def is_cuda(self) -> bool:
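
With this attribute in place, an out-of-tree platform plugin can redirect all of the decorated helpers above to its own backend by overriding a single class attribute. A minimal sketch, assuming a hypothetical plugin platform (the class name and the choice of PyTorch's built-in "eager" backend are illustrative, not part of this commit):

from vllm.platforms.interface import Platform

class MyAcceleratorPlatform(Platform):
    # Once this platform is active, every
    # @torch.compile(..., backend=current_platform.simple_compile_backend)
    # site above resolves to "eager" instead of "inductor".
    simple_compile_backend: str = "eager"

A platform that ships its own compiler can likewise register it under a name with torch._dynamo.register_backend and point simple_compile_backend at that name.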
