[V1] Refactor model executable interface for multimodal models #10570

Merged: 33 commits, Nov 26, 2024

vllm/model_executor/models/blip2.py (37 additions, 24 deletions)
```diff
@@ -16,6 +16,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import NestedTensors
 from vllm.multimodal.utils import consecutive_placeholder_ranges
 from vllm.sequence import IntermediateTensors, SequenceData
```

```diff
@@ -609,13 +610,33 @@ def _process_image_input(self,

         return self.language_projection(query_output)

+    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                BLIP2_IMAGE_TOKEN_ID)
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[SamplerOutput, IntermediateTensors]:
         """Run forward pass for BLIP-2.
```

```diff
@@ -648,32 +669,24 @@ def forward(
         See also:
             :class:`Blip2ImageInputs`
         """

         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
-        else:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-
-            if image_input is not None:
-                vision_embeddings = self._process_image_input(image_input)
-                inputs_embeds = self.language_model.model.get_input_embeddings(
-                    input_ids)
-
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, vision_embeddings,
-                    BLIP2_IMAGE_TOKEN_ID)
-
-                input_ids = None
-            else:
-                inputs_embeds = None
-
-        hidden_states = self.language_model.model(
-            input_ids,
-            positions,
-            kv_caches,
-            attn_metadata,
-            intermediate_tensors=intermediate_tensors,
-            inputs_embeds=inputs_embeds)
+
+        # NOTE: In v1, inputs_embeds is always generated at model runner, this
+        # condition is for v0 compatibility.
+        elif inputs_embeds is None:
+            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None
+
+        hidden_states = self.language_model.model(input_ids,
+                                                  positions,
+                                                  kv_caches,
+                                                  attn_metadata,
+                                                  intermediate_tensors,
+                                                  inputs_embeds=inputs_embeds)

         return hidden_states
```
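
The refactored path above delegates placeholder substitution to `merge_multimodal_embeddings` from `.utils`. As a rough mental model, the merge is an index-mask assignment; the sketch below is a simplified illustration rather than vLLM's actual implementation, and it assumes the multimodal embeddings arrive as a single flattened tensor instead of the more general `NestedTensors` structure.

```python
import torch


def merge_multimodal_embeddings_sketch(
        input_ids: torch.Tensor,              # (num_tokens, )
        inputs_embeds: torch.Tensor,          # (num_tokens, hidden_size)
        multimodal_embeddings: torch.Tensor,  # (num_mm_tokens, hidden_size)
        placeholder_token_id: int) -> torch.Tensor:
    # Positions in the prompt that hold the image placeholder token
    # (BLIP2_IMAGE_TOKEN_ID in the diff above).
    mask = input_ids == placeholder_token_id
    # Each placeholder slot must be backed by exactly one vision embedding.
    assert int(mask.sum()) == multimodal_embeddings.shape[0]
    # Overwrite the text embeddings at the placeholder positions.
    inputs_embeds = inputs_embeds.clone()
    inputs_embeds[mask] = multimodal_embeddings.to(inputs_embeds.dtype)
    return inputs_embeds
```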

vllm/model_executor/models/chameleon.py (41 additions, 17 deletions)
```diff
@@ -29,6 +29,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import NestedTensors
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges,
                                    repeat_and_pad_placeholder_tokens)
```
```diff
@@ -38,7 +39,7 @@
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
-                    maybe_prefix)
+                    maybe_prefix, merge_multimodal_embeddings)

 # These configs are not part of the model config but the preprocessor
 # and processor files, so we hardcode them in the model file for now.
```
```diff
@@ -987,34 +988,57 @@ def _parse_and_validate_image_input(
             data=self._validate_pixel_values(pixel_values),
         )

+    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        assert self.model.vqmodel is not None
+        image_tokens = self.model.get_image_tokens(image_input["data"].to(
+            self.config.torch_dtype))
+        vision_embeddings = self.model.get_input_embeddings(image_tokens)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+
+        inputs_embeds = self.model.get_input_embeddings(input_ids)
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                self.model.vocabulary_mapping.image_token_id)
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[torch.Tensor, IntermediateTensors]:

         if intermediate_tensors is not None:
-            input_ids = None
+            inputs_embeds = None
+
+        # NOTE: In v1, inputs_embeds is always generated at model runner, this
+        # condition is for v0 compatibility.
+        elif inputs_embeds is None:
+            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None
-        else:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-
-            if image_input is not None:
-                assert self.model.vqmodel is not None
-                image_tokens = self.model.get_image_tokens(
-                    image_input["data"].to(self.config.torch_dtype))
-                image_token_id = self.model.vocabulary_mapping.image_token_id
-                special_image_mask = input_ids == image_token_id
-                image_tokens = image_tokens.to(input_ids.device,
-                                               input_ids.dtype)
-                input_ids = input_ids.masked_scatter(special_image_mask,
-                                                     image_tokens)

-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+        hidden_states = self.model(input_ids,
+                                   positions,
+                                   kv_caches,
+                                   attn_metadata,
+                                   intermediate_tensors,
+                                   inputs_embeds=inputs_embeds)
         return hidden_states

     def compute_logits(
```
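
For context on the deleted Chameleon branch: unlike the other models in this PR, the old v0 path did not merge embeddings at all; it spliced the discrete VQ image token ids directly into `input_ids` with `torch.Tensor.masked_scatter`. A minimal standalone illustration of that call (the ids below are made up):

```python
import torch

IMAGE_TOKEN_ID = 9  # hypothetical placeholder id for this example

input_ids = torch.tensor([5, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 7])
image_tokens = torch.tensor([101, 102])  # discrete codes from the VQ model

# masked_scatter fills the True positions, in order, from image_tokens.
mask = input_ids == IMAGE_TOKEN_ID
merged = input_ids.masked_scatter(mask, image_tokens)
print(merged)  # tensor([  5, 101, 102,   7])
```

The new `get_multimodal_embeddings` reaches the same result one step earlier: it embeds the VQ tokens with `self.model.get_input_embeddings(image_tokens)` and merges them as vectors, which is what allows the v1 runner to precompute `inputs_embeds` outside `forward`.
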
vllm/model_executor/models/chatglm.py (35 additions, 19 deletions)
```diff
@@ -33,7 +33,8 @@
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs
+from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs,
+                                    NestedTensors)
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
```
```diff
@@ -545,33 +546,48 @@ def _parse_and_validate_image_input(
                 """)
         return GLMImagePixelInputs(pixel_values=pixel_values)

+    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input["pixel_values"] is None:
+            return None
+        pixel_values = image_input["pixel_values"].to(
+            dtype=self.config.torch_dtype)
+        vision_embeddings = self.vision(pixel_values)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.embedding(input_ids)
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_glm_vision_embeddings(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                vision_embeddings=multimodal_embeddings,
+                boi_token_id=self.config.boi_token_id,
+                eoi_token_id=self.config.eoi_token_id)
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> torch.Tensor:
-        if intermediate_tensors is None:
-            inputs_embeds = self.embedding(input_ids)
-            image_input = self._parse_and_validate_image_input(**kwargs)
-
-            if image_input["pixel_values"] is not None:
-                pixel_values = image_input["pixel_values"].to(
-                    dtype=inputs_embeds.dtype)
-                image_embeds = self.vision(pixel_values)
-
-                boi_token_id = self.config.boi_token_id
-                eoi_token_id = self.config.eoi_token_id
-
-                inputs_embeds = merge_glm_vision_embeddings(
-                    input_ids=input_ids,
-                    inputs_embeds=inputs_embeds,
-                    vision_embeddings=image_embeds,
-                    boi_token_id=boi_token_id,
-                    eoi_token_id=eoi_token_id)
-
+
+        # NOTE: In v1, inputs_embeds is always generated at model runner, this
+        # condition is for v0 compatibility.
+        if intermediate_tensors is None and inputs_embeds is None:
+            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None
         else:
             inputs_embeds = intermediate_tensors["hidden_states"]
```
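
GLM-4V differs from the placeholder-token models above: `merge_glm_vision_embeddings` writes the vision embeddings into the span bracketed by the BOI/EOI marker tokens. The sketch below is a simplified, single-image illustration of that idea, not vLLM's actual implementation, and it assumes the prompt already reserves exactly enough slots between the two markers.

```python
import torch


def merge_glm_vision_embeddings_sketch(
        input_ids: torch.Tensor,          # (num_tokens, )
        inputs_embeds: torch.Tensor,      # (num_tokens, hidden_size)
        vision_embeddings: torch.Tensor,  # (num_patches, hidden_size)
        boi_token_id: int,
        eoi_token_id: int) -> torch.Tensor:
    # Locate the <boi> ... <eoi> span reserved for the image.
    boi_pos = int((input_ids == boi_token_id).nonzero()[0])
    eoi_pos = int((input_ids == eoi_token_id).nonzero()[0])
    # The slots between the markers must match the vision tower output.
    assert eoi_pos - boi_pos - 1 == vision_embeddings.shape[0]
    inputs_embeds = inputs_embeds.clone()
    inputs_embeds[boi_pos + 1:eoi_pos] = vision_embeddings.to(
        inputs_embeds.dtype)
    return inputs_embeds
```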

vllm/model_executor/models/fuyu.py (29 additions, 14 deletions)
```diff
@@ -35,6 +35,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.image import cached_get_image_processor
+from vllm.multimodal.inputs import NestedTensors
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges)
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
```
```diff
@@ -302,31 +303,45 @@ def _process_image_input(
         vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
         return vision_embeddings

+    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                _IMAGE_TOKEN_ID)
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ):
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
-        else:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-
-            if image_input is not None:
-                vision_embeddings = self._process_image_input(image_input)
-                inputs_embeds = self.language_model.model.embed_tokens(
-                    input_ids)
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, vision_embeddings,
-                    self.image_token_id)
-
-            else:
-                inputs_embeds = None
-
+
+        # NOTE: In v1, inputs_embeds is always generated at model runner, this
+        # condition is for v0 compatibility.
+        elif inputs_embeds is None:
+            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None

         hidden_states = self.language_model(
             input_ids=input_ids,
```
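
Each of the four models ends up with the same three-way dispatch at the top of `forward`. Pulled out as a schematic restatement of the diffs (this helper is not code from the PR; `model` stands in for any of the refactored classes):

```python
def resolve_decoder_inputs(model, input_ids, intermediate_tensors,
                           inputs_embeds, **kwargs):
    """Schematic of the shared v0/v1 dispatch in the refactored forward()."""
    if intermediate_tensors is not None:
        # Pipeline-parallel, non-first rank: hidden states come from the
        # previous stage, so no input embeddings are needed here.
        inputs_embeds = None
    elif inputs_embeds is None:
        # v0 path: embeddings were not precomputed, so build them inside
        # the model from the multimodal kwargs.
        vision_embeddings = model.get_multimodal_embeddings(**kwargs)
        inputs_embeds = model.get_input_embeddings(input_ids,
                                                   vision_embeddings)
        input_ids = None
    # v1 path: the model runner has already produced inputs_embeds, so both
    # branches above are skipped.
    return input_ids, inputs_embeds
```
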
vllm/model_executor/models/interfaces.py (35 additions, 1 deletion)
```diff
@@ -2,18 +2,22 @@
                     Protocol, Type, Union, overload, runtime_checkable)

 import torch
-from typing_extensions import TypeIs
+from typing_extensions import TypeIs, TypeVar

 from vllm.logger import init_logger
 from vllm.utils import supports_kw

 from .interfaces_base import is_embedding_model

 if TYPE_CHECKING:
     from vllm.attention import AttentionMetadata
+    from vllm.multimodal.inputs import NestedTensors  # noqa: F401
     from vllm.sequence import IntermediateTensors

 logger = init_logger(__name__)

+T = TypeVar("T", default="NestedTensors")
+

 @runtime_checkable
 class SupportsMultiModal(Protocol):
```
```diff
@@ -28,6 +32,36 @@ class SupportsMultiModal(Protocol):
         MRO of your model class.
     """

+    def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
+        """
+        Returns multimodal embeddings generated from multimodal kwargs
+        to be merged with text embeddings.
+        """
+        ...
+
+    # Only for models that support v0 chunked prefill
+    # TODO(ywang96): Remove this overload once v0 is deprecated
+    @overload
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[T] = None,
+        attn_metadata: Optional["AttentionMetadata"] = None,
+    ) -> torch.Tensor:
+        ...
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[T] = None,
+    ) -> torch.Tensor:
+        """
+        Returns the input embeddings merged from the text embeddings from
+        input_ids and the multimodal embeddings generated from multimodal
+        kwargs.
+        """
+        ...
+

 # We can't use runtime_checkable with ClassVar for issubclass checks
 # so we need to treat the class as an instance and use isinstance instead
```
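
To make the new protocol concrete, here is a self-contained toy implementation of the two hooks. Everything in it (`ToyMultiModalModel`, its layers, and the placeholder id) is invented for illustration; only the two method signatures mirror the protocol above.

```python
from typing import Optional

import torch
import torch.nn as nn

IMAGE_TOKEN_ID = 32000  # placeholder id; made up for this example


class ToyMultiModalModel(nn.Module):

    def __init__(self, vocab_size: int = 32064, hidden_size: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.vision_proj = nn.Linear(8, hidden_size)  # stand-in vision tower

    def get_multimodal_embeddings(self, **kwargs) -> Optional[torch.Tensor]:
        pixel_values = kwargs.get("pixel_values")
        if pixel_values is None:
            return None
        # A real model would run a vision encoder plus projector here.
        return self.vision_proj(pixel_values)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        if multimodal_embeddings is not None:
            mask = input_ids == IMAGE_TOKEN_ID
            inputs_embeds = inputs_embeds.clone()
            inputs_embeds[mask] = multimodal_embeddings.to(inputs_embeds.dtype)
        return inputs_embeds


# Usage sketch: the caller computes embeddings before the decoder forward.
model = ToyMultiModalModel()
input_ids = torch.tensor([1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2])
pixel_values = torch.randn(2, 8)  # pretend two image "patches"
mm_embeds = model.get_multimodal_embeddings(pixel_values=pixel_values)
inputs_embeds = model.get_input_embeddings(input_ids, mm_embeds)
assert inputs_embeds.shape == (4, 16)
```

This mirrors the split the PR introduces: in v1 the model runner is meant to call these two methods itself and hand the result to `forward(..., inputs_embeds=...)`, while the per-model `elif inputs_embeds is None` branches above only keep the v0 runner working.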