Adds pooling config as engine CLI arg
Signed-off-by: Flavia Beo <[email protected]>
flaviabeo committed Oct 28, 2024
1 parent 32ee574 commit 0b948a4
Showing 5 changed files with 38 additions and 2 deletions.
11 changes: 11 additions & 0 deletions tests/engine/test_arg_utils.py
@@ -5,6 +5,8 @@
from vllm.engine.arg_utils import EngineArgs, nullable_kvs
from vllm.model_executor.layers.pooler import PoolingConfig
from vllm.utils import FlexibleArgumentParser


@pytest.mark.parametrize(("arg", "expected"), [
(None, None),
@@ -30,6 +32,15 @@ def test_limit_mm_per_prompt_parser(arg, expected):
assert args.limit_mm_per_prompt == expected


def test_valid_pooling_config():
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args(["--pooling-type=MEAN", "--normalize=True"])
    expected_pooling_config = PoolingConfig(pooling_type='MEAN',
                                            normalize=True)
    engine_args = EngineArgs.from_cli_args(args=args)
    assert engine_args.pooling_config == expected_pooling_config


@pytest.mark.parametrize(
("arg"),
[
3 changes: 2 additions & 1 deletion vllm/config.py
@@ -149,6 +149,7 @@ def __init__(self,
override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO,
chat_template_text_format: str = "string",
pooling_config: Optional[PoolingConfig] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
self.model = model
self.tokenizer = tokenizer
@@ -179,7 +180,7 @@ def __init__(self,
code_revision, rope_scaling, rope_theta,
config_format)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.pooling_config = self.get_pooling_config()
self.pooling_config = pooling_config or self.get_pooling_config()
self.bert_config = self._get_bert_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
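A note on the config.py change above: pooling_config or self.get_pooling_config() prefers an explicitly supplied config and falls back to deriving one from the model files only when none is given. A minimal standalone sketch of the same pattern (names hypothetical, not from this commit):

    from typing import Optional

    def resolve_pooling(explicit: Optional[str]) -> str:
        # "or" falls through only when explicit is None (or otherwise
        # falsy); a real PoolingConfig instance is always truthy, so an
        # explicitly passed config always wins.
        return explicit or "derived-from-model-files"

    assert resolve_pooling(None) == "derived-from-model-files"
    assert resolve_pooling("MEAN") == "MEAN"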
19 changes: 19 additions & 0 deletions vllm/engine/arg_utils.py
@@ -16,6 +16,7 @@
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolingConfig
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.utils import check_gguf_file
@@ -183,6 +184,7 @@ class EngineArgs:
override_neuron_config: Optional[Dict[str, Any]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
pooling_config: Optional[PoolingConfig] = None

def __post_init__(self):
if not self.tokenizer:
@@ -850,13 +852,29 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'priority (lower value means earlier handling) and time of '
'arrival deciding any ties).')

parser.add_argument('--pooling-type',
                    type=str,
                    default='CLS',
                    choices=['LAST', 'ALL', 'CLS', 'MEAN'],
                    help='Configures the pooling operation; applies '
                    'only to sentence-transformers models.')

parser.add_argument('--normalize',
                    # Note: a plain type=bool would treat any non-empty
                    # string (including "False") as True.
                    type=lambda s: s.lower() == 'true',
                    default=None,
                    help='Whether to normalize the pooled data.')

return parser

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
    # Get the list of attributes of this dataclass.
    attrs = [attr.name for attr in dataclasses.fields(cls)]
    # Fold the two pooling CLI flags into a single PoolingConfig and
    # attach it to the namespace so it is picked up below.
    args.pooling_config = PoolingConfig(pooling_type=args.pooling_type,
                                        normalize=args.normalize)
    # Set the attributes from the parsed arguments.
    engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
    return engine_args

@@ -890,6 +908,7 @@ def create_model_config(self) -> ModelConfig:
use_async_output_proc=not self.disable_async_output_proc,
override_neuron_config=self.override_neuron_config,
config_format=self.config_format,
pooling_config=self.pooling_config,
mm_processor_kwargs=self.mm_processor_kwargs,
)

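Taken together, the hunks above mean the new flags round-trip from the command line into ModelConfig. A minimal sketch of that flow, using only code shown in this commit (flag values are illustrative):

    from vllm.engine.arg_utils import EngineArgs
    from vllm.model_executor.layers.pooler import PoolingConfig
    from vllm.utils import FlexibleArgumentParser

    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args(["--pooling-type=MEAN", "--normalize=True"])
    engine_args = EngineArgs.from_cli_args(args)
    # from_cli_args folds the two flags into a single PoolingConfig,
    # which create_model_config() then forwards to ModelConfig.
    assert isinstance(engine_args.pooling_config, PoolingConfig)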
3 changes: 3 additions & 0 deletions vllm/entrypoints/llm.py
@@ -18,6 +18,7 @@
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
from vllm.model_executor.layers.pooler import PoolingConfig
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@@ -156,6 +157,7 @@ def __init__(
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
pooling_config: Optional[PoolingConfig] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
@@ -192,6 +194,7 @@ def __init__(
max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
disable_async_output_proc=disable_async_output_proc,
pooling_config=pooling_config,
mm_processor_kwargs=mm_processor_kwargs,
**kwargs,
)
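With the plumbing above, callers of the LLM entrypoint can pass a pooling configuration directly instead of going through the CLI. A sketch (the model name is illustrative, not part of this commit):

    from vllm import LLM
    from vllm.model_executor.layers.pooler import PoolingConfig

    llm = LLM(model="sentence-transformers/all-MiniLM-L6-v2",
              pooling_config=PoolingConfig(pooling_type="MEAN",
                                           normalize=True))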
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/pooler.py
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from enum import IntEnum
from typing import Optional

import torch
import torch.nn as nn

@@ -32,7 +34,7 @@ class PoolingConfig:
type enum value corresponding to the given string.
"""

def __init__(self, pooling_type: str, normalize: bool):
def __init__(self, pooling_type: str, normalize: Optional[bool] = None):
self.pooling_type = self.get_pooling_type(pooling_type)
self.normalize = normalize

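One caveat for the new test in tests/engine/test_arg_utils.py: its equality assertion only holds if PoolingConfig compares by value rather than identity. If the class does not already provide that (e.g. via a dataclass-generated __eq__), a minimal sketch of the method would be:

    def __eq__(self, other: object) -> bool:
        # Compare by value so two configs built from the same CLI flags
        # are considered equal regardless of identity.
        if not isinstance(other, PoolingConfig):
            return NotImplemented
        return (self.pooling_type == other.pooling_type
                and self.normalize == other.normalize)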
