[VLM] Merged multi-modal processor for Pixtral #12211

Open · wants to merge 16 commits into base: main
92 changes: 2 additions & 90 deletions vllm/model_executor/models/llava.py
@@ -9,7 +9,6 @@
                          PixtralVisionConfig, PretrainedConfig,
                          SiglipVisionConfig)
from transformers.models.llava import LlavaProcessor
-from transformers.models.pixtral import PixtralProcessor

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
@@ -34,8 +33,8 @@

from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
-from .pixtral import (PixtralHFVisionModel,
-                      get_pixtral_hf_image_feature_grid_size)
+from .pixtral import (PixtralHFMultiModalProcessor, PixtralHFProcessingInfo,
+                      PixtralHFVisionModel)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    maybe_prefix, merge_multimodal_embeddings)
@@ -261,93 +260,6 @@ def _get_mm_fields_config(
        )


class PixtralHFProcessingInfo(BaseLlavaProcessingInfo):

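    # Pixtral-HF checkpoints ship a PixtralProcessor, so override the
    # LLaVA default, which returns a LlavaProcessor.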
    def get_hf_processor(self):
        return self.ctx.get_hf_processor(PixtralProcessor)


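# Merged multi-modal processor for Pixtral-HF checkpoints: it normalizes the
# HF processor's output shapes and rewrites the prompt's image placeholders
# into per-image token grids.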
class PixtralHFMultiModalProcessor(
        BaseMultiModalProcessor[PixtralHFProcessingInfo]):

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

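        # vLLM tracks multi-modal items per image, so unwrap the extra
        # batch dimension that the HF processor adds (see the shape
        # comments below).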
        pixel_values = processed_outputs.get("pixel_values")
        if pixel_values is not None:
            images = mm_data["images"]
            assert isinstance(images, list)

            # Original output: (1, num_images, C, H, W)
            # New output: (num_images, C, H, W)
            assert (isinstance(pixel_values, list) and len(pixel_values) == 1)
            assert (isinstance(pixel_values[0], list)
                    and len(pixel_values[0]) == len(images))

            processed_outputs["pixel_values"] = pixel_values[0]

        return processed_outputs

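    # Every image contributes exactly one batched element, whether it
    # arrives as raw pixels (pixel_values) or precomputed embeddings
    # (image_embeds).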
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

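    # Expand each image placeholder token into the full 2-D layout that the
    # vision tower produces: `ncols` image tokens per row, rows terminated
    # by the image-break token, with the image-end token closing the final
    # row.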
    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_config = self.info.get_hf_config()
        image_token_id = hf_config.image_token_index

        processor = self.info.get_hf_processor()
        image_token = processor.image_token
        image_break_token = processor.image_break_token
        image_end_token = processor.image_end_token

        vision_config = hf_config.vision_config
        assert isinstance(vision_config, PixtralVisionConfig)

        def get_replacement(item_idx: int):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)

            ncols, nrows = get_pixtral_hf_image_feature_grid_size(
                vision_config,
                image_width=image_size.width,
                image_height=image_size.height,
            )

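            # e.g. ncols=3, nrows=2 (with Pixtral's [IMG] token family) gives
            # [IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END]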
            tokens = ([image_token] * ncols + [image_break_token]) * nrows
            tokens[-1] = image_end_token

            return "".join(tokens)

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=get_replacement,
            ),
        ]


def _build_llava_or_pixtral_hf_info(
        ctx: InputProcessingContext, ) -> BaseLlavaProcessingInfo:
    hf_config = ctx.get_hf_config(LlavaConfig)