[torch.compile] support encoder based models (vllm-project#10613)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Nov 25, 2024
1 parent 7ea3cd7 commit 571841b
Showing 2 changed files with 17 additions and 10 deletions.
tests/compile/test_basic_correctness.py (10 additions, 0 deletions)

@@ -62,6 +62,16 @@ class TestSetting:
             method="encode",
             fullgraph=True,
         ),
+        # encoder-based embedding model (BERT)
+        TestSetting(
+            model="BAAI/bge-base-en-v1.5",
+            model_args=["--task", "embedding"],
+            pp_size=1,
+            tp_size=1,
+            attn_backend="XFORMERS",
+            method="encode",
+            fullgraph=True,
+        ),
         # vision language model
         TestSetting(
             model="microsoft/Phi-3.5-vision-instruct",
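For reference, the new setting exercises an encoder-only embedding model under full-graph torch.compile. A minimal offline sketch of the same workload is below; it assumes vLLM's offline LLM API with task="embedding" and LLM.encode(), and the exact output field names are an assumption that may vary by version:

# Hedged sketch, not part of the commit: run the BERT-based embedding model
# that the new TestSetting covers. Assumes LLM(task="embedding") and
# LLM.encode(); output field names may differ across vLLM versions.
from vllm import LLM

llm = LLM(model="BAAI/bge-base-en-v1.5", task="embedding")
outputs = llm.encode(["vLLM can now compile encoder-only models such as BERT."])
print(len(outputs[0].outputs.embedding))  # embedding width (768 for bge-base)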
vllm/model_executor/models/bert.py (7 additions, 10 deletions)

@@ -5,6 +5,7 @@
 from transformers import BertConfig
 
 from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, PoolerConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -92,14 +93,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return pooled_output
 
 
+@support_torch_compile
 class BertEncoder(nn.Module):
 
-    def __init__(self,
-                 config: BertConfig,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
         self.layer = nn.ModuleList([
             BertLayer(config=config,
                       cache_config=cache_config,
@@ -336,12 +337,8 @@ def __init__(self,
                  add_pooling_layer: bool = False):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
         self.embeddings = embedding_class(config)
-        self.encoder = BertEncoder(config,
-                                   cache_config,
-                                   quant_config,
+        self.encoder = BertEncoder(vllm_config=vllm_config,
                                    prefix=f"{prefix}.encoder")
         self.pooler = BertPooler(config) if add_pooling_layer else None
 
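The change to BertEncoder follows the shape that the support_torch_compile decorator expects: the decorated nn.Module takes the full VllmConfig (plus a prefix) in its constructor and unpacks the sub-configs itself. A condensed sketch of that pattern follows; MyEncoder and its placeholder layers are hypothetical, and only the decorator and constructor shape mirror the actual change:

# Hedged sketch of the wiring this commit applies to BertEncoder.
# "MyEncoder" and its placeholder layers are hypothetical; only the decorator
# and the vllm_config-based constructor mirror the actual change.
import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class MyEncoder(nn.Module):

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config  # HF BertConfig-style config
        cache_config = vllm_config.cache_config      # KV-cache settings (passed to real layers)
        quant_config = vllm_config.quant_config      # quantization settings (passed to real layers)
        self.layers = nn.ModuleList([
            nn.Linear(config.hidden_size, config.hidden_size)  # stand-in for real encoder layers
            for _ in range(config.num_hidden_layers)
        ])

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states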
