From ed5a4ab837bffc8cb3f21bc30d470303530c230a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 15 Nov 2024 14:59:00 +0800 Subject: [PATCH] [Misc] Consolidate pooler config overrides (#10351) Signed-off-by: DarkLight1337 --- docs/source/models/supported_models.rst | 10 ++- tests/engine/test_arg_utils.py | 9 +- tests/test_config.py | 50 +++++------ vllm/config.py | 112 ++++++++++++------------ vllm/engine/arg_utils.py | 85 ++++-------------- vllm/entrypoints/llm.py | 15 +--- vllm/model_executor/layers/pooler.py | 54 +++++++----- 7 files changed, 143 insertions(+), 192 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index a76bb775c6ee6..96a513d42753b 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -345,6 +345,9 @@ Text Embedding Some model architectures support both generation and embedding tasks. In this case, you have to pass :code:`--task embedding` to run the model in embedding mode. +.. tip:: + You can override the model's pooling method by passing :code:`--override-pooler-config`. + Reward Modeling --------------- @@ -364,7 +367,7 @@ Reward Modeling - ✅︎ .. note:: - As an interim measure, these models are supported via Embeddings API. See `this RFC `_ for upcoming changes. + As an interim measure, these models are supported in both offline and online inference via Embeddings API. Classification --------------- @@ -385,7 +388,7 @@ Classification - ✅︎ .. note:: - As an interim measure, these models are supported via Embeddings API. It will be supported via Classification API in the future (no reference APIs exist now). + As an interim measure, these models are supported in both offline and online inference via Embeddings API. Multimodal Language Models @@ -600,6 +603,9 @@ Multimodal Embedding Some model architectures support both generation and embedding tasks. In this case, you have to pass :code:`--task embedding` to run the model in embedding mode. +.. tip:: + You can override the model's pooling method by passing :code:`--override-pooler-config`. + Model Support Policy ===================== diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index e92e2588d01cb..7b1be5a9802fd 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -2,6 +2,7 @@ import pytest +from vllm.config import PoolerConfig from vllm.engine.arg_utils import EngineArgs, nullable_kvs from vllm.utils import FlexibleArgumentParser @@ -32,9 +33,13 @@ def test_limit_mm_per_prompt_parser(arg, expected): def test_valid_pooling_config(): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) - args = parser.parse_args(["--pooling-type=MEAN"]) + args = parser.parse_args([ + '--override-pooler-config', + '{"pooling_type": "MEAN"}', + ]) engine_args = EngineArgs.from_cli_args(args=args) - assert engine_args.pooling_type == 'MEAN' + assert engine_args.override_pooler_config == PoolerConfig( + pooling_type="MEAN", ) @pytest.mark.parametrize( diff --git a/tests/test_config.py b/tests/test_config.py index df382d22d83ec..3cf90297ce177 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,8 @@ +from dataclasses import asdict + import pytest -from vllm.config import ModelConfig +from vllm.config import ModelConfig, PoolerConfig from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform @@ -108,7 +110,7 @@ def test_get_sliding_window(): reason="Xformers backend is not supported on ROCm.") def test_get_pooling_config(): model_id = "sentence-transformers/all-MiniLM-L12-v2" - minilm_model_config = ModelConfig( + model_config = ModelConfig( model_id, task="auto", tokenizer=model_id, @@ -119,39 +121,31 @@ def test_get_pooling_config(): revision=None, ) - minilm_pooling_config = minilm_model_config._init_pooler_config( - pooling_type=None, - pooling_norm=None, - pooling_returned_token_ids=None, - pooling_softmax=None, - pooling_step_tag_id=None) + pooling_config = model_config._init_pooler_config(None) + assert pooling_config is not None - assert minilm_pooling_config.pooling_norm - assert minilm_pooling_config.pooling_type == PoolingType.MEAN.name + assert pooling_config.normalize + assert pooling_config.pooling_type == PoolingType.MEAN.name @pytest.mark.skipif(current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm.") def test_get_pooling_config_from_args(): model_id = "sentence-transformers/all-MiniLM-L12-v2" - minilm_model_config = ModelConfig(model_id, - task="auto", - tokenizer=model_id, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype="float16", - revision=None) - - minilm_pooling_config = minilm_model_config._init_pooler_config( - pooling_type='CLS', - pooling_norm=True, - pooling_returned_token_ids=None, - pooling_softmax=None, - pooling_step_tag_id=None) - - assert minilm_pooling_config.pooling_norm - assert minilm_pooling_config.pooling_type == PoolingType.CLS.name + model_config = ModelConfig(model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + revision=None) + + override_config = PoolerConfig(pooling_type='CLS', normalize=True) + + pooling_config = model_config._init_pooler_config(override_config) + assert pooling_config is not None + assert asdict(pooling_config) == asdict(override_config) @pytest.mark.skipif(current_platform.is_rocm(), diff --git a/vllm/config.py b/vllm/config.py index 83b1483eb99e0..1c190da1d327e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -112,10 +112,6 @@ class ModelConfig: the model name will be the same as `model`. limit_mm_per_prompt: Maximum number of data items per modality per prompt. Only applicable for multimodal models. - override_neuron_config: Initialize non default neuron config or - override default neuron config that are specific to Neuron devices, - this argument will be used to configure the neuron config that - can not be gathered from the vllm arguments. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. hf_overrides: If a dictionary, contains arguments to be forwarded to the @@ -123,20 +119,12 @@ class ModelConfig: HuggingFace config. mm_processor_kwargs: Arguments to be forwarded to the model's processor for multi-modal data, e.g., image processor. - pooling_type: Used to configure the pooling method in the embedding - model. - pooling_norm: Used to determine whether to normalize the pooled - data in the embedding model. - pooling_softmax: Used to determine whether to softmax the pooled - data in the embedding model. - pooling_step_tag_id: When pooling_step_tag_id is not -1, it indicates - that the score corresponding to the pooling_step_tag_id in the - generated sentence should be returned. Otherwise, it returns - the scores for all tokens. - pooling_returned_token_ids: pooling_returned_token_ids represents a - list of indices for the vocabulary dimensions to be extracted, - such as the token IDs of good_token and bad_token in the - math-shepherd-mistral-7b-prm model. + override_neuron_config: Initialize non default neuron config or + override default neuron config that are specific to Neuron devices, + this argument will be used to configure the neuron config that + can not be gathered from the vllm arguments. + override_pooling_config: Initialize non default pooling config or + override default pooling config for the embedding model. """ def __init__( @@ -166,16 +154,12 @@ def __init__( served_model_name: Optional[Union[str, List[str]]] = None, limit_mm_per_prompt: Optional[Mapping[str, int]] = None, use_async_output_proc: bool = True, - override_neuron_config: Optional[Dict[str, Any]] = None, config_format: ConfigFormat = ConfigFormat.AUTO, chat_template_text_format: str = "string", hf_overrides: Optional[HfOverrides] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, - pooling_type: Optional[str] = None, - pooling_norm: Optional[bool] = None, - pooling_softmax: Optional[bool] = None, - pooling_step_tag_id: Optional[int] = None, - pooling_returned_token_ids: Optional[List[int]] = None) -> None: + override_neuron_config: Optional[Dict[str, Any]] = None, + override_pooler_config: Optional["PoolerConfig"] = None) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -280,13 +264,7 @@ def __init__( supported_tasks, task = self._resolve_task(task, self.hf_config) self.supported_tasks = supported_tasks self.task: Final = task - self.pooler_config = self._init_pooler_config( - pooling_type, - pooling_norm, - pooling_softmax, - pooling_step_tag_id, - pooling_returned_token_ids, - ) + self.pooler_config = self._init_pooler_config(override_pooler_config) self._verify_quantization() self._verify_cuda_graph() @@ -311,27 +289,21 @@ def _get_encoder_config(self): def _init_pooler_config( self, - pooling_type: Optional[str] = None, - pooling_norm: Optional[bool] = None, - pooling_softmax: Optional[bool] = None, - pooling_step_tag_id: Optional[int] = None, - pooling_returned_token_ids: Optional[List[int]] = None + override_pooler_config: Optional["PoolerConfig"], ) -> Optional["PoolerConfig"]: + if self.task == "embedding": - pooling_config = get_pooling_config(self.model, self.revision) - if pooling_config is not None: - # override if user does not - # specifies pooling_type and/or pooling_norm - if pooling_type is None: - pooling_type = pooling_config["pooling_type"] - if pooling_norm is None: - pooling_norm = pooling_config["normalize"] - return PoolerConfig( - pooling_type=pooling_type, - pooling_norm=pooling_norm, - pooling_softmax=pooling_softmax, - pooling_step_tag_id=pooling_step_tag_id, - pooling_returned_token_ids=pooling_returned_token_ids) + user_config = override_pooler_config or PoolerConfig() + + base_config = get_pooling_config(self.model, self.revision) + if base_config is not None: + # Only set values that are not overridden by the user + for k, v in base_config.items(): + if getattr(user_config, k) is None: + setattr(user_config, k, v) + + return user_config + return None def _init_attention_free(self) -> bool: @@ -1786,13 +1758,43 @@ class MultiModalConfig: @dataclass class PoolerConfig: - """Controls the behavior of pooler in embedding model""" + """Controls the behavior of output pooling in embedding models.""" pooling_type: Optional[str] = None - pooling_norm: Optional[bool] = None - pooling_softmax: Optional[bool] = None - pooling_step_tag_id: Optional[int] = None - pooling_returned_token_ids: Optional[List[int]] = None + """ + The pooling method of the embedding model. This should be a key in + :class:`vllm.model_executor.layers.pooler.PoolingType`. + """ + + normalize: Optional[bool] = None + """ + Whether to normalize the pooled outputs. Usually, this should be set to + ``True`` for embedding outputs. + """ + + softmax: Optional[bool] = None + """ + Whether to apply softmax to the pooled outputs. Usually, this should be set + to ``True`` for classification outputs. + """ + + step_tag_id: Optional[int] = None + """ + If set, only the score corresponding to the ``step_tag_id`` in the + generated sentence should be returned. Otherwise, the scores for all tokens + are returned. + """ + + returned_token_ids: Optional[List[int]] = None + """ + A list of indices for the vocabulary dimensions to be extracted, + such as the token IDs of ``good_token`` and ``bad_token`` in the + ``math-shepherd-mistral-7b-prm`` model. + """ + + @staticmethod + def from_json(json_str: str) -> "PoolerConfig": + return PoolerConfig(**json.loads(json_str)) _STR_DTYPE_TO_TORCH_DTYPE = { diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 244aa09e12552..4afc61c8d0c4c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -11,12 +11,11 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig, DeviceConfig, HfOverrides, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, ObservabilityConfig, - ParallelConfig, PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig, TaskOption, TokenizerPoolConfig, - VllmConfig) + ParallelConfig, PoolerConfig, PromptAdapterConfig, + SchedulerConfig, SpeculativeConfig, TaskOption, + TokenizerPoolConfig, VllmConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import PoolingType from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.platforms import current_platform from vllm.transformers_utils.utils import check_gguf_file @@ -187,15 +186,10 @@ class EngineArgs: otlp_traces_endpoint: Optional[str] = None collect_detailed_traces: Optional[str] = None disable_async_output_proc: bool = False - override_neuron_config: Optional[Dict[str, Any]] = None scheduling_policy: Literal["fcfs", "priority"] = "fcfs" - # Pooling configuration. - pooling_type: Optional[str] = None - pooling_norm: Optional[bool] = None - pooling_softmax: Optional[bool] = None - pooling_step_tag_id: Optional[int] = None - pooling_returned_token_ids: Optional[List[int]] = None + override_neuron_config: Optional[Dict[str, Any]] = None + override_pooler_config: Optional[PoolerConfig] = None def __post_init__(self): if not self.tokenizer: @@ -859,12 +853,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.disable_async_output_proc, help="Disable async output processing. This may result in " "lower performance.") - parser.add_argument( - '--override-neuron-config', - type=json.loads, - default=None, - help="Override or set neuron device configuration. " - "e.g. {\"cast_logits_dtype\": \"bloat16\"}.'") parser.add_argument( '--scheduling-policy', @@ -877,56 +865,17 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'arrival deciding any ties).') parser.add_argument( - '--pooling-type', - choices=[pt.name for pt in PoolingType], - default=None, - help='Used to configure the pooling method in the embedding model.' - ) - - parser.add_argument('--pooling-norm', - default=None, - action='store_true', - help="Used to determine whether to normalize " - "the pooled data in the embedding model.") - - parser.add_argument('--no-pooling-norm', - default=None, - action='store_false', - dest='pooling_norm', - help="Used to determine whether to normalize " - "the pooled data in the embedding model.") - - parser.add_argument('--pooling-softmax', - default=None, - action='store_true', - help="Used to determine whether to softmax " - "the pooled data in the embedding model.") - - parser.add_argument('--no-pooling-softmax', - default=None, - action='store_false', - dest='pooling_softmax', - help="Used to determine whether to softmax " - "the pooled data in the embedding model.") - - parser.add_argument( - '--pooling-step-tag-id', - type=int, + '--override-neuron-config', + type=json.loads, default=None, - help="When pooling-step-tag-id is not -1, it indicates " - "that the score corresponding to the step-tag-ids in the " - "generated sentence should be returned. Otherwise, it " - "returns the scores for all tokens.") - + help="Override or set neuron device configuration. " + "e.g. {\"cast_logits_dtype\": \"bloat16\"}.'") parser.add_argument( - '--pooling-returned-token-ids', - nargs='+', - type=int, + '--override-pooler-config', + type=PoolerConfig.from_json, default=None, - help="pooling-returned-token-ids represents a list of " - "indices for the vocabulary dimensions to be extracted, " - "such as the token IDs of good_token and bad_token in " - "the math-shepherd-mistral-7b-prm model.") + help="Override or set the pooling method in the embedding model. " + "e.g. {\"pooling_type\": \"mean\", \"normalize\": false}.'") return parser @@ -967,14 +916,10 @@ def create_model_config(self) -> ModelConfig: served_model_name=self.served_model_name, limit_mm_per_prompt=self.limit_mm_per_prompt, use_async_output_proc=not self.disable_async_output_proc, - override_neuron_config=self.override_neuron_config, config_format=self.config_format, mm_processor_kwargs=self.mm_processor_kwargs, - pooling_type=self.pooling_type, - pooling_norm=self.pooling_norm, - pooling_softmax=self.pooling_softmax, - pooling_step_tag_id=self.pooling_step_tag_id, - pooling_returned_token_ids=self.pooling_returned_token_ids, + override_neuron_config=self.override_neuron_config, + override_pooler_config=self.override_pooler_config, ) def create_load_config(self) -> LoadConfig: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 63c2bb6097079..3ab467e649b57 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -9,7 +9,8 @@ from vllm import envs from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, get_beam_search_score) -from vllm.engine.arg_utils import EngineArgs, HfOverrides, TaskOption +from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig, + TaskOption) from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, apply_hf_chat_template, @@ -162,11 +163,7 @@ def __init__( mm_processor_kwargs: Optional[Dict[str, Any]] = None, # After positional args are removed, move this right below `model` task: TaskOption = "auto", - pooling_type: Optional[str] = None, - pooling_norm: Optional[bool] = None, - pooling_softmax: Optional[bool] = None, - pooling_step_tag_id: Optional[int] = None, - pooling_returned_token_ids: Optional[List[int]] = None, + override_pooler_config: Optional[PoolerConfig] = None, **kwargs, ) -> None: ''' @@ -202,11 +199,7 @@ def __init__( disable_async_output_proc=disable_async_output_proc, hf_overrides=hf_overrides, mm_processor_kwargs=mm_processor_kwargs, - pooling_type=pooling_type, - pooling_norm=pooling_norm, - pooling_softmax=pooling_softmax, - pooling_step_tag_id=pooling_step_tag_id, - pooling_returned_token_ids=pooling_returned_token_ids, + override_pooler_config=override_pooler_config, **kwargs, ) # Logic to switch between engines is done at runtime instead of import diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 024badbc17b96..6fee57a0a03eb 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -63,14 +63,14 @@ def from_config_with_defaults( return cls( pooling_type=PoolingType[pooler_config.pooling_type] if pooler_config.pooling_type is not None else pooling_type, - normalize=pooler_config.pooling_norm - if pooler_config.pooling_norm is not None else normalize, - softmax=pooler_config.pooling_softmax - if pooler_config.pooling_softmax is not None else softmax, - step_tag_id=pooler_config.pooling_step_tag_id - if pooler_config.pooling_step_tag_id is not None else step_tag_id, - returned_token_ids=pooler_config.pooling_returned_token_ids - if pooler_config.pooling_returned_token_ids is not None else + normalize=pooler_config.normalize + if pooler_config.normalize is not None else normalize, + softmax=pooler_config.softmax + if pooler_config.softmax is not None else softmax, + step_tag_id=pooler_config.step_tag_id + if pooler_config.step_tag_id is not None else step_tag_id, + returned_token_ids=pooler_config.returned_token_ids + if pooler_config.returned_token_ids is not None else returned_token_ids, ) @@ -94,10 +94,14 @@ def forward( pooled_data = hidden_states[last_token_flat_indices] elif self.pooling_type == PoolingType.ALL: offset = 0 - pooled_data = [] + pooled_data_lst = [] for prompt_len in prompt_lens: - pooled_data.append(hidden_states[offset:offset + prompt_len]) + pooled_data_i = hidden_states[offset:offset + prompt_len] + + pooled_data_lst.append(pooled_data_i) offset += prompt_len + + pooled_data = torch.stack(pooled_data_lst) elif self.pooling_type == PoolingType.MEAN: # Calculate mean pooling cumsum = torch.cumsum(hidden_states, dim=0) @@ -110,24 +114,26 @@ def forward( cumsum[end_indices - 1] - cumsum[start_indices] + hidden_states[start_indices]) / prompt_lens.unsqueeze(1) elif self.pooling_type == PoolingType.STEP: - if self.returned_token_ids is not None and len( - self.returned_token_ids) > 0: - logits = hidden_states[:, - self.returned_token_ids].softmax(dim=-1) - else: - logits = hidden_states.softmax(dim=-1) + returned_token_ids = self.returned_token_ids + if returned_token_ids is not None and len(returned_token_ids) > 0: + hidden_states = hidden_states[:, returned_token_ids] + + logits = hidden_states.softmax(dim=-1) + step_tag_id = self.step_tag_id + offset = 0 - pooled_data = [] + pooled_data_lst = [] for prompt_len, seq_data_i in zip( prompt_lens, pooling_metadata.seq_data.values()): - if self.step_tag_id is None: - pooled_data.append(logits[offset:offset + prompt_len]) - else: - step_idxs = torch.tensor( - seq_data_i.prompt_token_ids) == self.step_tag_id - pooled_data.append(logits[offset:offset + - prompt_len][step_idxs]) + pooled_data_i = logits[offset:offset + prompt_len] + if step_tag_id is not None: + token_ids = torch.tensor(seq_data_i.prompt_token_ids) + pooled_data_i = pooled_data_i[token_ids == step_tag_id] + offset += prompt_len + pooled_data_lst.append(pooled_data_i) + + pooled_data = torch.stack(pooled_data_lst) else: raise ValueError(f"Invalid pooling type: {self.pooling_type}")