Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Support Pixtral models in the HF Transformers format #9036

Merged
merged 15 commits into from
Oct 18, 2024
17 changes: 17 additions & 0 deletions examples/offline_inference_vision_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,22 @@ def run_qwen2_vl(question: str, modality: str):
return llm, prompt, stop_token_ids


# Pixtral HF-format
def run_pixtral_hf(question: str, modality: str):
assert modality == "image"

model_name = "mistral-community/pixtral-12b"

llm = LLM(
model=model_name,
max_model_len=8192,
)

prompt = f"<s>[INST]{question}\n[IMG][/INST]"
stop_token_ids = None
return llm, prompt, stop_token_ids


# LLama 3.2
def run_mllama(question: str, modality: str):
assert modality == "image"
Expand Down Expand Up @@ -347,6 +363,7 @@ def run_glm4v(question: str, modality: str):
"NVLM_D": run_nvlm_d,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"pixtral_hf": run_pixtral_hf,
"mllama": run_mllama,
"molmo": run_molmo,
"glm4v": run_glm4v,
Expand Down
74 changes: 70 additions & 4 deletions vllm/model_executor/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
SiglipVisionConfig)

from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
Expand All @@ -22,6 +23,10 @@
dummy_seq_data_for_clip, get_max_clip_image_tokens,
input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
dummy_seq_data_for_pixtral_hf,
get_max_pixtral_hf_image_tokens,
input_processor_for_pixtral_hf)
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
Expand All @@ -31,8 +36,13 @@

class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size * num_images, num_channels, height, width)`

Note that `height` or `width` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""


class LlavaImageEmbeddingInputs(TypedDict):
Expand Down Expand Up @@ -77,6 +87,8 @@ def get_max_llava_image_tokens(ctx: InputContext):
num_image_tokens = get_max_clip_image_tokens(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_image_tokens = get_max_siglip_image_tokens(vision_config)
elif isinstance(vision_config, PixtralVisionConfig):
num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
Expand Down Expand Up @@ -120,6 +132,17 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,

mm_data = dummy_image_for_siglip(vision_config, num_images)
return seq_data, mm_data
elif isinstance(vision_config, PixtralVisionConfig):
seq_data = dummy_seq_data_for_pixtral_hf(
vision_config,
seq_len,
num_images,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)

mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
return seq_data, mm_data

msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
Expand Down Expand Up @@ -163,6 +186,15 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
elif isinstance(vision_config, PixtralVisionConfig):
# We ignore image_feature_size_override since we have non-uniform
# image sizes for Pixtral
return input_processor_for_pixtral_hf(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
)

msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
Expand All @@ -189,6 +221,9 @@ def _init_vision_tower(hf_config: LlavaConfig):
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
elif isinstance(vision_config, PixtralVisionConfig):
# TODO: allow layer override?
return PixtralHFVisionModel(vision_config)

msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
Expand All @@ -210,6 +245,15 @@ def __init__(self,
self.config = config
self.multimodal_config = multimodal_config

# NOTE: These are special cases for Pixtral-12B in the HF-format
# https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json # noqa
if (config.text_config.architectures is None
and config.text_config.model_type == "mistral"):
config.text_config.architectures = ["MistralForCausalLM"]
if (config.projector_hidden_act is None
and config.vision_config.hidden_act == "gelu"):
config.projector_hidden_act = "gelu"

# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = _init_vision_tower(config)
self.multi_modal_projector = LlavaMultiModalProjector(
Expand Down Expand Up @@ -246,6 +290,7 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)

if pixel_values is None and image_embeds is None:
Expand All @@ -256,6 +301,26 @@ def _parse_and_validate_image_input(
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")

# Case for models like PixtralHF that have dynamic image sizes
# so we need to produce a list of tensors
if image_sizes is not None:
images = pixel_values
if isinstance(images, torch.Tensor):
# if passed as batch take all images
NN, N, B, C, W, H = images.shape
images = images.reshape(NN * N * B, C, W, H)
images = [images[i] for i in range(images.size(0))]
elif isinstance(images, list):
# if passed as list flatten lists of tensors
while isinstance(images, list) and len(images) == 1:
images = images[0]

# TODO: Add validation based on image_sizes
return LlavaImagePixelInputs(
type="pixel_values",
data=images,
)

return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
Expand Down Expand Up @@ -286,7 +351,8 @@ def _select_image_features(self, image_features: torch.Tensor, *,

def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
PixtralHFVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:

Expand Down
Loading