diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 73c6ecb561e24..e92e2588d01cb 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -32,10 +32,9 @@ def test_limit_mm_per_prompt_parser(arg, expected): def test_valid_pooling_config(): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) - args = parser.parse_args(["--pooling-type=MEAN", "--normalize=True"]) + args = parser.parse_args(["--pooling-type=MEAN"]) engine_args = EngineArgs.from_cli_args(args=args) assert engine_args.pooling_type == 'MEAN' - assert engine_args.normalize @pytest.mark.parametrize( diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py index 7eab521848ad6..a3e5e04d48789 100644 --- a/tests/model_executor/test_model_load_with_params.py +++ b/tests/model_executor/test_model_load_with_params.py @@ -28,8 +28,8 @@ def test_model_loading_with_params(vllm_runner): assert model_config.bert_config["do_lower_case"] # asserts on the pooling config files - assert model_config.pooling_config.pooling_type == PoolingType.CLS - assert model_config.pooling_config.normalize + assert model_config.pooler_config.pooling_type == PoolingType.CLS.name + assert model_config.pooler_config.pooling_norm # asserts on the tokenizer loaded assert model_tokenizer.tokenizer_id == "BAAI/bge-base-en-v1.5" diff --git a/tests/test_config.py b/tests/test_config.py index a29250aa6027c..0ffa9172d2f53 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,7 +1,7 @@ import pytest from vllm.config import ModelConfig -from vllm.model_executor.layers.pooler import PoolingConfig, PoolingType +from vllm.model_executor.layers.pooler import PoolingType @pytest.mark.parametrize(("model_id", "expected_task"), [ @@ -116,12 +116,15 @@ def test_get_pooling_config(): revision=None, ) - minilm_pooling_config = minilm_model_config.get_pooling_config(None, None) + minilm_pooling_config = minilm_model_config._init_pooler_config( + pooling_type=None, + pooling_norm=None, + pooling_returned_token_ids=None, + pooling_softmax=None, + pooling_step_tag_id=None) - assert isinstance(minilm_model_config.pooling_config, PoolingConfig) - assert minilm_pooling_config.normalize - assert isinstance(minilm_pooling_config.pooling_type, PoolingType) - assert minilm_pooling_config.pooling_type == PoolingType.MEAN + assert minilm_pooling_config.pooling_norm + assert minilm_pooling_config.pooling_type == PoolingType.MEAN.name def test_get_pooling_config_from_args(): @@ -135,12 +138,15 @@ def test_get_pooling_config_from_args(): dtype="float16", revision=None) - minilm_pooling_config = minilm_model_config.get_pooling_config('CLS', True) + minilm_pooling_config = minilm_model_config._init_pooler_config( + pooling_type='CLS', + pooling_norm=True, + pooling_returned_token_ids=None, + pooling_softmax=None, + pooling_step_tag_id=None) - assert isinstance(minilm_model_config.pooling_config, PoolingConfig) - assert minilm_pooling_config.normalize - assert isinstance(minilm_pooling_config.pooling_type, PoolingType) - assert minilm_pooling_config.pooling_type == PoolingType.CLS + assert minilm_pooling_config.pooling_norm + assert minilm_pooling_config.pooling_type == PoolingType.CLS.name def test_get_bert_tokenization_sentence_transformer_config(): diff --git a/vllm/config.py b/vllm/config.py index c6c75ccdad746..0b2deb0cbb07a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -108,8 +108,6 @@ class ModelConfig: can not be gathered from the vllm arguments. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. - pooling_config: pooling and normalize config from the model - - only applies to sentence-transformers models. bert_config: tokenizationconfiguration dictionary for a given Sentence Transformer BERT model. mm_processor_kwargs: Arguments to be forwarded to the model's processor @@ -282,12 +280,28 @@ def _init_pooler_config( pooling_returned_token_ids: Optional[List[int]] = None ) -> Optional["PoolerConfig"]: if self.task == "embedding": - return PoolerConfig( - pooling_type=pooling_type, - pooling_norm=pooling_norm, - pooling_softmax=pooling_softmax, - pooling_step_tag_id=pooling_step_tag_id, - pooling_returned_token_ids=pooling_returned_token_ids) + pooling_config = get_pooling_config(self.model, self.revision) + if pooling_config is not None: + pooling_type_from_file = pooling_config["pooling_type"] + normalize_from_file = pooling_config["normalize"] + pooling_config_from_file = PoolerConfig( + pooling_type=pooling_type_from_file, + pooling_norm=normalize_from_file, + pooling_softmax=pooling_softmax, + pooling_step_tag_id=pooling_step_tag_id, + pooling_returned_token_ids=pooling_returned_token_ids) + if pooling_type is not None: + pooling_config_from_file.pooling_type = pooling_type + if pooling_norm is not None: + pooling_config_from_file.pooling_norm = pooling_norm + return pooling_config_from_file + else: + return PoolerConfig( + pooling_type=pooling_type, + pooling_norm=pooling_norm, + pooling_softmax=pooling_softmax, + pooling_step_tag_id=pooling_step_tag_id, + pooling_returned_token_ids=pooling_returned_token_ids) return None def _init_attention_free(self) -> bool: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 07216f5bb8fc4..43ab68a5de4f9 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -15,7 +15,6 @@ SpeculativeConfig, TaskOption, TokenizerPoolConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import PoolingType from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) @@ -183,8 +182,6 @@ class EngineArgs: override_neuron_config: Optional[Dict[str, Any]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None scheduling_policy: Literal["fcfs", "priority"] = "fcfs" - pooling_type: Optional[str] = None - normalize: Optional[bool] = None # Pooling configuration. pooling_type: Optional[str] = None @@ -853,7 +850,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--pooling-type', - choices=['LAST', 'ALL', 'CLS', 'STEP'], + choices=['LAST', 'ALL', 'CLS', 'STEP', 'MEAN'], default=None, help='Used to configure the pooling method in the embedding model.' ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 9138eae24c334..3fd34fadee1ca 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -278,15 +278,21 @@ def __init__( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size, parallel_config.disable_custom_all_reduce, - model_config.quantization, model_config.enforce_eager, - cache_config.cache_dtype, model_config.quantization_param_path, - device_config.device, decoding_config, observability_config, - model_config.seed, model_config.served_model_name, + model_config.quantization, + model_config.enforce_eager, + cache_config.cache_dtype, + model_config.quantization_param_path, + device_config.device, + decoding_config, + observability_config, + model_config.seed, + model_config.served_model_name, scheduler_config.num_scheduler_steps, scheduler_config.chunked_prefill_enabled, scheduler_config.multi_step_stream_outputs, cache_config.enable_prefix_caching, - model_config.use_async_output_proc, use_cached_outputs, + model_config.use_async_output_proc, + use_cached_outputs, model_config.chat_template_text_format, model_config.mm_processor_kwargs, model_config.pooler_config, diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 9634536eedb56..c21f0961d02cc 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from enum import IntEnum from typing import List, Optional diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 96dd0702a2b41..693f32160a289 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -31,8 +31,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig, - PoolingType) +from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 3f66f73ca3580..8a9e5203972be 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -38,8 +38,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig, - PoolingType) +from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 2ee24471dc0a9..253b345c3f6ea 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -18,7 +18,6 @@ from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import PoolingType # yapf conflicts with isort for this block # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, @@ -337,7 +336,7 @@ def get_pooling_config_name(pooling_name): if "lasttoken" in pooling_name: pooling_name = "last" - supported_pooling_types = [i.name for i in PoolingType] + supported_pooling_types = ['LAST', 'ALL', 'CLS', 'STEP', 'MEAN'] pooling_type_name = pooling_name.upper() try: