From 571841b7fcc67f8b1d171522f6249ed4224033e1 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 24 Nov 2024 21:24:33 -0800 Subject: [PATCH] [torch.compile] support encoder based models (#10613) Signed-off-by: youkaichao --- tests/compile/test_basic_correctness.py | 10 ++++++++++ vllm/model_executor/models/bert.py | 17 +++++++---------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index b7170886d2556..99781c55b672e 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -62,6 +62,16 @@ class TestSetting: method="encode", fullgraph=True, ), + # encoder-based embedding model (BERT) + TestSetting( + model="BAAI/bge-base-en-v1.5", + model_args=["--task", "embedding"], + pp_size=1, + tp_size=1, + attn_backend="XFORMERS", + method="encode", + fullgraph=True, + ), # vision language model TestSetting( model="microsoft/Phi-3.5-vision-instruct", diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 1fc87bc650d92..f570d6d3c12b3 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -5,6 +5,7 @@ from transformers import BertConfig from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn @@ -92,14 +93,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return pooled_output +@support_torch_compile class BertEncoder(nn.Module): - def __init__(self, - config: BertConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.layer = nn.ModuleList([ BertLayer(config=config, cache_config=cache_config, @@ -336,12 +337,8 @@ def __init__(self, add_pooling_layer: bool = False): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config self.embeddings = embedding_class(config) - self.encoder = BertEncoder(config, - cache_config, - quant_config, + self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder") self.pooler = BertPooler(config) if add_pooling_layer else None