[Misc] Consolidate pooler config overrides (vllm-project#10351)
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 authored Nov 15, 2024
1 parent 8041f94 commit c1900dc
Showing 7 changed files with 143 additions and 192 deletions.
10 changes: 8 additions & 2 deletions docs/source/models/supported_models.rst
@@ -345,6 +345,9 @@ Text Embedding
Some model architectures support both generation and embedding tasks.
In this case, you have to pass :code:`--task embedding` to run the model in embedding mode.

+.. tip::
+    You can override the model's pooling method by passing :code:`--override-pooler-config`.
+
Reward Modeling
---------------

@@ -364,7 +367,7 @@ Reward Modeling
- ✅︎

.. note::
-    As an interim measure, these models are supported via Embeddings API. See `this RFC <https://github.com/vllm-project/vllm/issues/8967>`_ for upcoming changes.
+    As an interim measure, these models are supported in both offline and online inference via Embeddings API.

Classification
---------------
@@ -385,7 +388,7 @@ Classification
- ✅︎

.. note::
-    As an interim measure, these models are supported via Embeddings API. It will be supported via Classification API in the future (no reference APIs exist now).
+    As an interim measure, these models are supported in both offline and online inference via Embeddings API.


Multimodal Language Models
@@ -600,6 +603,9 @@ Multimodal Embedding
Some model architectures support both generation and embedding tasks.
In this case, you have to pass :code:`--task embedding` to run the model in embedding mode.

+.. tip::
+    You can override the model's pooling method by passing :code:`--override-pooler-config`.
+
Model Support Policy
=====================

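Note: the :code:`--override-pooler-config` flag documented above also has an offline counterpart. The following is a minimal usage sketch, not part of this commit; it assumes a vLLM build that includes this change, and the model name is only an example:

from vllm import LLM
from vllm.config import PoolerConfig

# Run the model in embedding mode, but force MEAN pooling with
# normalization instead of the defaults inferred from the model repo.
llm = LLM(
    model="sentence-transformers/all-MiniLM-L12-v2",
    task="embedding",
    override_pooler_config=PoolerConfig(pooling_type="MEAN", normalize=True),
)
outputs = llm.encode(["Hello, my name is"])
print(outputs[0].outputs.embedding[:8])  # first few embedding dimensions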
9 changes: 7 additions & 2 deletions tests/engine/test_arg_utils.py
@@ -2,6 +2,7 @@

import pytest

+from vllm.config import PoolerConfig
from vllm.engine.arg_utils import EngineArgs, nullable_kvs
from vllm.utils import FlexibleArgumentParser

@@ -32,9 +33,13 @@ def test_limit_mm_per_prompt_parser(arg, expected):

 def test_valid_pooling_config():
     parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
-    args = parser.parse_args(["--pooling-type=MEAN"])
+    args = parser.parse_args([
+        '--override-pooler-config',
+        '{"pooling_type": "MEAN"}',
+    ])
     engine_args = EngineArgs.from_cli_args(args=args)
-    assert engine_args.pooling_type == 'MEAN'
+    assert engine_args.override_pooler_config == PoolerConfig(
+        pooling_type="MEAN", )


@pytest.mark.parametrize(
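The updated test exercises the new JSON-valued CLI flag, which replaces the five separate --pooling-* flags with a single structured option. A self-contained sketch of the same parsing pattern, using a hypothetical stand-in dataclass rather than vLLM's own PoolerConfig:

import argparse
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakePoolerConfig:
    # Illustrative stand-in for vllm.config.PoolerConfig.
    pooling_type: Optional[str] = None
    normalize: Optional[bool] = None

    @staticmethod
    def from_json(json_str: str) -> "FakePoolerConfig":
        return FakePoolerConfig(**json.loads(json_str))


parser = argparse.ArgumentParser()
# `type=` makes argparse deserialize the JSON string into the dataclass.
parser.add_argument("--override-pooler-config", type=FakePoolerConfig.from_json)

args = parser.parse_args(
    ["--override-pooler-config", '{"pooling_type": "MEAN"}'])
assert args.override_pooler_config == FakePoolerConfig(pooling_type="MEAN")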
50 changes: 22 additions & 28 deletions tests/test_config.py
@@ -1,6 +1,8 @@
+from dataclasses import asdict
+
import pytest

-from vllm.config import ModelConfig
+from vllm.config import ModelConfig, PoolerConfig
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform

@@ -108,7 +110,7 @@ def test_get_sliding_window():
reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
-    minilm_model_config = ModelConfig(
+    model_config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
@@ -119,39 +121,31 @@ def test_get_pooling_config():
revision=None,
)

-    minilm_pooling_config = minilm_model_config._init_pooler_config(
-        pooling_type=None,
-        pooling_norm=None,
-        pooling_returned_token_ids=None,
-        pooling_softmax=None,
-        pooling_step_tag_id=None)
+    pooling_config = model_config._init_pooler_config(None)
+    assert pooling_config is not None

-    assert minilm_pooling_config.pooling_norm
-    assert minilm_pooling_config.pooling_type == PoolingType.MEAN.name
+    assert pooling_config.normalize
+    assert pooling_config.pooling_type == PoolingType.MEAN.name


@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config_from_args():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
-    minilm_model_config = ModelConfig(model_id,
-                                      task="auto",
-                                      tokenizer=model_id,
-                                      tokenizer_mode="auto",
-                                      trust_remote_code=False,
-                                      seed=0,
-                                      dtype="float16",
-                                      revision=None)
-
-    minilm_pooling_config = minilm_model_config._init_pooler_config(
-        pooling_type='CLS',
-        pooling_norm=True,
-        pooling_returned_token_ids=None,
-        pooling_softmax=None,
-        pooling_step_tag_id=None)
-
-    assert minilm_pooling_config.pooling_norm
-    assert minilm_pooling_config.pooling_type == PoolingType.CLS.name
+    model_config = ModelConfig(model_id,
+                               task="auto",
+                               tokenizer=model_id,
+                               tokenizer_mode="auto",
+                               trust_remote_code=False,
+                               seed=0,
+                               dtype="float16",
+                               revision=None)
+
+    override_config = PoolerConfig(pooling_type='CLS', normalize=True)
+
+    pooling_config = model_config._init_pooler_config(override_config)
+    assert pooling_config is not None
+    assert asdict(pooling_config) == asdict(override_config)


@pytest.mark.skipif(current_platform.is_rocm(),
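The rewritten tests pin down the override semantics: fields the user sets win, and the model repo's pooling config only fills fields left as None. A minimal self-contained sketch of that merge rule, again with a hypothetical stand-in dataclass:

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakePoolerConfig:
    # Illustrative stand-in for vllm.config.PoolerConfig.
    pooling_type: Optional[str] = None
    normalize: Optional[bool] = None


def merge(user_config: FakePoolerConfig, base_config: dict) -> FakePoolerConfig:
    # Only set values that are not overridden by the user (i.e. still None).
    for k, v in base_config.items():
        if getattr(user_config, k) is None:
            setattr(user_config, k, v)
    return user_config


# The repo says MEAN + normalize; the user forces CLS but says nothing
# about normalization, so `normalize` falls back to the repo default.
merged = merge(FakePoolerConfig(pooling_type="CLS"),
               {"pooling_type": "MEAN", "normalize": True})
assert merged == FakePoolerConfig(pooling_type="CLS", normalize=True)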
112 changes: 57 additions & 55 deletions vllm/config.py
@@ -112,31 +112,19 @@ class ModelConfig:
the model name will be the same as `model`.
limit_mm_per_prompt: Maximum number of data items per modality
per prompt. Only applicable for multimodal models.
-        override_neuron_config: Initialize non default neuron config or
-            override default neuron config that are specific to Neuron devices,
-            this argument will be used to configure the neuron config that
-            can not be gathered from the vllm arguments.
         config_format: The config format which shall be loaded.
             Defaults to 'auto' which defaults to 'hf'.
         hf_overrides: If a dictionary, contains arguments to be forwarded to the
             HuggingFace config. If a callable, it is called to update the
             HuggingFace config.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
             for multi-modal data, e.g., image processor.
-        pooling_type: Used to configure the pooling method in the embedding
-            model.
-        pooling_norm: Used to determine whether to normalize the pooled
-            data in the embedding model.
-        pooling_softmax: Used to determine whether to softmax the pooled
-            data in the embedding model.
-        pooling_step_tag_id: When pooling_step_tag_id is not -1, it indicates
-            that the score corresponding to the pooling_step_tag_id in the
-            generated sentence should be returned. Otherwise, it returns
-            the scores for all tokens.
-        pooling_returned_token_ids: pooling_returned_token_ids represents a
-            list of indices for the vocabulary dimensions to be extracted,
-            such as the token IDs of good_token and bad_token in the
-            math-shepherd-mistral-7b-prm model.
+        override_neuron_config: Initialize non default neuron config or
+            override default neuron config that are specific to Neuron devices,
+            this argument will be used to configure the neuron config that
+            can not be gathered from the vllm arguments.
+        override_pooler_config: Initialize non default pooling config or
+            override default pooling config for the embedding model.
"""

def __init__(
@@ -166,16 +154,12 @@ def __init__(
         served_model_name: Optional[Union[str, List[str]]] = None,
         limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
         use_async_output_proc: bool = True,
-        override_neuron_config: Optional[Dict[str, Any]] = None,
         config_format: ConfigFormat = ConfigFormat.AUTO,
         chat_template_text_format: str = "string",
         hf_overrides: Optional[HfOverrides] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
-        pooling_type: Optional[str] = None,
-        pooling_norm: Optional[bool] = None,
-        pooling_softmax: Optional[bool] = None,
-        pooling_step_tag_id: Optional[int] = None,
-        pooling_returned_token_ids: Optional[List[int]] = None) -> None:
+        override_neuron_config: Optional[Dict[str, Any]] = None,
+        override_pooler_config: Optional["PoolerConfig"] = None) -> None:
self.model = model
self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode
@@ -280,13 +264,7 @@ def __init__(
         supported_tasks, task = self._resolve_task(task, self.hf_config)
         self.supported_tasks = supported_tasks
         self.task: Final = task
-        self.pooler_config = self._init_pooler_config(
-            pooling_type,
-            pooling_norm,
-            pooling_softmax,
-            pooling_step_tag_id,
-            pooling_returned_token_ids,
-        )
+        self.pooler_config = self._init_pooler_config(override_pooler_config)

self._verify_quantization()
self._verify_cuda_graph()
@@ -311,27 +289,21 @@ def _get_encoder_config(self):

     def _init_pooler_config(
         self,
-        pooling_type: Optional[str] = None,
-        pooling_norm: Optional[bool] = None,
-        pooling_softmax: Optional[bool] = None,
-        pooling_step_tag_id: Optional[int] = None,
-        pooling_returned_token_ids: Optional[List[int]] = None
+        override_pooler_config: Optional["PoolerConfig"],
     ) -> Optional["PoolerConfig"]:

         if self.task == "embedding":
-            pooling_config = get_pooling_config(self.model, self.revision)
-            if pooling_config is not None:
-                # override if user does not
-                # specifies pooling_type and/or pooling_norm
-                if pooling_type is None:
-                    pooling_type = pooling_config["pooling_type"]
-                if pooling_norm is None:
-                    pooling_norm = pooling_config["normalize"]
-            return PoolerConfig(
-                pooling_type=pooling_type,
-                pooling_norm=pooling_norm,
-                pooling_softmax=pooling_softmax,
-                pooling_step_tag_id=pooling_step_tag_id,
-                pooling_returned_token_ids=pooling_returned_token_ids)
+            user_config = override_pooler_config or PoolerConfig()
+
+            base_config = get_pooling_config(self.model, self.revision)
+            if base_config is not None:
+                # Only set values that are not overridden by the user
+                for k, v in base_config.items():
+                    if getattr(user_config, k) is None:
+                        setattr(user_config, k, v)
+
+            return user_config

         return None

def _init_attention_free(self) -> bool:
@@ -1786,13 +1758,43 @@ class MultiModalConfig:

@dataclass
class PoolerConfig:
"""Controls the behavior of pooler in embedding model"""
"""Controls the behavior of output pooling in embedding models."""

     pooling_type: Optional[str] = None
-    pooling_norm: Optional[bool] = None
-    pooling_softmax: Optional[bool] = None
-    pooling_step_tag_id: Optional[int] = None
-    pooling_returned_token_ids: Optional[List[int]] = None
+    """
+    The pooling method of the embedding model. This should be a key in
+    :class:`vllm.model_executor.layers.pooler.PoolingType`.
+    """
+
+    normalize: Optional[bool] = None
+    """
+    Whether to normalize the pooled outputs. Usually, this should be set to
+    ``True`` for embedding outputs.
+    """
+
+    softmax: Optional[bool] = None
+    """
+    Whether to apply softmax to the pooled outputs. Usually, this should be set
+    to ``True`` for classification outputs.
+    """
+
+    step_tag_id: Optional[int] = None
+    """
+    If set, only the score corresponding to the ``step_tag_id`` in the
+    generated sentence should be returned. Otherwise, the scores for all tokens
+    are returned.
+    """
+
+    returned_token_ids: Optional[List[int]] = None
+    """
+    A list of indices for the vocabulary dimensions to be extracted,
+    such as the token IDs of ``good_token`` and ``bad_token`` in the
+    ``math-shepherd-mistral-7b-prm`` model.
+    """
+
+    @staticmethod
+    def from_json(json_str: str) -> "PoolerConfig":
+        return PoolerConfig(**json.loads(json_str))


_STR_DTYPE_TO_TORCH_DTYPE = {
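For reference, a hedged round-trip sketch of the new from_json helper and field names; it assumes a vLLM build that includes this commit, and the STEP pooling type and token IDs are illustrative values only:

from vllm.config import PoolerConfig

cfg = PoolerConfig.from_json(
    '{"pooling_type": "STEP", "step_tag_id": 123, '
    '"returned_token_ids": [456, 789]}')
assert cfg.pooling_type == "STEP"
assert cfg.step_tag_id == 123
assert cfg.returned_token_ids == [456, 789]

# Fields absent from the JSON stay None, so they can still be filled in
# from the model repo's pooling config by ModelConfig._init_pooler_config.
assert cfg.normalize is None and cfg.softmax is None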