From 3de7d49ae4a46ecd3599f0fce4079138fb5238ba Mon Sep 17 00:00:00 2001
From: Flavia Beo
Date: Thu, 24 Oct 2024 10:19:14 -0300
Subject: [PATCH] More assertions and reverts part of the linting as requested

Signed-off-by: Flavia Beo
---
NOTE (review): the vllm/config.py hunk's context line checks the misspelled
key `"max_seq_lenght"` but then indexes `"max_seq_length"` — the guard can
never match a well-formed sentence-transformers config, and a config carrying
only the misspelled key would raise KeyError. The context must mirror the
current tree, so fix the typo in a follow-up commit rather than here.

 tests/test_config.py                       | 18 ++++++++++++++++++
 vllm/config.py                             |  2 +-
 vllm/model_executor/model_loader/loader.py | 19 ++++++++++---------
 3 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/tests/test_config.py b/tests/test_config.py
index d573b4ca64698..65c532be1156d 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -125,6 +125,7 @@ def test_get_pooling_config():
 def test_get_bert_sentence_transformer_config():
     bge_model_config = ModelConfig(
         model="BAAI/bge-base-en-v1.5",
+        task="auto",
         tokenizer="BAAI/bge-base-en-v1.5",
         tokenizer_mode="auto",
         trust_remote_code=False,
@@ -139,6 +140,23 @@ def test_get_bert_sentence_transformer_config():
     assert bert_bge_model_config["do_lower_case"]
 
 
+def test_get_tokenization_sentence_transformer_config():
+    bge_model_config = ModelConfig(
+        model="BAAI/bge-base-en-v1.5",
+        task="auto",
+        tokenizer="BAAI/bge-base-en-v1.5",
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+    bert_config = bge_model_config._get_bert_tokenization_config()
+
+    assert bert_config
+
+
 def test_rope_customization():
     TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
     TEST_ROPE_THETA = 16_000_000.0
diff --git a/vllm/config.py b/vllm/config.py
index cdda385dfaf27..b018eed27f042 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1820,7 +1820,7 @@ def _get_and_verify_max_len(
 
     if bert_config and "max_seq_lenght" in bert_config:
         derived_max_model_len = bert_config["max_seq_length"]
-    
+
     # If the user specified a max length, make sure it is smaller than the
     # derived length from the HF model config.
     if max_model_len is None:
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 133654e53c5d4..d322838a9367c 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -177,15 +177,16 @@ def _initialize_model(
     """Initialize a model with the given configurations."""
     model_class, _ = get_model_architecture(model_config)
 
-    return build_model(model_class,
-                       model_config.hf_config,
-                       cache_config=cache_config,
-                       quant_config=_get_quantization_config(
-                           model_config, load_config),
-                       lora_config=lora_config,
-                       multimodal_config=model_config.multimodal_config,
-                       scheduler_config=scheduler_config,
-                       pooling_config=model_config.pooling_config)
+    return build_model(
+        model_class,
+        model_config.hf_config,
+        cache_config=cache_config,
+        quant_config=_get_quantization_config(model_config, load_config),
+        lora_config=lora_config,
+        multimodal_config=model_config.multimodal_config,
+        scheduler_config=scheduler_config,
+        pooling_config=model_config.pooling_config
+    )
 
 
 class BaseModelLoader(ABC):