Skip to content

Commit

Permalink
Fix merge conflicts
Browse files Browse the repository at this point in the history
Signed-off-by: Flavia Beo <[email protected]>
  • Loading branch information
flaviabeo committed Oct 31, 2024
1 parent 4531c33 commit 8df9d63
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 39 deletions.
3 changes: 1 addition & 2 deletions tests/engine/test_arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@ def test_limit_mm_per_prompt_parser(arg, expected):

def test_valid_pooling_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(["--pooling-type=MEAN", "--normalize=True"])
args = parser.parse_args(["--pooling-type=MEAN"])
engine_args = EngineArgs.from_cli_args(args=args)
assert engine_args.pooling_type == 'MEAN'
assert engine_args.normalize


@pytest.mark.parametrize(
Expand Down
4 changes: 2 additions & 2 deletions tests/model_executor/test_model_load_with_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def test_model_loading_with_params(vllm_runner):
assert model_config.bert_config["do_lower_case"]

# asserts on the pooling config files
assert model_config.pooling_config.pooling_type == PoolingType.CLS
assert model_config.pooling_config.normalize
assert model_config.pooler_config.pooling_type == PoolingType.CLS.name
assert model_config.pooler_config.pooling_norm

# asserts on the tokenizer loaded
assert model_tokenizer.tokenizer_id == "BAAI/bge-base-en-v1.5"
Expand Down
28 changes: 17 additions & 11 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from vllm.config import ModelConfig
from vllm.model_executor.layers.pooler import PoolingConfig, PoolingType
from vllm.model_executor.layers.pooler import PoolingType


@pytest.mark.parametrize(("model_id", "expected_task"), [
Expand Down Expand Up @@ -116,12 +116,15 @@ def test_get_pooling_config():
revision=None,
)

minilm_pooling_config = minilm_model_config.get_pooling_config(None, None)
minilm_pooling_config = minilm_model_config._init_pooler_config(
pooling_type=None,
pooling_norm=None,
pooling_returned_token_ids=None,
pooling_softmax=None,
pooling_step_tag_id=None)

assert isinstance(minilm_model_config.pooling_config, PoolingConfig)
assert minilm_pooling_config.normalize
assert isinstance(minilm_pooling_config.pooling_type, PoolingType)
assert minilm_pooling_config.pooling_type == PoolingType.MEAN
assert minilm_pooling_config.pooling_norm
assert minilm_pooling_config.pooling_type == PoolingType.MEAN.name


def test_get_pooling_config_from_args():
Expand All @@ -135,12 +138,15 @@ def test_get_pooling_config_from_args():
dtype="float16",
revision=None)

minilm_pooling_config = minilm_model_config.get_pooling_config('CLS', True)
minilm_pooling_config = minilm_model_config._init_pooler_config(
pooling_type='CLS',
pooling_norm=True,
pooling_returned_token_ids=None,
pooling_softmax=None,
pooling_step_tag_id=None)

assert isinstance(minilm_model_config.pooling_config, PoolingConfig)
assert minilm_pooling_config.normalize
assert isinstance(minilm_pooling_config.pooling_type, PoolingType)
assert minilm_pooling_config.pooling_type == PoolingType.CLS
assert minilm_pooling_config.pooling_norm
assert minilm_pooling_config.pooling_type == PoolingType.CLS.name


def test_get_bert_tokenization_sentence_transformer_config():
Expand Down
30 changes: 22 additions & 8 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ class ModelConfig:
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
pooling_config: pooling and normalize config from the model -
only applies to sentence-transformers models.
bert_config: tokenization configuration dictionary for a given
    Sentence Transformer BERT model.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
Expand Down Expand Up @@ -282,12 +280,28 @@ def _init_pooler_config(
pooling_returned_token_ids: Optional[List[int]] = None
) -> Optional["PoolerConfig"]:
if self.task == "embedding":
return PoolerConfig(
pooling_type=pooling_type,
pooling_norm=pooling_norm,
pooling_softmax=pooling_softmax,
pooling_step_tag_id=pooling_step_tag_id,
pooling_returned_token_ids=pooling_returned_token_ids)
pooling_config = get_pooling_config(self.model, self.revision)
if pooling_config is not None:
pooling_type_from_file = pooling_config["pooling_type"]
normalize_from_file = pooling_config["normalize"]
pooling_config_from_file = PoolerConfig(
pooling_type=pooling_type_from_file,
pooling_norm=normalize_from_file,
pooling_softmax=pooling_softmax,
pooling_step_tag_id=pooling_step_tag_id,
pooling_returned_token_ids=pooling_returned_token_ids)
if pooling_type is not None:
pooling_config_from_file.pooling_type = pooling_type
if pooling_norm is not None:
pooling_config_from_file.pooling_norm = pooling_norm
return pooling_config_from_file
else:
return PoolerConfig(
pooling_type=pooling_type,
pooling_norm=pooling_norm,
pooling_softmax=pooling_softmax,
pooling_step_tag_id=pooling_step_tag_id,
pooling_returned_token_ids=pooling_returned_token_ids)
return None

def _init_attention_free(self) -> bool:
Expand Down
5 changes: 1 addition & 4 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
SpeculativeConfig, TaskOption, TokenizerPoolConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
Expand Down Expand Up @@ -183,8 +182,6 @@ class EngineArgs:
override_neuron_config: Optional[Dict[str, Any]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
pooling_type: Optional[str] = None
normalize: Optional[bool] = None

# Pooling configuration.
pooling_type: Optional[str] = None
Expand Down Expand Up @@ -853,7 +850,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:

parser.add_argument(
'--pooling-type',
choices=['LAST', 'ALL', 'CLS', 'STEP'],
choices=['LAST', 'ALL', 'CLS', 'STEP', 'MEAN'],
default=None,
help='Used to configure the pooling method in the embedding model.'
)
Expand Down
16 changes: 11 additions & 5 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,21 @@ def __init__(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.disable_custom_all_reduce,
model_config.quantization, model_config.enforce_eager,
cache_config.cache_dtype, model_config.quantization_param_path,
device_config.device, decoding_config, observability_config,
model_config.seed, model_config.served_model_name,
model_config.quantization,
model_config.enforce_eager,
cache_config.cache_dtype,
model_config.quantization_param_path,
device_config.device,
decoding_config,
observability_config,
model_config.seed,
model_config.served_model_name,
scheduler_config.num_scheduler_steps,
scheduler_config.chunked_prefill_enabled,
scheduler_config.multi_step_stream_outputs,
cache_config.enable_prefix_caching,
model_config.use_async_output_proc, use_cached_outputs,
model_config.use_async_output_proc,
use_cached_outputs,
model_config.chat_template_text_format,
model_config.mm_processor_kwargs,
model_config.pooler_config,
Expand Down
1 change: 0 additions & 1 deletion vllm/model_executor/layers/pooler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import dataclass
from enum import IntEnum
from typing import List, Optional

Expand Down
3 changes: 1 addition & 2 deletions vllm/model_executor/models/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig,
PoolingType)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
Expand Down
3 changes: 1 addition & 2 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import (Pooler, PoolingConfig,
PoolingType)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
Expand Down
3 changes: 1 addition & 2 deletions vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolingType
# yapf conflicts with isort for this block
# yapf: disable
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
Expand Down Expand Up @@ -337,7 +336,7 @@ def get_pooling_config_name(pooling_name):
if "lasttoken" in pooling_name:
pooling_name = "last"

supported_pooling_types = [i.name for i in PoolingType]
supported_pooling_types = ['LAST', 'ALL', 'CLS', 'STEP', 'MEAN']
pooling_type_name = pooling_name.upper()

try:
Expand Down

0 comments on commit 8df9d63

Please sign in to comment.