Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Enable Online Multi-image Support for MLlama #9393

Merged

Conversation

alex-jw-brooks
Copy link
Contributor

@alex-jw-brooks alex-jw-brooks commented Oct 15, 2024

Currently, there is some special handling for mllama in the chat utils that result in only one image placeholder being used when multiple images are added. Now that #9095 added multi-image / interleaved data support, I think this can be removed to handle image placeholders normally for mllama in the frontend.

Example usage:

python -m vllm.entrypoints.openai.api_server \
    --device cuda \
    --model meta-llama/Llama-3.2-11B-Vision-Instruct \
    --api-key token-abc123 \
    --tokenizer meta-llama/Llama-3.2-11B-Vision-Instruct \
    --limit-mm-per-prompt image=2 \
    --max-model-len 32000 \
    --dtype=half  \
    --enforce-eager \
    --max-num-seqs 2
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")

completion = client.chat.completions.create(
  model="meta-llama/Llama-3.2-11B-Vision-Instruct",
  messages = [{
        "role": "user", "content": [
          {"type": "image_url", "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"}},
          {"type": "image_url", "image_url": {"url": "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"}},
          {"type": "text", "text": "Can you compare these images?"},
        ]
  }]
)

print(completion.choices[0].message)

In the logs, it seems like we do get both placeholders:

INFO 10-15 18:01:33 logger.py:37] Received request chat-ca50cc85dea6435790ff6ff07624dade: prompt: '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 15 Oct 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n<|image|>\n<|image|>\nCan you compare these images?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n', params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.7, top_p=1.0, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=31955, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), guided_decoding=GuidedDecodingParams(json=None, regex=None, choice=None, grammar=None, json_object=None, backend=None, whitespace_pattern=None), prompt_token_ids: [128000, 128006, 9125, 128007, 271, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 868, 5020, 220, 2366, 19, 271, 128009, 128006, 882, 128007, 271, 128256, 198, 128256, 198, 6854, 499, 9616, 1521, 5448, 30, 128009, 128006, 78191, 128007, 271], lora_request: None, prompt_adapter_request: None.
and a reasonable response:

ChatCompletionMessage(content='The image of a lion is larger than the image of a duck.', refusal=None, role='assistant', function_call=None, tool_calls=[])

This does pass what's expected to the tokenizer when it's applying its chat template, although the format of the prompt is a bit different. It seems to match the example tool chat template for llama 3.2 though

from transformers import AutoTokenizer
model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
conversation=[{'role': 'user', 'content': '<|image|>\nCan you compare these images?'}]

out = tokenizer.apply_chat_template(
    conversation=conversation,
    add_generation_prompt=True,
    tokenize=False,
)
print(f"----Output with image placeholder----\n{out}")

conversation=[{'role': 'user', 'content': [{'type': 'image'}, {'type': 'text', 'text': 'Can you compare these images?'}]}]
out = tokenizer.apply_chat_template(
    conversation=conversation,
    add_generation_prompt=True,
    tokenize=False,
)
print(f"----Output with image type----\n{out}")

Result:

----Output with image placeholder----
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 15 Oct 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

<|image|>
Can you compare these images?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


----Output with image type----
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

<|image|>Can you compare these images?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
  • Use torch.libary.opcheck() to test the function registration and meta-function for any registered ops. See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@heheda12345
Copy link
Collaborator

This change will affect how the image tokens and text tokens are mixed in almost all multi-modal models. And I think we need to figure out how to test its correctness.

@alex-jw-brooks
Copy link
Contributor Author

alex-jw-brooks commented Oct 16, 2024

Hi @heheda12345, thank you for the quick response! Can you please elaborate on how this would affect other models? As far as I understand, keep_multimodal_content is only True if the loaded model is mllama, and the rest of the multimodal models should only be hitting the other codepath (which was added in the PR for mllama support).

If there is a reason this can't be changed, maybe another path forward could be to collect the individual {"type": "text", "text":...} pieces with the interleaved {"type": "image"} so that we get the right number of placeholders? Then the cross attention token mask should have the right spans also, I think

CC @joerunde since we were chatting about this earlier today 🙂

