Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix crash with llama 3.2 vision models and guided decoding #9631

Merged
merged 4 commits into from
Oct 25, 2024

Conversation

tjohnson31415
Copy link
Contributor

@tjohnson31415 tjohnson31415 commented Oct 23, 2024

Sending a request that includes guided decoding to a Llama 3.2 Vision model crashes the engine with a CUDA device-side error:

ERROR 10-23 20:28:40 engine.py:165]   File "/home/vllm/vllm/lib64/python3.12/site-packages/vllm/model_executor/layers/logits_processor.py", line 144, in _apply_logits_processors
ERROR 10-23 20:28:40 engine.py:165]     logits_row = logits_processor(past_tokens_ids,
ERROR 10-23 20:28:40 engine.py:165]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 10-23 20:28:40 engine.py:165]   File "/home/vllm/vllm/lib64/python3.12/site-packages/lmformatenforcer/integrations/vllm.py", line 29, in __call__
ERROR 10-23 20:28:40 engine.py:165]     self.mask[allowed_tokens] = 0
ERROR 10-23 20:28:40 engine.py:165]     ~~~~~~~~~^^^^^^^^^^^^^^^^
ERROR 10-23 20:28:40 engine.py:165] RuntimeError: CUDA error: device-side assert triggered
ERROR 10-23 20:28:40 engine.py:165] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
ERROR 10-23 20:28:40 engine.py:165] 

All credit to @pavlo-ruban who described the root cause problem and identified this solution; I'm just creating a PR out of it.

FIX #8952

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
  • Use torch.libary.opcheck() to test the function registration and meta-function for any registered ops. See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: pavlo-ruban <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

Signed-off-by: Travis Johnson <[email protected]>
@tjohnson31415
Copy link
Contributor Author

The fix in this PR is just for outlines. I looked into lm-format-enforcer as well and it would need a different approach: there was a changed made to it to support passing in a vocab_size separate from the tokenizer's vocab_size
noamgat/lm-format-enforcer@c61f00c

So we'd need to detect that the model's vocab size differs from the tokenizer and pass the vocab_size through. I think that could be a separate PR though so we can get the default of using outlines working.

Comment on lines 86 to 87
allowed_tokens = np.array(allowed_tokens)
allowed_tokens = allowed_tokens[allowed_tokens < scores.shape[-1]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tjohnson31415 when you tried with pytorch did you copy to gpu first?

allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
allowed_tokens = allowed_tokens.masked_select(allowed_tokens < scores.shape[-1])
mask.index_fill_(allowed_tokens, 0)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. Creating a torch.Tensor from a large list is oddly slow, whether on CPU side or GPU side it was about the same in my testing: 15ms for ~120k element list. Creating an np.array is less than 5ms. The rest of the operations are fast in either case 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @tjohnson31415! I was intrigued and tried this myself. I found the same thing as you, that creating a tensor from a python list was the most expensive part. But actually it's about 30% faster overall if you combine the two!

allowed_tokens = np.array(allowed_tokens, dtype=np.int64)
allowed_tokens = torch.tensor(allowed_tokens, device=scores.device)
allowed_tokens = allowed_tokens.masked_select(allowed_tokens < scores.shape[-1])
mask.index_fill_(0, allowed_tokens, 0)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pushed an update with this change because @tjohnson31415 is OOO for a few days.

@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 25, 2024
@njhill njhill merged commit 6567e13 into vllm-project:main Oct 25, 2024
55 checks passed
MErkinSag pushed a commit to MErkinSag/vllm that referenced this pull request Oct 26, 2024
…llm-project#9631)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: pavlo-ruban <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Erkin Sagiroglu <[email protected]>
MErkinSag pushed a commit to MErkinSag/vllm that referenced this pull request Oct 26, 2024
…llm-project#9631)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: pavlo-ruban <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Erkin Sagiroglu <[email protected]>
cooleel pushed a commit to cooleel/vllm that referenced this pull request Oct 28, 2024
…llm-project#9631)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: pavlo-ruban <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Shanshan Wang <[email protected]>
cooleel pushed a commit to cooleel/vllm that referenced this pull request Oct 28, 2024
…llm-project#9631)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: pavlo-ruban <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Shanshan Wang <[email protected]>
FerdinandZhong pushed a commit to FerdinandZhong/vllm that referenced this pull request Oct 29, 2024
…llm-project#9631)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: pavlo-ruban <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: qishuai <[email protected]>
rasmith pushed a commit to rasmith/vllm that referenced this pull request Oct 30, 2024
…llm-project#9631)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: pavlo-ruban <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Randall Smith <[email protected]>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
…llm-project#9631)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: pavlo-ruban <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: NickLucche <[email protected]>
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
…llm-project#9631)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: pavlo-ruban <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: NickLucche <[email protected]>
@tjohnson31415 tjohnson31415 deleted the fix-llama-vision-guide-crash branch November 2, 2024 05:21
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
…llm-project#9631)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: pavlo-ruban <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Linkun Chen <[email protected]>
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
…llm-project#9631)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: pavlo-ruban <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
…llm-project#9631)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: pavlo-ruban <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
…llm-project#9631)

Signed-off-by: Travis Johnson <[email protected]>
Co-authored-by: pavlo-ruban <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug]: Llama-3.2-11B-Vision-Instruct server crashes when asked guided generation
3 participants