Patch pixtral processor
Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 committed Dec 6, 2024
1 parent 1d5a4d4 commit 000736b
Showing 1 changed file with 54 additions and 42 deletions.
vllm/model_executor/models/llava.py — 96 changes: 54 additions & 42 deletions
@@ -1,4 +1,5 @@
 from functools import cached_property
+from types import MethodType
 from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
                     Tuple, TypedDict, Union)

@@ -7,7 +8,8 @@
 from PIL.Image import Image
 from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                           PixtralVisionConfig, PretrainedConfig,
-                          SiglipVisionConfig)
+                          ProcessorMixin, SiglipVisionConfig)
+from transformers.models.pixtral import PixtralProcessor
 
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
@@ -23,7 +25,7 @@
 from vllm.multimodal.processing import (InputProcessingContext,
                                         ModalityProcessingMetadata,
                                         MultiModalProcessingMetadata,
-                                        PromptReplacement)
+                                        MultiModalProcessor, PromptReplacement)
 from vllm.sequence import IntermediateTensors
 
 from .clip import (CLIPVisionModel, dummy_image_for_clip,
@@ -128,11 +130,15 @@ def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext,
         msg = f"Unsupported vision config: {type(vision_config)}"
         raise NotImplementedError(msg)
 
-    image_processor = ctx.get_hf_processor().image_processor  # type: ignore
-    hf_inputs = image_processor \
-        .preprocess(data['image'], return_tensors="pt") \
-        .data
-    return MultiModalKwargs(hf_inputs)
+    hf_processor = ctx.get_hf_processor()
+    image_processor = hf_processor.image_processor  # type: ignore
+    hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
+    is_pixtral = isinstance(hf_processor, PixtralProcessor)
+
+    return MultiModalKwargs(
+        **hf_inputs,
+        is_pixtral=torch.tensor(is_pixtral),
+    )
 
 
 def create_metadata_for_llava(
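In the hunk above, `is_pixtral` is stored as a tensor rather than a plain bool so it travels through multimodal kwargs batching like any other field and can be recovered per batch later. A minimal standalone sketch of the pattern (the function names here are illustrative, not vLLM's API):

import torch

def make_kwargs(pixel_values: torch.Tensor, is_pixtral: bool) -> dict:
    # Store the flag as a tensor so it batches/collates like other fields.
    return {"pixel_values": pixel_values,
            "is_pixtral": torch.tensor(is_pixtral)}

def parse_kwargs(**kwargs) -> str:
    # After batching, the flag may be a 1-D tensor of per-item values;
    # .any() takes the Pixtral path if any item in the batch set the flag.
    is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
    assert isinstance(is_pixtral, torch.Tensor)
    return "pixtral path" if is_pixtral.any() else "default path"

batch = make_kwargs(torch.rand(1, 3, 336, 336), is_pixtral=True)
print(parse_kwargs(**batch))  # -> "pixtral path"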
@@ -157,6 +163,39 @@ def get_repl_count(
     }
 
 
+class LlavaProcessor(MultiModalProcessor):
+
+    def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
+        if getattr(hf_processor, "__is_patched__", False):
+            return  # Already patched
+
+        image_processor = hf_processor.image_processor  # type: ignore
+        orig_preprocess = image_processor.preprocess
+
+        def preprocess(__self, *args, **kwargs):
+            hf_inputs = orig_preprocess(*args, **kwargs)
+            hf_inputs["is_pixtral"] = torch.tensor(True)
+            return hf_inputs
+
+        image_processor.preprocess = MethodType(preprocess, image_processor)
+
+        hf_processor.__is_patched__ = True  # type: ignore
+
+    def _get_hf_processor(self) -> ProcessorMixin:
+        hf_processor = self.ctx.get_hf_processor()
+
+        if isinstance(hf_processor, PixtralProcessor):
+            self._patch_pixtral_processor(hf_processor)
+
+        return hf_processor
+
+    def _get_dummy_mm_kwargs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> MultiModalKwargs:
+        return dummy_mm_kwargs_for_llava(self.ctx, mm_counts)
+
+
 class LlavaLikeConfig(Protocol):
     vision_config: PretrainedConfig
     vision_feature_layer: Union[int, List[int]]
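The `__is_patched__` sentinel is what makes the patch idempotent: `_get_hf_processor` may run once per request, and without the guard each call would wrap `preprocess` in yet another closure. A self-contained sketch of the same `types.MethodType` pattern, using a toy `ImageProcessor` class rather than the transformers one:

from types import MethodType

class ImageProcessor:
    def preprocess(self, image):
        return {"pixel_values": image}

def patch_once(proc: ImageProcessor) -> None:
    if getattr(proc, "__is_patched__", False):
        return  # already patched; wrapping again would stack wrappers

    orig_preprocess = proc.preprocess  # bound method, captured by closure

    def preprocess(self, *args, **kwargs):
        out = orig_preprocess(*args, **kwargs)
        out["is_pixtral"] = True  # inject the extra field
        return out

    # Rebind on this instance only; other instances are untouched.
    proc.preprocess = MethodType(preprocess, proc)
    proc.__is_patched__ = True

proc = ImageProcessor()
patch_once(proc)
patch_once(proc)  # safe: the guard prevents double wrapping
print(proc.preprocess("img"))  # {'pixel_values': 'img', 'is_pixtral': True}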
@@ -238,8 +277,10 @@ def init_vision_tower_for_llava(
 
 
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
-@MULTIMODAL_REGISTRY.register_processor_by_metadata(create_metadata_for_llava,
-                                                    dummy_mm_kwargs_for_llava)
+@MULTIMODAL_REGISTRY.register_processor(lambda ctx: LlavaProcessor(
+    ctx=ctx,
+    metadata=create_metadata_for_llava(ctx),
+))
 class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     # BitandBytes specific attributes
     bitsandbytes_stacked_params_mapping = {
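Registration switches from a metadata pair to a factory callable, so the registry can construct the model-specific `MultiModalProcessor` subclass lazily once a context exists. A hypothetical sketch of a factory-based registry (the real `MULTIMODAL_REGISTRY` API is more involved):

from typing import Any, Callable, Dict

class Registry:
    def __init__(self):
        self._factories: Dict[type, Callable] = {}

    def register_processor(self, factory: Callable):
        def wrapper(model_cls: type) -> type:
            # Store the factory keyed by model class; the decorator
            # leaves the class itself unchanged.
            self._factories[model_cls] = factory
            return model_cls
        return wrapper

    def create_processor(self, model_cls: type, ctx: Any):
        # The factory runs only now, with a live context.
        return self._factories[model_cls](ctx)

REGISTRY = Registry()

@REGISTRY.register_processor(lambda ctx: f"processor for {ctx}")
class MyModel:
    pass

print(REGISTRY.create_processor(MyModel, ctx="llava-config"))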
@@ -312,38 +353,10 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
 
         return data
 
-    def _validate_image_sizes(self, images: List[torch.Tensor],
-                              sizes: List[torch.Tensor]) -> List[torch.Tensor]:
-        if not isinstance(sizes, list):
-            sizes = [sizes]
-
-        total_images = sum(size.numel() // 2 for size in sizes)
-        if total_images != len(images):
-            raise ValueError("Mismatch in number of images. "
-                             f"Expected {total_images}, got {len(images)}")
-        img_idx = 0
-        for size in sizes:
-            # Flatten the size tensor to a list of (height, width) pairs
-            size = size.view(-1, 2).tolist()
-            for expected_h, expected_w in size:
-                if img_idx >= len(images):
-                    raise ValueError("Ran out of images before sizes. "
-                                     f"{img_idx} >= {len(images)}")
-                img = images[img_idx]
-                if img.shape[-2:] != (expected_h, expected_w):
-                    raise ValueError(
-                        "Image size mismatch. Expected "
-                        f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
-                if img.shape[-3] != 3:
-                    raise ValueError("Image channel mismatch. Expected 3, "
-                                     f"got {img.shape[-3]}")
-                img_idx += 1
-        return images
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
-        image_sizes = kwargs.pop("image_sizes", None)
+        is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False]))
         image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values is None and image_embeds is None:
@@ -354,9 +367,8 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
-            # Case for models like PixtralHF that have dynamic image sizes
-            # so we need to produce a list of tensors
-            if image_sizes is not None:
+            assert isinstance(is_pixtral, torch.Tensor)
+            if is_pixtral.any():
                 images = pixel_values
 
                 def flatten_to_3d_tensors(item):
@@ -379,7 +391,7 @@ def flatten_to_3d_tensors(item):
 
                 return LlavaImagePixelInputs(
                     type="pixel_values",
-                    data=self._validate_image_sizes(images, image_sizes),
+                    data=images,
                 )
 
             return LlavaImagePixelInputs(
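The body of `flatten_to_3d_tensors` is collapsed in this view. For context, a plausible sketch of what such a helper does (an assumption, not the committed implementation): recursively unpack nested lists and batched tensors into a flat list of per-image (C, H, W) tensors, which is what the Pixtral path's `data=images` then carries.

from typing import List, Union
import torch

NestedTensors = Union[torch.Tensor, List["NestedTensors"]]

def flatten_to_3d_tensors(item: NestedTensors) -> List[torch.Tensor]:
    if isinstance(item, torch.Tensor):
        if item.dim() == 3:
            return [item]  # already a single (C, H, W) image
        if item.dim() > 3:
            # Split a batched tensor along its leading dimensions.
            return [t for t in item.view(-1, *item.shape[-3:])]
        raise ValueError(f"Unexpected tensor rank: {item.dim()}")
    # A list: flatten each element and concatenate the results.
    return [t for sub in item for t in flatten_to_3d_tensors(sub)]

nested = [torch.rand(3, 8, 8), [torch.rand(2, 3, 4, 4)]]
print([t.shape for t in flatten_to_3d_tensors(nested)])
# [torch.Size([3, 8, 8]), torch.Size([3, 4, 4]), torch.Size([3, 4, 4])]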