diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py
index 3b2c3a19a83f0..6d4fb5e81194b 100644
--- a/tests/model_executor/test_model_load_with_params.py
+++ b/tests/model_executor/test_model_load_with_params.py
@@ -1,10 +1,12 @@
 import os
 
+from vllm.model_executor.layers.pooler import PoolingType
+
 MAX_MODEL_LEN = 128
 MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
 REVISION = os.environ.get("REVISION", "main")
 
 
 def test_model_loading_with_params(vllm_runner):
     """
     Test parameter weight loading with tp>1.
@@ -18,10 +20,20 @@ def test_model_loading_with_params(vllm_runner):
         model_config = model.model.llm_engine.model_config
 
-        print(model.model.llm_engine.__dict__.items())
+        model_tokenizer = model.model.llm_engine.tokenizer
 
+        # asserts on the bert model config file
         assert model_config.bert_config["max_seq_length"] == 512
         assert model_config.bert_config["do_lower_case"]
-        assert model_config.pooling_config.pooling_type == 2
+
+        # asserts on the pooling config files
+        assert model_config.pooling_config.pooling_type == PoolingType.CLS
         assert model_config.pooling_config.normalize
 
+        # asserts on the tokenizer loaded
+        assert model_tokenizer.tokenizer_id == "BAAI/bge-base-en-v1.5"
+        assert model_tokenizer.tokenizer_config["do_lower_case"]
+        assert model_tokenizer.tokenizer.model_max_length == 512
+
+        # assert output
         assert output
diff --git a/vllm/config.py b/vllm/config.py
index c93d21a1e6951..67572d3edceb1 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -180,7 +180,7 @@ def __init__(self,
                                     config_format)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.pooling_config = self.get_pooling_config()
-        self.bert_config, self.do_lower_case = self._get_bert_config()
+        self.bert_config = self._get_bert_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
@@ -1756,6 +1756,7 @@ def _get_and_verify_max_len(
         "max_seq_length",
         "seq_len",
     ]
+    # Choose the smallest "max_length" from the possible keys.
    max_len_key = None
     for key in possible_keys:
         max_len = getattr(hf_config, key, None)
diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py
index d12d5246d9222..0b5b704d0a7c7 100644
--- a/vllm/transformers_utils/tokenizer_group/__init__.py
+++ b/vllm/transformers_utils/tokenizer_group/__init__.py
@@ -25,8 +25,8 @@ def init_tokenizer_from_configs(model_config: ModelConfig,
                        trust_remote_code=model_config.trust_remote_code,
                        revision=model_config.tokenizer_revision)
 
-    if model_config.bert_config is not None and "do_lower_case"
-    in model_config.bert_config:
+    if (model_config.bert_config is not None
+            and "do_lower_case" in model_config.bert_config):
         init_kwargs["do_lower_case"] = model_config.bert_config["do_lower_case"]
 
     return get_tokenizer_group(parallel_config.tokenizer_pool_config,
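
Note: the snippet below is not part of the patch. It is a minimal, self-contained sketch of the behaviour the patch wires up and the test exercises, assuming bert_config is the plain dict parsed from the model's sentence_bert_config.json; the helper name build_tokenizer_kwargs is hypothetical and only mirrors the guard added to init_tokenizer_from_configs.

# Hypothetical illustration only -- not vLLM code. `bert_config` stands in
# for the dict parsed from sentence_bert_config.json (None if the model
# ships no such file).
from typing import Any, Dict, Optional


def build_tokenizer_kwargs(bert_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Mirror the do_lower_case guard added in init_tokenizer_from_configs."""
    init_kwargs: Dict[str, Any] = {}
    if bert_config is not None and "do_lower_case" in bert_config:
        init_kwargs["do_lower_case"] = bert_config["do_lower_case"]
    return init_kwargs


# BAAI/bge-base-en-v1.5 sets do_lower_case=True, which is what the tokenizer
# assertions in the test rely on; models without the key add nothing.
assert build_tokenizer_kwargs({"do_lower_case": True}) == {"do_lower_case": True}
assert build_tokenizer_kwargs(None) == {}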