Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Explicit interface for vLLM models and support OOT embedding models #9108

Merged
merged 8 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,7 @@ def num_gpus_available():
temp_dir = tempfile.gettempdir()
_dummy_opt_path = os.path.join(temp_dir, "dummy_opt")
_dummy_llava_path = os.path.join(temp_dir, "dummy_llava")
_dummy_gemma2_embedding_path = os.path.join(temp_dir, "dummy_gemma2_embedding")


@pytest.fixture
Expand Down Expand Up @@ -923,3 +924,22 @@ def dummy_llava_path():
with open(json_path, "w") as f:
json.dump(config, f)
return _dummy_llava_path


@pytest.fixture
def dummy_gemma2_embedding_path():
    """Return a local copy of the Gemma2 embedding repo patched for OOT tests.

    The Hugging Face repo ``BAAI/bge-multilingual-gemma2`` is downloaded into
    ``_dummy_gemma2_embedding_path`` (skipping the file patterns listed in
    ``ignore_patterns``), and the ``architectures`` entry of its
    ``config.json`` is overwritten with ``["MyGemma2Embedding"]``.

    Returns:
        The directory path holding the patched snapshot.
    """
    json_path = os.path.join(_dummy_gemma2_embedding_path, "config.json")
    # Gate the download on config.json rather than on the directory itself:
    # an interrupted earlier run can leave the directory behind without the
    # config file, which would otherwise make the assertion below fail on
    # every subsequent run until the temp directory is cleaned by hand.
    if not os.path.exists(json_path):
        snapshot_download(repo_id="BAAI/bge-multilingual-gemma2",
                          local_dir=_dummy_gemma2_embedding_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
    assert os.path.exists(json_path)
    with open(json_path, "r") as f:
        config = json.load(f)
    # Point the checkpoint at the out-of-tree architecture name.
    config["architectures"] = ["MyGemma2Embedding"]
    with open(json_path, "w") as f:
        json.dump(config, f)
    return _dummy_gemma2_embedding_path
18 changes: 15 additions & 3 deletions tests/models/test_oot_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from vllm import LLM, SamplingParams
from vllm import LLM, PoolingParams, SamplingParams
from vllm.assets.image import ImageAsset

from ..utils import fork_new_process_for_each_test
Expand All @@ -17,7 +17,7 @@ def test_plugin(dummy_opt_path):


@fork_new_process_for_each_test
def test_oot_registration(dummy_opt_path):
def test_oot_registration_text_generation(dummy_opt_path):
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
prompts = ["Hello, my name is", "The text does not matter"]
sampling_params = SamplingParams(temperature=0)
Expand All @@ -32,11 +32,23 @@ def test_oot_registration(dummy_opt_path):
assert rest == ""


@fork_new_process_for_each_test
def test_oot_registration_embedding(dummy_gemma2_embedding_path):
    """Encode two prompts with the dummy embedding model and expect zeros.

    The ``register_dummy_model`` plugin is selected through ``VLLM_PLUGINS``
    and the model is loaded with ``load_format="dummy"``; every value of
    every returned embedding must equal zero.
    """
    os.environ["VLLM_PLUGINS"] = "register_dummy_model"
    pooling_params = PoolingParams()
    llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy")
    outputs = llm.encode(
        ["Hello, my name is", "The text does not matter"],
        pooling_params,
    )

    for output in outputs:
        embedding = output.outputs.embedding
        assert all(v == 0 for v in embedding)


image = ImageAsset("cherry_blossom").pil_image.convert("RGB")


@fork_new_process_for_each_test
def test_oot_multimodal_registration(dummy_llava_path):
def test_oot_registration_multimodal(dummy_llava_path):
os.environ["VLLM_PLUGINS"] = "register_dummy_model"
prompts = [{
"prompt": "What's in the image?<image>",
Expand Down
24 changes: 22 additions & 2 deletions tests/models/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
import pytest
import torch.cuda

from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models import (is_embedding_model,
is_text_generation_model,
supports_multimodal)
from vllm.model_executor.models.registry import (_EMBEDDING_MODELS,
_MULTIMODAL_MODELS,
_SPECULATIVE_DECODING_MODELS,
_TEXT_GENERATION_MODELS,
ModelRegistry)
from vllm.platforms import current_platform

from ..utils import fork_new_process_for_each_test
Expand All @@ -12,7 +19,20 @@
@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch):
# Ensure all model classes can be imported successfully
ModelRegistry.resolve_model_cls(model_arch)
model_cls, _ = ModelRegistry.resolve_model_cls(model_arch)

if model_arch in _SPECULATIVE_DECODING_MODELS:
pass # Ignore these models which do not have a unified format
else:
assert is_text_generation_model(model_cls) is (
model_arch in _TEXT_GENERATION_MODELS
or model_arch in _MULTIMODAL_MODELS)

assert is_embedding_model(model_cls) is (model_arch
in _EMBEDDING_MODELS)

assert supports_multimodal(model_cls) is (model_arch
in _MULTIMODAL_MODELS)


@fork_new_process_for_each_test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ def register():
ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)

# Test passing lazy model
if "MyGemma2Embedding" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model(
"MyGemma2Embedding",
"vllm_add_dummy_model.my_gemma_embedding:MyGemma2Embedding",
)

if "MyLlava" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model("MyLlava",
"vllm_add_dummy_model.my_llava:MyLlava")
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch

from vllm.model_executor.models.gemma2_embedding import Gemma2EmbeddingModel


class MyGemma2Embedding(Gemma2EmbeddingModel):
    """``Gemma2EmbeddingModel`` subclass whose forward pass returns zeros."""

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Run the parent forward pass, then return zeros shaped like it.

        All positional and keyword arguments are forwarded unchanged to the
        parent class.
        """
        parent_output = super().forward(*args, **kwargs)

        # We assume PP isn't used in the test
        assert isinstance(parent_output, torch.Tensor)

        # Return all-zero embeddings
        return torch.zeros_like(parent_output)
7 changes: 7 additions & 0 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
SupportsPP, has_inner_state, supports_lora,
supports_multimodal, supports_pp)
from .interfaces_base import (VllmModelForEmbedding,
VllmModelForTextGeneration, is_embedding_model,
is_text_generation_model)
from .registry import ModelRegistry

__all__ = [
"ModelRegistry",
"VllmModelForEmbedding",
"is_embedding_model",
"VllmModelForTextGeneration",
"is_text_generation_model",
"HasInnerState",
"has_inner_state",
"SupportsLoRA",
Expand Down
28 changes: 7 additions & 21 deletions vllm/model_executor/models/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import inspect
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol, Type, Union, overload, runtime_checkable)

import torch
from typing_extensions import TypeIs

from vllm.logger import init_logger
from vllm.utils import supports_kw

if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.sequence import IntermediateTensors

Expand Down Expand Up @@ -142,9 +141,7 @@ def supports_lora(
return result


def _supports_lora(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
def _supports_lora(model: Union[Type[object], object]) -> bool:
if isinstance(model, type):
return isinstance(model, _SupportsLoRAType)

Expand Down Expand Up @@ -175,10 +172,7 @@ def make_empty_intermediate_tensors(

def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
*,
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
"""
Expand All @@ -205,10 +199,7 @@ def make_empty_intermediate_tensors(

def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
*,
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
...
Expand Down Expand Up @@ -257,24 +248,19 @@ def supports_pp(
return supports_attributes and supports_inspect


def _supports_pp_attributes(
model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
if isinstance(model, type):
return isinstance(model, _SupportsPPType)

return isinstance(model, SupportsPP)


def _supports_pp_inspect(
model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
model_forward = getattr(model, "forward", None)
if not callable(model_forward):
return False

forward_params = inspect.signature(model_forward).parameters
return "intermediate_tensors" in forward_params
return supports_kw(model_forward, "intermediate_tensors")


@runtime_checkable
Expand Down
Loading
Loading