@heheda12345
Copy link
Collaborator

The reason I add this special branch for mllama is that the other branch will add an additional \n after <|image|>. So I think simply removing this branch will not work. Maybe we should start with verifying whether this pr can generate exact the same tokens as hf.

@alex-jw-brooks
Copy link
Contributor Author

alex-jw-brooks commented Oct 22, 2024

Thanks @heheda12345 - I investigated a bit, and you're right, it didn't match exactly when the branch was removed.

For now, I added the separate branch for mllama back rewrote it to build a list of {'type': 'text', 'text': text} / {'type': 'image'} dicts for mllama. This will change the current single image behavior a bit because the image will no longer be preprended if the text comes first, but should match the result of applying the chat template with the HF tokenizer directly now - there is a test that explicitly checks this mllama on interleaved images here

@heheda12345
Copy link
Collaborator

Thanks for your fix!
@ywang96 @DarkLight1337 I think this pr is correct, i.e., it uses two branches, so that it can generate correct result for mllama and does not affect the behavior of other models. However, as many other models, e.g., llava, should also use this new code path for mllama instead of doing the previous hack. How should we safely changing these models to this new code path?

@DarkLight1337 DarkLight1337 added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 23, 2024
vllm/entrypoints/chat_utils.py Outdated Show resolved Hide resolved
vllm/entrypoints/chat_utils.py Outdated Show resolved Hide resolved
@DarkLight1337
Copy link
Member

DarkLight1337 commented Oct 23, 2024

However, as many other models, e.g., llava, should also use this new code path for mllama instead of doing the previous hack. How should we safely changing these models to this new code path?

Let's add a test that compares the output of applying parse_chat_messages then apply_hf_chat_template to some example inputs, for the cases keep_multimodal_content is set to True/False.

alex-jw-brooks and others added 6 commits October 23, 2024 08:13
@alex-jw-brooks alex-jw-brooks force-pushed the mllama_server_multiimage branch from bf65f96 to fc3f377 Compare October 23, 2024 12:13
@alex-jw-brooks
Copy link
Contributor Author

Let's add a test that compares the output of applying parse_chat_messages then apply_hf_chat_template to some example inputs, for the cases keep_multimodal_content is set to True/False.

I think that sounds like a good plan to me too! I had already added a test for doing this for mllama as part of this PR; I went ahead and moved the creation of the tokenizer group / model config into the test so that we can parametrize other models for this test as new ones are moved over to this code path

@DarkLight1337
Copy link
Member

I think that sounds like a good plan to me too! I had already added a test for doing this for mllama as part of this PR; I went ahead and moved the creation of the tokenizer group / model config into the test so that we can parametrize other models for this test as new ones are moved over to this code path

Sounds good. Let's focus on mllama for this PR and move over the rest in another PR.

alex-jw-brooks and others added 2 commits October 23, 2024 08:26
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Alex-Brooks <[email protected]>
Signed-off-by: Alex-Brooks <[email protected]>
@alex-jw-brooks alex-jw-brooks force-pushed the mllama_server_multiimage branch from 6faff0b to 62785ed Compare October 23, 2024 12:27
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) October 23, 2024 15:49
@DarkLight1337 DarkLight1337 merged commit 150b779 into vllm-project:main Oct 23, 2024
57 checks passed
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
MErkinSag pushed a commit to MErkinSag/vllm that referenced this pull request Oct 26, 2024
…#9393)

Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Erkin Sagiroglu <[email protected]>
cooleel pushed a commit to cooleel/vllm that referenced this pull request Oct 28, 2024
…#9393)

Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Shanshan Wang <[email protected]>
cooleel pushed a commit to cooleel/vllm that referenced this pull request Oct 28, 2024
…#9393)

Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Shanshan Wang <[email protected]>
FerdinandZhong pushed a commit to FerdinandZhong/vllm that referenced this pull request Oct 29, 2024
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
…#9393)

Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: NickLucche <[email protected]>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
…#9393)

Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: NickLucche <[email protected]>
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
…#9393)

Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
…#9393)

Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
…#9393)

Signed-off-by: Alex-Brooks <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants