diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py new file mode 100644 index 0000000000000..7eab521848ad6 --- /dev/null +++ b/tests/model_executor/test_model_load_with_params.py @@ -0,0 +1,45 @@ +import os + +from vllm.model_executor.layers.pooler import PoolingType +from vllm.model_executor.models.bert import BertEmbeddingModel + +MAX_MODEL_LEN = 128 +MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5") +REVISION = os.environ.get("REVISION", "main") + + +def test_model_loading_with_params(vllm_runner): + """ + Test parameter weight loading with tp>1. + """ + with vllm_runner(model_name=MODEL_NAME, + revision=REVISION, + dtype="float16", + max_model_len=MAX_MODEL_LEN) as model: + output = model.encode("Write a short story about a robot that" + " dreams for the first time.\n") + + model_config = model.model.llm_engine.model_config + + model_tokenizer = model.model.llm_engine.tokenizer + + # asserts on the bert model config file + assert model_config.bert_config["max_seq_length"] == 512 + assert model_config.bert_config["do_lower_case"] + + # asserts on the pooling config files + assert model_config.pooling_config.pooling_type == PoolingType.CLS + assert model_config.pooling_config.normalize + + # asserts on the tokenizer loaded + assert model_tokenizer.tokenizer_id == "BAAI/bge-base-en-v1.5" + assert model_tokenizer.tokenizer_config["do_lower_case"] + assert model_tokenizer.tokenizer.model_max_length == 512 + + model = model.model.llm_engine.model_executor\ + .driver_worker.model_runner.model + assert isinstance(model, BertEmbeddingModel) + assert model._pooler.pooling_type == PoolingType.CLS + assert model._pooler.normalize + # assert output + assert output diff --git a/tests/test_config.py b/tests/test_config.py index 69918b67607d9..9c484dd4f4266 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,7 @@ import pytest from vllm.config import ModelConfig +from vllm.model_executor.layers.pooler import PoolingConfig, PoolingType @pytest.mark.parametrize(("model_id", "expected_task"), [ @@ -102,6 +103,45 @@ def test_get_sliding_window(): assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW +def test_get_pooling_config(): + model_id = "sentence-transformers/all-MiniLM-L12-v2" + minilm_model_config = ModelConfig( + model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + revision=None, + ) + + minilm_pooling_config = minilm_model_config.get_pooling_config() + + assert isinstance(minilm_model_config.pooling_config, PoolingConfig) + assert minilm_pooling_config.normalize + assert isinstance(minilm_pooling_config.pooling_type, PoolingType) + assert minilm_pooling_config.pooling_type == PoolingType.MEAN + + +def test_get_bert_tokenization_sentence_transformer_config(): + bge_model_config = ModelConfig( + model="BAAI/bge-base-en-v1.5", + task="auto", + tokenizer="BAAI/bge-base-en-v1.5", + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + revision=None, + ) + + bert_bge_model_config = bge_model_config._get_bert_config() + + assert bert_bge_model_config["max_seq_length"] == 512 + assert bert_bge_model_config["do_lower_case"] + + def test_rope_customization(): TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0} TEST_ROPE_THETA = 16_000_000.0 diff --git a/vllm/config.py b/vllm/config.py index 25f841231dedd..adb084ca88aba 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ 
-9,13 +9,15 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.model_executor.layers.pooler import PoolingConfig
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
-from vllm.transformers_utils.config import (ConfigFormat, get_config,
-                                            get_hf_image_processor_config,
-                                            get_hf_text_config)
+from vllm.transformers_utils.config import (
+    ConfigFormat, get_config, get_hf_image_processor_config,
+    get_hf_text_config, get_pooling_config,
+    get_sentence_transformer_tokenizer_config)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless,
                         get_cpu_memory, is_hip, is_openvino,
                         print_warning_once)
@@ -110,6 +112,10 @@ class ModelConfig:
             can not be gathered from the vllm arguments.
         config_format: The config format which shall be loaded.
             Defaults to 'auto' which defaults to 'hf'.
+        pooling_config: pooling and normalize config from the model;
+            only applies to sentence-transformers models.
+        bert_config: tokenization configuration dictionary for a given
+            Sentence Transformer BERT model.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
             for multi-modal data, e.g., image processor.
     """
@@ -173,6 +179,8 @@ def __init__(self,
             code_revision, rope_scaling, rope_theta, config_format)
         self.hf_text_config = get_hf_text_config(self.hf_config)
+        self.pooling_config = self.get_pooling_config()
+        self.bert_config = self._get_bert_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, revision)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
@@ -205,7 +213,8 @@ def __init__(self,
             max_model_len=max_model_len,
             disable_sliding_window=self.disable_sliding_window,
             sliding_window_len=self.get_hf_config_sliding_window(),
-            spec_target_max_model_len=spec_target_max_model_len)
+            spec_target_max_model_len=spec_target_max_model_len,
+            bert_config=self.bert_config)
         self.served_model_name = get_served_model_name(model,
                                                        served_model_name)
         self.multimodal_config = self._init_multimodal_config(
@@ -242,6 +251,10 @@ def _init_multimodal_config(
 
         return None
 
+    def _get_bert_config(self):
+        return get_sentence_transformer_tokenizer_config(
+            self.model, self.revision)
+
     def _init_attention_free(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_attention_free_model(architectures)
@@ -405,6 +418,13 @@ def _verify_bnb_config(self) -> None:
                 "fallback to the eager mode.")
             self.enforce_eager = True
 
+    def get_pooling_config(self) -> Optional[PoolingConfig]:
+        pooling_config = get_pooling_config(self.model, self.revision)
+        if pooling_config is not None:
+            return PoolingConfig(pooling_config["pooling_type"],
+                                 pooling_config["normalize"])
+        return None
+
     def verify_async_output_proc(self, parallel_config, speculative_config,
                                  device_config) -> None:
         if not self.use_async_output_proc:
@@ -1715,6 +1735,7 @@ def _get_and_verify_max_len(
     disable_sliding_window: bool,
     sliding_window_len: Optional[Union[int, List[Optional[int]]]],
    spec_target_max_model_len: Optional[int] = None,
+    bert_config: Optional[Any] = None,
 ) -> int:
     """Get and verify the model's maximum length."""
     derived_max_model_len = float("inf")
@@ -1797,6 +1818,9 @@ def _get_and_verify_max_len(
                     "original_max_position_embeddings"]
             derived_max_model_len *= scaling_factor
 
+    if bert_config and "max_seq_length" in bert_config:
+        derived_max_model_len = 
bert_config["max_seq_length"] + # If the user specified a max length, make sure it is smaller than the # derived length from the HF model config. if max_model_len is None: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1dd0f097c74ff..06e3cdad2d0cf 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -254,44 +254,33 @@ def __init__( "num_scheduler_steps=%d, chunked_prefill_enabled=%s " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, " + "pooling_config_type=%s, normalize=%s, " "chat_template_text_format=%s, mm_processor_kwargs=%s)", - VLLM_VERSION, - model_config.model, - speculative_config, - model_config.tokenizer, - model_config.skip_tokenizer_init, - model_config.tokenizer_mode, - model_config.revision, - model_config.override_neuron_config, - model_config.rope_scaling, - model_config.rope_theta, - model_config.tokenizer_revision, - model_config.trust_remote_code, - model_config.dtype, - model_config.max_model_len, - load_config.download_dir, - load_config.load_format, - parallel_config.tensor_parallel_size, + VLLM_VERSION, model_config.model, speculative_config, + model_config.tokenizer, model_config.skip_tokenizer_init, + model_config.tokenizer_mode, model_config.revision, + model_config.override_neuron_config, model_config.rope_scaling, + model_config.rope_theta, model_config.tokenizer_revision, + model_config.trust_remote_code, model_config.dtype, + model_config.max_model_len, load_config.download_dir, + load_config.load_format, parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size, parallel_config.disable_custom_all_reduce, - model_config.quantization, - model_config.enforce_eager, - cache_config.cache_dtype, - model_config.quantization_param_path, - device_config.device, - decoding_config, - observability_config, - model_config.seed, - model_config.served_model_name, + model_config.quantization, model_config.enforce_eager, + cache_config.cache_dtype, model_config.quantization_param_path, + device_config.device, decoding_config, observability_config, + model_config.seed, model_config.served_model_name, scheduler_config.num_scheduler_steps, scheduler_config.chunked_prefill_enabled, scheduler_config.multi_step_stream_outputs, cache_config.enable_prefix_caching, - model_config.use_async_output_proc, - use_cached_outputs, + model_config.use_async_output_proc, use_cached_outputs, + model_config.pooling_config.pooling_type + if model_config.pooling_config is not None else None, + model_config.pooling_config.normalize + if model_config.pooling_config is not None else None, model_config.chat_template_text_format, - model_config.mm_processor_kwargs, - ) + model_config.mm_processor_kwargs) # TODO(woosuk): Print more configs in debug mode. self.model_config = model_config self.cache_config = cache_config diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 3455a4ccf282f..221bb77434868 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from enum import IntEnum import torch @@ -13,6 +14,33 @@ class PoolingType(IntEnum): LAST = 0 ALL = 1 CLS = 2 + MEAN = 3 + + +@dataclass +class PoolingConfig: + """A class that configures the pooling operation which + only applies to sentence-transformers models. + More at: https://www.sbert.net/ + + Attributes: + pooling_type (str): The type of pooling to use. 
+ normalize (bool): Whether to normalize the pooled data. + + Methods: + get_pooling_type(pooling_type_name): Returns the pooling + type enum value corresponding to the given string. + """ + + def __init__(self, pooling_type: str, normalize: bool): + self.pooling_type = self.get_pooling_type(pooling_type) + self.normalize = normalize + + def get_pooling_type(self, pooling_type_name: str) -> PoolingType: + pooling_types = PoolingType.__dict__.items() + return PoolingType( + next((value for key, value in pooling_types + if key.lower() in pooling_type_name), PoolingType.CLS)) class Pooler(nn.Module): @@ -24,7 +52,7 @@ class Pooler(nn.Module): 3. Returns structured results as `PoolerOutput`. Attributes: - pooling_type: The type of pooling to use (LAST, ALL, CLS). + pooling_type: The type of pooling to use (LAST, ALL, CLS, MEAN). normalize: Whether to normalize the pooled data. """ @@ -58,6 +86,17 @@ def forward( for prompt_len in prompt_lens: pooled_data.append(hidden_states[offset:offset + prompt_len]) offset += prompt_len + elif self.pooling_type == PoolingType.MEAN: + # Calculate mean pooling + cumsum = torch.cumsum(hidden_states, dim=0) + start_indices = torch.cat([ + torch.tensor([0], device=hidden_states.device), + torch.cumsum(prompt_lens[:-1], dim=0) + ]) + end_indices = torch.cumsum(prompt_lens, dim=0) + pooled_data = ( + cumsum[end_indices - 1] - cumsum[start_indices] + + hidden_states[start_indices]) / prompt_lens.unsqueeze(1) else: raise ValueError(f"Invalid pooling type: {self.pooling_type}") diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 813f58339da37..91c7b870cf671 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -28,6 +28,7 @@ get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger +from vllm.model_executor.layers.pooler import PoolingConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.model_loader.tensorizer import ( @@ -122,7 +123,8 @@ def _get_model_initialization_kwargs( model_class: Type[nn.Module], lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], - scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]: + scheduler_config: Optional[SchedulerConfig] = None, + pooling_config: Optional[PoolingConfig] = None) -> Dict[str, Any]: """Get extra kwargs for model initialization.""" extra_kwargs: Dict[str, Any] = {} @@ -144,18 +146,25 @@ def _get_model_initialization_kwargs( if has_inner_state(model_class) and scheduler_config: extra_kwargs["scheduler_config"] = scheduler_config + if pooling_config is not None: + extra_kwargs["pooling_config"] = pooling_config + return extra_kwargs -def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig, +def build_model(model_class: Type[nn.Module], + hf_config: PretrainedConfig, cache_config: Optional[CacheConfig], - quant_config: Optional[QuantizationConfig], *, + quant_config: Optional[QuantizationConfig], + *, lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], - scheduler_config: Optional[SchedulerConfig]) -> nn.Module: + scheduler_config: Optional[SchedulerConfig], + pooling_config: Optional[PoolingConfig] = None) -> nn.Module: extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config, multimodal_config, - scheduler_config) + scheduler_config, + pooling_config) return 
model_class(config=hf_config, cache_config=cache_config, @@ -172,15 +181,15 @@ def _initialize_model( """Initialize a model with the given configurations.""" model_class, _ = get_model_architecture(model_config) - return build_model( - model_class, - model_config.hf_config, - cache_config=cache_config, - quant_config=_get_quantization_config(model_config, load_config), - lora_config=lora_config, - multimodal_config=model_config.multimodal_config, - scheduler_config=scheduler_config, - ) + return build_model(model_class, + model_config.hf_config, + cache_config=cache_config, + quant_config=_get_quantization_config( + model_config, load_config), + lora_config=lora_config, + multimodal_config=model_config.multimodal_config, + scheduler_config=scheduler_config, + pooling_config=model_config.pooling_config) class BaseModelLoader(ABC): diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 4c0a0e303e655..54b85b05287fb 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -12,7 +12,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import Pooler, PoolingConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -382,15 +382,15 @@ class BertEmbeddingModel(nn.Module): _pooler: An instance of Pooler used for pooling operations. """ - def __init__( - self, - config: BertConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: + def __init__(self, + config: BertConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + pooling_config: Optional[PoolingConfig] = None) -> None: super().__init__() self.model = BertModel(config, cache_config, quant_config) - self._pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True) + self._pooler = Pooler(pooling_config.pooling_type, + pooling_config.normalize) def forward( self, diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index d79248f93f5ae..6e62ef28926fd 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -31,7 +31,8 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig, + PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -473,12 +474,18 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP): def __init__( self, + pooling_config: Optional[PoolingConfig] = None, **kwargs, ) -> None: super().__init__() self.model = Gemma2Model(**kwargs) - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + if pooling_config is not None: + self._pooler = Pooler(pooling_config.pooling_type, + pooling_config.normalize) + else: + self._pooler = Pooler(pooling_type=PoolingType.LAST, + normalize=True) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py 
index c346e3e808e3f..c3c992cf2e17f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -38,7 +38,8 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig, + PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) @@ -627,12 +628,18 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP): def __init__( self, + pooling_config: Optional[PoolingConfig] = None, **kwargs, ) -> None: super().__init__() self.model = LlamaModel(**kwargs) - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + if pooling_config is not None: + self._pooler = Pooler(pooling_config.pooling_type, + pooling_config.normalize) + else: + self._pooler = Pooler(pooling_type=PoolingType.LAST, + normalize=True) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 2a582deeaa2c9..7cb719f5c57aa 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -13,7 +13,8 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig, + PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.pooling_metadata import PoolingMetadata @@ -285,7 +286,8 @@ def __init__(self, config: LlavaNextConfig, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + pooling_config: Optional[PoolingConfig] = None) -> None: super().__init__() self.config = config @@ -306,7 +308,12 @@ def __init__(self, # The same model class supports both language generation and embedding # because the architecture name is the same - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + if pooling_config is not None: + self._pooler = Pooler(pooling_config.pooling_type, + pooling_config.normalize) + else: + self._pooler = Pooler(pooling_type=PoolingType.LAST, + normalize=True) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 855a9b17585a4..6e8b323e89fb4 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -30,7 +30,8 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext, token_inputs) from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig, + PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -525,7 +526,8 @@ def 
__init__(self, config: PretrainedConfig, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + pooling_config: Optional[PoolingConfig] = None) -> None: super().__init__() self.config = config @@ -547,7 +549,12 @@ def __init__(self, # The same model class supports both language generation and embedding # because the architecture name is the same - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + if pooling_config is not None: + self._pooler = Pooler(pooling_config.pooling_type, + pooling_config.normalize) + else: + self._pooler = Pooler(pooling_type=PoolingType.LAST, + normalize=True) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index ee0eeb9db3808..3493aeffd2062 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -14,7 +14,8 @@ from vllm.config import CacheConfig, LoRAConfig from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig, + PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput @@ -64,6 +65,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + pooling_config: Optional[PoolingConfig] = None, ) -> None: # TODO (@robertgshaw2): see if this can be moved out if (cache_config.sliding_window is not None @@ -93,7 +95,12 @@ def __init__( RowParallelLinear(config.hidden_size, 1, quant_config=quant_config), ) - self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False) + if pooling_config is not None: + self._pooler = Pooler(pooling_config.pooling_type, + pooling_config.normalize) + else: + self._pooler = Pooler(pooling_type=PoolingType.ALL, + normalize=False) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 9bd2531d7a15c..65a7930824fae 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -6,6 +6,9 @@ import huggingface_hub from huggingface_hub import (file_exists, hf_hub_download, try_to_load_from_cache) +from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError) from transformers import GenerationConfig, PretrainedConfig from transformers.models.auto.image_processing_auto import ( get_image_processor_config) @@ -202,7 +205,7 @@ def get_config( raise e elif config_format == ConfigFormat.MISTRAL: - config = load_params_config(model, revision) + config = load_params_config(model, revision, token=kwargs.get("token")) else: raise ValueError(f"Unsupported config format: {config_format}") @@ -232,6 +235,133 @@ def get_config( return config +def get_hf_file_to_dict(file_name, + model, + revision, + token: Optional[str] = None): + """ + Downloads a file from the Hugging Face Hub and returns + its contents as a dictionary. 
+
+    Parameters:
+    - file_name (str): The name of the file to download.
+    - model (str): The name of the model on the Hugging Face Hub.
+    - revision (str): The specific version of the model.
+    - token (str): The Hugging Face authentication token.
+
+    Returns:
+    - config_dict (dict): A dictionary containing
+    the contents of the downloaded file.
+    """
+    file_path = Path(model) / file_name
+
+    if file_or_path_exists(model=model,
+                           config_name=file_name,
+                           revision=revision,
+                           token=token):
+
+        if not file_path.is_file():
+            try:
+                hf_hub_file = hf_hub_download(model,
+                                              file_name,
+                                              revision=revision)
+            except (RepositoryNotFoundError, RevisionNotFoundError,
+                    EntryNotFoundError, LocalEntryNotFoundError) as e:
+                logger.debug(
+                    "File or repository not found in hf_hub_download: %s", e)
+                return None
+            file_path = Path(hf_hub_file)
+
+        with open(file_path, "r") as file:
+            config_dict = json.load(file)
+
+        return config_dict
+    return None
+
+
+def get_pooling_config(model, revision='main', token: Optional[str] = None):
+    """
+    This function gets the pooling and normalize
+    config from the model; it only applies to
+    sentence-transformers models.
+
+    Args:
+        model (str): The name of the Hugging Face model.
+        revision (str, optional): The specific version
+        of the model to use. Defaults to 'main'.
+
+    Returns:
+        dict: A dictionary containing the pooling
+        type and whether normalization is used.
+    """
+
+    modules_file_name = "modules.json"
+    modules_dict = get_hf_file_to_dict(modules_file_name, model, revision,
+                                       token)
+
+    if modules_dict is None:
+        return None
+
+    pooling = next((item for item in modules_dict
+                    if item["type"] == "sentence_transformers.models.Pooling"),
+                   None)
+    normalize = bool(
+        next((item for item in modules_dict
+              if item["type"] == "sentence_transformers.models.Normalize"),
+             False))
+
+    if pooling:
+
+        pooling_file_name = "{}/config.json".format(pooling["path"])
+        pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision,
+                                           token)
+        pooling_type_name = next(
+            (item for item, val in pooling_dict.items() if val is True), None)
+
+        return {"pooling_type": pooling_type_name, "normalize": normalize}
+
+    return None
+
+
+def get_sentence_transformer_tokenizer_config(model,
+                                              revision='main',
+                                              token: Optional[str] = None):
+    """
+    Returns the tokenization configuration dictionary for a
+    given Sentence Transformer BERT model.
+
+    Parameters:
+    - model (str): The name of the Sentence Transformer
+    BERT model.
+    - revision (str, optional): The revision of the model
+    to use. Defaults to 'main'.
+    - token (str): A Hugging Face access token.
+
+    Returns:
+    - dict: A dictionary containing the configuration parameters
+    for the Sentence Transformer BERT model.
+ """ + for config_name in [ + "sentence_bert_config.json", + "sentence_roberta_config.json", + "sentence_distilbert_config.json", + "sentence_camembert_config.json", + "sentence_albert_config.json", + "sentence_xlm-roberta_config.json", + "sentence_xlnet_config.json", + ]: + bert_dict = get_hf_file_to_dict(config_name, model, revision, token) + if bert_dict: + break + + if not bert_dict: + return None + + if all(k in bert_dict for k in ("max_seq_length", "do_lower_case")): + return bert_dict + return None + + def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None: """Try to register HF model configuration class to serialize by value @@ -294,20 +424,15 @@ def _reduce_modelconfig(mc: ModelConfig): exc_info=e) -def load_params_config(model, revision) -> PretrainedConfig: +def load_params_config(model, + revision, + token: Optional[str] = None) -> PretrainedConfig: # This function loads a params.json config which # should be used when loading models in mistral format config_file_name = "params.json" - config_path = Path(model) / config_file_name - - if not config_path.is_file(): - config_path = Path( - hf_hub_download(model, config_file_name, revision=revision)) - - with open(config_path, "r") as file: - config_dict = json.load(file) + config_dict = get_hf_file_to_dict(config_file_name, model, revision, token) config_mapping = { "dim": "hidden_size", diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 9a4149251d747..95ac1d4e6baf7 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -25,6 +25,11 @@ def init_tokenizer_from_configs(model_config: ModelConfig, trust_remote_code=model_config.trust_remote_code, revision=model_config.tokenizer_revision) + if (model_config.bert_config is not None + and "do_lower_case" in model_config.bert_config): + init_kwargs["do_lower_case"] = model_config.bert_config[ + "do_lower_case"] + return get_tokenizer_group(parallel_config.tokenizer_pool_config, **init_kwargs)
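The cumulative-sum formulation used for PoolingType.MEAN in Pooler.forward recovers each prompt's token sum from a single cumsum over the flattened hidden states as cumsum[end - 1] - cumsum[start] + hidden_states[start], then divides by the prompt length. The sketch below is illustrative only and not part of the patch; the tensor shapes, prompt lengths, and variable names are assumptions chosen for the example.

# Standalone check of the MEAN-pooling cumsum formulation (illustrative only).
import torch

hidden_states = torch.randn(7, 4)   # 7 tokens total across all prompts, hidden size 4
prompt_lens = torch.tensor([3, 4])  # two prompts: 3 tokens and 4 tokens

# Same idea as the patch: one cumsum over all tokens, then per-prompt sums
# recovered as cumsum[end - 1] - cumsum[start] + hidden_states[start].
cumsum = torch.cumsum(hidden_states, dim=0)
start_indices = torch.cat(
    [torch.tensor([0]), torch.cumsum(prompt_lens[:-1], dim=0)])
end_indices = torch.cumsum(prompt_lens, dim=0)
mean_fast = (cumsum[end_indices - 1] - cumsum[start_indices] +
             hidden_states[start_indices]) / prompt_lens.unsqueeze(1)

# Naive reference: slice each prompt's tokens and average directly.
mean_naive = torch.stack(
    [hidden_states[0:3].mean(dim=0),
     hidden_states[3:7].mean(dim=0)])

assert torch.allclose(mean_fast, mean_naive, atol=1e-6)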
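For the configuration side, PoolingConfig.get_pooling_type maps the key reported by a sentence-transformers Pooling module (for example "pooling_mode_mean_tokens" in the Pooling module's config.json, as returned by get_pooling_config above) onto the PoolingType enum by substring-matching lowercased member names, falling back to CLS. The following is a hedged, standalone mirror of that idea: the enum values are copied from the patch, but resolve_pooling_type is a re-implementation for illustration, not the patch's code.

from enum import IntEnum


class PoolingType(IntEnum):
    LAST = 0
    ALL = 1
    CLS = 2
    MEAN = 3


def resolve_pooling_type(pooling_type_name: str) -> PoolingType:
    # First enum member whose lowercased name appears in the key wins;
    # unknown keys fall back to CLS, matching the patch's default.
    for member in PoolingType:
        if member.name.lower() in pooling_type_name.lower():
            return member
    return PoolingType.CLS


print(resolve_pooling_type("pooling_mode_mean_tokens"))  # PoolingType.MEAN
print(resolve_pooling_type("pooling_mode_cls_token"))    # PoolingType.CLS
print(resolve_pooling_type("pooling_mode_unknown"))      # PoolingType.CLS (fallback)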