
Improve vlm support (add idefics3 support) #2437

Merged
drbh merged 17 commits into main from improve-vlm-support on Jan 9, 2025

Conversation

drbh
Collaborator

@drbh drbh commented Aug 20, 2024

This PR is a work in progress that adds support for Idefics3 in TGI. Opening it now for transparency and feedback.

This implementation uses the AutoProcessor/Idefics3Processor that will be available once huggingface/transformers#32473 is merged.
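For context, a minimal sketch (not part of this PR) of how the upstream processor is expected to be exercised once huggingface/transformers#32473 lands; the checkpoint name and image path here are illustrative assumptions:

# Illustrative only: the Idefics3 processor via AutoProcessor, assuming the
# transformers PR above is merged. Checkpoint name and image path are assumptions.
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
image = Image.open("example.jpg")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[image], return_tensors="pt")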

todos

  • add more comprehensive tests
  • ensure rust image token logic is correct
  • ensure correct config is loaded (related to processor_kwargs)
  • refactors/cleanup, typos, etc.

@ErikKaum ErikKaum mentioned this pull request Sep 9, 2024
@drbh drbh force-pushed the improve-vlm-support branch from c93fd85 to 35c64b2 on October 3, 2024 12:57
@drbh drbh force-pushed the improve-vlm-support branch from 35c64b2 to ebef284 on December 17, 2024 18:25
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@drbh drbh marked this pull request as ready for review December 19, 2024 02:40
@@ -632,13 +630,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights):
super().__init__()

if config.model_type == "mllama_text_model":
prefix = f"{prefix}.model"
Collaborator

No. The correct line is whatever was there before.

This class cannot know about the model_type (shouldn't).
Especially since you're removing everything a few lines below.

No shenanigans here.

Collaborator Author

100% agreed that the class should not know about the model type. I've updated the logic to handle this case by avoiding appending .model when the prefix ends in text_model:

base_model = "" if prefix.endswith("text_model") else ".model"

The reason for this complexity is the naming convention used by idefics3. The model has weights with names like model.text_model.embed_tokens.weight, and the current logic always expects models to contain model.embed_tokens or X.model.embed_tokens.

The latest changes handle this case by conditionally appending ".model" before constructing the prefixes. Please let me know if there's a better way to handle this 🙏
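To make the two cases concrete, here is a hypothetical standalone sketch (not the exact TGI code) of how the conditional suffix plays out:

# Hypothetical illustration of the conditional ".model" suffix: idefics3 prefixes
# already end in text_model, so nothing is appended; other wrappers keep the
# existing X.model.embed_tokens convention.
def embed_tokens_prefix(prefix: str) -> str:
    base_model = "" if prefix.endswith("text_model") else ".model"
    return f"{prefix}{base_model}.embed_tokens"

assert embed_tokens_prefix("model.text_model") == "model.text_model.embed_tokens"
assert embed_tokens_prefix("language_model") == "language_model.model.embed_tokens"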

@@ -679,6 +679,215 @@ def forward(self, image_hidden_states, attention_mask):
return image_hidden_states


class Idefics3Connector(nn.Module):
Collaborator

This belongs in the idefics3 file.

Collaborator Author

Agreed, moved into a new file in the latest commit.

Comment on lines 861 to 865
diff = mask_size - unrolled_image_size
if diff > 0:
print(
f"Mask size {mask_size} is greater than the number of images {unrolled_image_size}."
)
Collaborator

Suggested change
diff = mask_size - unrolled_image_size
if diff > 0:
print(
f"Mask size {mask_size} is greater than the number of images {unrolled_image_size}."
)

Collaborator Author

Removed in the latest changes, thanks!

Comment on lines 867 to 870
if mask_size == unrolled_image_size:
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)
Collaborator

Suggested change
if mask_size == unrolled_image_size:
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)

Let it crash if something is wrong here. We should NEVER do silent errors.
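A hedged sketch of the fail-loudly alternative (illustrative only, not the exact code that landed):

# Illustrative: raise instead of silently skipping the merge when the number of
# image-token positions disagrees with the provided image features.
if mask_size != unrolled_image_size:
    raise ValueError(
        f"Mask size {mask_size} does not match image features {unrolled_image_size}"
    )
inputs_embeds = self._merge_input_ids_with_image_features(
    input_ids, inputs_embeds, image_hidden_states
)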

Collaborator Author

Yes, agreed! Removed in the latest commits.

Comment on lines 26 to 93
IDEFICS3_IMAGE_TOKEN = "<image>"
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"


def _prompt_split_image(
image_seq_len,
image_rows,
image_cols,
fake_token_around_image,
image_token,
global_img_token,
):
"""Prompt with expanded image tokens for when the image is split into patches."""
text_split_images = ""
for n_h in range(image_rows):
for n_w in range(image_cols):
text_split_images += (
f"{fake_token_around_image}"
+ f"<row_{n_h + 1}_col_{n_w + 1}>"
+ f"{image_token}" * image_seq_len
)
text_split_images += "\n"

text_split_images += (
f"\n{fake_token_around_image}"
+ f"{global_img_token}"
+ f"{image_token}" * image_seq_len
+ f"{fake_token_around_image}"
)
return text_split_images


def _prompt_single_image(
image_seq_len, fake_token_around_image, image_token, global_img_token
):
"""Prompt with expanded image tokens for a single image."""
return (
f"{fake_token_around_image}"
+ f"{global_img_token}"
+ f"{image_token}" * image_seq_len
+ f"{fake_token_around_image}"
)


def get_image_prompt_string(
image_rows,
image_cols,
image_seq_len,
fake_token_around_image,
image_token,
global_img_token,
):
if image_rows == 0 and image_cols == 0:
return _prompt_single_image(
image_seq_len,
fake_token_around_image=fake_token_around_image,
image_token=image_token,
global_img_token=global_img_token,
)
return _prompt_split_image(
image_seq_len,
image_rows,
image_cols,
fake_token_around_image,
image_token,
global_img_token,
)
Collaborator

Put everything in some idefics3 file.

Can't those 4 functions be trivially merged into one using joins instead of for loops and ifs?

Collaborator Author

Good point. I've moved the idefics3 code into a new file and reduced this logic to a much simpler function:

def get_image_prompt_string(
    rows=0,
    cols=0,
    seq_len=1,
    fake_token=IDEFICS3_FAKE_IMAGE_TOKEN,
    img_token=IDEFICS3_IMAGE_TOKEN,
    global_token=IDEFICS3_GLOBAL_IMG_TOKEN,
):
    tokens = img_token * seq_len
    end_token = f"{fake_token}{global_token}{tokens}{fake_token}"

    if rows == 0 or cols == 0:
        return end_token

    grid = "\n".join(
        "".join(f"{fake_token}<row_{i+1}_col_{j+1}>{tokens}" for j in range(cols))
        for i in range(rows)
    )

    return f"{grid}\n\n{end_token}"

@drbh drbh requested a review from Narsil December 23, 2024 16:48
@@ -507,6 +507,7 @@ def __init__(self, prefix, config, weights):
process_group = weights.process_group
self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size()
base_model = "" if prefix.endswith("text_model") else ".model"
Member

Since this is always appended to the prefix, maybe it's cleaner to extend prefix when prefix is not None and the condition holds?

Collaborator Author

Good point! I've updated FlashLlamaForCausalLM to accept a name argument, similar to the implementation in FlashMistralForCausalLM, which avoids this complexity in favor of explicitly passing the name from the VLM.

i.e.:

class FlashLlamaForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights, name=None):
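A hypothetical call site (names and arguments are assumptions for illustration, not the exact TGI code):

# Hypothetical: the VLM wrapper passes the inner module name explicitly, so the
# causal LM no longer has to infer it from the prefix.
self.text_model = FlashLlamaForCausalLM(
    prefix=prefix, config=config.text_config, weights=weights, name="text_model"
)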

Member

I think there is still a bug in the prefix generation; CI fails because model.model.layers.0.self_attn.q_proj.qweight is used as a prefix.

Three additional review threads on server/text_generation_server/models/vlm_causal_lm.py were marked as resolved (outdated).
@drbh drbh merged commit da5ab46 into main Jan 9, 2025
13 of 15 checks passed
@drbh drbh deleted the improve-vlm-support branch January 9, 2025 15:35