From 00506abb9cc298bd7427244372544cd5ee8c9a47 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Thu, 24 Oct 2024 19:00:23 -0300 Subject: [PATCH] fix loading of non-bert models and fix tests Signed-off-by: Max de Bayser --- .../test_model_load_with_params.py | 6 +++ tests/test_config.py | 6 ++- vllm/config.py | 14 +++--- vllm/engine/llm_engine.py | 47 +++++++------------ vllm/model_executor/model_loader/loader.py | 10 ++-- vllm/transformers_utils/config.py | 7 ++- 6 files changed, 45 insertions(+), 45 deletions(-) diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py index e3480f0c8f6db..7eab521848ad6 100644 --- a/tests/model_executor/test_model_load_with_params.py +++ b/tests/model_executor/test_model_load_with_params.py @@ -1,6 +1,7 @@ import os from vllm.model_executor.layers.pooler import PoolingType +from vllm.model_executor.models.bert import BertEmbeddingModel MAX_MODEL_LEN = 128 MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5") @@ -35,5 +36,10 @@ def test_model_loading_with_params(vllm_runner): assert model_tokenizer.tokenizer_config["do_lower_case"] assert model_tokenizer.tokenizer.model_max_length == 512 + model = model.model.llm_engine.model_executor\ + .driver_worker.model_runner.model + assert isinstance(model, BertEmbeddingModel) + assert model._pooler.pooling_type == PoolingType.CLS + assert model._pooler.normalize # assert output assert output diff --git a/tests/test_config.py b/tests/test_config.py index 85202a136f478..9c484dd4f4266 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -104,9 +104,11 @@ def test_get_sliding_window(): def test_get_pooling_config(): + model_id = "sentence-transformers/all-MiniLM-L12-v2" minilm_model_config = ModelConfig( - "sentence-transformers/all-MiniLM-L12-v2", - "sentence-transformers/all-MiniLM-L12-v2", + model_id, + task="auto", + tokenizer=model_id, tokenizer_mode="auto", trust_remote_code=False, seed=0, diff --git a/vllm/config.py b/vllm/config.py index 88bde618a65d3..b7be41517bbb7 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -181,7 +181,6 @@ def __init__(self, self.hf_text_config = get_hf_text_config(self.hf_config) self.pooling_config = self.get_pooling_config() self.bert_config = self._get_bert_config() - self.do_lower_case = self._get_bert_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) @@ -253,11 +252,8 @@ def _init_multimodal_config( return None def _get_bert_config(self): - bert_config = get_sentence_transformer_tokenizer_config( + return get_sentence_transformer_tokenizer_config( self.model, self.revision) - if bert_config is not None: - return bert_config - return None def _init_attention_free(self) -> bool: architectures = getattr(self.hf_config, "architectures", []) @@ -422,10 +418,12 @@ def _verify_bnb_config(self) -> None: "fallback to the eager mode.") self.enforce_eager = True - def get_pooling_config(self) -> PoolingConfig: + def get_pooling_config(self) -> Optional[PoolingConfig]: pooling_config = get_pooling_config(self.model, self.revision) - return PoolingConfig(pooling_config["pooling_type"], - pooling_config["normalize"]) + if pooling_config is not None: + return PoolingConfig(pooling_config["pooling_type"], + pooling_config["normalize"]) + return None def verify_async_output_proc(self, parallel_config, speculative_config, device_config) -> None: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e96ad40224233..06e3cdad2d0cf 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -256,42 +256,29 @@ def __init__( "use_async_output_proc=%s, use_cached_outputs=%s, " "pooling_config_type=%s, normalize=%s, " "chat_template_text_format=%s, mm_processor_kwargs=%s)", - VLLM_VERSION, - model_config.model, - speculative_config, - model_config.tokenizer, - model_config.skip_tokenizer_init, - model_config.tokenizer_mode, - model_config.revision, - model_config.override_neuron_config, - model_config.rope_scaling, - model_config.rope_theta, - model_config.tokenizer_revision, - model_config.trust_remote_code, - model_config.dtype, - model_config.max_model_len, - load_config.download_dir, - load_config.load_format, - parallel_config.tensor_parallel_size, + VLLM_VERSION, model_config.model, speculative_config, + model_config.tokenizer, model_config.skip_tokenizer_init, + model_config.tokenizer_mode, model_config.revision, + model_config.override_neuron_config, model_config.rope_scaling, + model_config.rope_theta, model_config.tokenizer_revision, + model_config.trust_remote_code, model_config.dtype, + model_config.max_model_len, load_config.download_dir, + load_config.load_format, parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size, parallel_config.disable_custom_all_reduce, - model_config.quantization, - model_config.enforce_eager, - cache_config.cache_dtype, - model_config.quantization_param_path, - device_config.device, - decoding_config, - observability_config, - model_config.seed, - model_config.served_model_name, + model_config.quantization, model_config.enforce_eager, + cache_config.cache_dtype, model_config.quantization_param_path, + device_config.device, decoding_config, observability_config, + model_config.seed, model_config.served_model_name, scheduler_config.num_scheduler_steps, scheduler_config.chunked_prefill_enabled, scheduler_config.multi_step_stream_outputs, cache_config.enable_prefix_caching, - model_config.use_async_output_proc, - use_cached_outputs, - model_config.pooling_config.pooling_type, - model_config.pooling_config.normalize, + model_config.use_async_output_proc, use_cached_outputs, + model_config.pooling_config.pooling_type + if model_config.pooling_config is not None else None, + model_config.pooling_config.normalize + if model_config.pooling_config is not None else None, model_config.chat_template_text_format, model_config.mm_processor_kwargs) # TODO(woosuk): Print more configs in debug mode. diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 133654e53c5d4..91c7b870cf671 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -123,7 +123,8 @@ def _get_model_initialization_kwargs( model_class: Type[nn.Module], lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], - scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]: + scheduler_config: Optional[SchedulerConfig] = None, + pooling_config: Optional[PoolingConfig] = None) -> Dict[str, Any]: """Get extra kwargs for model initialization.""" extra_kwargs: Dict[str, Any] = {} @@ -145,6 +146,9 @@ def _get_model_initialization_kwargs( if has_inner_state(model_class) and scheduler_config: extra_kwargs["scheduler_config"] = scheduler_config + if pooling_config is not None: + extra_kwargs["pooling_config"] = pooling_config + return extra_kwargs @@ -159,12 +163,12 @@ def build_model(model_class: Type[nn.Module], pooling_config: Optional[PoolingConfig] = None) -> nn.Module: extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config, multimodal_config, - scheduler_config) + scheduler_config, + pooling_config) return model_class(config=hf_config, cache_config=cache_config, quant_config=quant_config, - pooling_config=pooling_config, **extra_kwargs) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index a92146eafffd9..513541acaeeff 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -9,13 +9,13 @@ from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError) +from transformers import GenerationConfig, PretrainedConfig from transformers.models.auto.image_processing_auto import ( get_image_processor_config) from transformers.models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME -from transformers import GenerationConfig, PretrainedConfig from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger # yapf conflicts with isort for this block @@ -256,7 +256,7 @@ def get_hf_file_to_dict(file_name, model, revision): hf_hub_file = hf_hub_download(model, file_name, revision=revision) except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError, LocalEntryNotFoundError) as e: - logger.info("File or repository not found in hf_hub_download", e) + logger.debug("File or repository not found in hf_hub_download", e) return None file_path = Path(hf_hub_file) @@ -285,6 +285,9 @@ def get_pooling_config(model, revision='main'): modules_file_name = "modules.json" modules_dict = get_hf_file_to_dict(modules_file_name, model, revision) + if modules_dict is None: + return None + pooling = next((item for item in modules_dict if item["type"] == "sentence_transformers.models.Pooling"), None)