diff --git a/vllm/config.py b/vllm/config.py index 4533fb017188c..885f28bd39d01 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -10,11 +10,13 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.pooler import PoolingConfig from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.transformers_utils.config import (ConfigFormat, get_config, get_hf_image_processor_config, + get_pooling_config, get_hf_text_config) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, is_hip, is_neuron, is_openvino, is_xpu, @@ -163,6 +165,7 @@ def __init__(self, code_revision, rope_scaling, rope_theta, config_format) self.hf_text_config = get_hf_text_config(self.hf_config) + self.pooling_config = self.get_pooling_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) @@ -362,6 +365,12 @@ def _verify_bnb_config(self) -> None: "fallback to the eager mode.") self.enforce_eager = True + def get_pooling_config(self) -> PoolingConfig: + pooling_config = get_pooling_config(self.model, + self.revision) + return PoolingConfig(pooling_config["pooling_type"], + pooling_config["normalize"]) + def verify_async_output_proc(self, parallel_config, speculative_config, device_config) -> None: if not self.use_async_output_proc: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 61c21887e6816..dad445ee29c40 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -251,7 +251,7 @@ def __init__( "num_scheduler_steps=%d, chunked_prefill_enabled=%s " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, " - "mm_processor_kwargs=%s)", + "mm_processor_kwargs=%s, pooling_config=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -287,6 +287,7 @@ def __init__( model_config.use_async_output_proc, use_cached_outputs, model_config.mm_processor_kwargs, + model_config.pooling_config ) # TODO(woosuk): Print more configs in debug mode. self.model_config = model_config diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 3455a4ccf282f..941293f84f855 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -13,6 +13,31 @@ class PoolingType(IntEnum): LAST = 0 ALL = 1 CLS = 2 + MEAN = 3 + MAX = 4 + + +class PoolingConfig(): + """A class that configures the pooling operation. + + Attributes: + pooling_type (str): The type of pooling to use. + normalize (bool): Whether to normalize the pooled data. + + Methods: + get_pooling_type(pooling_type_name): Returns the pooling + type enum value corresponding to the given string. + """ + def __init__(self, pooling_type: str, normalize: bool): + self.pooling_type = self.get_pooling_type(pooling_type) + self.normalize = normalize + + def get_pooling_type(self, pooling_type_name: str) -> PoolingType: + pooling_types = PoolingType.__dict__.items() + return PoolingType(next((value for key, + value in pooling_types if key.lower() + in pooling_type_name), + None)) class Pooler(nn.Module): diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 813f58339da37..bc319ea92354d 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -30,6 +30,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.pooler import PoolingConfig from vllm.model_executor.model_loader.tensorizer import ( TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, serialize_vllm_model, tensorizer_weights_iterator) @@ -122,7 +123,8 @@ def _get_model_initialization_kwargs( model_class: Type[nn.Module], lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], - scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]: + scheduler_config: Optional[SchedulerConfig] = None + ) -> Dict[str, Any]: """Get extra kwargs for model initialization.""" extra_kwargs: Dict[str, Any] = {} @@ -152,14 +154,17 @@ def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig, quant_config: Optional[QuantizationConfig], *, lora_config: Optional[LoRAConfig], multimodal_config: Optional[MultiModalConfig], - scheduler_config: Optional[SchedulerConfig]) -> nn.Module: + scheduler_config: Optional[SchedulerConfig], + pooling_config: Optional[PoolingConfig] = None) -> nn.Module: extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config, multimodal_config, - scheduler_config) + scheduler_config + ) return model_class(config=hf_config, cache_config=cache_config, quant_config=quant_config, + pooling_config=pooling_config, **extra_kwargs) @@ -180,6 +185,7 @@ def _initialize_model( lora_config=lora_config, multimodal_config=model_config.multimodal_config, scheduler_config=scheduler_config, + pooling_config=model_config.pooling_config ) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 4c0a0e303e655..e3e7b4a3ed2e5 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -12,7 +12,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import Pooler, PoolingConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -387,10 +387,13 @@ def __init__( config: BertConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + pooling_config: Optional[PoolingConfig] = None ) -> None: super().__init__() self.model = BertModel(config, cache_config, quant_config) - self._pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True) + print(pooling_config) + self._pooler = Pooler(pooling_config.pooling_type, + pooling_config.normalize) def forward( self, diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 46405f3529215..f0f4074556cd7 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -231,6 +231,72 @@ def get_config( return config +def get_hf_file_to_dict(file_name, model, revision): + """ + Downloads a file from the Hugging Face Hub and returns + its contents as a dictionary. + + Parameters: + - file_name (str): The name of the file to download. + - model (str): The name of the model on the Hugging Face Hub. + - revision (str): The specific version of the model. + + Returns: + - config_dict (dict): A dictionary containing + the contents of the downloaded file. + """ + file_path = Path(model) / file_name + + if not file_path.is_file(): + file_path = Path( + hf_hub_download(model, file_name, revision=revision)) + + with open(file_path, "r") as file: + config_dict = json.load(file) + + return config_dict + +def get_pooling_config(model, revision='main'): + """ + This function gets the pooling and normalize + config from the model. + + Args: + model (str): The name of the Hugging Face model. + revision (str, optional): The specific version + of the model to use. Defaults to 'main'. + + Returns: + dict: A dictionary containing the pooling + type and whether normalization is used. + """ + + modules_file_name = "modules.json" + modules_dict = get_hf_file_to_dict(modules_file_name, model, revision) + + pooling = next((item for item in modules_dict if + item["type"] == "sentence_transformers.models.Pooling"), + None) + normalize = next((item for item in modules_dict if + item["type"] == + "sentence_transformers.models.Normalize"), + False) + + if pooling: + + pooling_file_name = "{}/config.json".format(pooling["path"]) + pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision) + pooling_type_name = next((item for item, + val in pooling_dict.items() if val is True), + None) + + return { + "pooling_type": pooling_type_name, + "normalize": normalize + } + + return None + def load_params_config(model, revision) -> PretrainedConfig: # This function loads a params.json config which @@ -238,14 +304,7 @@ def load_params_config(model, revision) -> PretrainedConfig: config_file_name = "params.json" - config_path = Path(model) / config_file_name - - if not config_path.is_file(): - config_path = Path( - hf_hub_download(model, config_file_name, revision=revision)) - - with open(config_path, "r") as file: - config_dict = json.load(file) + config_dict = get_hf_file_to_dict(config_file_name, model, revision) config_mapping = { "dim": "hidden_size",