diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index f838e7ad74285..0645790e24f31 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -305,17 +305,17 @@ def forward(self, hidden_states: torch.Tensor, class BertModel(nn.Module): - def __init__(self, *, - vllm_config: VllmConfig, + def __init__(self, + *, + config: BertConfig, + vllm_config: VllmConfig, prefix: str = "", embedding_class: type = BertEmbedding): super().__init__() - self.embeddings = embedding_class(config) config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - - self.embeddings = BertEmbedding(config) + self.embeddings = embedding_class(config) self.encoder = BertEncoder(config, cache_config, quant_config,