[Model] SiglipVisionModel ported from transformers #6942

Merged Aug 5, 2024 (24 commits)

Commits
9222552  feat: initial siglip implementation (ChristopherCho, Jul 30, 2024)
5e09410  fix: typo fixed (ChristopherCho, Jul 30, 2024)
8af6456  fix: change paligemma to use ported siglip (ChristopherCho, Jul 30, 2024)
c00edeb  fix: style fixed (ChristopherCho, Jul 30, 2024)
db99a08  feat: modify paligemma to fully utilize siglip (ChristopherCho, Jul 30, 2024)
3e3b032  feat: sync model methods for paligemma (ChristopherCho, Jul 30, 2024)
f04da2b  fix: style fix (ChristopherCho, Jul 30, 2024)
b3ccec5  fix: sync with transformers siglip (ChristopherCho, Jul 31, 2024)
106e193  fix: style fix (ChristopherCho, Jul 31, 2024)
5b9242f  fix: faulty weight loading logic for vision model (ChristopherCho, Jul 31, 2024)
3dc8ea0  feat: add various attention mechanisms (ChristopherCho, Jul 31, 2024)
5afa010  fix: style update (ChristopherCho, Jul 31, 2024)
7fdb13d  fix: remove unnecessary comments (ChristopherCho, Jul 31, 2024)
c47e54a  fix: remove unrequired docstring (ChristopherCho, Aug 5, 2024)
cac1933  fix: remove unrequired docstring (ChristopherCho, Aug 5, 2024)
2d1aeec  fix: detach vllm attention (ChristopherCho, Aug 5, 2024)
bb570c3  fix: remove vllm attention (ChristopherCho, Aug 5, 2024)
dee55d0  fix: revert vision tower weight loading (ChristopherCho, Aug 5, 2024)
bffc385  fix: use basic SiglipAttention for now (ChristopherCho, Aug 5, 2024)
681b36d  fix: remove unnecessary weight loading logic (ChristopherCho, Aug 5, 2024)
fb4972d  cleanup (ywang96, Aug 5, 2024)
d15a299  typing (ywang96, Aug 5, 2024)
9ef79b9  update (ywang96, Aug 5, 2024)
54faf5d  Merge remote-tracking branch 'upstream/main' into siglip-support (ywang96, Aug 5, 2024)
examples/offline_inference_vision_language.py (2 additions, 1 deletion)

@@ -65,7 +65,8 @@ def run_phi3v(question):
 # PaliGemma
 def run_paligemma(question):
 
-    prompt = question
+    # PaliGemma has special prompt format for VQA
+    prompt = "caption en"
     llm = LLM(model="google/paligemma-3b-mix-224")
 
     return llm, prompt
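For reference, the rest of the example script consumes each runner's (llm, prompt) pair together with an image. Below is a minimal end-to-end sketch of the updated PaliGemma path, assuming the dict-style generate() input vLLM accepted at the time; the image path is a placeholder:

```python
from PIL import Image

from vllm import LLM

# "caption en" is one of PaliGemma's task prefixes; the -mix- checkpoints
# expect a prefix such as "caption en" or "answer en <question>" rather
# than a free-form question, which is why the example prompt changed.
llm = LLM(model="google/paligemma-3b-mix-224")
image = Image.open("example.jpg")  # placeholder path

outputs = llm.generate({
    "prompt": "caption en",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)
```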
vllm/model_executor/models/paligemma.py (27 additions, 52 deletions)

@@ -1,9 +1,8 @@
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
 
 import torch
-from PIL import Image
 from torch import nn
-from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel
+from transformers import PaliGemmaConfig
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -18,9 +17,11 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import cached_get_tokenizer
-from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
+from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .interfaces import SupportsVision
+from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
+                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
 from .utils import merge_vision_embeddings
 
 logger = init_logger(__name__)
@@ -32,55 +33,22 @@
 
 def get_max_paligemma_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(PaliGemmaConfig)
-    text_config = hf_config.text_config
-
-    return text_config.num_image_tokens
-
-
-def dummy_seq_data_for_paligemma(
-    hf_config: PaliGemmaConfig,
-    seq_len: int,
-    *,
-    image_token_id: int,
-    image_feature_size_override: Optional[int] = None,
-):
-    if image_feature_size_override is None:
-        image_feature_size = hf_config.text_config.num_image_tokens
-    else:
-        image_feature_size = image_feature_size_override
-
-    token_ids = [image_token_id] * image_feature_size
-    token_ids += [0] * (seq_len - image_feature_size)
-    return SequenceData(token_ids)
-
-
-def dummy_image_for_paligemma(
-    hf_config: SiglipVisionConfig,
-    *,
-    image_width_override: Optional[int] = None,
-    image_height_override: Optional[int] = None,
-):
-    width = height = hf_config.image_size
-    if image_width_override is not None:
-        width = image_width_override
-    if image_height_override is not None:
-        height = image_height_override
+    vision_config = hf_config.vision_config
 
-    image = Image.new("RGB", (width, height), color=0)
-    return {"image": image}
+    return get_max_siglip_image_tokens(vision_config)
 
 
 def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
     hf_config = ctx.get_hf_config(PaliGemmaConfig)
     vision_config = hf_config.vision_config
 
-    seq_data = dummy_seq_data_for_paligemma(
-        hf_config,
+    seq_data = dummy_seq_data_for_siglip(
+        vision_config,
         seq_len,
         image_token_id=hf_config.image_token_index,
     )
 
-    mm_data = dummy_image_for_paligemma(vision_config)
+    mm_data = dummy_image_for_siglip(vision_config)
     return seq_data, mm_data
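The dummy-data helpers now come from the new vllm/model_executor/models/siglip.py added by this PR, which this diff view does not show. Here is a paraphrased sketch of what they plausibly compute, inferred from the deleted PaliGemma-specific versions above and SigLIP's patch-grid geometry (no CLS token); only the names are taken from the import, the bodies are assumptions:

```python
from typing import Optional

from PIL import Image
from transformers import SiglipVisionConfig

from vllm.sequence import SequenceData


def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
    # SigLIP has no CLS token, so every patch yields one image feature.
    grid_length = hf_config.image_size // hf_config.patch_size
    return grid_length * grid_length


def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
    return get_siglip_image_feature_size(hf_config)


def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    # Same dummy sequence as the deleted helper: one image token per image
    # feature, zero-padded out to seq_len.
    if image_feature_size_override is None:
        image_feature_size = get_siglip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    return SequenceData(token_ids)


def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    # A blank RGB image at the vision tower's native resolution.
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    return {"image": Image.new("RGB", (width, height), color=0)}
```

For google/paligemma-3b-mix-224 (image_size 224, patch_size 14) this gives (224 / 14)² = 256 image tokens, matching the text_config.num_image_tokens value that get_max_paligemma_image_tokens used to read, so the refactor preserves behavior while keying everything off the vision config.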


@@ -208,30 +176,37 @@ def _parse_and_validate_image_input(
             data=self._validate_pixel_values(pixel_values),
         )
 
-    def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
-                                  pixel_values: torch.Tensor) -> torch.Tensor:
+    def _image_pixels_to_features(
+        self,
+        vision_tower: SiglipVisionModel,
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
 
         target_dtype = vision_tower.get_input_embeddings().weight.dtype
-        image_outputs = vision_tower(pixel_values.to(dtype=target_dtype),
-                                     output_hidden_states=True)
-
-        selected_image_features = image_outputs.last_hidden_state
+        image_features = vision_tower(pixel_values.to(dtype=target_dtype))
 
-        return selected_image_features
+        return image_features
 
     def _process_image_pixels(
-            self, inputs: PaliGemmaImagePixelInputs) -> torch.Tensor:
+        self,
+        inputs: PaliGemmaImagePixelInputs,
+    ) -> torch.Tensor:
         assert self.vision_tower is not None
 
         pixel_values = inputs["data"]
 
-        return self._image_pixels_to_features(self.vision_tower, pixel_values)
+        return self._image_pixels_to_features(
+            self.vision_tower,
+            pixel_values,
+        )
 
     def _process_image_input(
-            self, image_input: PaliGemmaImageInputs) -> torch.Tensor:
+        self,
+        image_input: PaliGemmaImageInputs,
+    ) -> torch.Tensor:
 
         assert self.vision_tower is not None
-        image_features = self._process_image_pixels(image_input)
+        image_features = self._process_image_pixels(image_input, )
 
         return self.multi_modal_projector(image_features)
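The substantive change in this last hunk is the call contract: the ported SiglipVisionModel behaves like a plain nn.Module returning the final hidden states as a tensor, so the output_hidden_states=True flag and the last_hidden_state selection required by the transformers module disappear. A standalone sketch of the new pattern, with shapes assumed from PaliGemma's SigLIP-So400m/14 tower at 224x224 resolution:

```python
import torch


def encode_image(vision_tower: torch.nn.Module,
                 pixel_values: torch.Tensor) -> torch.Tensor:
    """Illustration-only mirror of _image_pixels_to_features."""
    # Cast pixels to the tower's parameter dtype (e.g. bfloat16) before
    # the forward pass, exactly as the diff above does.
    target_dtype = vision_tower.get_input_embeddings().weight.dtype

    # The in-repo model returns the features tensor directly; assumed
    # shape (batch, num_patches, hidden_size), e.g. (1, 256, 1152).
    return vision_tower(pixel_values.to(dtype=target_dtype))
```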