diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index f7dc167fea6e4..31a2443d1f94e 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -5,6 +5,8 @@ from vllm.engine.arg_utils import EngineArgs, nullable_kvs from vllm.utils import FlexibleArgumentParser +from vllm.model_executor.layers.pooler import PoolingConfig + @pytest.mark.parametrize(("arg", "expected"), [ (None, None), @@ -30,6 +32,15 @@ def test_limit_mm_per_prompt_parser(arg, expected): assert args.limit_mm_per_prompt == expected +def test_valid_pooling_config(): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + args = parser.parse_args(["--pooling-type=MEAN", "--normalize=True"]) + expected_pooling_config = PoolingConfig(pooling_type='MEAN', + normalize=True) + engine_args = EngineArgs.from_cli_args(args=args) + assert engine_args.pooling_config == expected_pooling_config + + @pytest.mark.parametrize( ("arg"), [ diff --git a/vllm/config.py b/vllm/config.py index adb084ca88aba..212a8089dcdb8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -149,6 +149,7 @@ def __init__(self, override_neuron_config: Optional[Dict[str, Any]] = None, config_format: ConfigFormat = ConfigFormat.AUTO, chat_template_text_format: str = "string", + pooling_config: Optional[PoolingConfig] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None: self.model = model self.tokenizer = tokenizer @@ -179,7 +180,7 @@ def __init__(self, code_revision, rope_scaling, rope_theta, config_format) self.hf_text_config = get_hf_text_config(self.hf_config) - self.pooling_config = self.get_pooling_config() + self.pooling_config = pooling_config or self.get_pooling_config() self.bert_config = self._get_bert_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, revision) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c49f475b9ee61..94d9ecfe7c978 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -16,6 +16,7 @@ from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.pooler import PoolingConfig from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) from vllm.transformers_utils.utils import check_gguf_file @@ -183,6 +184,7 @@ class EngineArgs: override_neuron_config: Optional[Dict[str, Any]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None scheduling_policy: Literal["fcfs", "priority"] = "fcfs" + pooling_config: Optional[PoolingConfig] = None def __post_init__(self): if not self.tokenizer: @@ -850,6 +852,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'priority (lower value means earlier handling) and time of ' 'arrival deciding any ties).') + parser.add_argument('--pooling-type', + type=str, + default='CLS', + choices=['LAST', 'ALL', 'CLS', 'MEAN'], + help='Configures the pooling operation which ' + 'only applies to sentence-transformers models. ') + + parser.add_argument('--normalize', + type=bool, + default=None, + help='Wheter to normalize the pooled data.') + return parser @classmethod @@ -857,6 +871,10 @@ def from_cli_args(cls, args: argparse.Namespace): # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. + pooling_config = PoolingConfig(pooling_type=args.pooling_type, + normalize=args.normalize) + d = vars(args) + d['pooling_config'] = pooling_config engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) return engine_args @@ -890,6 +908,7 @@ def create_model_config(self) -> ModelConfig: use_async_output_proc=not self.disable_async_output_proc, override_neuron_config=self.override_neuron_config, config_format=self.config_format, + pooling_config=self.pooling_config, mm_processor_kwargs=self.mm_processor_kwargs, ) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index db97fe0a0285b..c143a2542cd86 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -18,6 +18,7 @@ from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.pooler import PoolingConfig from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest, LLMGuidedOptions) from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -156,6 +157,7 @@ def __init__( max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, disable_async_output_proc: bool = False, + pooling_config: Optional[PoolingConfig] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, # After positional args are removed, move this right below `model` task: TaskOption = "auto", @@ -192,6 +194,7 @@ def __init__( max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, disable_async_output_proc=disable_async_output_proc, + pooling_config=pooling_config, mm_processor_kwargs=mm_processor_kwargs, **kwargs, ) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 221bb77434868..e97792d8ad908 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from enum import IntEnum +from typing import Optional + import torch import torch.nn as nn @@ -32,7 +34,7 @@ class PoolingConfig: type enum value corresponding to the given string. """ - def __init__(self, pooling_type: str, normalize: bool): + def __init__(self, pooling_type: str, normalize: Optional[bool] = None): self.pooling_type = self.get_pooling_type(pooling_type) self.normalize = normalize