Linter test #9598 (Closed)
wants to merge 10 commits
45 changes: 45 additions & 0 deletions tests/model_executor/test_model_load_with_params.py
@@ -0,0 +1,45 @@
import os

from vllm.model_executor.layers.pooler import PoolingType
from vllm.model_executor.models.bert import BertEmbeddingModel

MAX_MODEL_LEN = 128
MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
REVISION = os.environ.get("REVISION", "main")


def test_model_loading_with_params(vllm_runner):
"""
Test parameter weight loading for a sentence-transformers BERT embedding model.
"""
with vllm_runner(model_name=MODEL_NAME,
revision=REVISION,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as model:
output = model.encode("Write a short story about a robot that"
" dreams for the first time.\n")

model_config = model.model.llm_engine.model_config

model_tokenizer = model.model.llm_engine.tokenizer

# asserts on the bert model config file
assert model_config.bert_config["max_seq_length"] == 512
assert model_config.bert_config["do_lower_case"]

# asserts on the pooling config files
assert model_config.pooling_config.pooling_type == PoolingType.CLS
assert model_config.pooling_config.normalize

# asserts on the tokenizer loaded
assert model_tokenizer.tokenizer_id == MODEL_NAME
assert model_tokenizer.tokenizer_config["do_lower_case"]
assert model_tokenizer.tokenizer.model_max_length == 512

model = model.model.llm_engine.model_executor\
.driver_worker.model_runner.model
assert isinstance(model, BertEmbeddingModel)
assert model._pooler.pooling_type == PoolingType.CLS
assert model._pooler.normalize
# assert output
assert output
40 changes: 40 additions & 0 deletions tests/test_config.py
@@ -1,6 +1,7 @@
import pytest

from vllm.config import ModelConfig
from vllm.model_executor.layers.pooler import PoolingConfig, PoolingType


@pytest.mark.parametrize(("model_id", "expected_task"), [
@@ -102,6 +103,45 @@ def test_get_sliding_window():
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW


def test_get_pooling_config():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
minilm_model_config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
)

minilm_pooling_config = minilm_model_config.get_pooling_config()

assert isinstance(minilm_model_config.pooling_config, PoolingConfig)
assert minilm_pooling_config.normalize
assert isinstance(minilm_pooling_config.pooling_type, PoolingType)
assert minilm_pooling_config.pooling_type == PoolingType.MEAN


def test_get_bert_tokenization_sentence_transformer_config():
bge_model_config = ModelConfig(
model="BAAI/bge-base-en-v1.5",
task="auto",
tokenizer="BAAI/bge-base-en-v1.5",
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
)

bert_bge_model_config = bge_model_config._get_bert_config()

assert bert_bge_model_config["max_seq_length"] == 512
assert bert_bge_model_config["do_lower_case"]


def test_rope_customization():
TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
TEST_ROPE_THETA = 16_000_000.0
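For context on the values these tests expect: the pooling and tokenization settings come from the sidecar JSON files that sentence-transformers repositories ship next to the weights. Below is a minimal, standalone sketch of reading them directly with huggingface_hub; the file names (1_Pooling/config.json, sentence_bert_config.json) and keys are assumptions about the upstream repo layout, not part of this PR.

import json

from huggingface_hub import hf_hub_download

repo = "sentence-transformers/all-MiniLM-L12-v2"

# Pooling settings (assumed path: 1_Pooling/config.json). For MiniLM the
# mean-token flag is expected to be set, which is why the test above
# expects PoolingType.MEAN.
with open(hf_hub_download(repo, "1_Pooling/config.json")) as f:
    pooling = json.load(f)
print(pooling.get("pooling_mode_mean_tokens"))

# Tokenization settings (assumed path: sentence_bert_config.json), the
# source of the max_seq_length / do_lower_case values asserted above.
with open(hf_hub_download(repo, "sentence_bert_config.json")) as f:
    bert_cfg = json.load(f)
print(bert_cfg.get("max_seq_length"), bert_cfg.get("do_lower_case"))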
32 changes: 28 additions & 4 deletions vllm/config.py
Original file line number Diff line number Diff line change
@@ -9,13 +9,15 @@

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolingConfig
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform
from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config,
get_hf_text_config)
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
is_hip, is_openvino, print_warning_once)

@@ -110,6 +112,10 @@ class ModelConfig:
can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
pooling_config: Pooling and normalization config from the model -
only applies to sentence-transformers models.
bert_config: Tokenization configuration dictionary for a given
Sentence Transformers BERT model.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
"""
@@ -173,6 +179,8 @@ def __init__(self,
code_revision, rope_scaling, rope_theta,
config_format)
self.hf_text_config = get_hf_text_config(self.hf_config)
self.pooling_config = self.get_pooling_config()
self.bert_config = self._get_bert_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
@@ -205,7 +213,8 @@ def __init__(self,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window(),
spec_target_max_model_len=spec_target_max_model_len)
spec_target_max_model_len=spec_target_max_model_len,
bert_config=self.bert_config)
self.served_model_name = get_served_model_name(model,
served_model_name)
self.multimodal_config = self._init_multimodal_config(
@@ -242,6 +251,10 @@ def _init_multimodal_config(

return None

def _get_bert_config(self):
return get_sentence_transformer_tokenizer_config(
self.model, self.revision)

def _init_attention_free(self) -> bool:
architectures = getattr(self.hf_config, "architectures", [])
return ModelRegistry.is_attention_free_model(architectures)
@@ -405,6 +418,13 @@ def _verify_bnb_config(self) -> None:
"fallback to the eager mode.")
self.enforce_eager = True

def get_pooling_config(self) -> Optional[PoolingConfig]:
pooling_config = get_pooling_config(self.model, self.revision)
if pooling_config is not None:
return PoolingConfig(pooling_config["pooling_type"],
pooling_config["normalize"])
return None

def verify_async_output_proc(self, parallel_config, speculative_config,
device_config) -> None:
if not self.use_async_output_proc:
@@ -1715,6 +1735,7 @@ def _get_and_verify_max_len(
disable_sliding_window: bool,
sliding_window_len: Optional[Union[int, List[Optional[int]]]],
spec_target_max_model_len: Optional[int] = None,
bert_config: Optional[Any] = None,
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
@@ -1797,6 +1818,9 @@
"original_max_position_embeddings"]
derived_max_model_len *= scaling_factor

if bert_config and "max_seq_length" in bert_config:
derived_max_model_len = bert_config["max_seq_length"]

# If the user specified a max length, make sure it is smaller than the
# derived length from the HF model config.
if max_model_len is None:
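The pooling-type value that get_pooling_config() forwards into PoolingConfig is a plain string taken from the model repository; PoolingConfig (see vllm/model_executor/layers/pooler.py further down) resolves it to the PoolingType enum by substring matching and falls back to CLS. A minimal self-contained sketch of that resolution logic, under the assumption that the incoming strings look like sentence-transformers pooling keys:

from enum import IntEnum


class PoolingType(IntEnum):
    LAST = 0
    ALL = 1
    CLS = 2
    MEAN = 3


def resolve_pooling_type(name: str) -> PoolingType:
    # Pick the first enum member whose lowercased name occurs in the
    # given string; default to CLS when nothing matches.
    for member in PoolingType:
        if member.name.lower() in name.lower():
            return member
    return PoolingType.CLS


assert resolve_pooling_type("MEAN") == PoolingType.MEAN
assert resolve_pooling_type("pooling_mode_cls_token") == PoolingType.CLS
assert resolve_pooling_type("unknown") == PoolingType.CLS  # fallback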
49 changes: 19 additions & 30 deletions vllm/engine/llm_engine.py
@@ -254,44 +254,33 @@ def __init__(
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"pooling_config_type=%s, normalize=%s, "
"chat_template_text_format=%s, mm_processor_kwargs=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
model_config.tokenizer,
model_config.skip_tokenizer_init,
model_config.tokenizer_mode,
model_config.revision,
model_config.override_neuron_config,
model_config.rope_scaling,
model_config.rope_theta,
model_config.tokenizer_revision,
model_config.trust_remote_code,
model_config.dtype,
model_config.max_model_len,
load_config.download_dir,
load_config.load_format,
parallel_config.tensor_parallel_size,
VLLM_VERSION, model_config.model, speculative_config,
model_config.tokenizer, model_config.skip_tokenizer_init,
model_config.tokenizer_mode, model_config.revision,
model_config.override_neuron_config, model_config.rope_scaling,
model_config.rope_theta, model_config.tokenizer_revision,
model_config.trust_remote_code, model_config.dtype,
model_config.max_model_len, load_config.download_dir,
load_config.load_format, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.disable_custom_all_reduce,
model_config.quantization,
model_config.enforce_eager,
cache_config.cache_dtype,
model_config.quantization_param_path,
device_config.device,
decoding_config,
observability_config,
model_config.seed,
model_config.served_model_name,
model_config.quantization, model_config.enforce_eager,
cache_config.cache_dtype, model_config.quantization_param_path,
device_config.device, decoding_config, observability_config,
model_config.seed, model_config.served_model_name,
scheduler_config.num_scheduler_steps,
scheduler_config.chunked_prefill_enabled,
scheduler_config.multi_step_stream_outputs,
cache_config.enable_prefix_caching,
model_config.use_async_output_proc,
use_cached_outputs,
model_config.use_async_output_proc, use_cached_outputs,
model_config.pooling_config.pooling_type
if model_config.pooling_config is not None else None,
model_config.pooling_config.normalize
if model_config.pooling_config is not None else None,
model_config.chat_template_text_format,
model_config.mm_processor_kwargs,
)
model_config.mm_processor_kwargs)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config
self.cache_config = cache_config
41 changes: 40 additions & 1 deletion vllm/model_executor/layers/pooler.py
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from enum import IntEnum

import torch
@@ -13,6 +14,33 @@ class PoolingType(IntEnum):
LAST = 0
ALL = 1
CLS = 2
MEAN = 3


@dataclass
class PoolingConfig:
"""A class that configures the pooling operation which
only applies to sentence-transformers models.
More at: https://www.sbert.net/

Attributes:
pooling_type (PoolingType): The type of pooling to use, resolved
from the pooling-type string passed to the constructor.
normalize (bool): Whether to normalize the pooled data.

Methods:
get_pooling_type(pooling_type_name): Returns the pooling
type enum value corresponding to the given string.
"""

def __init__(self, pooling_type: str, normalize: bool):
self.pooling_type = self.get_pooling_type(pooling_type)
self.normalize = normalize

def get_pooling_type(self, pooling_type_name: str) -> PoolingType:
pooling_types = PoolingType.__dict__.items()
return PoolingType(
next((value for key, value in pooling_types
if key.lower() in pooling_type_name), PoolingType.CLS))


class Pooler(nn.Module):
@@ -24,7 +52,7 @@ class Pooler(nn.Module):
3. Returns structured results as `PoolerOutput`.

Attributes:
pooling_type: The type of pooling to use (LAST, ALL, CLS).
pooling_type: The type of pooling to use (LAST, ALL, CLS, MEAN).
normalize: Whether to normalize the pooled data.
"""

@@ -58,6 +86,17 @@ def forward(
for prompt_len in prompt_lens:
pooled_data.append(hidden_states[offset:offset + prompt_len])
offset += prompt_len
elif self.pooling_type == PoolingType.MEAN:
# Calculate mean pooling
cumsum = torch.cumsum(hidden_states, dim=0)
start_indices = torch.cat([
torch.tensor([0], device=hidden_states.device),
torch.cumsum(prompt_lens[:-1], dim=0)
])
end_indices = torch.cumsum(prompt_lens, dim=0)
pooled_data = (
cumsum[end_indices - 1] - cumsum[start_indices] +
hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")

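The new MEAN branch avoids a Python loop by differencing a running cumulative sum over the packed hidden states. A short standalone check that this closed form matches a straightforward per-prompt mean (tensor shapes here are illustrative only):

import torch

# Two prompts of lengths 3 and 2, hidden size 4, packed into one tensor.
prompt_lens = torch.tensor([3, 2])
hidden_states = torch.randn(int(prompt_lens.sum()), 4)

# Closed form used in the MEAN branch above.
cumsum = torch.cumsum(hidden_states, dim=0)
start_indices = torch.cat(
    [torch.tensor([0]), torch.cumsum(prompt_lens[:-1], dim=0)])
end_indices = torch.cumsum(prompt_lens, dim=0)
pooled = (cumsum[end_indices - 1] - cumsum[start_indices] +
          hidden_states[start_indices]) / prompt_lens.unsqueeze(1)

# Reference: mean over each prompt's rows, one prompt at a time.
reference = torch.stack([
    chunk.mean(dim=0)
    for chunk in torch.split(hidden_states, prompt_lens.tolist())
])
assert torch.allclose(pooled, reference, atol=1e-6)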
37 changes: 23 additions & 14 deletions vllm/model_executor/model_loader/loader.py
@@ -28,6 +28,7 @@
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import PoolingConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.tensorizer import (
@@ -122,7 +123,8 @@ def _get_model_initialization_kwargs(
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]:
scheduler_config: Optional[SchedulerConfig] = None,
pooling_config: Optional[PoolingConfig] = None) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {}

@@ -144,18 +146,25 @@
if has_inner_state(model_class) and scheduler_config:
extra_kwargs["scheduler_config"] = scheduler_config

if pooling_config is not None:
extra_kwargs["pooling_config"] = pooling_config

return extra_kwargs


def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
def build_model(model_class: Type[nn.Module],
hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig], *,
quant_config: Optional[QuantizationConfig],
*,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
scheduler_config: Optional[SchedulerConfig],
pooling_config: Optional[PoolingConfig] = None) -> nn.Module:
extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
multimodal_config,
scheduler_config)
scheduler_config,
pooling_config)

return model_class(config=hf_config,
cache_config=cache_config,
@@ -172,15 +181,15 @@ def _initialize_model(
"""Initialize a model with the given configurations."""
model_class, _ = get_model_architecture(model_config)

return build_model(
model_class,
model_config.hf_config,
cache_config=cache_config,
quant_config=_get_quantization_config(model_config, load_config),
lora_config=lora_config,
multimodal_config=model_config.multimodal_config,
scheduler_config=scheduler_config,
)
return build_model(model_class,
model_config.hf_config,
cache_config=cache_config,
quant_config=_get_quantization_config(
model_config, load_config),
lora_config=lora_config,
multimodal_config=model_config.multimodal_config,
scheduler_config=scheduler_config,
pooling_config=model_config.pooling_config)


class BaseModelLoader(ABC):